bulkcopy.go

package mssql

import (
	"bytes"
	"context"
	"encoding/binary"
	"fmt"
	"math"
	"reflect"
	"strconv"
	"strings"
	"time"
)

type Bulk struct {
	// ctx is used only for AddRow and Done methods.
	// This could be removed if AddRow and Done accepted
	// a ctx field as well, which is available with the
	// database/sql call.
	ctx context.Context

	cn          *Conn
	metadata    []columnStruct
	bulkColumns []columnStruct
	columnsName []string
	tablename   string
	numRows     int

	headerSent bool
	Options    BulkOptions
	Debug      bool
}

type BulkOptions struct {
	CheckConstraints  bool
	FireTriggers      bool
	KeepNulls         bool
	KilobytesPerBatch int
	RowsPerBatch      int
	Order             []string
	Tablock           bool
}
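// Note: the fields above map directly onto the bulk insert options that
// sendBulkCommand renders into the WITH clause of the generated INSERT BULK
// statement. A minimal sketch (the values here are illustrative only):
//
//	opts := BulkOptions{RowsPerBatch: 10000, Tablock: true}
//
// produces a statement ending in:
//
//	WITH (ROWS_PER_BATCH = 10000,TABLOCK)
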
type DataValue interface{}

func (cn *Conn) CreateBulk(table string, columns []string) (_ *Bulk) {
	b := Bulk{ctx: context.Background(), cn: cn, tablename: table, headerSent: false, columnsName: columns}
	b.Debug = false
	return &b
}

func (cn *Conn) CreateBulkContext(ctx context.Context, table string, columns []string) (_ *Bulk) {
	b := Bulk{ctx: ctx, cn: cn, tablename: table, headerSent: false, columnsName: columns}
	b.Debug = false
	return &b
}
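// A minimal usage sketch, assuming a *Conn named cn obtained from this driver;
// the table name, column names and row values below are illustrative only:
//
//	bulk := cn.CreateBulkContext(ctx, "dbo.my_table", []string{"id", "name"})
//	bulk.Options = BulkOptions{Tablock: true}
//	if err := bulk.AddRow([]interface{}{1, "alice"}); err != nil {
//		// handle the error
//	}
//	rowCount, err := bulk.Done()
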
func (b *Bulk) sendBulkCommand(ctx context.Context) (err error) {
	// get the table columns info
	err = b.getMetadata(ctx)
	if err != nil {
		return err
	}

	// match the columns
	for _, colname := range b.columnsName {
		var bulkCol *columnStruct

		for _, m := range b.metadata {
			if m.ColName == colname {
				bulkCol = &m
				break
			}
		}
		if bulkCol != nil {
			if bulkCol.ti.TypeId == typeUdt {
				// send UDT as binary
				bulkCol.ti.TypeId = typeBigVarBin
			}
			b.bulkColumns = append(b.bulkColumns, *bulkCol)
			b.dlogf("Adding column %s %s %#x", colname, bulkCol.ColName, bulkCol.ti.TypeId)
		} else {
			return fmt.Errorf("column %s does not exist in destination table %s", colname, b.tablename)
		}
	}

	// create the bulk command

	// column definitions
	var col_defs bytes.Buffer
	for i, col := range b.bulkColumns {
		if i != 0 {
			col_defs.WriteString(", ")
		}
		col_defs.WriteString("[" + col.ColName + "] " + makeDecl(col.ti))
	}

	// options
	var with_opts []string

	if b.Options.CheckConstraints {
		with_opts = append(with_opts, "CHECK_CONSTRAINTS")
	}
	if b.Options.FireTriggers {
		with_opts = append(with_opts, "FIRE_TRIGGERS")
	}
	if b.Options.KeepNulls {
		with_opts = append(with_opts, "KEEP_NULLS")
	}
	if b.Options.KilobytesPerBatch > 0 {
		with_opts = append(with_opts, fmt.Sprintf("KILOBYTES_PER_BATCH = %d", b.Options.KilobytesPerBatch))
	}
	if b.Options.RowsPerBatch > 0 {
		with_opts = append(with_opts, fmt.Sprintf("ROWS_PER_BATCH = %d", b.Options.RowsPerBatch))
	}
	if len(b.Options.Order) > 0 {
		with_opts = append(with_opts, fmt.Sprintf("ORDER(%s)", strings.Join(b.Options.Order, ",")))
	}
	if b.Options.Tablock {
		with_opts = append(with_opts, "TABLOCK")
	}
	var with_part string
	if len(with_opts) > 0 {
		with_part = fmt.Sprintf("WITH (%s)", strings.Join(with_opts, ","))
	}

	query := fmt.Sprintf("INSERT BULK %s (%s) %s", b.tablename, col_defs.String(), with_part)

	stmt, err := b.cn.PrepareContext(ctx, query)
	if err != nil {
		return fmt.Errorf("prepare failed: %s", err.Error())
	}
	b.dlogf("%s", query)

	_, err = stmt.(*Stmt).ExecContext(ctx, nil)
	if err != nil {
		return err
	}

	b.headerSent = true

	var buf = b.cn.sess.buf
	buf.BeginPacket(packBulkLoadBCP, false)

	// send the columns metadata
	columnMetadata := b.createColMetadata()
	_, err = buf.Write(columnMetadata)

	return
}

// AddRow immediately writes the row to the destination table.
// The arguments are the row values in the order they were specified.
func (b *Bulk) AddRow(row []interface{}) (err error) {
	if !b.headerSent {
		err = b.sendBulkCommand(b.ctx)
		if err != nil {
			return
		}
	}

	if len(row) != len(b.bulkColumns) {
		return fmt.Errorf("row does not have the same number of columns as the destination table: %d vs %d",
			len(row), len(b.bulkColumns))
	}

	bytes, err := b.makeRowData(row)
	if err != nil {
		return
	}

	_, err = b.cn.sess.buf.Write(bytes)
	if err != nil {
		return
	}

	b.numRows = b.numRows + 1
	return
}

func (b *Bulk) makeRowData(row []interface{}) ([]byte, error) {
	buf := new(bytes.Buffer)
	buf.WriteByte(byte(tokenRow))

	var logcol bytes.Buffer
	for i, col := range b.bulkColumns {
		if b.Debug {
			logcol.WriteString(fmt.Sprintf(" col[%d]='%v' ", i, row[i]))
		}
		param, err := b.makeParam(row[i], col)
		if err != nil {
			return nil, fmt.Errorf("bulkcopy: %s", err.Error())
		}

		if col.ti.Writer == nil {
			return nil, fmt.Errorf("no writer for column: %s, TypeId: %#x",
				col.ColName, col.ti.TypeId)
		}
		err = col.ti.Writer(buf, param.ti, param.buffer)
		if err != nil {
			return nil, fmt.Errorf("bulkcopy: %s", err.Error())
		}
	}

	b.dlogf("row[%d] %s\n", b.numRows, logcol.String())

	return buf.Bytes(), nil
}

func (b *Bulk) Done() (rowcount int64, err error) {
	if !b.headerSent {
		// no rows have been sent
		return 0, nil
	}
	var buf = b.cn.sess.buf
	buf.WriteByte(byte(tokenDone))

	binary.Write(buf, binary.LittleEndian, uint16(doneFinal))
	binary.Write(buf, binary.LittleEndian, uint16(0)) // curcmd

	if b.cn.sess.loginAck.TDSVersion >= verTDS72 {
		binary.Write(buf, binary.LittleEndian, uint64(0)) // rowcount 0
	} else {
		binary.Write(buf, binary.LittleEndian, uint32(0)) // rowcount 0
	}

	buf.FinishPacket()

	tokchan := make(chan tokenStruct, 5)
	go processResponse(b.ctx, b.cn.sess, tokchan, nil)

	var rowCount int64
	for token := range tokchan {
		switch token := token.(type) {
		case doneStruct:
			if token.Status&doneCount != 0 {
				rowCount = int64(token.RowCount)
			}
			if token.isError() {
				return 0, token.getError()
			}
		case error:
			return 0, b.cn.checkBadConn(token)
		}
	}
	return rowCount, nil
}

func (b *Bulk) createColMetadata() []byte {
	buf := new(bytes.Buffer)
	buf.WriteByte(byte(tokenColMetadata))                               // token
	binary.Write(buf, binary.LittleEndian, uint16(len(b.bulkColumns))) // column count

	for i, col := range b.bulkColumns {

		if b.cn.sess.loginAck.TDSVersion >= verTDS72 {
			binary.Write(buf, binary.LittleEndian, uint32(col.UserType)) // usertype, always 0?
		} else {
			binary.Write(buf, binary.LittleEndian, uint16(col.UserType))
		}
		binary.Write(buf, binary.LittleEndian, uint16(col.Flags))

		writeTypeInfo(buf, &b.bulkColumns[i].ti)

		if col.ti.TypeId == typeNText ||
			col.ti.TypeId == typeText ||
			col.ti.TypeId == typeImage {

			tablename_ucs2 := str2ucs2(b.tablename)
			binary.Write(buf, binary.LittleEndian, uint16(len(tablename_ucs2)/2))
			buf.Write(tablename_ucs2)
		}
		colname_ucs2 := str2ucs2(col.ColName)
		buf.WriteByte(uint8(len(colname_ucs2) / 2))
		buf.Write(colname_ucs2)
	}

	return buf.Bytes()
}

func (b *Bulk) getMetadata(ctx context.Context) (err error) {
	stmt, err := b.cn.prepareContext(ctx, "SET FMTONLY ON")
	if err != nil {
		return
	}

	_, err = stmt.ExecContext(ctx, nil)
	if err != nil {
		return
	}

	// Get columns info.
	stmt, err = b.cn.prepareContext(ctx, fmt.Sprintf("select * from %s SET FMTONLY OFF", b.tablename))
	if err != nil {
		return
	}
	rows, err := stmt.QueryContext(ctx, nil)
	if err != nil {
		return fmt.Errorf("get columns info failed: %v", err)
	}
	b.metadata = rows.(*Rows).cols

	if b.Debug {
		for _, col := range b.metadata {
			b.dlogf("col: %s typeId: %#x size: %d scale: %d prec: %d flags: %d lcid: %#x\n",
				col.ColName, col.ti.TypeId, col.ti.Size, col.ti.Scale, col.ti.Prec,
				col.Flags, col.ti.Collation.LcidAndFlags)
		}
	}

	return rows.Close()
}

func (b *Bulk) makeParam(val DataValue, col columnStruct) (res Param, err error) {
	res.ti.Size = col.ti.Size
	res.ti.TypeId = col.ti.TypeId

	if val == nil {
		res.ti.Size = 0
		return
	}

	switch col.ti.TypeId {

	case typeInt1, typeInt2, typeInt4, typeInt8, typeIntN:
		var intvalue int64

		switch val := val.(type) {
		case int:
			intvalue = int64(val)
		case int32:
			intvalue = int64(val)
		case int64:
			intvalue = val
		default:
			err = fmt.Errorf("mssql: invalid type for int column")
			return
		}

		res.buffer = make([]byte, res.ti.Size)
		if col.ti.Size == 1 {
			res.buffer[0] = byte(intvalue)
		} else if col.ti.Size == 2 {
			binary.LittleEndian.PutUint16(res.buffer, uint16(intvalue))
		} else if col.ti.Size == 4 {
			binary.LittleEndian.PutUint32(res.buffer, uint32(intvalue))
		} else if col.ti.Size == 8 {
			binary.LittleEndian.PutUint64(res.buffer, uint64(intvalue))
		}
	case typeFlt4, typeFlt8, typeFltN:
		var floatvalue float64

		switch val := val.(type) {
		case float32:
			floatvalue = float64(val)
		case float64:
			floatvalue = val
		case int:
			floatvalue = float64(val)
		case int64:
			floatvalue = float64(val)
		default:
			err = fmt.Errorf("mssql: invalid type for float column: %v", val)
			return
		}

		if col.ti.Size == 4 {
			res.buffer = make([]byte, 4)
			binary.LittleEndian.PutUint32(res.buffer, math.Float32bits(float32(floatvalue)))
		} else if col.ti.Size == 8 {
			res.buffer = make([]byte, 8)
			binary.LittleEndian.PutUint64(res.buffer, math.Float64bits(floatvalue))
		}
	case typeNVarChar, typeNText, typeNChar:

		switch val := val.(type) {
		case string:
			res.buffer = str2ucs2(val)
		case []byte:
			res.buffer = val
		default:
			err = fmt.Errorf("mssql: invalid type for nvarchar column: %v", val)
			return
		}
		res.ti.Size = len(res.buffer)

	case typeVarChar, typeBigVarChar, typeText, typeChar, typeBigChar:
		switch val := val.(type) {
		case string:
			res.buffer = []byte(val)
		case []byte:
			res.buffer = val
		default:
			err = fmt.Errorf("mssql: invalid type for varchar column: %v", val)
			return
		}
		res.ti.Size = len(res.buffer)
	case typeBit, typeBitN:
		if reflect.TypeOf(val).Kind() != reflect.Bool {
			err = fmt.Errorf("mssql: invalid type for bit column: %v", val)
			return
		}
		res.ti.TypeId = typeBitN
		res.ti.Size = 1
		res.buffer = make([]byte, 1)
		if val.(bool) {
			res.buffer[0] = 1
		}
	case typeDateTime2N, typeDateTimeOffsetN:
		switch val := val.(type) {
		case time.Time:
			days, ns := dateTime2(val)
			ns /= int64(math.Pow10(int(col.ti.Scale)*-1) * 1000000000)

			var data = make([]byte, 5)

			data[0] = byte(ns)
			data[1] = byte(ns >> 8)
			data[2] = byte(ns >> 16)
			data[3] = byte(ns >> 24)
			data[4] = byte(ns >> 32)

			if col.ti.Scale <= 2 {
				res.ti.Size = 6
			} else if col.ti.Scale <= 4 {
				res.ti.Size = 7
			} else {
				res.ti.Size = 8
			}
			var buf []byte
			buf = make([]byte, res.ti.Size)
			copy(buf, data[0:res.ti.Size-3])

			buf[res.ti.Size-3] = byte(days)
			buf[res.ti.Size-2] = byte(days >> 8)
			buf[res.ti.Size-1] = byte(days >> 16)

			if col.ti.TypeId == typeDateTimeOffsetN {
				_, offset := val.Zone()
				var offsetMinute = uint16(offset / 60)
				buf = append(buf, byte(offsetMinute))
				buf = append(buf, byte(offsetMinute>>8))
				res.ti.Size = res.ti.Size + 2
			}

			res.buffer = buf

		default:
			err = fmt.Errorf("mssql: invalid type for datetime2 column: %v", val)
			return
		}
	case typeDateN:
		switch val := val.(type) {
		case time.Time:
			days, _ := dateTime2(val)

			res.ti.Size = 3
			res.buffer = make([]byte, 3)

			res.buffer[0] = byte(days)
			res.buffer[1] = byte(days >> 8)
			res.buffer[2] = byte(days >> 16)
		default:
			err = fmt.Errorf("mssql: invalid type for date column: %v", val)
			return
		}
	case typeDateTime, typeDateTimeN, typeDateTim4:
		switch val := val.(type) {
		case time.Time:
			if col.ti.Size == 4 {
				res.ti.Size = 4
				res.buffer = make([]byte, 4)

				ref := time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC)
				dur := val.Sub(ref)
				days := dur / (24 * time.Hour)

				if days < 0 {
					err = fmt.Errorf("mssql: date %s is out of range", val)
					return
				}
				mins := val.Hour()*60 + val.Minute()
				binary.LittleEndian.PutUint16(res.buffer[0:2], uint16(days))
				binary.LittleEndian.PutUint16(res.buffer[2:4], uint16(mins))
			} else if col.ti.Size == 8 {
				res.ti.Size = 8
				res.buffer = make([]byte, 8)

				days := divFloor(val.Unix(), 24*60*60)
				// 25567 - number of days from Jan 1 1900 UTC to Jan 1 1970
				days = days + 25567
				tm := (val.Hour()*60*60+val.Minute()*60+val.Second())*300 + int(val.Nanosecond()/10000000*3)

				binary.LittleEndian.PutUint32(res.buffer[0:4], uint32(days))
				binary.LittleEndian.PutUint32(res.buffer[4:8], uint32(tm))
			} else {
				err = fmt.Errorf("mssql: invalid size of column")
			}

		default:
			err = fmt.Errorf("mssql: invalid type for datetime column: %v", val)
		}
	// case typeMoney, typeMoney4, typeMoneyN:
	case typeDecimal, typeDecimalN, typeNumeric, typeNumericN:
		var value float64
		switch v := val.(type) {
		case int:
			value = float64(v)
		case int8:
			value = float64(v)
		case int16:
			value = float64(v)
		case int32:
			value = float64(v)
		case int64:
			value = float64(v)
		case float32:
			value = float64(v)
		case float64:
			value = v
		case string:
			if value, err = strconv.ParseFloat(v, 64); err != nil {
				return res, fmt.Errorf("bulk: unable to convert string to float: %v", err)
			}
		default:
			return res, fmt.Errorf("unknown value for decimal: %#v", v)
		}

		prec := col.ti.Prec
		scale := col.ti.Scale
		var dec Decimal
		dec, err = Float64ToDecimalScale(value, scale)
		if err != nil {
			return res, err
		}
		dec.prec = prec

		var length byte
		switch {
		case prec <= 9:
			length = 4
		case prec <= 19:
			length = 8
		case prec <= 28:
			length = 12
		default:
			length = 16
		}

		buf := make([]byte, length+1)
		// the first byte (the length) is written by the typeInfo writer
		res.ti.Size = int(length) + 1
		// the next byte is the sign
		if value < 0 {
			buf[0] = 0
		} else {
			buf[0] = 1
		}

		ub := dec.UnscaledBytes()
		l := len(ub)
		if l > int(length) {
			err = fmt.Errorf("decimal out of range: %s", dec)
			return res, err
		}
		// reverse the bytes
		for i, j := 1, l-1; j >= 0; i, j = i+1, j-1 {
			buf[i] = ub[j]
		}
		res.buffer = buf
	case typeBigVarBin:
		switch val := val.(type) {
		case []byte:
			res.ti.Size = len(val)
			res.buffer = val
		default:
			err = fmt.Errorf("mssql: invalid type for Binary column: %v", val)
			return
		}
	case typeGuid:
		switch val := val.(type) {
		case []byte:
			res.ti.Size = len(val)
			res.buffer = val
		default:
			err = fmt.Errorf("mssql: invalid type for Guid column: %v", val)
			return
		}
	default:
		err = fmt.Errorf("mssql: type %x not implemented", col.ti.TypeId)
	}
	return
}

func (b *Bulk) dlogf(format string, v ...interface{}) {
	if b.Debug {
		b.cn.sess.log.Printf(format, v...)
	}
}