Marcus Efraimsson 7 lat temu
rodzic
commit
a2eaf3954a
34 zmienionych plików z 411 dodań i 272 usunięć
  1. 6 3
      Gopkg.lock
  2. 1 1
      Gopkg.toml
  3. 11 3
      vendor/github.com/denisenkom/go-mssqldb/buf.go
  4. 51 54
      vendor/github.com/denisenkom/go-mssqldb/bulkcopy.go
  5. 13 12
      vendor/github.com/denisenkom/go-mssqldb/bulkcopy_sql.go
  6. 0 39
      vendor/github.com/denisenkom/go-mssqldb/collation.go
  7. 4 4
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/charset.go
  8. 20 0
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/collation.go
  9. 1 1
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1250.go
  10. 1 1
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1251.go
  11. 1 1
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1252.go
  12. 1 1
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1253.go
  13. 1 1
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1254.go
  14. 1 1
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1255.go
  15. 1 1
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1256.go
  16. 1 1
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1257.go
  17. 1 1
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1258.go
  18. 1 1
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp437.go
  19. 1 1
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp850.go
  20. 1 1
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp874.go
  21. 1 1
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp932.go
  22. 1 1
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp936.go
  23. 1 1
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp949.go
  24. 1 1
      vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp950.go
  25. 140 82
      vendor/github.com/denisenkom/go-mssqldb/mssql.go
  26. 50 0
      vendor/github.com/denisenkom/go-mssqldb/mssql_go110.go
  27. 9 9
      vendor/github.com/denisenkom/go-mssqldb/mssql_go18.go
  28. 14 3
      vendor/github.com/denisenkom/go-mssqldb/mssql_go19.go
  29. 1 1
      vendor/github.com/denisenkom/go-mssqldb/mssql_go19pre.go
  30. 11 7
      vendor/github.com/denisenkom/go-mssqldb/net.go
  31. 2 2
      vendor/github.com/denisenkom/go-mssqldb/rpc.go
  32. 31 22
      vendor/github.com/denisenkom/go-mssqldb/tds.go
  33. 6 7
      vendor/github.com/denisenkom/go-mssqldb/tran.go
  34. 25 7
      vendor/github.com/denisenkom/go-mssqldb/types.go

+ 6 - 3
Gopkg.lock

@@ -105,8 +105,11 @@
 
 [[projects]]
   name = "github.com/denisenkom/go-mssqldb"
-  packages = ["."]
-  revision = "ee492709d4324cdcb051d2ac266b77ddc380f5c5"
+  packages = [
+    ".",
+    "internal/cp"
+  ]
+  revision = "270bc3860bb94dd3a3ffd047377d746c5e276726"
 
 [[projects]]
   name = "github.com/fatih/color"
@@ -639,6 +642,6 @@
 [solve-meta]
   analyzer-name = "dep"
   analyzer-version = 1
-  inputs-digest = "d2f67abb94028a388f051164896bfb69b1ff3a7255d285dc4d78d298f4793383"
+  inputs-digest = "5e65aeace832f1b4be17e7ff5d5714513c40f31b94b885f64f98f2332968d7c6"
   solver-name = "gps-cdcl"
   solver-version = 1

+ 1 - 1
Gopkg.toml

@@ -200,4 +200,4 @@ ignored = [
 
 [[constraint]]
   name = "github.com/denisenkom/go-mssqldb"
-  revision = "ee492709d4324cdcb051d2ac266b77ddc380f5c5"
+  revision = "270bc3860bb94dd3a3ffd047377d746c5e276726"

+ 11 - 3
vendor/github.com/denisenkom/go-mssqldb/buf.go

@@ -115,15 +115,23 @@ func (w *tdsBuffer) WriteByte(b byte) error {
 	return nil
 }
 
-func (w *tdsBuffer) BeginPacket(packetType packetType) {
-	w.wbuf[1] = 0 // Packet is incomplete. This byte is set again in FinishPacket.
+func (w *tdsBuffer) BeginPacket(packetType packetType, resetSession bool) {
+	status := byte(0)
+	if resetSession {
+		switch packetType {
+		// Reset session can only be set on the following packet types.
+		case packSQLBatch, packRPCRequest, packTransMgrReq:
+			status = 0x8
+		}
+	}
+	w.wbuf[1] = status // Packet is incomplete. This byte is set again in FinishPacket.
 	w.wpos = 8
 	w.wPacketSeq = 1
 	w.wPacketType = packetType
 }
 
 func (w *tdsBuffer) FinishPacket() error {
-	w.wbuf[1] = 1 // Mark this as the last packet in the message.
+	w.wbuf[1] |= 1 // Mark this as the last packet in the message.
 	return w.flush()
 }
 

+ 51 - 54
vendor/github.com/denisenkom/go-mssqldb/bulkcopy.go

@@ -12,8 +12,14 @@ import (
 	"time"
 )
 
-type MssqlBulk struct {
-	cn          *MssqlConn
+type Bulk struct {
+	// ctx is used only for AddRow and Done methods.
+	// This could be removed if AddRow and Done accepted
+	// a ctx field as well, which is available with the
+	// database/sql call.
+	ctx context.Context
+
+	cn          *Conn
 	metadata    []columnStruct
 	bulkColumns []columnStruct
 	columnsName []string
@@ -21,10 +27,10 @@ type MssqlBulk struct {
 	numRows     int
 
 	headerSent bool
-	Options    MssqlBulkOptions
+	Options    BulkOptions
 	Debug      bool
 }
-type MssqlBulkOptions struct {
+type BulkOptions struct {
 	CheckConstraints  bool
 	FireTriggers      bool
 	KeepNulls         bool
@@ -36,15 +42,21 @@ type MssqlBulkOptions struct {
 
 type DataValue interface{}
 
-func (cn *MssqlConn) CreateBulk(table string, columns []string) (_ *MssqlBulk) {
-	b := MssqlBulk{cn: cn, tablename: table, headerSent: false, columnsName: columns}
+func (cn *Conn) CreateBulk(table string, columns []string) (_ *Bulk) {
+	b := Bulk{ctx: context.Background(), cn: cn, tablename: table, headerSent: false, columnsName: columns}
 	b.Debug = false
 	return &b
 }
 
-func (b *MssqlBulk) sendBulkCommand() (err error) {
+func (cn *Conn) CreateBulkContext(ctx context.Context, table string, columns []string) (_ *Bulk) {
+	b := Bulk{ctx: ctx, cn: cn, tablename: table, headerSent: false, columnsName: columns}
+	b.Debug = false
+	return &b
+}
+
+func (b *Bulk) sendBulkCommand(ctx context.Context) (err error) {
 	//get table columns info
-	err = b.getMetadata()
+	err = b.getMetadata(ctx)
 	if err != nil {
 		return err
 	}
@@ -114,13 +126,13 @@ func (b *MssqlBulk) sendBulkCommand() (err error) {
 
 	query := fmt.Sprintf("INSERT BULK %s (%s) %s", b.tablename, col_defs.String(), with_part)
 
-	stmt, err := b.cn.Prepare(query)
+	stmt, err := b.cn.PrepareContext(ctx, query)
 	if err != nil {
 		return fmt.Errorf("Prepare failed: %s", err.Error())
 	}
 	b.dlogf(query)
 
-	_, err = stmt.Exec(nil)
+	_, err = stmt.(*Stmt).ExecContext(ctx, nil)
 	if err != nil {
 		return err
 	}
@@ -128,9 +140,9 @@ func (b *MssqlBulk) sendBulkCommand() (err error) {
 	b.headerSent = true
 
 	var buf = b.cn.sess.buf
-	buf.BeginPacket(packBulkLoadBCP)
+	buf.BeginPacket(packBulkLoadBCP, false)
 
-	// send the columns metadata
+	// Send the columns metadata.
 	columnMetadata := b.createColMetadata()
 	_, err = buf.Write(columnMetadata)
 
@@ -139,9 +151,9 @@ func (b *MssqlBulk) sendBulkCommand() (err error) {
 
 // AddRow immediately writes the row to the destination table.
 // The arguments are the row values in the order they were specified.
-func (b *MssqlBulk) AddRow(row []interface{}) (err error) {
+func (b *Bulk) AddRow(row []interface{}) (err error) {
 	if !b.headerSent {
-		err = b.sendBulkCommand()
+		err = b.sendBulkCommand(b.ctx)
 		if err != nil {
 			return
 		}
@@ -166,7 +178,7 @@ func (b *MssqlBulk) AddRow(row []interface{}) (err error) {
 	return
 }
 
-func (b *MssqlBulk) makeRowData(row []interface{}) ([]byte, error) {
+func (b *Bulk) makeRowData(row []interface{}) ([]byte, error) {
 	buf := new(bytes.Buffer)
 	buf.WriteByte(byte(tokenRow))
 
@@ -196,7 +208,7 @@ func (b *MssqlBulk) makeRowData(row []interface{}) ([]byte, error) {
 	return buf.Bytes(), nil
 }
 
-func (b *MssqlBulk) Done() (rowcount int64, err error) {
+func (b *Bulk) Done() (rowcount int64, err error) {
 	if b.headerSent == false {
 		//no rows had been sent
 		return 0, nil
@@ -216,7 +228,7 @@ func (b *MssqlBulk) Done() (rowcount int64, err error) {
 	buf.FinishPacket()
 
 	tokchan := make(chan tokenStruct, 5)
-	go processResponse(context.Background(), b.cn.sess, tokchan, nil)
+	go processResponse(b.ctx, b.cn.sess, tokchan, nil)
 
 	var rowCount int64
 	for token := range tokchan {
@@ -235,7 +247,7 @@ func (b *MssqlBulk) Done() (rowcount int64, err error) {
 	return rowCount, nil
 }
 
-func (b *MssqlBulk) createColMetadata() []byte {
+func (b *Bulk) createColMetadata() []byte {
 	buf := new(bytes.Buffer)
 	buf.WriteByte(byte(tokenColMetadata))                              // token
 	binary.Write(buf, binary.LittleEndian, uint16(len(b.bulkColumns))) // column count
@@ -267,64 +279,40 @@ func (b *MssqlBulk) createColMetadata() []byte {
 	return buf.Bytes()
 }
 
-func (b *MssqlBulk) getMetadata() (err error) {
-	stmt, err := b.cn.Prepare("SET FMTONLY ON")
+func (b *Bulk) getMetadata(ctx context.Context) (err error) {
+	stmt, err := b.cn.prepareContext(ctx, "SET FMTONLY ON")
 	if err != nil {
 		return
 	}
 
-	_, err = stmt.Exec(nil)
+	_, err = stmt.ExecContext(ctx, nil)
 	if err != nil {
 		return
 	}
 
-	//get columns info
-	stmt, err = b.cn.Prepare(fmt.Sprintf("select * from %s SET FMTONLY OFF", b.tablename))
+	// Get columns info.
+	stmt, err = b.cn.prepareContext(ctx, fmt.Sprintf("select * from %s SET FMTONLY OFF", b.tablename))
 	if err != nil {
 		return
 	}
-	stmt2 := stmt.(*MssqlStmt)
-	cols, err := stmt2.QueryMeta()
+	rows, err := stmt.QueryContext(ctx, nil)
 	if err != nil {
-		return fmt.Errorf("get columns info failed: %v", err.Error())
+		return fmt.Errorf("get columns info failed: %v", err)
 	}
-	b.metadata = cols
+	b.metadata = rows.(*Rows).cols
 
 	if b.Debug {
 		for _, col := range b.metadata {
 			b.dlogf("col: %s typeId: %#x size: %d scale: %d prec: %d flags: %d lcid: %#x\n",
 				col.ColName, col.ti.TypeId, col.ti.Size, col.ti.Scale, col.ti.Prec,
-				col.Flags, col.ti.Collation.lcidAndFlags)
+				col.Flags, col.ti.Collation.LcidAndFlags)
 		}
 	}
 
-	return nil
-}
-
-// QueryMeta is almost the same as MssqlStmt.Query, but returns all the columns info.
-func (s *MssqlStmt) QueryMeta() (cols []columnStruct, err error) {
-	if err = s.sendQuery(nil); err != nil {
-		return
-	}
-	tokchan := make(chan tokenStruct, 5)
-	go processResponse(context.Background(), s.c.sess, tokchan, s.c.outs)
-	s.c.clearOuts()
-loop:
-	for tok := range tokchan {
-		switch token := tok.(type) {
-		case doneStruct:
-			break loop
-		case []columnStruct:
-			cols = token
-			break loop
-		case error:
-			return nil, s.c.checkBadConn(token)
-		}
-	}
-	return cols, nil
+	return rows.Close()
 }
 
-func (b *MssqlBulk) makeParam(val DataValue, col columnStruct) (res Param, err error) {
+func (b *Bulk) makeParam(val DataValue, col columnStruct) (res Param, err error) {
 	res.ti.Size = col.ti.Size
 	res.ti.TypeId = col.ti.TypeId
 
@@ -592,6 +580,15 @@ func (b *MssqlBulk) makeParam(val DataValue, col columnStruct) (res Param, err e
 			err = fmt.Errorf("mssql: invalid type for Binary column: %s", val)
 			return
 		}
+	case typeGuid:
+		switch val := val.(type) {
+		case []byte:
+			res.ti.Size = len(val)
+			res.buffer = val
+		default:
+			err = fmt.Errorf("mssql: invalid type for Guid column: %s", val)
+			return
+		}
 
 	default:
 		err = fmt.Errorf("mssql: type %x not implemented", col.ti.TypeId)
@@ -600,7 +597,7 @@ func (b *MssqlBulk) makeParam(val DataValue, col columnStruct) (res Param, err e
 
 }
 
-func (b *MssqlBulk) dlogf(format string, v ...interface{}) {
+func (b *Bulk) dlogf(format string, v ...interface{}) {
 	if b.Debug {
 		b.cn.sess.log.Printf(format, v...)
 	}

+ 13 - 12
vendor/github.com/denisenkom/go-mssqldb/bulkcopy_sql.go

@@ -1,37 +1,38 @@
 package mssql
 
 import (
+	"context"
 	"database/sql/driver"
 	"encoding/json"
 	"errors"
 )
 
 type copyin struct {
-	cn       *MssqlConn
-	bulkcopy *MssqlBulk
+	cn       *Conn
+	bulkcopy *Bulk
 	closed   bool
 }
 
-type SerializableBulkConfig struct {
+type serializableBulkConfig struct {
 	TableName   string
 	ColumnsName []string
-	Options     MssqlBulkOptions
+	Options     BulkOptions
 }
 
-func (d *MssqlDriver) OpenConnection(dsn string) (*MssqlConn, error) {
-	return d.open(dsn)
+func (d *Driver) OpenConnection(dsn string) (*Conn, error) {
+	return d.open(context.Background(), dsn)
 }
 
-func (c *MssqlConn) prepareCopyIn(query string) (_ driver.Stmt, err error) {
+func (c *Conn) prepareCopyIn(ctx context.Context, query string) (_ driver.Stmt, err error) {
 	config_json := query[11:]
 
-	bulkconfig := SerializableBulkConfig{}
+	bulkconfig := serializableBulkConfig{}
 	err = json.Unmarshal([]byte(config_json), &bulkconfig)
 	if err != nil {
 		return
 	}
 
-	bulkcopy := c.CreateBulk(bulkconfig.TableName, bulkconfig.ColumnsName)
+	bulkcopy := c.CreateBulkContext(ctx, bulkconfig.TableName, bulkconfig.ColumnsName)
 	bulkcopy.Options = bulkconfig.Options
 
 	ci := &copyin{
@@ -42,8 +43,8 @@ func (c *MssqlConn) prepareCopyIn(query string) (_ driver.Stmt, err error) {
 	return ci, nil
 }
 
-func CopyIn(table string, options MssqlBulkOptions, columns ...string) string {
-	bulkconfig := &SerializableBulkConfig{TableName: table, Options: options, ColumnsName: columns}
+func CopyIn(table string, options BulkOptions, columns ...string) string {
+	bulkconfig := &serializableBulkConfig{TableName: table, Options: options, ColumnsName: columns}
 
 	config_json, err := json.Marshal(bulkconfig)
 	if err != nil {
@@ -60,7 +61,7 @@ func (ci *copyin) NumInput() int {
 }
 
 func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) {
-	return nil, errors.New("ErrNotSupported")
+	panic("should never be called")
 }
 
 func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) {

+ 0 - 39
vendor/github.com/denisenkom/go-mssqldb/collation.go

@@ -1,39 +0,0 @@
-package mssql
-
-import (
-	"encoding/binary"
-	"io"
-)
-
-// http://msdn.microsoft.com/en-us/library/dd340437.aspx
-
-type collation struct {
-	lcidAndFlags uint32
-	sortId       uint8
-}
-
-func (c collation) getLcid() uint32 {
-	return c.lcidAndFlags & 0x000fffff
-}
-
-func (c collation) getFlags() uint32 {
-	return (c.lcidAndFlags & 0x0ff00000) >> 20
-}
-
-func (c collation) getVersion() uint32 {
-	return (c.lcidAndFlags & 0xf0000000) >> 28
-}
-
-func readCollation(r *tdsBuffer) (res collation) {
-	res.lcidAndFlags = r.uint32()
-	res.sortId = r.byte()
-	return
-}
-
-func writeCollation(w io.Writer, col collation) (err error) {
-	if err = binary.Write(w, binary.LittleEndian, col.lcidAndFlags); err != nil {
-		return
-	}
-	err = binary.Write(w, binary.LittleEndian, col.sortId)
-	return
-}

+ 4 - 4
vendor/github.com/denisenkom/go-mssqldb/charset.go → vendor/github.com/denisenkom/go-mssqldb/internal/cp/charset.go

@@ -1,14 +1,14 @@
-package mssql
+package cp
 
 type charsetMap struct {
 	sb [256]rune    // single byte runes, -1 for a double byte character lead byte
 	db map[int]rune // double byte runes
 }
 
-func collation2charset(col collation) *charsetMap {
+func collation2charset(col Collation) *charsetMap {
 	// http://msdn.microsoft.com/en-us/library/ms144250.aspx
 	// http://msdn.microsoft.com/en-us/library/ms144250(v=sql.105).aspx
-	switch col.sortId {
+	switch col.SortId {
 	case 30, 31, 32, 33, 34:
 		return cp437
 	case 40, 41, 42, 44, 49, 55, 56, 57, 58, 59, 60, 61:
@@ -86,7 +86,7 @@ func collation2charset(col collation) *charsetMap {
 	return cp1252
 }
 
-func charset2utf8(col collation, s []byte) string {
+func CharsetToUTF8(col Collation, s []byte) string {
 	cm := collation2charset(col)
 	if cm == nil {
 		return string(s)

+ 20 - 0
vendor/github.com/denisenkom/go-mssqldb/internal/cp/collation.go

@@ -0,0 +1,20 @@
+package cp
+
+// http://msdn.microsoft.com/en-us/library/dd340437.aspx
+
+type Collation struct {
+	LcidAndFlags uint32
+	SortId       uint8
+}
+
+func (c Collation) getLcid() uint32 {
+	return c.LcidAndFlags & 0x000fffff
+}
+
+func (c Collation) getFlags() uint32 {
+	return (c.LcidAndFlags & 0x0ff00000) >> 20
+}
+
+func (c Collation) getVersion() uint32 {
+	return (c.LcidAndFlags & 0xf0000000) >> 28
+}

+ 1 - 1
vendor/github.com/denisenkom/go-mssqldb/cp1250.go → vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1250.go

@@ -1,4 +1,4 @@
-package mssql
+package cp
 
 var cp1250 *charsetMap = &charsetMap{
 	sb: [256]rune{

+ 1 - 1
vendor/github.com/denisenkom/go-mssqldb/cp1251.go → vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1251.go

@@ -1,4 +1,4 @@
-package mssql
+package cp
 
 var cp1251 *charsetMap = &charsetMap{
 	sb: [256]rune{

+ 1 - 1
vendor/github.com/denisenkom/go-mssqldb/cp1252.go → vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1252.go

@@ -1,4 +1,4 @@
-package mssql
+package cp
 
 var cp1252 *charsetMap = &charsetMap{
 	sb: [256]rune{

+ 1 - 1
vendor/github.com/denisenkom/go-mssqldb/cp1253.go → vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1253.go

@@ -1,4 +1,4 @@
-package mssql
+package cp
 
 var cp1253 *charsetMap = &charsetMap{
 	sb: [256]rune{

+ 1 - 1
vendor/github.com/denisenkom/go-mssqldb/cp1254.go → vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1254.go

@@ -1,4 +1,4 @@
-package mssql
+package cp
 
 var cp1254 *charsetMap = &charsetMap{
 	sb: [256]rune{

+ 1 - 1
vendor/github.com/denisenkom/go-mssqldb/cp1255.go → vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1255.go

@@ -1,4 +1,4 @@
-package mssql
+package cp
 
 var cp1255 *charsetMap = &charsetMap{
 	sb: [256]rune{

+ 1 - 1
vendor/github.com/denisenkom/go-mssqldb/cp1256.go → vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1256.go

@@ -1,4 +1,4 @@
-package mssql
+package cp
 
 var cp1256 *charsetMap = &charsetMap{
 	sb: [256]rune{

+ 1 - 1
vendor/github.com/denisenkom/go-mssqldb/cp1257.go → vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1257.go

@@ -1,4 +1,4 @@
-package mssql
+package cp
 
 var cp1257 *charsetMap = &charsetMap{
 	sb: [256]rune{

+ 1 - 1
vendor/github.com/denisenkom/go-mssqldb/cp1258.go → vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1258.go

@@ -1,4 +1,4 @@
-package mssql
+package cp
 
 var cp1258 *charsetMap = &charsetMap{
 	sb: [256]rune{

+ 1 - 1
vendor/github.com/denisenkom/go-mssqldb/cp437.go → vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp437.go

@@ -1,4 +1,4 @@
-package mssql
+package cp
 
 var cp437 *charsetMap = &charsetMap{
 	sb: [256]rune{

+ 1 - 1
vendor/github.com/denisenkom/go-mssqldb/cp850.go → vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp850.go

@@ -1,4 +1,4 @@
-package mssql
+package cp
 
 var cp850 *charsetMap = &charsetMap{
 	sb: [256]rune{

+ 1 - 1
vendor/github.com/denisenkom/go-mssqldb/cp874.go → vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp874.go

@@ -1,4 +1,4 @@
-package mssql
+package cp
 
 var cp874 *charsetMap = &charsetMap{
 	sb: [256]rune{

+ 1 - 1
vendor/github.com/denisenkom/go-mssqldb/cp932.go → vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp932.go

@@ -1,4 +1,4 @@
-package mssql
+package cp
 
 var cp932 *charsetMap = &charsetMap{
 	sb: [256]rune{

+ 1 - 1
vendor/github.com/denisenkom/go-mssqldb/cp936.go → vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp936.go

@@ -1,4 +1,4 @@
-package mssql
+package cp
 
 var cp936 *charsetMap = &charsetMap{
 	sb: [256]rune{

+ 1 - 1
vendor/github.com/denisenkom/go-mssqldb/cp949.go → vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp949.go

@@ -1,4 +1,4 @@
-package mssql
+package cp
 
 var cp949 *charsetMap = &charsetMap{
 	sb: [256]rune{

+ 1 - 1
vendor/github.com/denisenkom/go-mssqldb/cp950.go → vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp950.go

@@ -1,4 +1,4 @@
-package mssql
+package cp
 
 var cp950 *charsetMap = &charsetMap{
 	sb: [256]rune{

+ 140 - 82
vendor/github.com/denisenkom/go-mssqldb/mssql.go

@@ -15,20 +15,20 @@ import (
 	"time"
 )
 
-var driverInstance = &MssqlDriver{processQueryText: true}
-var driverInstanceNoProcess = &MssqlDriver{processQueryText: false}
+var driverInstance = &Driver{processQueryText: true}
+var driverInstanceNoProcess = &Driver{processQueryText: false}
 
 func init() {
 	sql.Register("mssql", driverInstance)
 	sql.Register("sqlserver", driverInstanceNoProcess)
 	createDialer = func(p *connectParams) dialer {
-		return tcpDialer{&net.Dialer{Timeout: p.dial_timeout, KeepAlive: p.keepAlive}}
+		return tcpDialer{&net.Dialer{KeepAlive: p.keepAlive}}
 	}
 }
 
 // Abstract the dialer for testing and for non-TCP based connections.
 type dialer interface {
-	Dial(addr string) (net.Conn, error)
+	Dial(ctx context.Context, addr string) (net.Conn, error)
 }
 
 var createDialer func(p *connectParams) dialer
@@ -37,28 +37,75 @@ type tcpDialer struct {
 	nd *net.Dialer
 }
 
-func (d tcpDialer) Dial(addr string) (net.Conn, error) {
-	return d.nd.Dial("tcp", addr)
+func (d tcpDialer) Dial(ctx context.Context, addr string) (net.Conn, error) {
+	return d.nd.DialContext(ctx, "tcp", addr)
 }
 
-type MssqlDriver struct {
+type Driver struct {
 	log optionalLogger
 
 	processQueryText bool
 }
 
+// OpenConnector opens a new connector. Useful to dial with a context.
+func (d *Driver) OpenConnector(dsn string) (*Connector, error) {
+	params, err := parseConnectParams(dsn)
+	if err != nil {
+		return nil, err
+	}
+	return &Connector{
+		params: params,
+		driver: d,
+	}, nil
+}
+
+func (d *Driver) Open(dsn string) (driver.Conn, error) {
+	return d.open(context.Background(), dsn)
+}
+
 func SetLogger(logger Logger) {
 	driverInstance.SetLogger(logger)
 	driverInstanceNoProcess.SetLogger(logger)
 }
 
-func (d *MssqlDriver) SetLogger(logger Logger) {
+func (d *Driver) SetLogger(logger Logger) {
 	d.log = optionalLogger{logger}
 }
 
-type MssqlConn struct {
+// Connector holds the parsed DSN and is ready to make a new connection
+// at any time.
+//
+// In the future, settings that cannot be passed through a string DSN
+// may be set directly on the connector.
+type Connector struct {
+	params connectParams
+	driver *Driver
+
+	// ResetSQL is executed after marking a given connection to be reset.
+	// When not present, the next query will be reset to the database
+	// defaults.
+	// When present the connection will immediately mark the connection to
+	// be reset, then execute the ResetSQL text to setup the session
+	// that may be different from the base database defaults.
+	//
+	// For Example, the application relies on the following defaults
+	// but is not allowed to set them at the database system level.
+	//
+	//    SET XACT_ABORT ON;
+	//    SET TEXTSIZE -1;
+	//    SET ANSI_NULLS ON;
+	//    SET LOCK_TIMEOUT 10000;
+	//
+	// ResetSQL should not attempt to manually call sp_reset_connection.
+	// This will happen at the TDS layer.
+	ResetSQL string
+}
+
+type Conn struct {
+	connector      *Connector
 	sess           *tdsSession
 	transactionCtx context.Context
+	resetSession   bool
 
 	processQueryText bool
 	connectionGood   bool
@@ -66,7 +113,7 @@ type MssqlConn struct {
 	outs map[string]interface{}
 }
 
-func (c *MssqlConn) checkBadConn(err error) error {
+func (c *Conn) checkBadConn(err error) error {
 	// this is a hack to address Issue #275
 	// we set connectionGood flag to false if
 	// error indicates that connection is not usable
@@ -81,11 +128,12 @@ func (c *MssqlConn) checkBadConn(err error) error {
 	case nil:
 		return nil
 	case io.EOF:
+		c.connectionGood = false
 		return driver.ErrBadConn
 	case driver.ErrBadConn:
 		// It is an internal programming error if driver.ErrBadConn
 		// is ever passed to this function. driver.ErrBadConn should
-		// only ever be returned in response to a *MssqlConn.connectionGood == false
+		// only ever be returned in response to a *mssql.Conn.connectionGood == false
 		// check in the external facing API.
 		panic("driver.ErrBadConn in checkBadConn. This should not happen.")
 	}
@@ -102,11 +150,11 @@ func (c *MssqlConn) checkBadConn(err error) error {
 	}
 }
 
-func (c *MssqlConn) clearOuts() {
+func (c *Conn) clearOuts() {
 	c.outs = nil
 }
 
-func (c *MssqlConn) simpleProcessResp(ctx context.Context) error {
+func (c *Conn) simpleProcessResp(ctx context.Context) error {
 	tokchan := make(chan tokenStruct, 5)
 	go processResponse(ctx, c.sess, tokchan, c.outs)
 	c.clearOuts()
@@ -123,7 +171,7 @@ func (c *MssqlConn) simpleProcessResp(ctx context.Context) error {
 	return nil
 }
 
-func (c *MssqlConn) Commit() error {
+func (c *Conn) Commit() error {
 	if !c.connectionGood {
 		return driver.ErrBadConn
 	}
@@ -133,12 +181,14 @@ func (c *MssqlConn) Commit() error {
 	return c.simpleProcessResp(c.transactionCtx)
 }
 
-func (c *MssqlConn) sendCommitRequest() error {
+func (c *Conn) sendCommitRequest() error {
 	headers := []headerStruct{
 		{hdrtype: dataStmHdrTransDescr,
 			data: transDescrHdr{c.sess.tranid, 1}.pack()},
 	}
-	if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, ""); err != nil {
+	reset := c.resetSession
+	c.resetSession = false
+	if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, "", reset); err != nil {
 		if c.sess.logFlags&logErrors != 0 {
 			c.sess.log.Printf("Failed to send CommitXact with %v", err)
 		}
@@ -148,7 +198,7 @@ func (c *MssqlConn) sendCommitRequest() error {
 	return nil
 }
 
-func (c *MssqlConn) Rollback() error {
+func (c *Conn) Rollback() error {
 	if !c.connectionGood {
 		return driver.ErrBadConn
 	}
@@ -158,12 +208,14 @@ func (c *MssqlConn) Rollback() error {
 	return c.simpleProcessResp(c.transactionCtx)
 }
 
-func (c *MssqlConn) sendRollbackRequest() error {
+func (c *Conn) sendRollbackRequest() error {
 	headers := []headerStruct{
 		{hdrtype: dataStmHdrTransDescr,
 			data: transDescrHdr{c.sess.tranid, 1}.pack()},
 	}
-	if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, ""); err != nil {
+	reset := c.resetSession
+	c.resetSession = false
+	if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, "", reset); err != nil {
 		if c.sess.logFlags&logErrors != 0 {
 			c.sess.log.Printf("Failed to send RollbackXact with %v", err)
 		}
@@ -173,11 +225,11 @@ func (c *MssqlConn) sendRollbackRequest() error {
 	return nil
 }
 
-func (c *MssqlConn) Begin() (driver.Tx, error) {
+func (c *Conn) Begin() (driver.Tx, error) {
 	return c.begin(context.Background(), isolationUseCurrent)
 }
 
-func (c *MssqlConn) begin(ctx context.Context, tdsIsolation isoLevel) (tx driver.Tx, err error) {
+func (c *Conn) begin(ctx context.Context, tdsIsolation isoLevel) (tx driver.Tx, err error) {
 	if !c.connectionGood {
 		return nil, driver.ErrBadConn
 	}
@@ -192,13 +244,15 @@ func (c *MssqlConn) begin(ctx context.Context, tdsIsolation isoLevel) (tx driver
 	return
 }
 
-func (c *MssqlConn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) error {
+func (c *Conn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) error {
 	c.transactionCtx = ctx
 	headers := []headerStruct{
 		{hdrtype: dataStmHdrTransDescr,
 			data: transDescrHdr{0, 1}.pack()},
 	}
-	if err := sendBeginXact(c.sess.buf, headers, tdsIsolation, ""); err != nil {
+	reset := c.resetSession
+	c.resetSession = false
+	if err := sendBeginXact(c.sess.buf, headers, tdsIsolation, "", reset); err != nil {
 		if c.sess.logFlags&logErrors != 0 {
 			c.sess.log.Printf("Failed to send BeginXact with %v", err)
 		}
@@ -208,7 +262,7 @@ func (c *MssqlConn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel)
 	return nil
 }
 
-func (c *MssqlConn) processBeginResponse(ctx context.Context) (driver.Tx, error) {
+func (c *Conn) processBeginResponse(ctx context.Context) (driver.Tx, error) {
 	if err := c.simpleProcessResp(ctx); err != nil {
 		return nil, err
 	}
@@ -217,17 +271,17 @@ func (c *MssqlConn) processBeginResponse(ctx context.Context) (driver.Tx, error)
 	return c, nil
 }
 
-func (d *MssqlDriver) Open(dsn string) (driver.Conn, error) {
-	return d.open(dsn)
-}
-
-func (d *MssqlDriver) open(dsn string) (*MssqlConn, error) {
+func (d *Driver) open(ctx context.Context, dsn string) (*Conn, error) {
 	params, err := parseConnectParams(dsn)
 	if err != nil {
 		return nil, err
 	}
+	return d.connect(ctx, params)
+}
 
-	sess, err := connect(d.log, params)
+// connect to the server, using the provided context for dialing only.
+func (d *Driver) connect(ctx context.Context, params connectParams) (*Conn, error) {
+	sess, err := connect(ctx, d.log, params)
 	if err != nil {
 		// main server failed, try fail-over partner
 		if params.failOverPartner == "" {
@@ -239,29 +293,30 @@ func (d *MssqlDriver) open(dsn string) (*MssqlConn, error) {
 			params.port = params.failOverPort
 		}
 
-		sess, err = connect(d.log, params)
+		sess, err = connect(ctx, d.log, params)
 		if err != nil {
 			// fail-over partner also failed, now fail
 			return nil, err
 		}
 	}
 
-	conn := &MssqlConn{
+	conn := &Conn{
 		sess:             sess,
 		transactionCtx:   context.Background(),
 		processQueryText: d.processQueryText,
 		connectionGood:   true,
 	}
 	conn.sess.log = d.log
+
 	return conn, nil
 }
 
-func (c *MssqlConn) Close() error {
+func (c *Conn) Close() error {
 	return c.sess.buf.transport.Close()
 }
 
-type MssqlStmt struct {
-	c          *MssqlConn
+type Stmt struct {
+	c          *Conn
 	query      string
 	paramCount int
 	notifSub   *queryNotifSub
@@ -273,30 +328,29 @@ type queryNotifSub struct {
 	timeout uint32
 }
 
-func (c *MssqlConn) Prepare(query string) (driver.Stmt, error) {
+func (c *Conn) Prepare(query string) (driver.Stmt, error) {
 	if !c.connectionGood {
 		return nil, driver.ErrBadConn
 	}
 	if len(query) > 10 && strings.EqualFold(query[:10], "INSERTBULK") {
-		return c.prepareCopyIn(query)
+		return c.prepareCopyIn(context.Background(), query)
 	}
-
 	return c.prepareContext(context.Background(), query)
 }
 
-func (c *MssqlConn) prepareContext(ctx context.Context, query string) (*MssqlStmt, error) {
+func (c *Conn) prepareContext(ctx context.Context, query string) (*Stmt, error) {
 	paramCount := -1
 	if c.processQueryText {
 		query, paramCount = parseParams(query)
 	}
-	return &MssqlStmt{c, query, paramCount, nil}, nil
+	return &Stmt{c, query, paramCount, nil}, nil
 }
 
-func (s *MssqlStmt) Close() error {
+func (s *Stmt) Close() error {
 	return nil
 }
 
-func (s *MssqlStmt) SetQueryNotification(id, options string, timeout time.Duration) {
+func (s *Stmt) SetQueryNotification(id, options string, timeout time.Duration) {
 	to := uint32(timeout / time.Second)
 	if to < 1 {
 		to = 1
@@ -304,11 +358,11 @@ func (s *MssqlStmt) SetQueryNotification(id, options string, timeout time.Durati
 	s.notifSub = &queryNotifSub{id, options, to}
 }
 
-func (s *MssqlStmt) NumInput() int {
+func (s *Stmt) NumInput() int {
 	return s.paramCount
 }
 
-func (s *MssqlStmt) sendQuery(args []namedValue) (err error) {
+func (s *Stmt) sendQuery(args []namedValue) (err error) {
 	headers := []headerStruct{
 		{hdrtype: dataStmHdrTransDescr,
 			data: transDescrHdr{s.c.sess.tranid, 1}.pack()},
@@ -326,11 +380,13 @@ func (s *MssqlStmt) sendQuery(args []namedValue) (err error) {
 			})
 	}
 
+	conn := s.c
+
 	// no need to check number of parameters here, it is checked by database/sql
-	if s.c.sess.logFlags&logSQL != 0 {
-		s.c.sess.log.Println(s.query)
+	if conn.sess.logFlags&logSQL != 0 {
+		conn.sess.log.Println(s.query)
 	}
-	if s.c.sess.logFlags&logParams != 0 && len(args) > 0 {
+	if conn.sess.logFlags&logParams != 0 && len(args) > 0 {
 		for i := 0; i < len(args); i++ {
 			if len(args[i].Name) > 0 {
 				s.c.sess.log.Printf("\t@%s\t%v\n", args[i].Name, args[i].Value)
@@ -338,14 +394,16 @@ func (s *MssqlStmt) sendQuery(args []namedValue) (err error) {
 				s.c.sess.log.Printf("\t@p%d\t%v\n", i+1, args[i].Value)
 			}
 		}
-
 	}
+
+	reset := conn.resetSession
+	conn.resetSession = false
 	if len(args) == 0 {
-		if err = sendSqlBatch72(s.c.sess.buf, s.query, headers); err != nil {
-			if s.c.sess.logFlags&logErrors != 0 {
-				s.c.sess.log.Printf("Failed to send SqlBatch with %v", err)
+		if err = sendSqlBatch72(conn.sess.buf, s.query, headers, reset); err != nil {
+			if conn.sess.logFlags&logErrors != 0 {
+				conn.sess.log.Printf("Failed to send SqlBatch with %v", err)
 			}
-			s.c.connectionGood = false
+			conn.connectionGood = false
 			return fmt.Errorf("failed to send SQL Batch: %v", err)
 		}
 	} else {
@@ -363,11 +421,11 @@ func (s *MssqlStmt) sendQuery(args []namedValue) (err error) {
 			params[0] = makeStrParam(s.query)
 			params[1] = makeStrParam(strings.Join(decls, ","))
 		}
-		if err = sendRpc(s.c.sess.buf, headers, proc, 0, params); err != nil {
-			if s.c.sess.logFlags&logErrors != 0 {
-				s.c.sess.log.Printf("Failed to send Rpc with %v", err)
+		if err = sendRpc(conn.sess.buf, headers, proc, 0, params, reset); err != nil {
+			if conn.sess.logFlags&logErrors != 0 {
+				conn.sess.log.Printf("Failed to send Rpc with %v", err)
 			}
-			s.c.connectionGood = false
+			conn.connectionGood = false
 			return fmt.Errorf("Failed to send RPC: %v", err)
 		}
 	}
@@ -386,7 +444,7 @@ func isProc(s string) bool {
 	return !strings.ContainsAny(s, " \t\n\r;")
 }
 
-func (s *MssqlStmt) makeRPCParams(args []namedValue, offset int) ([]Param, []string, error) {
+func (s *Stmt) makeRPCParams(args []namedValue, offset int) ([]Param, []string, error) {
 	var err error
 	params := make([]Param, len(args)+offset)
 	decls := make([]string, len(args))
@@ -424,11 +482,11 @@ func convertOldArgs(args []driver.Value) []namedValue {
 	return list
 }
 
-func (s *MssqlStmt) Query(args []driver.Value) (driver.Rows, error) {
+func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) {
 	return s.queryContext(context.Background(), convertOldArgs(args))
 }
 
-func (s *MssqlStmt) queryContext(ctx context.Context, args []namedValue) (rows driver.Rows, err error) {
+func (s *Stmt) queryContext(ctx context.Context, args []namedValue) (rows driver.Rows, err error) {
 	if !s.c.connectionGood {
 		return nil, driver.ErrBadConn
 	}
@@ -438,7 +496,7 @@ func (s *MssqlStmt) queryContext(ctx context.Context, args []namedValue) (rows d
 	return s.processQueryResponse(ctx)
 }
 
-func (s *MssqlStmt) processQueryResponse(ctx context.Context) (res driver.Rows, err error) {
+func (s *Stmt) processQueryResponse(ctx context.Context) (res driver.Rows, err error) {
 	tokchan := make(chan tokenStruct, 5)
 	ctx, cancel := context.WithCancel(ctx)
 	go processResponse(ctx, s.c.sess, tokchan, s.c.outs)
@@ -466,15 +524,15 @@ loop:
 			return nil, s.c.checkBadConn(token)
 		}
 	}
-	res = &MssqlRows{stmt: s, tokchan: tokchan, cols: cols, cancel: cancel}
+	res = &Rows{stmt: s, tokchan: tokchan, cols: cols, cancel: cancel}
 	return
 }
 
-func (s *MssqlStmt) Exec(args []driver.Value) (driver.Result, error) {
+func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) {
 	return s.exec(context.Background(), convertOldArgs(args))
 }
 
-func (s *MssqlStmt) exec(ctx context.Context, args []namedValue) (res driver.Result, err error) {
+func (s *Stmt) exec(ctx context.Context, args []namedValue) (res driver.Result, err error) {
 	if !s.c.connectionGood {
 		return nil, driver.ErrBadConn
 	}
@@ -487,7 +545,7 @@ func (s *MssqlStmt) exec(ctx context.Context, args []namedValue) (res driver.Res
 	return
 }
 
-func (s *MssqlStmt) processExec(ctx context.Context) (res driver.Result, err error) {
+func (s *Stmt) processExec(ctx context.Context) (res driver.Result, err error) {
 	tokchan := make(chan tokenStruct, 5)
 	go processResponse(ctx, s.c.sess, tokchan, s.c.outs)
 	s.c.clearOuts()
@@ -509,11 +567,11 @@ func (s *MssqlStmt) processExec(ctx context.Context) (res driver.Result, err err
 			return nil, token
 		}
 	}
-	return &MssqlResult{s.c, rowCount}, nil
+	return &Result{s.c, rowCount}, nil
 }
 
-type MssqlRows struct {
-	stmt    *MssqlStmt
+type Rows struct {
+	stmt    *Stmt
 	cols    []columnStruct
 	tokchan chan tokenStruct
 
@@ -522,7 +580,7 @@ type MssqlRows struct {
 	cancel func()
 }
 
-func (rc *MssqlRows) Close() error {
+func (rc *Rows) Close() error {
 	rc.cancel()
 	for range rc.tokchan {
 	}
@@ -530,7 +588,7 @@ func (rc *MssqlRows) Close() error {
 	return nil
 }
 
-func (rc *MssqlRows) Columns() (res []string) {
+func (rc *Rows) Columns() (res []string) {
 	res = make([]string, len(rc.cols))
 	for i, col := range rc.cols {
 		res[i] = col.ColName
@@ -538,7 +596,7 @@ func (rc *MssqlRows) Columns() (res []string) {
 	return
 }
 
-func (rc *MssqlRows) Next(dest []driver.Value) error {
+func (rc *Rows) Next(dest []driver.Value) error {
 	if !rc.stmt.c.connectionGood {
 		return driver.ErrBadConn
 	}
@@ -566,11 +624,11 @@ func (rc *MssqlRows) Next(dest []driver.Value) error {
 	return io.EOF
 }
 
-func (rc *MssqlRows) HasNextResultSet() bool {
+func (rc *Rows) HasNextResultSet() bool {
 	return rc.nextCols != nil
 }
 
-func (rc *MssqlRows) NextResultSet() error {
+func (rc *Rows) NextResultSet() error {
 	rc.cols = rc.nextCols
 	rc.nextCols = nil
 	if rc.cols == nil {
@@ -582,7 +640,7 @@ func (rc *MssqlRows) NextResultSet() error {
 // It should return
 // the value type that can be used to scan types into. For example, the database
 // column type "bigint" this should return "reflect.TypeOf(int64(0))".
-func (r *MssqlRows) ColumnTypeScanType(index int) reflect.Type {
+func (r *Rows) ColumnTypeScanType(index int) reflect.Type {
 	return makeGoLangScanType(r.cols[index].ti)
 }
 
@@ -591,7 +649,7 @@ func (r *MssqlRows) ColumnTypeScanType(index int) reflect.Type {
 // Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT",
 // "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML",
 // "TIMESTAMP".
-func (r *MssqlRows) ColumnTypeDatabaseTypeName(index int) string {
+func (r *Rows) ColumnTypeDatabaseTypeName(index int) string {
 	return makeGoLangTypeName(r.cols[index].ti)
 }
 
@@ -606,7 +664,7 @@ func (r *MssqlRows) ColumnTypeDatabaseTypeName(index int) string {
 //   decimal       (0, false)
 //   int           (0, false)
 //   bytea(30)     (30, true)
-func (r *MssqlRows) ColumnTypeLength(index int) (int64, bool) {
+func (r *Rows) ColumnTypeLength(index int) (int64, bool) {
 	return makeGoLangTypeLength(r.cols[index].ti)
 }
 
@@ -616,7 +674,7 @@ func (r *MssqlRows) ColumnTypeLength(index int) (int64, bool) {
 //   decimal(38, 4)    (38, 4, true)
 //   int               (0, 0, false)
 //   decimal           (math.MaxInt64, math.MaxInt64, true)
-func (r *MssqlRows) ColumnTypePrecisionScale(index int) (int64, int64, bool) {
+func (r *Rows) ColumnTypePrecisionScale(index int) (int64, int64, bool) {
 	return makeGoLangTypePrecisionScale(r.cols[index].ti)
 }
 
@@ -624,7 +682,7 @@ func (r *MssqlRows) ColumnTypePrecisionScale(index int) (int64, int64, bool) {
 // be true if it is known the column may be null, or false if the column is known
 // to be not nullable.
 // If the column nullability is unknown, ok should be false.
-func (r *MssqlRows) ColumnTypeNullable(index int) (nullable, ok bool) {
+func (r *Rows) ColumnTypeNullable(index int) (nullable, ok bool) {
 	nullable = r.cols[index].Flags&colFlagNullable != 0
 	ok = true
 	return
@@ -637,7 +695,7 @@ func makeStrParam(val string) (res Param) {
 	return
 }
 
-func (s *MssqlStmt) makeParam(val driver.Value) (res Param, err error) {
+func (s *Stmt) makeParam(val driver.Value) (res Param, err error) {
 	if val == nil {
 		res.ti.TypeId = typeNull
 		res.buffer = nil
@@ -706,16 +764,16 @@ func (s *MssqlStmt) makeParam(val driver.Value) (res Param, err error) {
 	return
 }
 
-type MssqlResult struct {
-	c            *MssqlConn
+type Result struct {
+	c            *Conn
 	rowsAffected int64
 }
 
-func (r *MssqlResult) RowsAffected() (int64, error) {
+func (r *Result) RowsAffected() (int64, error) {
 	return r.rowsAffected, nil
 }
 
-func (r *MssqlResult) LastInsertId() (int64, error) {
+func (r *Result) LastInsertId() (int64, error) {
 	s, err := r.c.Prepare("select cast(@@identity as bigint)")
 	if err != nil {
 		return 0, err

+ 50 - 0
vendor/github.com/denisenkom/go-mssqldb/mssql_go110.go

@@ -0,0 +1,50 @@
+// +build go1.10
+
+package mssql
+
+import (
+	"context"
+	"database/sql/driver"
+)
+
+var _ driver.Connector = &Connector{}
+var _ driver.SessionResetter = &Conn{}
+
+func (c *Conn) ResetSession(ctx context.Context) error {
+	if !c.connectionGood {
+		return driver.ErrBadConn
+	}
+	c.resetSession = true
+
+	if c.connector == nil || len(c.connector.ResetSQL) == 0 {
+		return nil
+	}
+
+	s, err := c.prepareContext(ctx, c.connector.ResetSQL)
+	if err != nil {
+		return driver.ErrBadConn
+	}
+	_, err = s.exec(ctx, nil)
+	if err != nil {
+		return driver.ErrBadConn
+	}
+
+	return nil
+}
+
+// Connect to the server and return a TDS connection.
+func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) {
+	conn, err := c.driver.connect(ctx, c.params)
+	if conn != nil {
+		conn.connector = c
+	}
+	if err == nil {
+		err = conn.ResetSession(ctx)
+	}
+	return conn, err
+}
+
+// Driver underlying the Connector.
+func (c *Connector) Driver() driver.Driver {
+	return c.driver
+}

+ 9 - 9
vendor/github.com/denisenkom/go-mssqldb/mssql_go18.go

@@ -10,22 +10,22 @@ import (
 	"strings"
 )
 
-var _ driver.Pinger = &MssqlConn{}
+var _ driver.Pinger = &Conn{}
 
 // Ping is used to check if the remote server is available and satisfies the Pinger interface.
-func (c *MssqlConn) Ping(ctx context.Context) error {
+func (c *Conn) Ping(ctx context.Context) error {
 	if !c.connectionGood {
 		return driver.ErrBadConn
 	}
-	stmt := &MssqlStmt{c, `select 1;`, 0, nil}
+	stmt := &Stmt{c, `select 1;`, 0, nil}
 	_, err := stmt.ExecContext(ctx, nil)
 	return err
 }
 
-var _ driver.ConnBeginTx = &MssqlConn{}
+var _ driver.ConnBeginTx = &Conn{}
 
 // BeginTx satisfies ConnBeginTx.
-func (c *MssqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
+func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
 	if !c.connectionGood {
 		return nil, driver.ErrBadConn
 	}
@@ -57,18 +57,18 @@ func (c *MssqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.
 	return c.begin(ctx, tdsIsolation)
 }
 
-func (c *MssqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
+func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
 	if !c.connectionGood {
 		return nil, driver.ErrBadConn
 	}
 	if len(query) > 10 && strings.EqualFold(query[:10], "INSERTBULK") {
-		return c.prepareCopyIn(query)
+		return c.prepareCopyIn(ctx, query)
 	}
 
 	return c.prepareContext(ctx, query)
 }
 
-func (s *MssqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
+func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
 	if !s.c.connectionGood {
 		return nil, driver.ErrBadConn
 	}
@@ -79,7 +79,7 @@ func (s *MssqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue)
 	return s.queryContext(ctx, list)
 }
 
-func (s *MssqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
+func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
 	if !s.c.connectionGood {
 		return nil, driver.ErrBadConn
 	}

+ 14 - 3
vendor/github.com/denisenkom/go-mssqldb/mssql_go19.go

@@ -9,9 +9,20 @@ import (
 	// "github.com/cockroachdb/apd"
 )
 
-var _ driver.NamedValueChecker = &MssqlConn{}
+// Type alias provided for compibility.
+//
+// Deprecated: users should transition to the new names when possible.
+type MssqlDriver = Driver
+type MssqlBulk = Bulk
+type MssqlBulkOptions = BulkOptions
+type MssqlConn = Conn
+type MssqlResult = Result
+type MssqlRows = Rows
+type MssqlStmt = Stmt
 
-func (c *MssqlConn) CheckNamedValue(nv *driver.NamedValue) error {
+var _ driver.NamedValueChecker = &Conn{}
+
+func (c *Conn) CheckNamedValue(nv *driver.NamedValue) error {
 	switch v := nv.Value.(type) {
 	case sql.Out:
 		if c.outs == nil {
@@ -41,7 +52,7 @@ func (c *MssqlConn) CheckNamedValue(nv *driver.NamedValue) error {
 	}
 }
 
-func (s *MssqlStmt) makeParamExtra(val driver.Value) (res Param, err error) {
+func (s *Stmt) makeParamExtra(val driver.Value) (res Param, err error) {
 	switch val := val.(type) {
 	case sql.Out:
 		res, err = s.makeParam(val.Dest)

+ 1 - 1
vendor/github.com/denisenkom/go-mssqldb/mssql_go19pre.go

@@ -7,6 +7,6 @@ import (
 	"fmt"
 )
 
-func (s *MssqlStmt) makeParamExtra(val driver.Value) (Param, error) {
+func (s *Stmt) makeParamExtra(val driver.Value) (Param, error) {
 	return Param{}, fmt.Errorf("mssql: unknown type for %T", val)
 }

+ 11 - 7
vendor/github.com/denisenkom/go-mssqldb/net.go

@@ -48,9 +48,11 @@ func (c *timeoutConn) Read(b []byte) (n int, err error) {
 		n, err = c.buf.Read(b)
 		return
 	}
-	err = c.c.SetDeadline(time.Now().Add(c.timeout))
-	if err != nil {
-		return
+	if c.timeout > 0 {
+		err = c.c.SetDeadline(time.Now().Add(c.timeout))
+		if err != nil {
+			return
+		}
 	}
 	return c.c.Read(b)
 }
@@ -58,7 +60,7 @@ func (c *timeoutConn) Read(b []byte) (n int, err error) {
 func (c *timeoutConn) Write(b []byte) (n int, err error) {
 	if c.buf != nil {
 		if !c.packetPending {
-			c.buf.BeginPacket(packPrelogin)
+			c.buf.BeginPacket(packPrelogin, false)
 			c.packetPending = true
 		}
 		n, err = c.buf.Write(b)
@@ -67,9 +69,11 @@ func (c *timeoutConn) Write(b []byte) (n int, err error) {
 		}
 		return
 	}
-	err = c.c.SetDeadline(time.Now().Add(c.timeout))
-	if err != nil {
-		return
+	if c.timeout > 0 {
+		err = c.c.SetDeadline(time.Now().Add(c.timeout))
+		if err != nil {
+			return
+		}
 	}
 	return c.c.Write(b)
 }

+ 2 - 2
vendor/github.com/denisenkom/go-mssqldb/rpc.go

@@ -57,8 +57,8 @@ var (
 )
 
 // http://msdn.microsoft.com/en-us/library/dd357576.aspx
-func sendRpc(buf *tdsBuffer, headers []headerStruct, proc ProcId, flags uint16, params []Param) (err error) {
-	buf.BeginPacket(packRPCRequest)
+func sendRpc(buf *tdsBuffer, headers []headerStruct, proc ProcId, flags uint16, params []Param, resetSession bool) (err error) {
+	buf.BeginPacket(packRPCRequest, resetSession)
 	writeAllHeaders(buf, headers)
 	if len(proc.name) == 0 {
 		var idswitch uint16 = 0xffff

+ 31 - 22
vendor/github.com/denisenkom/go-mssqldb/tds.go

@@ -50,13 +50,17 @@ func parseInstances(msg []byte) map[string]map[string]string {
 	return results
 }
 
-func getInstances(address string) (map[string]map[string]string, error) {
-	conn, err := net.DialTimeout("udp", address+":1434", 5*time.Second)
+func getInstances(ctx context.Context, address string) (map[string]map[string]string, error) {
+	maxTime := 5 * time.Second
+	dialer := &net.Dialer{
+		Timeout: maxTime,
+	}
+	conn, err := dialer.DialContext(ctx, "udp", address+":1434")
 	if err != nil {
 		return nil, err
 	}
 	defer conn.Close()
-	conn.SetDeadline(time.Now().Add(5 * time.Second))
+	conn.SetDeadline(time.Now().Add(maxTime))
 	_, err = conn.Write([]byte{3})
 	if err != nil {
 		return nil, err
@@ -159,7 +163,7 @@ func (p KeySlice) Swap(i, j int)      { p[i], p[j] = p[j], p[i] }
 func writePrelogin(w *tdsBuffer, fields map[uint8][]byte) error {
 	var err error
 
-	w.BeginPacket(packPrelogin)
+	w.BeginPacket(packPrelogin, false)
 	offset := uint16(5*len(fields) + 1)
 	keys := make(KeySlice, 0, len(fields))
 	for k := range fields {
@@ -349,7 +353,7 @@ func manglePassword(password string) []byte {
 
 // http://msdn.microsoft.com/en-us/library/dd304019.aspx
 func sendLogin(w *tdsBuffer, login login) error {
-	w.BeginPacket(packLogin7)
+	w.BeginPacket(packLogin7, false)
 	hostname := str2ucs2(login.HostName)
 	username := str2ucs2(login.UserName)
 	password := manglePassword(login.Password)
@@ -630,8 +634,8 @@ func writeAllHeaders(w io.Writer, headers []headerStruct) (err error) {
 	return nil
 }
 
-func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct) (err error) {
-	buf.BeginPacket(packSQLBatch)
+func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct, resetSession bool) (err error) {
+	buf.BeginPacket(packSQLBatch, resetSession)
 
 	if err = writeAllHeaders(buf, headers); err != nil {
 		return
@@ -647,7 +651,7 @@ func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct) (err
 // 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
 // 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
 func sendAttention(buf *tdsBuffer) error {
-	buf.BeginPacket(packAttention)
+	buf.BeginPacket(packAttention, false)
 	return buf.FinishPacket()
 }
 
@@ -935,13 +939,13 @@ func parseConnectParams(dsn string) (connectParams, error) {
 	strlog, ok := params["log"]
 	if ok {
 		var err error
-		p.logFlags, err = strconv.ParseUint(strlog, 10, 0)
+		p.logFlags, err = strconv.ParseUint(strlog, 10, 64)
 		if err != nil {
 			return p, fmt.Errorf("Invalid log parameter '%s': %s", strlog, err.Error())
 		}
 	}
 	server := params["server"]
-	parts := strings.SplitN(server, "\\", 2)
+	parts := strings.SplitN(server, `\`, 2)
 	p.host = parts[0]
 	if p.host == "." || strings.ToUpper(p.host) == "(LOCAL)" || p.host == "" {
 		p.host = "localhost"
@@ -957,7 +961,7 @@ func parseConnectParams(dsn string) (connectParams, error) {
 	strport, ok := params["port"]
 	if ok {
 		var err error
-		p.port, err = strconv.ParseUint(strport, 0, 16)
+		p.port, err = strconv.ParseUint(strport, 10, 16)
 		if err != nil {
 			f := "Invalid tcp port '%v': %v"
 			return p, fmt.Errorf(f, strport, err.Error())
@@ -993,7 +997,7 @@ func parseConnectParams(dsn string) (connectParams, error) {
 	p.conn_timeout = 30 * time.Second
 	strconntimeout, ok := params["connection timeout"]
 	if ok {
-		timeout, err := strconv.ParseUint(strconntimeout, 0, 16)
+		timeout, err := strconv.ParseUint(strconntimeout, 10, 64)
 		if err != nil {
 			f := "Invalid connection timeout '%v': %v"
 			return p, fmt.Errorf(f, strconntimeout, err.Error())
@@ -1002,7 +1006,7 @@ func parseConnectParams(dsn string) (connectParams, error) {
 	}
 	strdialtimeout, ok := params["dial timeout"]
 	if ok {
-		timeout, err := strconv.ParseUint(strdialtimeout, 0, 16)
+		timeout, err := strconv.ParseUint(strdialtimeout, 10, 64)
 		if err != nil {
 			f := "Invalid dial timeout '%v': %v"
 			return p, fmt.Errorf(f, strdialtimeout, err.Error())
@@ -1015,7 +1019,7 @@ func parseConnectParams(dsn string) (connectParams, error) {
 	p.keepAlive = 30 * time.Second
 
 	if keepAlive, ok := params["keepalive"]; ok {
-		timeout, err := strconv.ParseUint(keepAlive, 0, 16)
+		timeout, err := strconv.ParseUint(keepAlive, 10, 64)
 		if err != nil {
 			f := "Invalid keepAlive value '%s': %s"
 			return p, fmt.Errorf(f, keepAlive, err.Error())
@@ -1109,7 +1113,7 @@ type auth interface {
 // SQL Server AlwaysOn Availability Group Listeners are bound by DNS to a
 // list of IP addresses.  So if there is more than one, try them all and
 // use the first one that allows a connection.
-func dialConnection(p connectParams) (conn net.Conn, err error) {
+func dialConnection(ctx context.Context, p connectParams) (conn net.Conn, err error) {
 	var ips []net.IP
 	ips, err = net.LookupIP(p.host)
 	if err != nil {
@@ -1122,7 +1126,7 @@ func dialConnection(p connectParams) (conn net.Conn, err error) {
 	if len(ips) == 1 {
 		d := createDialer(&p)
 		addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(p.port)))
-		conn, err = d.Dial(addr)
+		conn, err = d.Dial(ctx, addr)
 
 	} else {
 		//Try Dials in parallel to avoid waiting for timeouts.
@@ -1133,7 +1137,7 @@ func dialConnection(p connectParams) (conn net.Conn, err error) {
 			go func(ip net.IP) {
 				d := createDialer(&p)
 				addr := net.JoinHostPort(ip.String(), portStr)
-				conn, err := d.Dial(addr)
+				conn, err := d.Dial(ctx, addr)
 				if err == nil {
 					connChan <- conn
 				} else {
@@ -1171,12 +1175,17 @@ func dialConnection(p connectParams) (conn net.Conn, err error) {
 	return conn, err
 }
 
-func connect(log optionalLogger, p connectParams) (res *tdsSession, err error) {
-	res = nil
+func connect(ctx context.Context, log optionalLogger, p connectParams) (res *tdsSession, err error) {
+	dialCtx := ctx
+	if p.dial_timeout > 0 {
+		var cancel func()
+		dialCtx, cancel = context.WithTimeout(ctx, p.dial_timeout)
+		defer cancel()
+	}
 	// if instance is specified use instance resolution service
 	if p.instance != "" {
 		p.instance = strings.ToUpper(p.instance)
-		instances, err := getInstances(p.host)
+		instances, err := getInstances(dialCtx, p.host)
 		if err != nil {
 			f := "Unable to get instances from Sql Server Browser on host %v: %v"
 			return nil, fmt.Errorf(f, p.host, err.Error())
@@ -1194,7 +1203,7 @@ func connect(log optionalLogger, p connectParams) (res *tdsSession, err error) {
 	}
 
 initiate_connection:
-	conn, err := dialConnection(p)
+	conn, err := dialConnection(dialCtx, p)
 	if err != nil {
 		return nil, err
 	}
@@ -1334,7 +1343,7 @@ continue_login:
 		}
 	}
 	if sspi_msg != nil {
-		outbuf.BeginPacket(packSSPIMessage)
+		outbuf.BeginPacket(packSSPIMessage, false)
 		_, err = outbuf.Write(sspi_msg)
 		if err != nil {
 			return nil, err

+ 6 - 7
vendor/github.com/denisenkom/go-mssqldb/tran.go

@@ -28,9 +28,8 @@ const (
 	isolationSnapshot                = 5
 )
 
-func sendBeginXact(buf *tdsBuffer, headers []headerStruct, isolation isoLevel,
-	name string) (err error) {
-	buf.BeginPacket(packTransMgrReq)
+func sendBeginXact(buf *tdsBuffer, headers []headerStruct, isolation isoLevel, name string, resetSession bool) (err error) {
+	buf.BeginPacket(packTransMgrReq, resetSession)
 	writeAllHeaders(buf, headers)
 	var rqtype uint16 = tmBeginXact
 	err = binary.Write(buf, binary.LittleEndian, &rqtype)
@@ -52,8 +51,8 @@ const (
 	fBeginXact = 1
 )
 
-func sendCommitXact(buf *tdsBuffer, headers []headerStruct, name string, flags uint8, isolation uint8, newname string) error {
-	buf.BeginPacket(packTransMgrReq)
+func sendCommitXact(buf *tdsBuffer, headers []headerStruct, name string, flags uint8, isolation uint8, newname string, resetSession bool) error {
+	buf.BeginPacket(packTransMgrReq, resetSession)
 	writeAllHeaders(buf, headers)
 	var rqtype uint16 = tmCommitXact
 	err := binary.Write(buf, binary.LittleEndian, &rqtype)
@@ -81,8 +80,8 @@ func sendCommitXact(buf *tdsBuffer, headers []headerStruct, name string, flags u
 	return buf.FinishPacket()
 }
 
-func sendRollbackXact(buf *tdsBuffer, headers []headerStruct, name string, flags uint8, isolation uint8, newname string) error {
-	buf.BeginPacket(packTransMgrReq)
+func sendRollbackXact(buf *tdsBuffer, headers []headerStruct, name string, flags uint8, isolation uint8, newname string, resetSession bool) error {
+	buf.BeginPacket(packTransMgrReq, resetSession)
 	writeAllHeaders(buf, headers)
 	var rqtype uint16 = tmRollbackXact
 	err := binary.Write(buf, binary.LittleEndian, &rqtype)

+ 25 - 7
vendor/github.com/denisenkom/go-mssqldb/types.go

@@ -9,6 +9,8 @@ import (
 	"reflect"
 	"strconv"
 	"time"
+
+	"github.com/denisenkom/go-mssqldb/internal/cp"
 )
 
 // fixed-length data types
@@ -79,7 +81,7 @@ type typeInfo struct {
 	Scale     uint8
 	Prec      uint8
 	Buffer    []byte
-	Collation collation
+	Collation cp.Collation
 	UdtInfo   udtInfo
 	XmlInfo   xmlInfo
 	Reader    func(ti *typeInfo, r *tdsBuffer) (res interface{})
@@ -487,6 +489,20 @@ func writeLongLenType(w io.Writer, ti typeInfo, buf []byte) (err error) {
 	return
 }
 
+func readCollation(r *tdsBuffer) (res cp.Collation) {
+	res.LcidAndFlags = r.uint32()
+	res.SortId = r.byte()
+	return
+}
+
+func writeCollation(w io.Writer, col cp.Collation) (err error) {
+	if err = binary.Write(w, binary.LittleEndian, col.LcidAndFlags); err != nil {
+		return
+	}
+	err = binary.Write(w, binary.LittleEndian, col.SortId)
+	return
+}
+
 // reads variant value
 // http://msdn.microsoft.com/en-us/library/dd303302.aspx
 func readVariantType(ti *typeInfo, r *tdsBuffer) interface{} {
@@ -848,8 +864,8 @@ func dateTime2(t time.Time) (days int32, ns int64) {
 	return
 }
 
-func decodeChar(col collation, buf []byte) string {
-	return charset2utf8(col, buf)
+func decodeChar(col cp.Collation, buf []byte) string {
+	return cp.CharsetToUTF8(col, buf)
 }
 
 func decodeUcs2(buf []byte) string {
@@ -922,7 +938,7 @@ func makeGoLangScanType(ti typeInfo) reflect.Type {
 		return reflect.TypeOf(true)
 	case typeDecimalN, typeNumericN:
 		return reflect.TypeOf([]byte{})
-	case typeMoneyN:
+	case typeMoney, typeMoney4, typeMoneyN:
 		switch ti.Size {
 		case 4:
 			return reflect.TypeOf([]byte{})
@@ -1083,6 +1099,8 @@ func makeDecl(ti typeInfo) string {
 		return "ntext"
 	case typeUdt:
 		return ti.UdtInfo.TypeName
+	case typeGuid:
+		return "uniqueidentifier"
 	default:
 		panic(fmt.Sprintf("not implemented makeDecl for type %#x", ti.TypeId))
 	}
@@ -1140,7 +1158,7 @@ func makeGoLangTypeName(ti typeInfo) string {
 		return "BIT"
 	case typeDecimalN, typeNumericN:
 		return "DECIMAL"
-	case typeMoneyN:
+	case typeMoney, typeMoney4, typeMoneyN:
 		switch ti.Size {
 		case 4:
 			return "SMALLMONEY"
@@ -1247,7 +1265,7 @@ func makeGoLangTypeLength(ti typeInfo) (int64, bool) {
 		return 0, false
 	case typeDecimalN, typeNumericN:
 		return 0, false
-	case typeMoneyN:
+	case typeMoney, typeMoney4, typeMoneyN:
 		switch ti.Size {
 		case 4:
 			return 0, false
@@ -1370,7 +1388,7 @@ func makeGoLangTypePrecisionScale(ti typeInfo) (int64, int64, bool) {
 		return 0, 0, false
 	case typeDecimalN, typeNumericN:
 		return int64(ti.Prec), int64(ti.Scale), true
-	case typeMoneyN:
+	case typeMoney, typeMoney4, typeMoneyN:
 		switch ti.Size {
 		case 4:
 			return 0, 0, false