session_insert.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543
  1. // Copyright 2016 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package xorm
  5. import (
  6. "errors"
  7. "fmt"
  8. "reflect"
  9. "strconv"
  10. "strings"
  11. "github.com/go-xorm/core"
  12. )
  13. // Insert insert one or more beans
  14. func (session *Session) Insert(beans ...interface{}) (int64, error) {
  15. var affected int64
  16. var err error
  17. if session.IsAutoClose {
  18. defer session.Close()
  19. }
  20. defer session.resetStatement()
  21. for _, bean := range beans {
  22. sliceValue := reflect.Indirect(reflect.ValueOf(bean))
  23. if sliceValue.Kind() == reflect.Slice {
  24. size := sliceValue.Len()
  25. if size > 0 {
  26. if session.Engine.SupportInsertMany() {
  27. cnt, err := session.innerInsertMulti(bean)
  28. if err != nil {
  29. return affected, err
  30. }
  31. affected += cnt
  32. } else {
  33. for i := 0; i < size; i++ {
  34. cnt, err := session.innerInsert(sliceValue.Index(i).Interface())
  35. if err != nil {
  36. return affected, err
  37. }
  38. affected += cnt
  39. }
  40. }
  41. }
  42. } else {
  43. cnt, err := session.innerInsert(bean)
  44. if err != nil {
  45. return affected, err
  46. }
  47. affected += cnt
  48. }
  49. }
  50. return affected, err
  51. }
  52. func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error) {
  53. sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
  54. if sliceValue.Kind() != reflect.Slice {
  55. return 0, errors.New("needs a pointer to a slice")
  56. }
  57. if sliceValue.Len() <= 0 {
  58. return 0, errors.New("could not insert a empty slice")
  59. }
  60. session.Statement.setRefValue(sliceValue.Index(0))
  61. if len(session.Statement.TableName()) <= 0 {
  62. return 0, ErrTableNotFound
  63. }
  64. table := session.Statement.RefTable
  65. size := sliceValue.Len()
  66. var colNames []string
  67. var colMultiPlaces []string
  68. var args []interface{}
  69. var cols []*core.Column
  70. for i := 0; i < size; i++ {
  71. v := sliceValue.Index(i)
  72. vv := reflect.Indirect(v)
  73. elemValue := v.Interface()
  74. var colPlaces []string
  75. // handle BeforeInsertProcessor
  76. // !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi??
  77. for _, closure := range session.beforeClosures {
  78. closure(elemValue)
  79. }
  80. if processor, ok := interface{}(elemValue).(BeforeInsertProcessor); ok {
  81. processor.BeforeInsert()
  82. }
  83. // --
  84. if i == 0 {
  85. for _, col := range table.Columns() {
  86. ptrFieldValue, err := col.ValueOfV(&vv)
  87. if err != nil {
  88. return 0, err
  89. }
  90. fieldValue := *ptrFieldValue
  91. if col.IsAutoIncrement && isZero(fieldValue.Interface()) {
  92. continue
  93. }
  94. if col.MapType == core.ONLYFROMDB {
  95. continue
  96. }
  97. if col.IsDeleted {
  98. continue
  99. }
  100. if session.Statement.ColumnStr != "" {
  101. if _, ok := getFlagForColumn(session.Statement.columnMap, col); !ok {
  102. continue
  103. }
  104. }
  105. if session.Statement.OmitStr != "" {
  106. if _, ok := getFlagForColumn(session.Statement.columnMap, col); ok {
  107. continue
  108. }
  109. }
  110. if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime {
  111. val, t := session.Engine.NowTime2(col.SQLType.Name)
  112. args = append(args, val)
  113. var colName = col.Name
  114. session.afterClosures = append(session.afterClosures, func(bean interface{}) {
  115. col := table.GetColumn(colName)
  116. setColumnTime(bean, col, t)
  117. })
  118. } else if col.IsVersion && session.Statement.checkVersion {
  119. args = append(args, 1)
  120. var colName = col.Name
  121. session.afterClosures = append(session.afterClosures, func(bean interface{}) {
  122. col := table.GetColumn(colName)
  123. setColumnInt(bean, col, 1)
  124. })
  125. } else {
  126. arg, err := session.value2Interface(col, fieldValue)
  127. if err != nil {
  128. return 0, err
  129. }
  130. args = append(args, arg)
  131. }
  132. colNames = append(colNames, col.Name)
  133. cols = append(cols, col)
  134. colPlaces = append(colPlaces, "?")
  135. }
  136. } else {
  137. for _, col := range cols {
  138. ptrFieldValue, err := col.ValueOfV(&vv)
  139. if err != nil {
  140. return 0, err
  141. }
  142. fieldValue := *ptrFieldValue
  143. if col.IsAutoIncrement && isZero(fieldValue.Interface()) {
  144. continue
  145. }
  146. if col.MapType == core.ONLYFROMDB {
  147. continue
  148. }
  149. if col.IsDeleted {
  150. continue
  151. }
  152. if session.Statement.ColumnStr != "" {
  153. if _, ok := getFlagForColumn(session.Statement.columnMap, col); !ok {
  154. continue
  155. }
  156. }
  157. if session.Statement.OmitStr != "" {
  158. if _, ok := getFlagForColumn(session.Statement.columnMap, col); ok {
  159. continue
  160. }
  161. }
  162. if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime {
  163. val, t := session.Engine.NowTime2(col.SQLType.Name)
  164. args = append(args, val)
  165. var colName = col.Name
  166. session.afterClosures = append(session.afterClosures, func(bean interface{}) {
  167. col := table.GetColumn(colName)
  168. setColumnTime(bean, col, t)
  169. })
  170. } else if col.IsVersion && session.Statement.checkVersion {
  171. args = append(args, 1)
  172. var colName = col.Name
  173. session.afterClosures = append(session.afterClosures, func(bean interface{}) {
  174. col := table.GetColumn(colName)
  175. setColumnInt(bean, col, 1)
  176. })
  177. } else {
  178. arg, err := session.value2Interface(col, fieldValue)
  179. if err != nil {
  180. return 0, err
  181. }
  182. args = append(args, arg)
  183. }
  184. colPlaces = append(colPlaces, "?")
  185. }
  186. }
  187. colMultiPlaces = append(colMultiPlaces, strings.Join(colPlaces, ", "))
  188. }
  189. cleanupProcessorsClosures(&session.beforeClosures)
  190. var sql = "INSERT INTO %s (%v%v%v) VALUES (%v)"
  191. var statement string
  192. if session.Engine.dialect.DBType() == core.ORACLE {
  193. sql = "INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL"
  194. temp := fmt.Sprintf(") INTO %s (%v%v%v) VALUES (",
  195. session.Engine.Quote(session.Statement.TableName()),
  196. session.Engine.QuoteStr(),
  197. strings.Join(colNames, session.Engine.QuoteStr() + ", " + session.Engine.QuoteStr()),
  198. session.Engine.QuoteStr())
  199. statement = fmt.Sprintf(sql,
  200. session.Engine.Quote(session.Statement.TableName()),
  201. session.Engine.QuoteStr(),
  202. strings.Join(colNames, session.Engine.QuoteStr() + ", " + session.Engine.QuoteStr()),
  203. session.Engine.QuoteStr(),
  204. strings.Join(colMultiPlaces, temp))
  205. } else {
  206. statement = fmt.Sprintf(sql,
  207. session.Engine.Quote(session.Statement.TableName()),
  208. session.Engine.QuoteStr(),
  209. strings.Join(colNames, session.Engine.QuoteStr() + ", " + session.Engine.QuoteStr()),
  210. session.Engine.QuoteStr(),
  211. strings.Join(colMultiPlaces, "),("))
  212. }
  213. res, err := session.exec(statement, args...)
  214. if err != nil {
  215. return 0, err
  216. }
  217. if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache {
  218. session.cacheInsert(session.Statement.TableName())
  219. }
  220. lenAfterClosures := len(session.afterClosures)
  221. for i := 0; i < size; i++ {
  222. elemValue := reflect.Indirect(sliceValue.Index(i)).Addr().Interface()
  223. // handle AfterInsertProcessor
  224. if session.IsAutoCommit {
  225. // !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi??
  226. for _, closure := range session.afterClosures {
  227. closure(elemValue)
  228. }
  229. if processor, ok := interface{}(elemValue).(AfterInsertProcessor); ok {
  230. processor.AfterInsert()
  231. }
  232. } else {
  233. if lenAfterClosures > 0 {
  234. if value, has := session.afterInsertBeans[elemValue]; has && value != nil {
  235. *value = append(*value, session.afterClosures...)
  236. } else {
  237. afterClosures := make([]func(interface{}), lenAfterClosures)
  238. copy(afterClosures, session.afterClosures)
  239. session.afterInsertBeans[elemValue] = &afterClosures
  240. }
  241. } else {
  242. if _, ok := interface{}(elemValue).(AfterInsertProcessor); ok {
  243. session.afterInsertBeans[elemValue] = nil
  244. }
  245. }
  246. }
  247. }
  248. cleanupProcessorsClosures(&session.afterClosures)
  249. return res.RowsAffected()
  250. }
  251. // InsertMulti insert multiple records
  252. func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
  253. defer session.resetStatement()
  254. if session.IsAutoClose {
  255. defer session.Close()
  256. }
  257. sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
  258. if sliceValue.Kind() != reflect.Slice {
  259. return 0, ErrParamsType
  260. }
  261. if sliceValue.Len() <= 0 {
  262. return 0, nil
  263. }
  264. return session.innerInsertMulti(rowsSlicePtr)
  265. }
  266. func (session *Session) innerInsert(bean interface{}) (int64, error) {
  267. session.Statement.setRefValue(rValue(bean))
  268. if len(session.Statement.TableName()) <= 0 {
  269. return 0, ErrTableNotFound
  270. }
  271. table := session.Statement.RefTable
  272. // handle BeforeInsertProcessor
  273. for _, closure := range session.beforeClosures {
  274. closure(bean)
  275. }
  276. cleanupProcessorsClosures(&session.beforeClosures) // cleanup after used
  277. if processor, ok := interface{}(bean).(BeforeInsertProcessor); ok {
  278. processor.BeforeInsert()
  279. }
  280. // --
  281. colNames, args, err := genCols(session.Statement.RefTable, session, bean, false, false)
  282. if err != nil {
  283. return 0, err
  284. }
  285. // insert expr columns, override if exists
  286. exprColumns := session.Statement.getExpr()
  287. exprColVals := make([]string, 0, len(exprColumns))
  288. for _, v := range exprColumns {
  289. // remove the expr columns
  290. for i, colName := range colNames {
  291. if colName == v.colName {
  292. colNames = append(colNames[:i], colNames[i + 1:]...)
  293. args = append(args[:i], args[i + 1:]...)
  294. }
  295. }
  296. // append expr column to the end
  297. colNames = append(colNames, v.colName)
  298. exprColVals = append(exprColVals, v.expr)
  299. }
  300. colPlaces := strings.Repeat("?, ", len(colNames) - len(exprColumns))
  301. if len(exprColVals) > 0 {
  302. colPlaces = colPlaces + strings.Join(exprColVals, ", ")
  303. } else {
  304. colPlaces = colPlaces[0 : len(colPlaces) - 2]
  305. }
  306. sqlStr := fmt.Sprintf("INSERT INTO %s (%v%v%v) VALUES (%v)",
  307. session.Engine.Quote(session.Statement.TableName()),
  308. session.Engine.QuoteStr(),
  309. strings.Join(colNames, session.Engine.Quote(", ")),
  310. session.Engine.QuoteStr(),
  311. colPlaces)
  312. handleAfterInsertProcessorFunc := func(bean interface{}) {
  313. if session.IsAutoCommit {
  314. for _, closure := range session.afterClosures {
  315. closure(bean)
  316. }
  317. if processor, ok := interface{}(bean).(AfterInsertProcessor); ok {
  318. processor.AfterInsert()
  319. }
  320. } else {
  321. lenAfterClosures := len(session.afterClosures)
  322. if lenAfterClosures > 0 {
  323. if value, has := session.afterInsertBeans[bean]; has && value != nil {
  324. *value = append(*value, session.afterClosures...)
  325. } else {
  326. afterClosures := make([]func(interface{}), lenAfterClosures)
  327. copy(afterClosures, session.afterClosures)
  328. session.afterInsertBeans[bean] = &afterClosures
  329. }
  330. } else {
  331. if _, ok := interface{}(bean).(AfterInsertProcessor); ok {
  332. session.afterInsertBeans[bean] = nil
  333. }
  334. }
  335. }
  336. cleanupProcessorsClosures(&session.afterClosures) // cleanup after used
  337. }
  338. // for postgres, many of them didn't implement lastInsertId, so we should
  339. // implemented it ourself.
  340. if session.Engine.dialect.DBType() == core.ORACLE && len(table.AutoIncrement) > 0 {
  341. //assert table.AutoIncrement != ""
  342. res, err := session.query("select seq_atable.currval from dual", args...)
  343. if err != nil {
  344. return 0, err
  345. }
  346. handleAfterInsertProcessorFunc(bean)
  347. if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache {
  348. session.cacheInsert(session.Statement.TableName())
  349. }
  350. if table.Version != "" && session.Statement.checkVersion {
  351. verValue, err := table.VersionColumn().ValueOf(bean)
  352. if err != nil {
  353. session.Engine.logger.Error(err)
  354. } else if verValue.IsValid() && verValue.CanSet() {
  355. verValue.SetInt(1)
  356. }
  357. }
  358. if len(res) < 1 {
  359. return 0, errors.New("insert no error but not returned id")
  360. }
  361. idByte := res[0][table.AutoIncrement]
  362. id, err := strconv.ParseInt(string(idByte), 10, 64)
  363. if err != nil || id <= 0 {
  364. return 1, err
  365. }
  366. aiValue, err := table.AutoIncrColumn().ValueOf(bean)
  367. if err != nil {
  368. session.Engine.logger.Error(err)
  369. }
  370. if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
  371. return 1, nil
  372. }
  373. aiValue.Set(int64ToIntValue(id, aiValue.Type()))
  374. return 1, nil
  375. } else if session.Engine.dialect.DBType() == core.POSTGRES && len(table.AutoIncrement) > 0 {
  376. //assert table.AutoIncrement != ""
  377. sqlStr = sqlStr + " RETURNING " + session.Engine.Quote(table.AutoIncrement)
  378. res, err := session.query(sqlStr, args...)
  379. if err != nil {
  380. return 0, err
  381. }
  382. handleAfterInsertProcessorFunc(bean)
  383. if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache {
  384. session.cacheInsert(session.Statement.TableName())
  385. }
  386. if table.Version != "" && session.Statement.checkVersion {
  387. verValue, err := table.VersionColumn().ValueOf(bean)
  388. if err != nil {
  389. session.Engine.logger.Error(err)
  390. } else if verValue.IsValid() && verValue.CanSet() {
  391. verValue.SetInt(1)
  392. }
  393. }
  394. if len(res) < 1 {
  395. return 0, errors.New("insert no error but not returned id")
  396. }
  397. idByte := res[0][table.AutoIncrement]
  398. id, err := strconv.ParseInt(string(idByte), 10, 64)
  399. if err != nil || id <= 0 {
  400. return 1, err
  401. }
  402. aiValue, err := table.AutoIncrColumn().ValueOf(bean)
  403. if err != nil {
  404. session.Engine.logger.Error(err)
  405. }
  406. if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
  407. return 1, nil
  408. }
  409. aiValue.Set(int64ToIntValue(id, aiValue.Type()))
  410. return 1, nil
  411. } else {
  412. res, err := session.exec(sqlStr, args...)
  413. if err != nil {
  414. return 0, err
  415. }
  416. defer handleAfterInsertProcessorFunc(bean)
  417. if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache {
  418. session.cacheInsert(session.Statement.TableName())
  419. }
  420. if table.Version != "" && session.Statement.checkVersion {
  421. verValue, err := table.VersionColumn().ValueOf(bean)
  422. if err != nil {
  423. session.Engine.logger.Error(err)
  424. } else if verValue.IsValid() && verValue.CanSet() {
  425. verValue.SetInt(1)
  426. }
  427. }
  428. if table.AutoIncrement == "" {
  429. return res.RowsAffected()
  430. }
  431. var id int64
  432. id, err = res.LastInsertId()
  433. if err != nil || id <= 0 {
  434. return res.RowsAffected()
  435. }
  436. aiValue, err := table.AutoIncrColumn().ValueOf(bean)
  437. if err != nil {
  438. session.Engine.logger.Error(err)
  439. }
  440. if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
  441. return res.RowsAffected()
  442. }
  443. aiValue.Set(int64ToIntValue(id, aiValue.Type()))
  444. return res.RowsAffected()
  445. }
  446. }
  447. // InsertOne insert only one struct into database as a record.
  448. // The in parameter bean must a struct or a point to struct. The return
  449. // parameter is inserted and error
  450. func (session *Session) InsertOne(bean interface{}) (int64, error) {
  451. defer session.resetStatement()
  452. if session.IsAutoClose {
  453. defer session.Close()
  454. }
  455. return session.innerInsert(bean)
  456. }
  457. func (session *Session) cacheInsert(tables ...string) error {
  458. if session.Statement.RefTable == nil {
  459. return ErrCacheFailed
  460. }
  461. table := session.Statement.RefTable
  462. cacher := session.Engine.getCacher2(table)
  463. for _, t := range tables {
  464. session.Engine.logger.Debug("[cache] clear sql:", t)
  465. cacher.ClearIds(t)
  466. }
  467. return nil
  468. }