mssql.go 24 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973
  1. package mssql
  2. import (
  3. "context"
  4. "database/sql"
  5. "database/sql/driver"
  6. "encoding/binary"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "math"
  11. "net"
  12. "reflect"
  13. "strings"
  14. "time"
  15. "unicode"
  16. )
// ReturnStatus may be used to return the return value from a proc.
//
//	var rs mssql.ReturnStatus
//	_, err := db.Exec("theproc", &rs)
//	log.Printf("return status = %d", rs)
type ReturnStatus int32
  23. var driverInstance = &Driver{processQueryText: true}
  24. var driverInstanceNoProcess = &Driver{processQueryText: false}
  25. func init() {
  26. sql.Register("mssql", driverInstance)
  27. sql.Register("sqlserver", driverInstanceNoProcess)
  28. createDialer = func(p *connectParams) Dialer {
  29. return netDialer{&net.Dialer{KeepAlive: p.keepAlive}}
  30. }
  31. }
// createDialer builds the Dialer used when a Connector does not supply its
// own. It is assigned in init.
var createDialer func(p *connectParams) Dialer
// netDialer adapts a *net.Dialer to the Dialer interface.
type netDialer struct {
	nd *net.Dialer // underlying standard-library dialer
}
// DialContext dials addr on the given network by delegating to the wrapped
// net.Dialer.
func (d netDialer) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
	return d.nd.DialContext(ctx, network, addr)
}
// Driver is the database/sql driver implementation registered by this
// package.
type Driver struct {
	// log receives driver log output; set through SetLogger.
	log optionalLogger
	// processQueryText is copied onto every Conn this driver creates and
	// controls whether prepared query text is run through parseParams.
	processQueryText bool
}
  43. // OpenConnector opens a new connector. Useful to dial with a context.
  44. func (d *Driver) OpenConnector(dsn string) (*Connector, error) {
  45. params, err := parseConnectParams(dsn)
  46. if err != nil {
  47. return nil, err
  48. }
  49. return &Connector{
  50. params: params,
  51. driver: d,
  52. }, nil
  53. }
  54. func (d *Driver) Open(dsn string) (driver.Conn, error) {
  55. return d.open(context.Background(), dsn)
  56. }
  57. func SetLogger(logger Logger) {
  58. driverInstance.SetLogger(logger)
  59. driverInstanceNoProcess.SetLogger(logger)
  60. }
// SetLogger installs logger on this driver, wrapped in an optionalLogger.
func (d *Driver) SetLogger(logger Logger) {
	d.log = optionalLogger{logger}
}
  64. // NewConnector creates a new connector from a DSN.
  65. // The returned connector may be used with sql.OpenDB.
  66. func NewConnector(dsn string) (*Connector, error) {
  67. params, err := parseConnectParams(dsn)
  68. if err != nil {
  69. return nil, err
  70. }
  71. c := &Connector{
  72. params: params,
  73. driver: driverInstanceNoProcess,
  74. }
  75. return c, nil
  76. }
// Connector holds the parsed DSN and is ready to make a new connection
// at any time.
//
// In the future, settings that cannot be passed through a string DSN
// may be set directly on the connector.
type Connector struct {
	// params is the parsed form of the DSN.
	params connectParams
	// driver is the Driver this connector was created from.
	driver *Driver

	// SessionInitSQL is executed after marking a given session to be reset.
	// When not present, the next query will still reset the session to the
	// database defaults.
	//
	// When present the connection will immediately mark the session to
	// be reset, then execute the SessionInitSQL text to setup the session
	// that may be different from the base database defaults.
	//
	// For example, the application relies on the following defaults
	// but is not allowed to set them at the database system level.
	//
	//	SET XACT_ABORT ON;
	//	SET TEXTSIZE -1;
	//	SET ANSI_NULLS ON;
	//	SET LOCK_TIMEOUT 10000;
	//
	// SessionInitSQL should not attempt to manually call sp_reset_connection.
	// This will happen at the TDS layer.
	//
	// SessionInitSQL is optional. The session will be reset even if
	// SessionInitSQL is empty.
	SessionInitSQL string

	// Dialer sets a custom dialer for all network operations.
	// If Dialer is not set, normal net dialers are used.
	Dialer Dialer
}
// Dialer is the connection factory used for all network operations. Its
// single method has the same signature as (*net.Dialer).DialContext.
type Dialer interface {
	DialContext(ctx context.Context, network string, addr string) (net.Conn, error)
}
  114. func (c *Connector) getDialer(p *connectParams) Dialer {
  115. if c != nil && c.Dialer != nil {
  116. return c.Dialer
  117. }
  118. return createDialer(p)
  119. }
// Conn is a single connection to the server. It also serves as the
// driver.Tx returned when a transaction is begun.
type Conn struct {
	// connector is the Connector this connection came from; it is nil when
	// the connection was opened through Driver.open.
	connector *Connector
	// sess is the underlying TDS session.
	sess *tdsSession
	// transactionCtx is recorded by sendBeginRequest and used when
	// processing the Commit/Rollback response.
	transactionCtx context.Context
	// resetSession is read and cleared each time a request is sent; its
	// value is forwarded as the request's reset flag.
	resetSession bool
	// processQueryText is copied from the Driver; when true, prepared query
	// text is passed through parseParams.
	processQueryText bool
	// connectionGood is cleared once an error indicates the connection is
	// unusable; entry points then return driver.ErrBadConn.
	connectionGood bool
	// outs is handed to processResponse for each request and then cleared.
	outs map[string]interface{}
	// returnStatus, when non-nil, receives ReturnStatus tokens.
	returnStatus *ReturnStatus
}
  130. func (c *Conn) setReturnStatus(s ReturnStatus) {
  131. if c.returnStatus == nil {
  132. return
  133. }
  134. *c.returnStatus = s
  135. }
  136. func (c *Conn) checkBadConn(err error) error {
  137. // this is a hack to address Issue #275
  138. // we set connectionGood flag to false if
  139. // error indicates that connection is not usable
  140. // but we return actual error instead of ErrBadConn
  141. // this will cause connection to stay in a pool
  142. // but next request to this connection will return ErrBadConn
  143. // it might be possible to revise this hack after
  144. // https://github.com/golang/go/issues/20807
  145. // is implemented
  146. switch err {
  147. case nil:
  148. return nil
  149. case io.EOF:
  150. c.connectionGood = false
  151. return driver.ErrBadConn
  152. case driver.ErrBadConn:
  153. // It is an internal programming error if driver.ErrBadConn
  154. // is ever passed to this function. driver.ErrBadConn should
  155. // only ever be returned in response to a *mssql.Conn.connectionGood == false
  156. // check in the external facing API.
  157. panic("driver.ErrBadConn in checkBadConn. This should not happen.")
  158. }
  159. switch err.(type) {
  160. case net.Error:
  161. c.connectionGood = false
  162. return err
  163. case StreamError:
  164. c.connectionGood = false
  165. return err
  166. default:
  167. return err
  168. }
  169. }
// clearOuts drops the connection's reference to the current outs map.
func (c *Conn) clearOuts() {
	c.outs = nil
}
  173. func (c *Conn) simpleProcessResp(ctx context.Context) error {
  174. tokchan := make(chan tokenStruct, 5)
  175. go processResponse(ctx, c.sess, tokchan, c.outs)
  176. c.clearOuts()
  177. for tok := range tokchan {
  178. switch token := tok.(type) {
  179. case doneStruct:
  180. if token.isError() {
  181. return c.checkBadConn(token.getError())
  182. }
  183. case error:
  184. return c.checkBadConn(token)
  185. }
  186. }
  187. return nil
  188. }
  189. func (c *Conn) Commit() error {
  190. if !c.connectionGood {
  191. return driver.ErrBadConn
  192. }
  193. if err := c.sendCommitRequest(); err != nil {
  194. return c.checkBadConn(err)
  195. }
  196. return c.simpleProcessResp(c.transactionCtx)
  197. }
  198. func (c *Conn) sendCommitRequest() error {
  199. headers := []headerStruct{
  200. {hdrtype: dataStmHdrTransDescr,
  201. data: transDescrHdr{c.sess.tranid, 1}.pack()},
  202. }
  203. reset := c.resetSession
  204. c.resetSession = false
  205. if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, "", reset); err != nil {
  206. if c.sess.logFlags&logErrors != 0 {
  207. c.sess.log.Printf("Failed to send CommitXact with %v", err)
  208. }
  209. c.connectionGood = false
  210. return fmt.Errorf("Faild to send CommitXact: %v", err)
  211. }
  212. return nil
  213. }
  214. func (c *Conn) Rollback() error {
  215. if !c.connectionGood {
  216. return driver.ErrBadConn
  217. }
  218. if err := c.sendRollbackRequest(); err != nil {
  219. return c.checkBadConn(err)
  220. }
  221. return c.simpleProcessResp(c.transactionCtx)
  222. }
  223. func (c *Conn) sendRollbackRequest() error {
  224. headers := []headerStruct{
  225. {hdrtype: dataStmHdrTransDescr,
  226. data: transDescrHdr{c.sess.tranid, 1}.pack()},
  227. }
  228. reset := c.resetSession
  229. c.resetSession = false
  230. if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, "", reset); err != nil {
  231. if c.sess.logFlags&logErrors != 0 {
  232. c.sess.log.Printf("Failed to send RollbackXact with %v", err)
  233. }
  234. c.connectionGood = false
  235. return fmt.Errorf("Failed to send RollbackXact: %v", err)
  236. }
  237. return nil
  238. }
// Begin starts a transaction with a background context, keeping the
// session's current isolation level.
func (c *Conn) Begin() (driver.Tx, error) {
	return c.begin(context.Background(), isolationUseCurrent)
}
  242. func (c *Conn) begin(ctx context.Context, tdsIsolation isoLevel) (tx driver.Tx, err error) {
  243. if !c.connectionGood {
  244. return nil, driver.ErrBadConn
  245. }
  246. err = c.sendBeginRequest(ctx, tdsIsolation)
  247. if err != nil {
  248. return nil, c.checkBadConn(err)
  249. }
  250. tx, err = c.processBeginResponse(ctx)
  251. if err != nil {
  252. return nil, c.checkBadConn(err)
  253. }
  254. return
  255. }
  256. func (c *Conn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) error {
  257. c.transactionCtx = ctx
  258. headers := []headerStruct{
  259. {hdrtype: dataStmHdrTransDescr,
  260. data: transDescrHdr{0, 1}.pack()},
  261. }
  262. reset := c.resetSession
  263. c.resetSession = false
  264. if err := sendBeginXact(c.sess.buf, headers, tdsIsolation, "", reset); err != nil {
  265. if c.sess.logFlags&logErrors != 0 {
  266. c.sess.log.Printf("Failed to send BeginXact with %v", err)
  267. }
  268. c.connectionGood = false
  269. return fmt.Errorf("Failed to send BeginXact: %v", err)
  270. }
  271. return nil
  272. }
  273. func (c *Conn) processBeginResponse(ctx context.Context) (driver.Tx, error) {
  274. if err := c.simpleProcessResp(ctx); err != nil {
  275. return nil, err
  276. }
  277. // successful BEGINXACT request will return sess.tranid
  278. // for started transaction
  279. return c, nil
  280. }
  281. func (d *Driver) open(ctx context.Context, dsn string) (*Conn, error) {
  282. params, err := parseConnectParams(dsn)
  283. if err != nil {
  284. return nil, err
  285. }
  286. return d.connect(ctx, nil, params)
  287. }
  288. // connect to the server, using the provided context for dialing only.
  289. func (d *Driver) connect(ctx context.Context, c *Connector, params connectParams) (*Conn, error) {
  290. sess, err := connect(ctx, c, d.log, params)
  291. if err != nil {
  292. // main server failed, try fail-over partner
  293. if params.failOverPartner == "" {
  294. return nil, err
  295. }
  296. params.host = params.failOverPartner
  297. if params.failOverPort != 0 {
  298. params.port = params.failOverPort
  299. }
  300. sess, err = connect(ctx, c, d.log, params)
  301. if err != nil {
  302. // fail-over partner also failed, now fail
  303. return nil, err
  304. }
  305. }
  306. conn := &Conn{
  307. connector: c,
  308. sess: sess,
  309. transactionCtx: context.Background(),
  310. processQueryText: d.processQueryText,
  311. connectionGood: true,
  312. }
  313. conn.sess.log = d.log
  314. return conn, nil
  315. }
// Close closes the session's underlying transport and returns its error.
func (c *Conn) Close() error {
	return c.sess.buf.transport.Close()
}
// Stmt is a prepared statement bound to a connection.
type Stmt struct {
	c *Conn // owning connection
	// query is the statement text, possibly rewritten by parseParams.
	query string
	// paramCount is the value reported by NumInput; prepareContext leaves it
	// at -1 when the query text is not processed.
	paramCount int
	// notifSub, when non-nil, adds a query-notification header to requests.
	notifSub *queryNotifSub
}
// queryNotifSub holds the query-notification settings recorded by
// Stmt.SetQueryNotification.
type queryNotifSub struct {
	msgText string // the id argument given to SetQueryNotification
	options string
	timeout uint32 // whole seconds, at least 1
}
  330. func (c *Conn) Prepare(query string) (driver.Stmt, error) {
  331. if !c.connectionGood {
  332. return nil, driver.ErrBadConn
  333. }
  334. if len(query) > 10 && strings.EqualFold(query[:10], "INSERTBULK") {
  335. return c.prepareCopyIn(context.Background(), query)
  336. }
  337. return c.prepareContext(context.Background(), query)
  338. }
  339. func (c *Conn) prepareContext(ctx context.Context, query string) (*Stmt, error) {
  340. paramCount := -1
  341. if c.processQueryText {
  342. query, paramCount = parseParams(query)
  343. }
  344. return &Stmt{c, query, paramCount, nil}, nil
  345. }
// Close is a no-op and always returns nil.
func (s *Stmt) Close() error {
	return nil
}
  349. func (s *Stmt) SetQueryNotification(id, options string, timeout time.Duration) {
  350. to := uint32(timeout / time.Second)
  351. if to < 1 {
  352. to = 1
  353. }
  354. s.notifSub = &queryNotifSub{id, options, to}
  355. }
// NumInput returns the statement's recorded parameter count, which is -1
// when the query text was not processed at prepare time.
func (s *Stmt) NumInput() int {
	return s.paramCount
}
  359. func (s *Stmt) sendQuery(args []namedValue) (err error) {
  360. headers := []headerStruct{
  361. {hdrtype: dataStmHdrTransDescr,
  362. data: transDescrHdr{s.c.sess.tranid, 1}.pack()},
  363. }
  364. if s.notifSub != nil {
  365. headers = append(headers,
  366. headerStruct{
  367. hdrtype: dataStmHdrQueryNotif,
  368. data: queryNotifHdr{
  369. s.notifSub.msgText,
  370. s.notifSub.options,
  371. s.notifSub.timeout,
  372. }.pack(),
  373. })
  374. }
  375. conn := s.c
  376. // no need to check number of parameters here, it is checked by database/sql
  377. if conn.sess.logFlags&logSQL != 0 {
  378. conn.sess.log.Println(s.query)
  379. }
  380. if conn.sess.logFlags&logParams != 0 && len(args) > 0 {
  381. for i := 0; i < len(args); i++ {
  382. if len(args[i].Name) > 0 {
  383. s.c.sess.log.Printf("\t@%s\t%v\n", args[i].Name, args[i].Value)
  384. } else {
  385. s.c.sess.log.Printf("\t@p%d\t%v\n", i+1, args[i].Value)
  386. }
  387. }
  388. }
  389. reset := conn.resetSession
  390. conn.resetSession = false
  391. if len(args) == 0 {
  392. if err = sendSqlBatch72(conn.sess.buf, s.query, headers, reset); err != nil {
  393. if conn.sess.logFlags&logErrors != 0 {
  394. conn.sess.log.Printf("Failed to send SqlBatch with %v", err)
  395. }
  396. conn.connectionGood = false
  397. return fmt.Errorf("failed to send SQL Batch: %v", err)
  398. }
  399. } else {
  400. proc := sp_ExecuteSql
  401. var params []param
  402. if isProc(s.query) {
  403. proc.name = s.query
  404. params, _, err = s.makeRPCParams(args, 0)
  405. if err != nil {
  406. return
  407. }
  408. } else {
  409. var decls []string
  410. params, decls, err = s.makeRPCParams(args, 2)
  411. if err != nil {
  412. return
  413. }
  414. params[0] = makeStrParam(s.query)
  415. params[1] = makeStrParam(strings.Join(decls, ","))
  416. }
  417. if err = sendRpc(conn.sess.buf, headers, proc, 0, params, reset); err != nil {
  418. if conn.sess.logFlags&logErrors != 0 {
  419. conn.sess.log.Printf("Failed to send Rpc with %v", err)
  420. }
  421. conn.connectionGood = false
  422. return fmt.Errorf("Failed to send RPC: %v", err)
  423. }
  424. }
  425. return
  426. }
  427. // isProc takes the query text in s and determines if it is a stored proc name
  428. // or SQL text.
  429. func isProc(s string) bool {
  430. if len(s) == 0 {
  431. return false
  432. }
  433. const (
  434. outside = iota
  435. text
  436. escaped
  437. )
  438. st := outside
  439. var rn1, rPrev rune
  440. for _, r := range s {
  441. rPrev = rn1
  442. rn1 = r
  443. switch r {
  444. // No newlines or string sequences.
  445. case '\n', '\r', '\'', ';':
  446. return false
  447. }
  448. switch st {
  449. case outside:
  450. switch {
  451. case unicode.IsSpace(r):
  452. return false
  453. case r == '[':
  454. st = escaped
  455. continue
  456. case r == ']' && rPrev == ']':
  457. st = escaped
  458. continue
  459. case unicode.IsLetter(r):
  460. st = text
  461. }
  462. case text:
  463. switch {
  464. case r == '.':
  465. st = outside
  466. continue
  467. case unicode.IsSpace(r):
  468. return false
  469. }
  470. case escaped:
  471. switch {
  472. case r == ']':
  473. st = outside
  474. continue
  475. }
  476. }
  477. }
  478. return true
  479. }
  480. func (s *Stmt) makeRPCParams(args []namedValue, offset int) ([]param, []string, error) {
  481. var err error
  482. params := make([]param, len(args)+offset)
  483. decls := make([]string, len(args))
  484. for i, val := range args {
  485. params[i+offset], err = s.makeParam(val.Value)
  486. if err != nil {
  487. return nil, nil, err
  488. }
  489. var name string
  490. if len(val.Name) > 0 {
  491. name = "@" + val.Name
  492. } else {
  493. name = fmt.Sprintf("@p%d", val.Ordinal)
  494. }
  495. params[i+offset].Name = name
  496. decls[i] = fmt.Sprintf("%s %s", name, makeDecl(params[i+offset].ti))
  497. }
  498. return params, decls, nil
  499. }
// namedValue is the driver's internal argument representation. Its fields
// mirror driver.NamedValue so values can be converted directly.
type namedValue struct {
	Name    string       // parameter name without the leading "@"; empty for positional
	Ordinal int          // 1-based position of the argument
	Value   driver.Value // argument value
}
  505. func convertOldArgs(args []driver.Value) []namedValue {
  506. list := make([]namedValue, len(args))
  507. for i, v := range args {
  508. list[i] = namedValue{
  509. Ordinal: i + 1,
  510. Value: v,
  511. }
  512. }
  513. return list
  514. }
// Query runs the statement with positional args using a background context
// and returns the resulting rows.
func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) {
	return s.queryContext(context.Background(), convertOldArgs(args))
}
  518. func (s *Stmt) queryContext(ctx context.Context, args []namedValue) (rows driver.Rows, err error) {
  519. if !s.c.connectionGood {
  520. return nil, driver.ErrBadConn
  521. }
  522. if err = s.sendQuery(args); err != nil {
  523. return nil, s.c.checkBadConn(err)
  524. }
  525. return s.processQueryResponse(ctx)
  526. }
  527. func (s *Stmt) processQueryResponse(ctx context.Context) (res driver.Rows, err error) {
  528. tokchan := make(chan tokenStruct, 5)
  529. ctx, cancel := context.WithCancel(ctx)
  530. go processResponse(ctx, s.c.sess, tokchan, s.c.outs)
  531. s.c.clearOuts()
  532. // process metadata
  533. var cols []columnStruct
  534. loop:
  535. for tok := range tokchan {
  536. switch token := tok.(type) {
  537. // By ignoring DONE token we effectively
  538. // skip empty result-sets.
  539. // This improves results in queries like that:
  540. // set nocount on; select 1
  541. // see TestIgnoreEmptyResults test
  542. //case doneStruct:
  543. //break loop
  544. case []columnStruct:
  545. cols = token
  546. break loop
  547. case doneStruct:
  548. if token.isError() {
  549. return nil, s.c.checkBadConn(token.getError())
  550. }
  551. case ReturnStatus:
  552. s.c.setReturnStatus(token)
  553. case error:
  554. return nil, s.c.checkBadConn(token)
  555. }
  556. }
  557. res = &Rows{stmt: s, tokchan: tokchan, cols: cols, cancel: cancel}
  558. return
  559. }
// Exec runs the statement with positional args using a background context
// and returns the result.
func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) {
	return s.exec(context.Background(), convertOldArgs(args))
}
  563. func (s *Stmt) exec(ctx context.Context, args []namedValue) (res driver.Result, err error) {
  564. if !s.c.connectionGood {
  565. return nil, driver.ErrBadConn
  566. }
  567. if err = s.sendQuery(args); err != nil {
  568. return nil, s.c.checkBadConn(err)
  569. }
  570. if res, err = s.processExec(ctx); err != nil {
  571. return nil, s.c.checkBadConn(err)
  572. }
  573. return
  574. }
  575. func (s *Stmt) processExec(ctx context.Context) (res driver.Result, err error) {
  576. tokchan := make(chan tokenStruct, 5)
  577. go processResponse(ctx, s.c.sess, tokchan, s.c.outs)
  578. s.c.clearOuts()
  579. var rowCount int64
  580. for token := range tokchan {
  581. switch token := token.(type) {
  582. case doneInProcStruct:
  583. if token.Status&doneCount != 0 {
  584. rowCount += int64(token.RowCount)
  585. }
  586. case doneStruct:
  587. if token.Status&doneCount != 0 {
  588. rowCount += int64(token.RowCount)
  589. }
  590. if token.isError() {
  591. return nil, token.getError()
  592. }
  593. case ReturnStatus:
  594. s.c.setReturnStatus(token)
  595. case error:
  596. return nil, token
  597. }
  598. }
  599. return &Result{s.c, rowCount}, nil
  600. }
// Rows iterates over the result sets of a query.
type Rows struct {
	stmt *Stmt // statement that produced these rows
	// cols describes the columns of the current result set.
	cols []columnStruct
	// tokchan delivers the response tokens; Close sets it to nil.
	tokchan chan tokenStruct
	// nextCols holds the column metadata of the following result set once
	// Next has seen it; nil otherwise.
	nextCols []columnStruct
	// cancel cancels the context given to the response reader.
	cancel func()
}
  608. func (rc *Rows) Close() error {
  609. rc.cancel()
  610. for _ = range rc.tokchan {
  611. }
  612. rc.tokchan = nil
  613. return nil
  614. }
  615. func (rc *Rows) Columns() (res []string) {
  616. res = make([]string, len(rc.cols))
  617. for i, col := range rc.cols {
  618. res[i] = col.ColName
  619. }
  620. return
  621. }
  622. func (rc *Rows) Next(dest []driver.Value) error {
  623. if !rc.stmt.c.connectionGood {
  624. return driver.ErrBadConn
  625. }
  626. if rc.nextCols != nil {
  627. return io.EOF
  628. }
  629. for tok := range rc.tokchan {
  630. switch tokdata := tok.(type) {
  631. case []columnStruct:
  632. rc.nextCols = tokdata
  633. return io.EOF
  634. case []interface{}:
  635. for i := range dest {
  636. dest[i] = tokdata[i]
  637. }
  638. return nil
  639. case doneStruct:
  640. if tokdata.isError() {
  641. return rc.stmt.c.checkBadConn(tokdata.getError())
  642. }
  643. case error:
  644. return rc.stmt.c.checkBadConn(tokdata)
  645. }
  646. }
  647. return io.EOF
  648. }
// HasNextResultSet reports whether column metadata for a following result
// set has already been received.
func (rc *Rows) HasNextResultSet() bool {
	return rc.nextCols != nil
}
  652. func (rc *Rows) NextResultSet() error {
  653. rc.cols = rc.nextCols
  654. rc.nextCols = nil
  655. if rc.cols == nil {
  656. return io.EOF
  657. }
  658. return nil
  659. }
  660. // It should return
  661. // the value type that can be used to scan types into. For example, the database
  662. // column type "bigint" this should return "reflect.TypeOf(int64(0))".
  663. func (r *Rows) ColumnTypeScanType(index int) reflect.Type {
  664. return makeGoLangScanType(r.cols[index].ti)
  665. }
  666. // RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the
  667. // database system type name without the length. Type names should be uppercase.
  668. // Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT",
  669. // "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML",
  670. // "TIMESTAMP".
  671. func (r *Rows) ColumnTypeDatabaseTypeName(index int) string {
  672. return makeGoLangTypeName(r.cols[index].ti)
  673. }
  674. // RowsColumnTypeLength may be implemented by Rows. It should return the length
  675. // of the column type if the column is a variable length type. If the column is
  676. // not a variable length type ok should return false.
  677. // If length is not limited other than system limits, it should return math.MaxInt64.
  678. // The following are examples of returned values for various types:
  679. // TEXT (math.MaxInt64, true)
  680. // varchar(10) (10, true)
  681. // nvarchar(10) (10, true)
  682. // decimal (0, false)
  683. // int (0, false)
  684. // bytea(30) (30, true)
  685. func (r *Rows) ColumnTypeLength(index int) (int64, bool) {
  686. return makeGoLangTypeLength(r.cols[index].ti)
  687. }
  688. // It should return
  689. // the precision and scale for decimal types. If not applicable, ok should be false.
  690. // The following are examples of returned values for various types:
  691. // decimal(38, 4) (38, 4, true)
  692. // int (0, 0, false)
  693. // decimal (math.MaxInt64, math.MaxInt64, true)
  694. func (r *Rows) ColumnTypePrecisionScale(index int) (int64, int64, bool) {
  695. return makeGoLangTypePrecisionScale(r.cols[index].ti)
  696. }
  697. // The nullable value should
  698. // be true if it is known the column may be null, or false if the column is known
  699. // to be not nullable.
  700. // If the column nullability is unknown, ok should be false.
  701. func (r *Rows) ColumnTypeNullable(index int) (nullable, ok bool) {
  702. nullable = r.cols[index].Flags&colFlagNullable != 0
  703. ok = true
  704. return
  705. }
  706. func makeStrParam(val string) (res param) {
  707. res.ti.TypeId = typeNVarChar
  708. res.buffer = str2ucs2(val)
  709. res.ti.Size = len(res.buffer)
  710. return
  711. }
  712. func (s *Stmt) makeParam(val driver.Value) (res param, err error) {
  713. if val == nil {
  714. res.ti.TypeId = typeNull
  715. res.buffer = nil
  716. res.ti.Size = 0
  717. return
  718. }
  719. switch val := val.(type) {
  720. case int64:
  721. res.ti.TypeId = typeIntN
  722. res.buffer = make([]byte, 8)
  723. res.ti.Size = 8
  724. binary.LittleEndian.PutUint64(res.buffer, uint64(val))
  725. case sql.NullInt64:
  726. // only null values should be getting here
  727. res.ti.TypeId = typeIntN
  728. res.ti.Size = 8
  729. res.buffer = []byte{}
  730. case float64:
  731. res.ti.TypeId = typeFltN
  732. res.ti.Size = 8
  733. res.buffer = make([]byte, 8)
  734. binary.LittleEndian.PutUint64(res.buffer, math.Float64bits(val))
  735. case sql.NullFloat64:
  736. // only null values should be getting here
  737. res.ti.TypeId = typeFltN
  738. res.ti.Size = 8
  739. res.buffer = []byte{}
  740. case []byte:
  741. res.ti.TypeId = typeBigVarBin
  742. res.ti.Size = len(val)
  743. res.buffer = val
  744. case string:
  745. res = makeStrParam(val)
  746. case sql.NullString:
  747. // only null values should be getting here
  748. res.ti.TypeId = typeNVarChar
  749. res.buffer = nil
  750. res.ti.Size = 8000
  751. case bool:
  752. res.ti.TypeId = typeBitN
  753. res.ti.Size = 1
  754. res.buffer = make([]byte, 1)
  755. if val {
  756. res.buffer[0] = 1
  757. }
  758. case sql.NullBool:
  759. // only null values should be getting here
  760. res.ti.TypeId = typeBitN
  761. res.ti.Size = 1
  762. res.buffer = []byte{}
  763. case time.Time:
  764. if s.c.sess.loginAck.TDSVersion >= verTDS73 {
  765. res.ti.TypeId = typeDateTimeOffsetN
  766. res.ti.Scale = 7
  767. res.buffer = encodeDateTimeOffset(val, int(res.ti.Scale))
  768. res.ti.Size = len(res.buffer)
  769. } else {
  770. res.ti.TypeId = typeDateTimeN
  771. res.buffer = encodeDateTime(val)
  772. res.ti.Size = len(res.buffer)
  773. }
  774. default:
  775. return s.makeParamExtra(val)
  776. }
  777. return
  778. }
// Result is the outcome of executing a statement that returns no rows.
type Result struct {
	c            *Conn // connection the statement ran on; used by LastInsertId
	rowsAffected int64 // accumulated row count from the response
}
// RowsAffected returns the row count accumulated while processing the
// response. The error is always nil.
func (r *Result) RowsAffected() (int64, error) {
	return r.rowsAffected, nil
}
  786. func (r *Result) LastInsertId() (int64, error) {
  787. s, err := r.c.Prepare("select cast(@@identity as bigint)")
  788. if err != nil {
  789. return 0, err
  790. }
  791. defer s.Close()
  792. rows, err := s.Query(nil)
  793. if err != nil {
  794. return 0, err
  795. }
  796. defer rows.Close()
  797. dest := make([]driver.Value, 1)
  798. err = rows.Next(dest)
  799. if err != nil {
  800. return 0, err
  801. }
  802. if dest[0] == nil {
  803. return -1, errors.New("There is no generated identity value")
  804. }
  805. lastInsertId := dest[0].(int64)
  806. return lastInsertId, nil
  807. }
  808. var _ driver.Pinger = &Conn{}
  809. // Ping is used to check if the remote server is available and satisfies the Pinger interface.
  810. func (c *Conn) Ping(ctx context.Context) error {
  811. if !c.connectionGood {
  812. return driver.ErrBadConn
  813. }
  814. stmt := &Stmt{c, `select 1;`, 0, nil}
  815. _, err := stmt.ExecContext(ctx, nil)
  816. return err
  817. }
  818. var _ driver.ConnBeginTx = &Conn{}
  819. // BeginTx satisfies ConnBeginTx.
  820. func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
  821. if !c.connectionGood {
  822. return nil, driver.ErrBadConn
  823. }
  824. if opts.ReadOnly {
  825. return nil, errors.New("Read-only transactions are not supported")
  826. }
  827. var tdsIsolation isoLevel
  828. switch sql.IsolationLevel(opts.Isolation) {
  829. case sql.LevelDefault:
  830. tdsIsolation = isolationUseCurrent
  831. case sql.LevelReadUncommitted:
  832. tdsIsolation = isolationReadUncommited
  833. case sql.LevelReadCommitted:
  834. tdsIsolation = isolationReadCommited
  835. case sql.LevelWriteCommitted:
  836. return nil, errors.New("LevelWriteCommitted isolation level is not supported")
  837. case sql.LevelRepeatableRead:
  838. tdsIsolation = isolationRepeatableRead
  839. case sql.LevelSnapshot:
  840. tdsIsolation = isolationSnapshot
  841. case sql.LevelSerializable:
  842. tdsIsolation = isolationSerializable
  843. case sql.LevelLinearizable:
  844. return nil, errors.New("LevelLinearizable isolation level is not supported")
  845. default:
  846. return nil, errors.New("Isolation level is not supported or unknown")
  847. }
  848. return c.begin(ctx, tdsIsolation)
  849. }
  850. func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
  851. if !c.connectionGood {
  852. return nil, driver.ErrBadConn
  853. }
  854. if len(query) > 10 && strings.EqualFold(query[:10], "INSERTBULK") {
  855. return c.prepareCopyIn(ctx, query)
  856. }
  857. return c.prepareContext(ctx, query)
  858. }
  859. func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
  860. if !s.c.connectionGood {
  861. return nil, driver.ErrBadConn
  862. }
  863. list := make([]namedValue, len(args))
  864. for i, nv := range args {
  865. list[i] = namedValue(nv)
  866. }
  867. return s.queryContext(ctx, list)
  868. }
  869. func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
  870. if !s.c.connectionGood {
  871. return nil, driver.ErrBadConn
  872. }
  873. list := make([]namedValue, len(args))
  874. for i, nv := range args {
  875. list[i] = namedValue(nv)
  876. }
  877. return s.exec(ctx, list)
  878. }