statement.go 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375
  1. // Copyright 2015 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package xorm
  5. import (
  6. "bytes"
  7. "database/sql/driver"
  8. "encoding/json"
  9. "errors"
  10. "fmt"
  11. "reflect"
  12. "strings"
  13. "time"
  14. "github.com/go-xorm/builder"
  15. "github.com/go-xorm/core"
  16. )
  // incrParam records a pending "column = column + arg" assignment
// registered through Statement.Incr.
type incrParam struct {
	colName string      // column name exactly as passed by the caller
	arg     interface{} // amount added to the column
}

// decrParam records a pending "column = column - arg" assignment
// registered through Statement.Decr.
type decrParam struct {
	colName string      // column name exactly as passed by the caller
	arg     interface{} // amount subtracted from the column
}

// exprParam records a pending "column = expr" assignment registered
// through Statement.SetExpr.
type exprParam struct {
	colName string // column name exactly as passed by the caller
	expr    string // raw SQL expression assigned to the column
}
  29. // Statement save all the sql info for executing SQL
  30. type Statement struct {
  31. RefTable *core.Table
  32. Engine *Engine
  33. Start int
  34. LimitN int
  35. idParam *core.PK
  36. OrderStr string
  37. JoinStr string
  38. joinArgs []interface{}
  39. GroupByStr string
  40. HavingStr string
  41. ColumnStr string
  42. selectStr string
  43. columnMap map[string]bool
  44. useAllCols bool
  45. OmitStr string
  46. AltTableName string
  47. tableName string
  48. RawSQL string
  49. RawParams []interface{}
  50. UseCascade bool
  51. UseAutoJoin bool
  52. StoreEngine string
  53. Charset string
  54. UseCache bool
  55. UseAutoTime bool
  56. noAutoCondition bool
  57. IsDistinct bool
  58. IsForUpdate bool
  59. TableAlias string
  60. allUseBool bool
  61. checkVersion bool
  62. unscoped bool
  63. mustColumnMap map[string]bool
  64. nullableMap map[string]bool
  65. incrColumns map[string]incrParam
  66. decrColumns map[string]decrParam
  67. exprColumns map[string]exprParam
  68. cond builder.Cond
  69. }
  70. // Init reset all the statement's fields
  71. func (statement *Statement) Init() {
  72. statement.RefTable = nil
  73. statement.Start = 0
  74. statement.LimitN = 0
  75. statement.OrderStr = ""
  76. statement.UseCascade = true
  77. statement.JoinStr = ""
  78. statement.joinArgs = make([]interface{}, 0)
  79. statement.GroupByStr = ""
  80. statement.HavingStr = ""
  81. statement.ColumnStr = ""
  82. statement.OmitStr = ""
  83. statement.columnMap = make(map[string]bool)
  84. statement.AltTableName = ""
  85. statement.tableName = ""
  86. statement.idParam = nil
  87. statement.RawSQL = ""
  88. statement.RawParams = make([]interface{}, 0)
  89. statement.UseCache = true
  90. statement.UseAutoTime = true
  91. statement.noAutoCondition = false
  92. statement.IsDistinct = false
  93. statement.IsForUpdate = false
  94. statement.TableAlias = ""
  95. statement.selectStr = ""
  96. statement.allUseBool = false
  97. statement.useAllCols = false
  98. statement.mustColumnMap = make(map[string]bool)
  99. statement.nullableMap = make(map[string]bool)
  100. statement.checkVersion = true
  101. statement.unscoped = false
  102. statement.incrColumns = make(map[string]incrParam)
  103. statement.decrColumns = make(map[string]decrParam)
  104. statement.exprColumns = make(map[string]exprParam)
  105. statement.cond = builder.NewCond()
  106. }
  107. // NoAutoCondition if you do not want convert bean's field as query condition, then use this function
  108. func (statement *Statement) NoAutoCondition(no ...bool) *Statement {
  109. statement.noAutoCondition = true
  110. if len(no) > 0 {
  111. statement.noAutoCondition = no[0]
  112. }
  113. return statement
  114. }
  115. // Alias set the table alias
  116. func (statement *Statement) Alias(alias string) *Statement {
  117. statement.TableAlias = alias
  118. return statement
  119. }
  120. // SQL adds raw sql statement
  121. func (statement *Statement) SQL(query interface{}, args ...interface{}) *Statement {
  122. switch query.(type) {
  123. case (*builder.Builder):
  124. var err error
  125. statement.RawSQL, statement.RawParams, err = query.(*builder.Builder).ToSQL()
  126. if err != nil {
  127. statement.Engine.logger.Error(err)
  128. }
  129. case string:
  130. statement.RawSQL = query.(string)
  131. statement.RawParams = args
  132. default:
  133. statement.Engine.logger.Error("unsupported sql type")
  134. }
  135. return statement
  136. }
  137. // Where add Where statement
  138. func (statement *Statement) Where(query interface{}, args ...interface{}) *Statement {
  139. return statement.And(query, args...)
  140. }
  141. // And add Where & and statement
  142. func (statement *Statement) And(query interface{}, args ...interface{}) *Statement {
  143. switch query.(type) {
  144. case string:
  145. cond := builder.Expr(query.(string), args...)
  146. statement.cond = statement.cond.And(cond)
  147. case builder.Cond:
  148. cond := query.(builder.Cond)
  149. statement.cond = statement.cond.And(cond)
  150. for _, v := range args {
  151. if vv, ok := v.(builder.Cond); ok {
  152. statement.cond = statement.cond.And(vv)
  153. }
  154. }
  155. default:
  156. // TODO: not support condition type
  157. }
  158. return statement
  159. }
  160. // Or add Where & Or statement
  161. func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement {
  162. switch query.(type) {
  163. case string:
  164. cond := builder.Expr(query.(string), args...)
  165. statement.cond = statement.cond.Or(cond)
  166. case builder.Cond:
  167. cond := query.(builder.Cond)
  168. statement.cond = statement.cond.Or(cond)
  169. for _, v := range args {
  170. if vv, ok := v.(builder.Cond); ok {
  171. statement.cond = statement.cond.Or(vv)
  172. }
  173. }
  174. default:
  175. // TODO: not support condition type
  176. }
  177. return statement
  178. }
  179. // In generate "Where column IN (?) " statement
  180. func (statement *Statement) In(column string, args ...interface{}) *Statement {
  181. in := builder.In(statement.Engine.Quote(column), args...)
  182. statement.cond = statement.cond.And(in)
  183. return statement
  184. }
  185. // NotIn generate "Where column NOT IN (?) " statement
  186. func (statement *Statement) NotIn(column string, args ...interface{}) *Statement {
  187. notIn := builder.NotIn(statement.Engine.Quote(column), args...)
  188. statement.cond = statement.cond.And(notIn)
  189. return statement
  190. }
  191. func (statement *Statement) setRefValue(v reflect.Value) {
  192. statement.RefTable = statement.Engine.autoMapType(reflect.Indirect(v))
  193. statement.tableName = statement.Engine.tbName(v)
  194. }
  195. // Table tempororily set table name, the parameter could be a string or a pointer of struct
  196. func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
  197. v := rValue(tableNameOrBean)
  198. t := v.Type()
  199. if t.Kind() == reflect.String {
  200. statement.AltTableName = tableNameOrBean.(string)
  201. } else if t.Kind() == reflect.Struct {
  202. statement.RefTable = statement.Engine.autoMapType(v)
  203. statement.AltTableName = statement.Engine.tbName(v)
  204. }
  205. return statement
  206. }
  207. // Auto generating update columnes and values according a struct
  208. func buildUpdates(engine *Engine, table *core.Table, bean interface{},
  209. includeVersion bool, includeUpdated bool, includeNil bool,
  210. includeAutoIncr bool, allUseBool bool, useAllCols bool,
  211. mustColumnMap map[string]bool, nullableMap map[string]bool,
  212. columnMap map[string]bool, update, unscoped bool) ([]string, []interface{}) {
  213. var colNames = make([]string, 0)
  214. var args = make([]interface{}, 0)
  215. for _, col := range table.Columns() {
  216. if !includeVersion && col.IsVersion {
  217. continue
  218. }
  219. if col.IsCreated {
  220. continue
  221. }
  222. if !includeUpdated && col.IsUpdated {
  223. continue
  224. }
  225. if !includeAutoIncr && col.IsAutoIncrement {
  226. continue
  227. }
  228. if col.IsDeleted && !unscoped {
  229. continue
  230. }
  231. if use, ok := columnMap[strings.ToLower(col.Name)]; ok && !use {
  232. continue
  233. }
  234. fieldValuePtr, err := col.ValueOf(bean)
  235. if err != nil {
  236. engine.logger.Error(err)
  237. continue
  238. }
  239. fieldValue := *fieldValuePtr
  240. fieldType := reflect.TypeOf(fieldValue.Interface())
  241. requiredField := useAllCols
  242. includeNil := useAllCols
  243. if b, ok := getFlagForColumn(mustColumnMap, col); ok {
  244. if b {
  245. requiredField = true
  246. } else {
  247. continue
  248. }
  249. }
  250. // !evalphobia! set fieldValue as nil when column is nullable and zero-value
  251. if b, ok := getFlagForColumn(nullableMap, col); ok {
  252. if b && col.Nullable && isZero(fieldValue.Interface()) {
  253. var nilValue *int
  254. fieldValue = reflect.ValueOf(nilValue)
  255. fieldType = reflect.TypeOf(fieldValue.Interface())
  256. includeNil = true
  257. }
  258. }
  259. var val interface{}
  260. if fieldValue.CanAddr() {
  261. if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok {
  262. data, err := structConvert.ToDB()
  263. if err != nil {
  264. engine.logger.Error(err)
  265. } else {
  266. val = data
  267. }
  268. goto APPEND
  269. }
  270. }
  271. if structConvert, ok := fieldValue.Interface().(core.Conversion); ok {
  272. data, err := structConvert.ToDB()
  273. if err != nil {
  274. engine.logger.Error(err)
  275. } else {
  276. val = data
  277. }
  278. goto APPEND
  279. }
  280. if fieldType.Kind() == reflect.Ptr {
  281. if fieldValue.IsNil() {
  282. if includeNil {
  283. args = append(args, nil)
  284. colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name)))
  285. }
  286. continue
  287. } else if !fieldValue.IsValid() {
  288. continue
  289. } else {
  290. // dereference ptr type to instance type
  291. fieldValue = fieldValue.Elem()
  292. fieldType = reflect.TypeOf(fieldValue.Interface())
  293. requiredField = true
  294. }
  295. }
  296. switch fieldType.Kind() {
  297. case reflect.Bool:
  298. if allUseBool || requiredField {
  299. val = fieldValue.Interface()
  300. } else {
  301. // if a bool in a struct, it will not be as a condition because it default is false,
  302. // please use Where() instead
  303. continue
  304. }
  305. case reflect.String:
  306. if !requiredField && fieldValue.String() == "" {
  307. continue
  308. }
  309. // for MyString, should convert to string or panic
  310. if fieldType.String() != reflect.String.String() {
  311. val = fieldValue.String()
  312. } else {
  313. val = fieldValue.Interface()
  314. }
  315. case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
  316. if !requiredField && fieldValue.Int() == 0 {
  317. continue
  318. }
  319. val = fieldValue.Interface()
  320. case reflect.Float32, reflect.Float64:
  321. if !requiredField && fieldValue.Float() == 0.0 {
  322. continue
  323. }
  324. val = fieldValue.Interface()
  325. case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
  326. if !requiredField && fieldValue.Uint() == 0 {
  327. continue
  328. }
  329. t := int64(fieldValue.Uint())
  330. val = reflect.ValueOf(&t).Interface()
  331. case reflect.Struct:
  332. if fieldType.ConvertibleTo(core.TimeType) {
  333. t := fieldValue.Convert(core.TimeType).Interface().(time.Time)
  334. if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
  335. continue
  336. }
  337. val = engine.FormatTime(col.SQLType.Name, t)
  338. } else if nulType, ok := fieldValue.Interface().(driver.Valuer); ok {
  339. val, _ = nulType.Value()
  340. } else {
  341. if !col.SQLType.IsJson() {
  342. engine.autoMapType(fieldValue)
  343. if table, ok := engine.Tables[fieldValue.Type()]; ok {
  344. if len(table.PrimaryKeys) == 1 {
  345. pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
  346. // fix non-int pk issues
  347. if pkField.IsValid() && (!requiredField && !isZero(pkField.Interface())) {
  348. val = pkField.Interface()
  349. } else {
  350. continue
  351. }
  352. } else {
  353. //TODO: how to handler?
  354. panic("not supported")
  355. }
  356. } else {
  357. val = fieldValue.Interface()
  358. }
  359. } else {
  360. // Blank struct could not be as update data
  361. if requiredField || !isStructZero(fieldValue) {
  362. bytes, err := json.Marshal(fieldValue.Interface())
  363. if err != nil {
  364. panic(fmt.Sprintf("mashal %v failed", fieldValue.Interface()))
  365. }
  366. if col.SQLType.IsText() {
  367. val = string(bytes)
  368. } else if col.SQLType.IsBlob() {
  369. val = bytes
  370. }
  371. } else {
  372. continue
  373. }
  374. }
  375. }
  376. case reflect.Array, reflect.Slice, reflect.Map:
  377. if !requiredField {
  378. if fieldValue == reflect.Zero(fieldType) {
  379. continue
  380. }
  381. if fieldType.Kind() == reflect.Array {
  382. if isArrayValueZero(fieldValue) {
  383. continue
  384. }
  385. } else if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 {
  386. continue
  387. }
  388. }
  389. if col.SQLType.IsText() {
  390. bytes, err := json.Marshal(fieldValue.Interface())
  391. if err != nil {
  392. engine.logger.Error(err)
  393. continue
  394. }
  395. val = string(bytes)
  396. } else if col.SQLType.IsBlob() {
  397. var bytes []byte
  398. var err error
  399. if fieldType.Kind() == reflect.Slice &&
  400. fieldType.Elem().Kind() == reflect.Uint8 {
  401. if fieldValue.Len() > 0 {
  402. val = fieldValue.Bytes()
  403. } else {
  404. continue
  405. }
  406. } else if fieldType.Kind() == reflect.Array &&
  407. fieldType.Elem().Kind() == reflect.Uint8 {
  408. val = fieldValue.Slice(0, 0).Interface()
  409. } else {
  410. bytes, err = json.Marshal(fieldValue.Interface())
  411. if err != nil {
  412. engine.logger.Error(err)
  413. continue
  414. }
  415. val = bytes
  416. }
  417. } else {
  418. continue
  419. }
  420. default:
  421. val = fieldValue.Interface()
  422. }
  423. APPEND:
  424. args = append(args, val)
  425. if col.IsPrimaryKey && engine.dialect.DBType() == "ql" {
  426. continue
  427. }
  428. colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name)))
  429. }
  430. return colNames, args
  431. }
  432. func (statement *Statement) needTableName() bool {
  433. return len(statement.JoinStr) > 0
  434. }
  435. func (statement *Statement) colName(col *core.Column, tableName string) string {
  436. if statement.needTableName() {
  437. var nm = tableName
  438. if len(statement.TableAlias) > 0 {
  439. nm = statement.TableAlias
  440. }
  441. return statement.Engine.Quote(nm) + "." + statement.Engine.Quote(col.Name)
  442. }
  443. return statement.Engine.Quote(col.Name)
  444. }
  445. func buildConds(engine *Engine, table *core.Table, bean interface{},
  446. includeVersion bool, includeUpdated bool, includeNil bool,
  447. includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool,
  448. mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) (builder.Cond, error) {
  449. var conds []builder.Cond
  450. for _, col := range table.Columns() {
  451. if !includeVersion && col.IsVersion {
  452. continue
  453. }
  454. if !includeUpdated && col.IsUpdated {
  455. continue
  456. }
  457. if !includeAutoIncr && col.IsAutoIncrement {
  458. continue
  459. }
  460. if engine.dialect.DBType() == core.MSSQL && (col.SQLType.Name == core.Text || col.SQLType.IsBlob() || col.SQLType.Name == core.TimeStampz) {
  461. continue
  462. }
  463. if col.SQLType.IsJson() {
  464. continue
  465. }
  466. var colName string
  467. if addedTableName {
  468. var nm = tableName
  469. if len(aliasName) > 0 {
  470. nm = aliasName
  471. }
  472. colName = engine.Quote(nm) + "." + engine.Quote(col.Name)
  473. } else {
  474. colName = engine.Quote(col.Name)
  475. }
  476. fieldValuePtr, err := col.ValueOf(bean)
  477. if err != nil {
  478. engine.logger.Error(err)
  479. continue
  480. }
  481. if col.IsDeleted && !unscoped { // tag "deleted" is enabled
  482. if engine.dialect.DBType() == core.MSSQL {
  483. conds = append(conds, builder.IsNull{colName})
  484. } else {
  485. conds = append(conds, builder.IsNull{colName}.Or(builder.Eq{colName: "0001-01-01 00:00:00"}))
  486. }
  487. }
  488. fieldValue := *fieldValuePtr
  489. if fieldValue.Interface() == nil {
  490. continue
  491. }
  492. fieldType := reflect.TypeOf(fieldValue.Interface())
  493. requiredField := useAllCols
  494. if b, ok := getFlagForColumn(mustColumnMap, col); ok {
  495. if b {
  496. requiredField = true
  497. } else {
  498. continue
  499. }
  500. }
  501. if fieldType.Kind() == reflect.Ptr {
  502. if fieldValue.IsNil() {
  503. if includeNil {
  504. conds = append(conds, builder.Eq{colName: nil})
  505. }
  506. continue
  507. } else if !fieldValue.IsValid() {
  508. continue
  509. } else {
  510. // dereference ptr type to instance type
  511. fieldValue = fieldValue.Elem()
  512. fieldType = reflect.TypeOf(fieldValue.Interface())
  513. requiredField = true
  514. }
  515. }
  516. var val interface{}
  517. switch fieldType.Kind() {
  518. case reflect.Bool:
  519. if allUseBool || requiredField {
  520. val = fieldValue.Interface()
  521. } else {
  522. // if a bool in a struct, it will not be as a condition because it default is false,
  523. // please use Where() instead
  524. continue
  525. }
  526. case reflect.String:
  527. if !requiredField && fieldValue.String() == "" {
  528. continue
  529. }
  530. // for MyString, should convert to string or panic
  531. if fieldType.String() != reflect.String.String() {
  532. val = fieldValue.String()
  533. } else {
  534. val = fieldValue.Interface()
  535. }
  536. case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
  537. if !requiredField && fieldValue.Int() == 0 {
  538. continue
  539. }
  540. val = fieldValue.Interface()
  541. case reflect.Float32, reflect.Float64:
  542. if !requiredField && fieldValue.Float() == 0.0 {
  543. continue
  544. }
  545. val = fieldValue.Interface()
  546. case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
  547. if !requiredField && fieldValue.Uint() == 0 {
  548. continue
  549. }
  550. t := int64(fieldValue.Uint())
  551. val = reflect.ValueOf(&t).Interface()
  552. case reflect.Struct:
  553. if fieldType.ConvertibleTo(core.TimeType) {
  554. t := fieldValue.Convert(core.TimeType).Interface().(time.Time)
  555. if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
  556. continue
  557. }
  558. val = engine.FormatTime(col.SQLType.Name, t)
  559. } else if _, ok := reflect.New(fieldType).Interface().(core.Conversion); ok {
  560. continue
  561. } else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok {
  562. val, _ = valNul.Value()
  563. if val == nil {
  564. continue
  565. }
  566. } else {
  567. if col.SQLType.IsJson() {
  568. if col.SQLType.IsText() {
  569. bytes, err := json.Marshal(fieldValue.Interface())
  570. if err != nil {
  571. engine.logger.Error(err)
  572. continue
  573. }
  574. val = string(bytes)
  575. } else if col.SQLType.IsBlob() {
  576. var bytes []byte
  577. var err error
  578. bytes, err = json.Marshal(fieldValue.Interface())
  579. if err != nil {
  580. engine.logger.Error(err)
  581. continue
  582. }
  583. val = bytes
  584. }
  585. } else {
  586. engine.autoMapType(fieldValue)
  587. if table, ok := engine.Tables[fieldValue.Type()]; ok {
  588. if len(table.PrimaryKeys) == 1 {
  589. pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
  590. // fix non-int pk issues
  591. //if pkField.Int() != 0 {
  592. if pkField.IsValid() && !isZero(pkField.Interface()) {
  593. val = pkField.Interface()
  594. } else {
  595. continue
  596. }
  597. } else {
  598. //TODO: how to handler?
  599. panic(fmt.Sprintln("not supported", fieldValue.Interface(), "as", table.PrimaryKeys))
  600. }
  601. } else {
  602. val = fieldValue.Interface()
  603. }
  604. }
  605. }
  606. case reflect.Array:
  607. continue
  608. case reflect.Slice, reflect.Map:
  609. if fieldValue == reflect.Zero(fieldType) {
  610. continue
  611. }
  612. if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 {
  613. continue
  614. }
  615. if col.SQLType.IsText() {
  616. bytes, err := json.Marshal(fieldValue.Interface())
  617. if err != nil {
  618. engine.logger.Error(err)
  619. continue
  620. }
  621. val = string(bytes)
  622. } else if col.SQLType.IsBlob() {
  623. var bytes []byte
  624. var err error
  625. if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) &&
  626. fieldType.Elem().Kind() == reflect.Uint8 {
  627. if fieldValue.Len() > 0 {
  628. val = fieldValue.Bytes()
  629. } else {
  630. continue
  631. }
  632. } else {
  633. bytes, err = json.Marshal(fieldValue.Interface())
  634. if err != nil {
  635. engine.logger.Error(err)
  636. continue
  637. }
  638. val = bytes
  639. }
  640. } else {
  641. continue
  642. }
  643. default:
  644. val = fieldValue.Interface()
  645. }
  646. conds = append(conds, builder.Eq{colName: val})
  647. }
  648. return builder.And(conds...), nil
  649. }
  650. // TableName return current tableName
  651. func (statement *Statement) TableName() string {
  652. if statement.AltTableName != "" {
  653. return statement.AltTableName
  654. }
  655. return statement.tableName
  656. }
  657. // ID generate "where id = ? " statement or for composite key "where key1 = ? and key2 = ?"
  658. func (statement *Statement) ID(id interface{}) *Statement {
  659. idValue := reflect.ValueOf(id)
  660. idType := reflect.TypeOf(idValue.Interface())
  661. switch idType {
  662. case ptrPkType:
  663. if pkPtr, ok := (id).(*core.PK); ok {
  664. statement.idParam = pkPtr
  665. return statement
  666. }
  667. case pkType:
  668. if pk, ok := (id).(core.PK); ok {
  669. statement.idParam = &pk
  670. return statement
  671. }
  672. }
  673. switch idType.Kind() {
  674. case reflect.String:
  675. statement.idParam = &core.PK{idValue.Convert(reflect.TypeOf("")).Interface()}
  676. return statement
  677. }
  678. statement.idParam = &core.PK{id}
  679. return statement
  680. }
  681. // Incr Generate "Update ... Set column = column + arg" statement
  682. func (statement *Statement) Incr(column string, arg ...interface{}) *Statement {
  683. k := strings.ToLower(column)
  684. if len(arg) > 0 {
  685. statement.incrColumns[k] = incrParam{column, arg[0]}
  686. } else {
  687. statement.incrColumns[k] = incrParam{column, 1}
  688. }
  689. return statement
  690. }
  691. // Decr Generate "Update ... Set column = column - arg" statement
  692. func (statement *Statement) Decr(column string, arg ...interface{}) *Statement {
  693. k := strings.ToLower(column)
  694. if len(arg) > 0 {
  695. statement.decrColumns[k] = decrParam{column, arg[0]}
  696. } else {
  697. statement.decrColumns[k] = decrParam{column, 1}
  698. }
  699. return statement
  700. }
  701. // SetExpr Generate "Update ... Set column = {expression}" statement
  702. func (statement *Statement) SetExpr(column string, expression string) *Statement {
  703. k := strings.ToLower(column)
  704. statement.exprColumns[k] = exprParam{column, expression}
  705. return statement
  706. }
  707. // Generate "Update ... Set column = column + arg" statement
  708. func (statement *Statement) getInc() map[string]incrParam {
  709. return statement.incrColumns
  710. }
  711. // Generate "Update ... Set column = column - arg" statement
  712. func (statement *Statement) getDec() map[string]decrParam {
  713. return statement.decrColumns
  714. }
  715. // Generate "Update ... Set column = {expression}" statement
  716. func (statement *Statement) getExpr() map[string]exprParam {
  717. return statement.exprColumns
  718. }
  719. func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
  720. newColumns := make([]string, 0)
  721. for _, col := range columns {
  722. col = strings.Replace(col, "`", "", -1)
  723. col = strings.Replace(col, statement.Engine.QuoteStr(), "", -1)
  724. ccols := strings.Split(col, ",")
  725. for _, c := range ccols {
  726. fields := strings.Split(strings.TrimSpace(c), ".")
  727. if len(fields) == 1 {
  728. newColumns = append(newColumns, statement.Engine.quote(fields[0]))
  729. } else if len(fields) == 2 {
  730. newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+
  731. statement.Engine.quote(fields[1]))
  732. } else {
  733. panic(errors.New("unwanted colnames"))
  734. }
  735. }
  736. }
  737. return newColumns
  738. }
  739. // Distinct generates "DISTINCT col1, col2 " statement
  740. func (statement *Statement) Distinct(columns ...string) *Statement {
  741. statement.IsDistinct = true
  742. statement.Cols(columns...)
  743. return statement
  744. }
  745. // ForUpdate generates "SELECT ... FOR UPDATE" statement
  746. func (statement *Statement) ForUpdate() *Statement {
  747. statement.IsForUpdate = true
  748. return statement
  749. }
  750. // Select replace select
  751. func (statement *Statement) Select(str string) *Statement {
  752. statement.selectStr = str
  753. return statement
  754. }
  755. // Cols generate "col1, col2" statement
  756. func (statement *Statement) Cols(columns ...string) *Statement {
  757. cols := col2NewCols(columns...)
  758. for _, nc := range cols {
  759. statement.columnMap[strings.ToLower(nc)] = true
  760. }
  761. newColumns := statement.col2NewColsWithQuote(columns...)
  762. statement.ColumnStr = strings.Join(newColumns, ", ")
  763. statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1)
  764. return statement
  765. }
  766. // AllCols update use only: update all columns
  767. func (statement *Statement) AllCols() *Statement {
  768. statement.useAllCols = true
  769. return statement
  770. }
  771. // MustCols update use only: must update columns
  772. func (statement *Statement) MustCols(columns ...string) *Statement {
  773. newColumns := col2NewCols(columns...)
  774. for _, nc := range newColumns {
  775. statement.mustColumnMap[strings.ToLower(nc)] = true
  776. }
  777. return statement
  778. }
  779. // UseBool indicates that use bool fields as update contents and query contiditions
  780. func (statement *Statement) UseBool(columns ...string) *Statement {
  781. if len(columns) > 0 {
  782. statement.MustCols(columns...)
  783. } else {
  784. statement.allUseBool = true
  785. }
  786. return statement
  787. }
  788. // Omit do not use the columns
  789. func (statement *Statement) Omit(columns ...string) {
  790. newColumns := col2NewCols(columns...)
  791. for _, nc := range newColumns {
  792. statement.columnMap[strings.ToLower(nc)] = false
  793. }
  794. statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", ")))
  795. }
  796. // Nullable Update use only: update columns to null when value is nullable and zero-value
  797. func (statement *Statement) Nullable(columns ...string) {
  798. newColumns := col2NewCols(columns...)
  799. for _, nc := range newColumns {
  800. statement.nullableMap[strings.ToLower(nc)] = true
  801. }
  802. }
  803. // Top generate LIMIT limit statement
  804. func (statement *Statement) Top(limit int) *Statement {
  805. statement.Limit(limit)
  806. return statement
  807. }
  808. // Limit generate LIMIT start, limit statement
  809. func (statement *Statement) Limit(limit int, start ...int) *Statement {
  810. statement.LimitN = limit
  811. if len(start) > 0 {
  812. statement.Start = start[0]
  813. }
  814. return statement
  815. }
  816. // OrderBy generate "Order By order" statement
  817. func (statement *Statement) OrderBy(order string) *Statement {
  818. if len(statement.OrderStr) > 0 {
  819. statement.OrderStr += ", "
  820. }
  821. statement.OrderStr += order
  822. return statement
  823. }
  824. // Desc generate `ORDER BY xx DESC`
  825. func (statement *Statement) Desc(colNames ...string) *Statement {
  826. var buf bytes.Buffer
  827. fmt.Fprintf(&buf, statement.OrderStr)
  828. if len(statement.OrderStr) > 0 {
  829. fmt.Fprint(&buf, ", ")
  830. }
  831. newColNames := statement.col2NewColsWithQuote(colNames...)
  832. fmt.Fprintf(&buf, "%v DESC", strings.Join(newColNames, " DESC, "))
  833. statement.OrderStr = buf.String()
  834. return statement
  835. }
  836. // Asc provide asc order by query condition, the input parameters are columns.
  837. func (statement *Statement) Asc(colNames ...string) *Statement {
  838. var buf bytes.Buffer
  839. fmt.Fprintf(&buf, statement.OrderStr)
  840. if len(statement.OrderStr) > 0 {
  841. fmt.Fprint(&buf, ", ")
  842. }
  843. newColNames := statement.col2NewColsWithQuote(colNames...)
  844. fmt.Fprintf(&buf, "%v ASC", strings.Join(newColNames, " ASC, "))
  845. statement.OrderStr = buf.String()
  846. return statement
  847. }
  848. // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
  849. func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement {
  850. var buf bytes.Buffer
  851. if len(statement.JoinStr) > 0 {
  852. fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP)
  853. } else {
  854. fmt.Fprintf(&buf, "%v JOIN ", joinOP)
  855. }
  856. switch tablename.(type) {
  857. case []string:
  858. t := tablename.([]string)
  859. if len(t) > 1 {
  860. fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(t[0]), statement.Engine.Quote(t[1]))
  861. } else if len(t) == 1 {
  862. fmt.Fprintf(&buf, statement.Engine.Quote(t[0]))
  863. }
  864. case []interface{}:
  865. t := tablename.([]interface{})
  866. l := len(t)
  867. var table string
  868. if l > 0 {
  869. f := t[0]
  870. v := rValue(f)
  871. t := v.Type()
  872. if t.Kind() == reflect.String {
  873. table = f.(string)
  874. } else if t.Kind() == reflect.Struct {
  875. table = statement.Engine.tbName(v)
  876. }
  877. }
  878. if l > 1 {
  879. fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(table),
  880. statement.Engine.Quote(fmt.Sprintf("%v", t[1])))
  881. } else if l == 1 {
  882. fmt.Fprintf(&buf, statement.Engine.Quote(table))
  883. }
  884. default:
  885. fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", tablename)))
  886. }
  887. fmt.Fprintf(&buf, " ON %v", condition)
  888. statement.JoinStr = buf.String()
  889. statement.joinArgs = append(statement.joinArgs, args...)
  890. return statement
  891. }
  892. // GroupBy generate "Group By keys" statement
  893. func (statement *Statement) GroupBy(keys string) *Statement {
  894. statement.GroupByStr = keys
  895. return statement
  896. }
  897. // Having generate "Having conditions" statement
  898. func (statement *Statement) Having(conditions string) *Statement {
  899. statement.HavingStr = fmt.Sprintf("HAVING %v", conditions)
  900. return statement
  901. }
  902. // Unscoped always disable struct tag "deleted"
  903. func (statement *Statement) Unscoped() *Statement {
  904. statement.unscoped = true
  905. return statement
  906. }
  907. func (statement *Statement) genColumnStr() string {
  908. var buf bytes.Buffer
  909. if statement.RefTable == nil {
  910. return ""
  911. }
  912. columns := statement.RefTable.Columns()
  913. for _, col := range columns {
  914. if statement.OmitStr != "" {
  915. if _, ok := getFlagForColumn(statement.columnMap, col); ok {
  916. continue
  917. }
  918. }
  919. if col.MapType == core.ONLYTODB {
  920. continue
  921. }
  922. if buf.Len() != 0 {
  923. buf.WriteString(", ")
  924. }
  925. if col.IsPrimaryKey && statement.Engine.Dialect().DBType() == "ql" {
  926. buf.WriteString("id() AS ")
  927. }
  928. if statement.JoinStr != "" {
  929. if statement.TableAlias != "" {
  930. buf.WriteString(statement.TableAlias)
  931. } else {
  932. buf.WriteString(statement.TableName())
  933. }
  934. buf.WriteString(".")
  935. }
  936. statement.Engine.QuoteTo(&buf, col.Name)
  937. }
  938. return buf.String()
  939. }
  940. func (statement *Statement) genCreateTableSQL() string {
  941. return statement.Engine.dialect.CreateTableSql(statement.RefTable, statement.TableName(),
  942. statement.StoreEngine, statement.Charset)
  943. }
  944. func (statement *Statement) genIndexSQL() []string {
  945. var sqls []string
  946. tbName := statement.TableName()
  947. quote := statement.Engine.Quote
  948. for idxName, index := range statement.RefTable.Indexes {
  949. if index.Type == core.IndexType {
  950. sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(tbName, idxName)),
  951. quote(tbName), quote(strings.Join(index.Cols, quote(","))))
  952. sqls = append(sqls, sql)
  953. }
  954. }
  955. return sqls
  956. }
  // uniqueName returns the physical name used for a unique index:
  // "UQE_<tableName>_<uqeName>".
  func uniqueName(tableName, uqeName string) string {
  	return fmt.Sprintf("UQE_%s_%s", tableName, uqeName)
  }
  960. func (statement *Statement) genUniqueSQL() []string {
  961. var sqls []string
  962. tbName := statement.TableName()
  963. for _, index := range statement.RefTable.Indexes {
  964. if index.Type == core.UniqueType {
  965. sql := statement.Engine.dialect.CreateIndexSql(tbName, index)
  966. sqls = append(sqls, sql)
  967. }
  968. }
  969. return sqls
  970. }
  971. func (statement *Statement) genDelIndexSQL() []string {
  972. var sqls []string
  973. tbName := statement.TableName()
  974. for idxName, index := range statement.RefTable.Indexes {
  975. var rIdxName string
  976. if index.Type == core.UniqueType {
  977. rIdxName = uniqueName(tbName, idxName)
  978. } else if index.Type == core.IndexType {
  979. rIdxName = indexName(tbName, idxName)
  980. }
  981. sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(rIdxName))
  982. if statement.Engine.dialect.IndexOnTable() {
  983. sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(statement.TableName()))
  984. }
  985. sqls = append(sqls, sql)
  986. }
  987. return sqls
  988. }
  989. func (statement *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) {
  990. quote := statement.Engine.Quote
  991. sql := fmt.Sprintf("ALTER TABLE %v ADD %v;", quote(statement.TableName()),
  992. col.String(statement.Engine.dialect))
  993. return sql, []interface{}{}
  994. }
  995. func (statement *Statement) buildConds(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) {
  996. return buildConds(statement.Engine, table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols,
  997. statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)
  998. }
  999. func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) {
  1000. if !statement.noAutoCondition {
  1001. var addedTableName = (len(statement.JoinStr) > 0)
  1002. autoCond, err := statement.buildConds(statement.RefTable, bean, true, true, false, true, addedTableName)
  1003. if err != nil {
  1004. return "", nil, err
  1005. }
  1006. statement.cond = statement.cond.And(autoCond)
  1007. }
  1008. statement.processIDParam()
  1009. return builder.ToSQL(statement.cond)
  1010. }
  1011. func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}) {
  1012. statement.setRefValue(rValue(bean))
  1013. var columnStr = statement.ColumnStr
  1014. if len(statement.selectStr) > 0 {
  1015. columnStr = statement.selectStr
  1016. } else {
  1017. // TODO: always generate column names, not use * even if join
  1018. if len(statement.JoinStr) == 0 {
  1019. if len(columnStr) == 0 {
  1020. if len(statement.GroupByStr) > 0 {
  1021. columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
  1022. } else {
  1023. columnStr = statement.genColumnStr()
  1024. }
  1025. }
  1026. } else {
  1027. if len(columnStr) == 0 {
  1028. if len(statement.GroupByStr) > 0 {
  1029. columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
  1030. } else {
  1031. columnStr = "*"
  1032. }
  1033. }
  1034. }
  1035. }
  1036. condSQL, condArgs, _ := statement.genConds(bean)
  1037. return statement.genSelectSQL(columnStr, condSQL), append(statement.joinArgs, condArgs...)
  1038. }
  1039. func (statement *Statement) genCountSQL(bean interface{}) (string, []interface{}) {
  1040. statement.setRefValue(rValue(bean))
  1041. condSQL, condArgs, _ := statement.genConds(bean)
  1042. var selectSQL = statement.selectStr
  1043. if len(selectSQL) <= 0 {
  1044. if statement.IsDistinct {
  1045. selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr)
  1046. } else {
  1047. selectSQL = "count(*)"
  1048. }
  1049. }
  1050. return statement.genSelectSQL(selectSQL, condSQL), append(statement.joinArgs, condArgs...)
  1051. }
  1052. func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}) {
  1053. statement.setRefValue(rValue(bean))
  1054. var sumStrs = make([]string, 0, len(columns))
  1055. for _, colName := range columns {
  1056. sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", statement.Engine.Quote(colName)))
  1057. }
  1058. condSQL, condArgs, _ := statement.genConds(bean)
  1059. return statement.genSelectSQL(strings.Join(sumStrs, ", "), condSQL), append(statement.joinArgs, condArgs...)
  1060. }
  1061. func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) {
  1062. var distinct string
  1063. if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
  1064. distinct = "DISTINCT "
  1065. }
  1066. var dialect = statement.Engine.Dialect()
  1067. var quote = statement.Engine.Quote
  1068. var top string
  1069. var mssqlCondi string
  1070. statement.processIDParam()
  1071. var buf bytes.Buffer
  1072. if len(condSQL) > 0 {
  1073. fmt.Fprintf(&buf, " WHERE %v", condSQL)
  1074. }
  1075. var whereStr = buf.String()
  1076. var fromStr = " FROM " + quote(statement.TableName())
  1077. if statement.TableAlias != "" {
  1078. if dialect.DBType() == core.ORACLE {
  1079. fromStr += " " + quote(statement.TableAlias)
  1080. } else {
  1081. fromStr += " AS " + quote(statement.TableAlias)
  1082. }
  1083. }
  1084. if statement.JoinStr != "" {
  1085. fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr)
  1086. }
  1087. if dialect.DBType() == core.MSSQL {
  1088. if statement.LimitN > 0 {
  1089. top = fmt.Sprintf(" TOP %d ", statement.LimitN)
  1090. }
  1091. if statement.Start > 0 {
  1092. var column string
  1093. if len(statement.RefTable.PKColumns()) == 0 {
  1094. for _, index := range statement.RefTable.Indexes {
  1095. if len(index.Cols) == 1 {
  1096. column = index.Cols[0]
  1097. break
  1098. }
  1099. }
  1100. if len(column) == 0 {
  1101. column = statement.RefTable.ColumnsSeq()[0]
  1102. }
  1103. } else {
  1104. column = statement.RefTable.PKColumns()[0].Name
  1105. }
  1106. if statement.needTableName() {
  1107. if len(statement.TableAlias) > 0 {
  1108. column = statement.TableAlias + "." + column
  1109. } else {
  1110. column = statement.TableName() + "." + column
  1111. }
  1112. }
  1113. var orderStr string
  1114. if len(statement.OrderStr) > 0 {
  1115. orderStr = " ORDER BY " + statement.OrderStr
  1116. }
  1117. var groupStr string
  1118. if len(statement.GroupByStr) > 0 {
  1119. groupStr = " GROUP BY " + statement.GroupByStr
  1120. }
  1121. mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s%s%s))",
  1122. column, statement.Start, column, fromStr, whereStr, orderStr, groupStr)
  1123. }
  1124. }
  1125. // !nashtsai! REVIEW Sprintf is considered slowest mean of string concatnation, better to work with builder pattern
  1126. a = fmt.Sprintf("SELECT %v%v%v%v%v", top, distinct, columnStr, fromStr, whereStr)
  1127. if len(mssqlCondi) > 0 {
  1128. if len(whereStr) > 0 {
  1129. a += " AND " + mssqlCondi
  1130. } else {
  1131. a += " WHERE " + mssqlCondi
  1132. }
  1133. }
  1134. if statement.GroupByStr != "" {
  1135. a = fmt.Sprintf("%v GROUP BY %v", a, statement.GroupByStr)
  1136. }
  1137. if statement.HavingStr != "" {
  1138. a = fmt.Sprintf("%v %v", a, statement.HavingStr)
  1139. }
  1140. if statement.OrderStr != "" {
  1141. a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
  1142. }
  1143. if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
  1144. if statement.Start > 0 {
  1145. a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start)
  1146. } else if statement.LimitN > 0 {
  1147. a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN)
  1148. }
  1149. } else if dialect.DBType() == core.ORACLE {
  1150. if statement.Start != 0 || statement.LimitN != 0 {
  1151. a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start)
  1152. }
  1153. }
  1154. if statement.IsForUpdate {
  1155. a = dialect.ForUpdateSql(a)
  1156. }
  1157. return
  1158. }
  1159. func (statement *Statement) processIDParam() {
  1160. if statement.idParam == nil {
  1161. return
  1162. }
  1163. for i, col := range statement.RefTable.PKColumns() {
  1164. var colName = statement.colName(col, statement.TableName())
  1165. if i < len(*(statement.idParam)) {
  1166. statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]})
  1167. } else {
  1168. statement.cond = statement.cond.And(builder.Eq{colName: ""})
  1169. }
  1170. }
  1171. }
  1172. func (statement *Statement) joinColumns(cols []*core.Column, includeTableName bool) string {
  1173. var colnames = make([]string, len(cols))
  1174. for i, col := range cols {
  1175. if includeTableName {
  1176. colnames[i] = statement.Engine.Quote(statement.TableName()) +
  1177. "." + statement.Engine.Quote(col.Name)
  1178. } else {
  1179. colnames[i] = statement.Engine.Quote(col.Name)
  1180. }
  1181. }
  1182. return strings.Join(colnames, ", ")
  1183. }
  1184. func (statement *Statement) convertIDSQL(sqlStr string) string {
  1185. if statement.RefTable != nil {
  1186. cols := statement.RefTable.PKColumns()
  1187. if len(cols) == 0 {
  1188. return ""
  1189. }
  1190. colstrs := statement.joinColumns(cols, false)
  1191. sqls := splitNNoCase(sqlStr, " from ", 2)
  1192. if len(sqls) != 2 {
  1193. return ""
  1194. }
  1195. var top string
  1196. if statement.LimitN > 0 && statement.Engine.dialect.DBType() == core.MSSQL {
  1197. top = fmt.Sprintf("TOP %d ", statement.LimitN)
  1198. }
  1199. return fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1])
  1200. }
  1201. return ""
  1202. }
  1203. func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
  1204. if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 {
  1205. return "", ""
  1206. }
  1207. colstrs := statement.joinColumns(statement.RefTable.PKColumns(), true)
  1208. sqls := splitNNoCase(sqlStr, "where", 2)
  1209. if len(sqls) != 2 {
  1210. if len(sqls) == 1 {
  1211. return sqls[0], fmt.Sprintf("SELECT %v FROM %v",
  1212. colstrs, statement.Engine.Quote(statement.TableName()))
  1213. }
  1214. return "", ""
  1215. }
  1216. var whereStr = sqls[1]
  1217. //TODO: for postgres only, if any other database?
  1218. var paraStr string
  1219. if statement.Engine.dialect.DBType() == core.POSTGRES {
  1220. paraStr = "$"
  1221. } else if statement.Engine.dialect.DBType() == core.MSSQL {
  1222. paraStr = ":"
  1223. }
  1224. if paraStr != "" {
  1225. if strings.Contains(sqls[1], paraStr) {
  1226. dollers := strings.Split(sqls[1], paraStr)
  1227. whereStr = dollers[0]
  1228. for i, c := range dollers[1:] {
  1229. ccs := strings.SplitN(c, " ", 2)
  1230. whereStr += fmt.Sprintf(paraStr+"%v %v", i+1, ccs[1])
  1231. }
  1232. }
  1233. }
  1234. return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v",
  1235. colstrs, statement.Engine.Quote(statement.TableName()),
  1236. whereStr)
  1237. }