|
|
@@ -15,20 +15,20 @@ import (
|
|
|
"time"
|
|
|
)
|
|
|
|
|
|
-var driverInstance = &MssqlDriver{processQueryText: true}
|
|
|
-var driverInstanceNoProcess = &MssqlDriver{processQueryText: false}
|
|
|
+var driverInstance = &Driver{processQueryText: true}
|
|
|
+var driverInstanceNoProcess = &Driver{processQueryText: false}
|
|
|
|
|
|
func init() {
|
|
|
sql.Register("mssql", driverInstance)
|
|
|
sql.Register("sqlserver", driverInstanceNoProcess)
|
|
|
createDialer = func(p *connectParams) dialer {
|
|
|
- return tcpDialer{&net.Dialer{Timeout: p.dial_timeout, KeepAlive: p.keepAlive}}
|
|
|
+ return tcpDialer{&net.Dialer{KeepAlive: p.keepAlive}}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// Abstract the dialer for testing and for non-TCP based connections.
|
|
|
type dialer interface {
|
|
|
- Dial(addr string) (net.Conn, error)
|
|
|
+ Dial(ctx context.Context, addr string) (net.Conn, error)
|
|
|
}
|
|
|
|
|
|
var createDialer func(p *connectParams) dialer
|
|
|
@@ -37,28 +37,75 @@ type tcpDialer struct {
|
|
|
nd *net.Dialer
|
|
|
}
|
|
|
|
|
|
-func (d tcpDialer) Dial(addr string) (net.Conn, error) {
|
|
|
- return d.nd.Dial("tcp", addr)
|
|
|
+func (d tcpDialer) Dial(ctx context.Context, addr string) (net.Conn, error) {
|
|
|
+ return d.nd.DialContext(ctx, "tcp", addr)
|
|
|
}
|
|
|
|
|
|
-type MssqlDriver struct {
|
|
|
+type Driver struct {
|
|
|
log optionalLogger
|
|
|
|
|
|
processQueryText bool
|
|
|
}
|
|
|
|
|
|
+// OpenConnector opens a new connector. Useful to dial with a context.
|
|
|
+func (d *Driver) OpenConnector(dsn string) (*Connector, error) {
|
|
|
+ params, err := parseConnectParams(dsn)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ return &Connector{
|
|
|
+ params: params,
|
|
|
+ driver: d,
|
|
|
+ }, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (d *Driver) Open(dsn string) (driver.Conn, error) {
|
|
|
+ return d.open(context.Background(), dsn)
|
|
|
+}
|
|
|
+
|
|
|
func SetLogger(logger Logger) {
|
|
|
driverInstance.SetLogger(logger)
|
|
|
driverInstanceNoProcess.SetLogger(logger)
|
|
|
}
|
|
|
|
|
|
-func (d *MssqlDriver) SetLogger(logger Logger) {
|
|
|
+func (d *Driver) SetLogger(logger Logger) {
|
|
|
d.log = optionalLogger{logger}
|
|
|
}
|
|
|
|
|
|
-type MssqlConn struct {
|
|
|
+// Connector holds the parsed DSN and is ready to make a new connection
|
|
|
+// at any time.
|
|
|
+//
|
|
|
+// In the future, settings that cannot be passed through a string DSN
|
|
|
+// may be set directly on the connector.
|
|
|
+type Connector struct {
|
|
|
+ params connectParams
|
|
|
+ driver *Driver
|
|
|
+
|
|
|
+ // ResetSQL is executed after marking a given connection to be reset.
|
|
|
+ // When not present, the next query will be reset to the database
|
|
|
+ // defaults.
|
|
|
+ // When present the connection will immediately mark the connection to
|
|
|
+ // be reset, then execute the ResetSQL text to setup the session
|
|
|
+ // that may be different from the base database defaults.
|
|
|
+ //
|
|
|
+ // For Example, the application relies on the following defaults
|
|
|
+ // but is not allowed to set them at the database system level.
|
|
|
+ //
|
|
|
+ // SET XACT_ABORT ON;
|
|
|
+ // SET TEXTSIZE -1;
|
|
|
+ // SET ANSI_NULLS ON;
|
|
|
+ // SET LOCK_TIMEOUT 10000;
|
|
|
+ //
|
|
|
+ // ResetSQL should not attempt to manually call sp_reset_connection.
|
|
|
+ // This will happen at the TDS layer.
|
|
|
+ ResetSQL string
|
|
|
+}
|
|
|
+
|
|
|
+type Conn struct {
|
|
|
+ connector *Connector
|
|
|
sess *tdsSession
|
|
|
transactionCtx context.Context
|
|
|
+ resetSession bool
|
|
|
|
|
|
processQueryText bool
|
|
|
connectionGood bool
|
|
|
@@ -66,7 +113,7 @@ type MssqlConn struct {
|
|
|
outs map[string]interface{}
|
|
|
}
|
|
|
|
|
|
-func (c *MssqlConn) checkBadConn(err error) error {
|
|
|
+func (c *Conn) checkBadConn(err error) error {
|
|
|
// this is a hack to address Issue #275
|
|
|
// we set connectionGood flag to false if
|
|
|
// error indicates that connection is not usable
|
|
|
@@ -81,11 +128,12 @@ func (c *MssqlConn) checkBadConn(err error) error {
|
|
|
case nil:
|
|
|
return nil
|
|
|
case io.EOF:
|
|
|
+ c.connectionGood = false
|
|
|
return driver.ErrBadConn
|
|
|
case driver.ErrBadConn:
|
|
|
// It is an internal programming error if driver.ErrBadConn
|
|
|
// is ever passed to this function. driver.ErrBadConn should
|
|
|
- // only ever be returned in response to a *MssqlConn.connectionGood == false
|
|
|
+ // only ever be returned in response to a *mssql.Conn.connectionGood == false
|
|
|
// check in the external facing API.
|
|
|
panic("driver.ErrBadConn in checkBadConn. This should not happen.")
|
|
|
}
|
|
|
@@ -102,11 +150,11 @@ func (c *MssqlConn) checkBadConn(err error) error {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (c *MssqlConn) clearOuts() {
|
|
|
+func (c *Conn) clearOuts() {
|
|
|
c.outs = nil
|
|
|
}
|
|
|
|
|
|
-func (c *MssqlConn) simpleProcessResp(ctx context.Context) error {
|
|
|
+func (c *Conn) simpleProcessResp(ctx context.Context) error {
|
|
|
tokchan := make(chan tokenStruct, 5)
|
|
|
go processResponse(ctx, c.sess, tokchan, c.outs)
|
|
|
c.clearOuts()
|
|
|
@@ -123,7 +171,7 @@ func (c *MssqlConn) simpleProcessResp(ctx context.Context) error {
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
-func (c *MssqlConn) Commit() error {
|
|
|
+func (c *Conn) Commit() error {
|
|
|
if !c.connectionGood {
|
|
|
return driver.ErrBadConn
|
|
|
}
|
|
|
@@ -133,12 +181,14 @@ func (c *MssqlConn) Commit() error {
|
|
|
return c.simpleProcessResp(c.transactionCtx)
|
|
|
}
|
|
|
|
|
|
-func (c *MssqlConn) sendCommitRequest() error {
|
|
|
+func (c *Conn) sendCommitRequest() error {
|
|
|
headers := []headerStruct{
|
|
|
{hdrtype: dataStmHdrTransDescr,
|
|
|
data: transDescrHdr{c.sess.tranid, 1}.pack()},
|
|
|
}
|
|
|
- if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, ""); err != nil {
|
|
|
+ reset := c.resetSession
|
|
|
+ c.resetSession = false
|
|
|
+ if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, "", reset); err != nil {
|
|
|
if c.sess.logFlags&logErrors != 0 {
|
|
|
c.sess.log.Printf("Failed to send CommitXact with %v", err)
|
|
|
}
|
|
|
@@ -148,7 +198,7 @@ func (c *MssqlConn) sendCommitRequest() error {
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
-func (c *MssqlConn) Rollback() error {
|
|
|
+func (c *Conn) Rollback() error {
|
|
|
if !c.connectionGood {
|
|
|
return driver.ErrBadConn
|
|
|
}
|
|
|
@@ -158,12 +208,14 @@ func (c *MssqlConn) Rollback() error {
|
|
|
return c.simpleProcessResp(c.transactionCtx)
|
|
|
}
|
|
|
|
|
|
-func (c *MssqlConn) sendRollbackRequest() error {
|
|
|
+func (c *Conn) sendRollbackRequest() error {
|
|
|
headers := []headerStruct{
|
|
|
{hdrtype: dataStmHdrTransDescr,
|
|
|
data: transDescrHdr{c.sess.tranid, 1}.pack()},
|
|
|
}
|
|
|
- if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, ""); err != nil {
|
|
|
+ reset := c.resetSession
|
|
|
+ c.resetSession = false
|
|
|
+ if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, "", reset); err != nil {
|
|
|
if c.sess.logFlags&logErrors != 0 {
|
|
|
c.sess.log.Printf("Failed to send RollbackXact with %v", err)
|
|
|
}
|
|
|
@@ -173,11 +225,11 @@ func (c *MssqlConn) sendRollbackRequest() error {
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
-func (c *MssqlConn) Begin() (driver.Tx, error) {
|
|
|
+func (c *Conn) Begin() (driver.Tx, error) {
|
|
|
return c.begin(context.Background(), isolationUseCurrent)
|
|
|
}
|
|
|
|
|
|
-func (c *MssqlConn) begin(ctx context.Context, tdsIsolation isoLevel) (tx driver.Tx, err error) {
|
|
|
+func (c *Conn) begin(ctx context.Context, tdsIsolation isoLevel) (tx driver.Tx, err error) {
|
|
|
if !c.connectionGood {
|
|
|
return nil, driver.ErrBadConn
|
|
|
}
|
|
|
@@ -192,13 +244,15 @@ func (c *MssqlConn) begin(ctx context.Context, tdsIsolation isoLevel) (tx driver
|
|
|
return
|
|
|
}
|
|
|
|
|
|
-func (c *MssqlConn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) error {
|
|
|
+func (c *Conn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) error {
|
|
|
c.transactionCtx = ctx
|
|
|
headers := []headerStruct{
|
|
|
{hdrtype: dataStmHdrTransDescr,
|
|
|
data: transDescrHdr{0, 1}.pack()},
|
|
|
}
|
|
|
- if err := sendBeginXact(c.sess.buf, headers, tdsIsolation, ""); err != nil {
|
|
|
+ reset := c.resetSession
|
|
|
+ c.resetSession = false
|
|
|
+ if err := sendBeginXact(c.sess.buf, headers, tdsIsolation, "", reset); err != nil {
|
|
|
if c.sess.logFlags&logErrors != 0 {
|
|
|
c.sess.log.Printf("Failed to send BeginXact with %v", err)
|
|
|
}
|
|
|
@@ -208,7 +262,7 @@ func (c *MssqlConn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel)
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
-func (c *MssqlConn) processBeginResponse(ctx context.Context) (driver.Tx, error) {
|
|
|
+func (c *Conn) processBeginResponse(ctx context.Context) (driver.Tx, error) {
|
|
|
if err := c.simpleProcessResp(ctx); err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
@@ -217,17 +271,17 @@ func (c *MssqlConn) processBeginResponse(ctx context.Context) (driver.Tx, error)
|
|
|
return c, nil
|
|
|
}
|
|
|
|
|
|
-func (d *MssqlDriver) Open(dsn string) (driver.Conn, error) {
|
|
|
- return d.open(dsn)
|
|
|
-}
|
|
|
-
|
|
|
-func (d *MssqlDriver) open(dsn string) (*MssqlConn, error) {
|
|
|
+func (d *Driver) open(ctx context.Context, dsn string) (*Conn, error) {
|
|
|
params, err := parseConnectParams(dsn)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
+ return d.connect(ctx, params)
|
|
|
+}
|
|
|
|
|
|
- sess, err := connect(d.log, params)
|
|
|
+// connect to the server, using the provided context for dialing only.
|
|
|
+func (d *Driver) connect(ctx context.Context, params connectParams) (*Conn, error) {
|
|
|
+ sess, err := connect(ctx, d.log, params)
|
|
|
if err != nil {
|
|
|
// main server failed, try fail-over partner
|
|
|
if params.failOverPartner == "" {
|
|
|
@@ -239,29 +293,30 @@ func (d *MssqlDriver) open(dsn string) (*MssqlConn, error) {
|
|
|
params.port = params.failOverPort
|
|
|
}
|
|
|
|
|
|
- sess, err = connect(d.log, params)
|
|
|
+ sess, err = connect(ctx, d.log, params)
|
|
|
if err != nil {
|
|
|
// fail-over partner also failed, now fail
|
|
|
return nil, err
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- conn := &MssqlConn{
|
|
|
+ conn := &Conn{
|
|
|
sess: sess,
|
|
|
transactionCtx: context.Background(),
|
|
|
processQueryText: d.processQueryText,
|
|
|
connectionGood: true,
|
|
|
}
|
|
|
conn.sess.log = d.log
|
|
|
+
|
|
|
return conn, nil
|
|
|
}
|
|
|
|
|
|
-func (c *MssqlConn) Close() error {
|
|
|
+func (c *Conn) Close() error {
|
|
|
return c.sess.buf.transport.Close()
|
|
|
}
|
|
|
|
|
|
-type MssqlStmt struct {
|
|
|
- c *MssqlConn
|
|
|
+type Stmt struct {
|
|
|
+ c *Conn
|
|
|
query string
|
|
|
paramCount int
|
|
|
notifSub *queryNotifSub
|
|
|
@@ -273,30 +328,29 @@ type queryNotifSub struct {
|
|
|
timeout uint32
|
|
|
}
|
|
|
|
|
|
-func (c *MssqlConn) Prepare(query string) (driver.Stmt, error) {
|
|
|
+func (c *Conn) Prepare(query string) (driver.Stmt, error) {
|
|
|
if !c.connectionGood {
|
|
|
return nil, driver.ErrBadConn
|
|
|
}
|
|
|
if len(query) > 10 && strings.EqualFold(query[:10], "INSERTBULK") {
|
|
|
- return c.prepareCopyIn(query)
|
|
|
+ return c.prepareCopyIn(context.Background(), query)
|
|
|
}
|
|
|
-
|
|
|
return c.prepareContext(context.Background(), query)
|
|
|
}
|
|
|
|
|
|
-func (c *MssqlConn) prepareContext(ctx context.Context, query string) (*MssqlStmt, error) {
|
|
|
+func (c *Conn) prepareContext(ctx context.Context, query string) (*Stmt, error) {
|
|
|
paramCount := -1
|
|
|
if c.processQueryText {
|
|
|
query, paramCount = parseParams(query)
|
|
|
}
|
|
|
- return &MssqlStmt{c, query, paramCount, nil}, nil
|
|
|
+ return &Stmt{c, query, paramCount, nil}, nil
|
|
|
}
|
|
|
|
|
|
-func (s *MssqlStmt) Close() error {
|
|
|
+func (s *Stmt) Close() error {
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
-func (s *MssqlStmt) SetQueryNotification(id, options string, timeout time.Duration) {
|
|
|
+func (s *Stmt) SetQueryNotification(id, options string, timeout time.Duration) {
|
|
|
to := uint32(timeout / time.Second)
|
|
|
if to < 1 {
|
|
|
to = 1
|
|
|
@@ -304,11 +358,11 @@ func (s *MssqlStmt) SetQueryNotification(id, options string, timeout time.Durati
|
|
|
s.notifSub = &queryNotifSub{id, options, to}
|
|
|
}
|
|
|
|
|
|
-func (s *MssqlStmt) NumInput() int {
|
|
|
+func (s *Stmt) NumInput() int {
|
|
|
return s.paramCount
|
|
|
}
|
|
|
|
|
|
-func (s *MssqlStmt) sendQuery(args []namedValue) (err error) {
|
|
|
+func (s *Stmt) sendQuery(args []namedValue) (err error) {
|
|
|
headers := []headerStruct{
|
|
|
{hdrtype: dataStmHdrTransDescr,
|
|
|
data: transDescrHdr{s.c.sess.tranid, 1}.pack()},
|
|
|
@@ -326,11 +380,13 @@ func (s *MssqlStmt) sendQuery(args []namedValue) (err error) {
|
|
|
})
|
|
|
}
|
|
|
|
|
|
+ conn := s.c
|
|
|
+
|
|
|
// no need to check number of parameters here, it is checked by database/sql
|
|
|
- if s.c.sess.logFlags&logSQL != 0 {
|
|
|
- s.c.sess.log.Println(s.query)
|
|
|
+ if conn.sess.logFlags&logSQL != 0 {
|
|
|
+ conn.sess.log.Println(s.query)
|
|
|
}
|
|
|
- if s.c.sess.logFlags&logParams != 0 && len(args) > 0 {
|
|
|
+ if conn.sess.logFlags&logParams != 0 && len(args) > 0 {
|
|
|
for i := 0; i < len(args); i++ {
|
|
|
if len(args[i].Name) > 0 {
|
|
|
s.c.sess.log.Printf("\t@%s\t%v\n", args[i].Name, args[i].Value)
|
|
|
@@ -338,14 +394,16 @@ func (s *MssqlStmt) sendQuery(args []namedValue) (err error) {
|
|
|
s.c.sess.log.Printf("\t@p%d\t%v\n", i+1, args[i].Value)
|
|
|
}
|
|
|
}
|
|
|
-
|
|
|
}
|
|
|
+
|
|
|
+ reset := conn.resetSession
|
|
|
+ conn.resetSession = false
|
|
|
if len(args) == 0 {
|
|
|
- if err = sendSqlBatch72(s.c.sess.buf, s.query, headers); err != nil {
|
|
|
- if s.c.sess.logFlags&logErrors != 0 {
|
|
|
- s.c.sess.log.Printf("Failed to send SqlBatch with %v", err)
|
|
|
+ if err = sendSqlBatch72(conn.sess.buf, s.query, headers, reset); err != nil {
|
|
|
+ if conn.sess.logFlags&logErrors != 0 {
|
|
|
+ conn.sess.log.Printf("Failed to send SqlBatch with %v", err)
|
|
|
}
|
|
|
- s.c.connectionGood = false
|
|
|
+ conn.connectionGood = false
|
|
|
return fmt.Errorf("failed to send SQL Batch: %v", err)
|
|
|
}
|
|
|
} else {
|
|
|
@@ -363,11 +421,11 @@ func (s *MssqlStmt) sendQuery(args []namedValue) (err error) {
|
|
|
params[0] = makeStrParam(s.query)
|
|
|
params[1] = makeStrParam(strings.Join(decls, ","))
|
|
|
}
|
|
|
- if err = sendRpc(s.c.sess.buf, headers, proc, 0, params); err != nil {
|
|
|
- if s.c.sess.logFlags&logErrors != 0 {
|
|
|
- s.c.sess.log.Printf("Failed to send Rpc with %v", err)
|
|
|
+ if err = sendRpc(conn.sess.buf, headers, proc, 0, params, reset); err != nil {
|
|
|
+ if conn.sess.logFlags&logErrors != 0 {
|
|
|
+ conn.sess.log.Printf("Failed to send Rpc with %v", err)
|
|
|
}
|
|
|
- s.c.connectionGood = false
|
|
|
+ conn.connectionGood = false
|
|
|
return fmt.Errorf("Failed to send RPC: %v", err)
|
|
|
}
|
|
|
}
|
|
|
@@ -386,7 +444,7 @@ func isProc(s string) bool {
|
|
|
return !strings.ContainsAny(s, " \t\n\r;")
|
|
|
}
|
|
|
|
|
|
-func (s *MssqlStmt) makeRPCParams(args []namedValue, offset int) ([]Param, []string, error) {
|
|
|
+func (s *Stmt) makeRPCParams(args []namedValue, offset int) ([]Param, []string, error) {
|
|
|
var err error
|
|
|
params := make([]Param, len(args)+offset)
|
|
|
decls := make([]string, len(args))
|
|
|
@@ -424,11 +482,11 @@ func convertOldArgs(args []driver.Value) []namedValue {
|
|
|
return list
|
|
|
}
|
|
|
|
|
|
-func (s *MssqlStmt) Query(args []driver.Value) (driver.Rows, error) {
|
|
|
+func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) {
|
|
|
return s.queryContext(context.Background(), convertOldArgs(args))
|
|
|
}
|
|
|
|
|
|
-func (s *MssqlStmt) queryContext(ctx context.Context, args []namedValue) (rows driver.Rows, err error) {
|
|
|
+func (s *Stmt) queryContext(ctx context.Context, args []namedValue) (rows driver.Rows, err error) {
|
|
|
if !s.c.connectionGood {
|
|
|
return nil, driver.ErrBadConn
|
|
|
}
|
|
|
@@ -438,7 +496,7 @@ func (s *MssqlStmt) queryContext(ctx context.Context, args []namedValue) (rows d
|
|
|
return s.processQueryResponse(ctx)
|
|
|
}
|
|
|
|
|
|
-func (s *MssqlStmt) processQueryResponse(ctx context.Context) (res driver.Rows, err error) {
|
|
|
+func (s *Stmt) processQueryResponse(ctx context.Context) (res driver.Rows, err error) {
|
|
|
tokchan := make(chan tokenStruct, 5)
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
|
go processResponse(ctx, s.c.sess, tokchan, s.c.outs)
|
|
|
@@ -466,15 +524,15 @@ loop:
|
|
|
return nil, s.c.checkBadConn(token)
|
|
|
}
|
|
|
}
|
|
|
- res = &MssqlRows{stmt: s, tokchan: tokchan, cols: cols, cancel: cancel}
|
|
|
+ res = &Rows{stmt: s, tokchan: tokchan, cols: cols, cancel: cancel}
|
|
|
return
|
|
|
}
|
|
|
|
|
|
-func (s *MssqlStmt) Exec(args []driver.Value) (driver.Result, error) {
|
|
|
+func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) {
|
|
|
return s.exec(context.Background(), convertOldArgs(args))
|
|
|
}
|
|
|
|
|
|
-func (s *MssqlStmt) exec(ctx context.Context, args []namedValue) (res driver.Result, err error) {
|
|
|
+func (s *Stmt) exec(ctx context.Context, args []namedValue) (res driver.Result, err error) {
|
|
|
if !s.c.connectionGood {
|
|
|
return nil, driver.ErrBadConn
|
|
|
}
|
|
|
@@ -487,7 +545,7 @@ func (s *MssqlStmt) exec(ctx context.Context, args []namedValue) (res driver.Res
|
|
|
return
|
|
|
}
|
|
|
|
|
|
-func (s *MssqlStmt) processExec(ctx context.Context) (res driver.Result, err error) {
|
|
|
+func (s *Stmt) processExec(ctx context.Context) (res driver.Result, err error) {
|
|
|
tokchan := make(chan tokenStruct, 5)
|
|
|
go processResponse(ctx, s.c.sess, tokchan, s.c.outs)
|
|
|
s.c.clearOuts()
|
|
|
@@ -509,11 +567,11 @@ func (s *MssqlStmt) processExec(ctx context.Context) (res driver.Result, err err
|
|
|
return nil, token
|
|
|
}
|
|
|
}
|
|
|
- return &MssqlResult{s.c, rowCount}, nil
|
|
|
+ return &Result{s.c, rowCount}, nil
|
|
|
}
|
|
|
|
|
|
-type MssqlRows struct {
|
|
|
- stmt *MssqlStmt
|
|
|
+type Rows struct {
|
|
|
+ stmt *Stmt
|
|
|
cols []columnStruct
|
|
|
tokchan chan tokenStruct
|
|
|
|
|
|
@@ -522,7 +580,7 @@ type MssqlRows struct {
|
|
|
cancel func()
|
|
|
}
|
|
|
|
|
|
-func (rc *MssqlRows) Close() error {
|
|
|
+func (rc *Rows) Close() error {
|
|
|
rc.cancel()
|
|
|
for range rc.tokchan {
|
|
|
}
|
|
|
@@ -530,7 +588,7 @@ func (rc *MssqlRows) Close() error {
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
-func (rc *MssqlRows) Columns() (res []string) {
|
|
|
+func (rc *Rows) Columns() (res []string) {
|
|
|
res = make([]string, len(rc.cols))
|
|
|
for i, col := range rc.cols {
|
|
|
res[i] = col.ColName
|
|
|
@@ -538,7 +596,7 @@ func (rc *MssqlRows) Columns() (res []string) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
-func (rc *MssqlRows) Next(dest []driver.Value) error {
|
|
|
+func (rc *Rows) Next(dest []driver.Value) error {
|
|
|
if !rc.stmt.c.connectionGood {
|
|
|
return driver.ErrBadConn
|
|
|
}
|
|
|
@@ -566,11 +624,11 @@ func (rc *MssqlRows) Next(dest []driver.Value) error {
|
|
|
return io.EOF
|
|
|
}
|
|
|
|
|
|
-func (rc *MssqlRows) HasNextResultSet() bool {
|
|
|
+func (rc *Rows) HasNextResultSet() bool {
|
|
|
return rc.nextCols != nil
|
|
|
}
|
|
|
|
|
|
-func (rc *MssqlRows) NextResultSet() error {
|
|
|
+func (rc *Rows) NextResultSet() error {
|
|
|
rc.cols = rc.nextCols
|
|
|
rc.nextCols = nil
|
|
|
if rc.cols == nil {
|
|
|
@@ -582,7 +640,7 @@ func (rc *MssqlRows) NextResultSet() error {
|
|
|
// It should return
|
|
|
// the value type that can be used to scan types into. For example, the database
|
|
|
// column type "bigint" this should return "reflect.TypeOf(int64(0))".
|
|
|
-func (r *MssqlRows) ColumnTypeScanType(index int) reflect.Type {
|
|
|
+func (r *Rows) ColumnTypeScanType(index int) reflect.Type {
|
|
|
return makeGoLangScanType(r.cols[index].ti)
|
|
|
}
|
|
|
|
|
|
@@ -591,7 +649,7 @@ func (r *MssqlRows) ColumnTypeScanType(index int) reflect.Type {
|
|
|
// Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT",
|
|
|
// "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML",
|
|
|
// "TIMESTAMP".
|
|
|
-func (r *MssqlRows) ColumnTypeDatabaseTypeName(index int) string {
|
|
|
+func (r *Rows) ColumnTypeDatabaseTypeName(index int) string {
|
|
|
return makeGoLangTypeName(r.cols[index].ti)
|
|
|
}
|
|
|
|
|
|
@@ -606,7 +664,7 @@ func (r *MssqlRows) ColumnTypeDatabaseTypeName(index int) string {
|
|
|
// decimal (0, false)
|
|
|
// int (0, false)
|
|
|
// bytea(30) (30, true)
|
|
|
-func (r *MssqlRows) ColumnTypeLength(index int) (int64, bool) {
|
|
|
+func (r *Rows) ColumnTypeLength(index int) (int64, bool) {
|
|
|
return makeGoLangTypeLength(r.cols[index].ti)
|
|
|
}
|
|
|
|
|
|
@@ -616,7 +674,7 @@ func (r *MssqlRows) ColumnTypeLength(index int) (int64, bool) {
|
|
|
// decimal(38, 4) (38, 4, true)
|
|
|
// int (0, 0, false)
|
|
|
// decimal (math.MaxInt64, math.MaxInt64, true)
|
|
|
-func (r *MssqlRows) ColumnTypePrecisionScale(index int) (int64, int64, bool) {
|
|
|
+func (r *Rows) ColumnTypePrecisionScale(index int) (int64, int64, bool) {
|
|
|
return makeGoLangTypePrecisionScale(r.cols[index].ti)
|
|
|
}
|
|
|
|
|
|
@@ -624,7 +682,7 @@ func (r *MssqlRows) ColumnTypePrecisionScale(index int) (int64, int64, bool) {
|
|
|
// be true if it is known the column may be null, or false if the column is known
|
|
|
// to be not nullable.
|
|
|
// If the column nullability is unknown, ok should be false.
|
|
|
-func (r *MssqlRows) ColumnTypeNullable(index int) (nullable, ok bool) {
|
|
|
+func (r *Rows) ColumnTypeNullable(index int) (nullable, ok bool) {
|
|
|
nullable = r.cols[index].Flags&colFlagNullable != 0
|
|
|
ok = true
|
|
|
return
|
|
|
@@ -637,7 +695,7 @@ func makeStrParam(val string) (res Param) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
-func (s *MssqlStmt) makeParam(val driver.Value) (res Param, err error) {
|
|
|
+func (s *Stmt) makeParam(val driver.Value) (res Param, err error) {
|
|
|
if val == nil {
|
|
|
res.ti.TypeId = typeNull
|
|
|
res.buffer = nil
|
|
|
@@ -706,16 +764,16 @@ func (s *MssqlStmt) makeParam(val driver.Value) (res Param, err error) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
-type MssqlResult struct {
|
|
|
- c *MssqlConn
|
|
|
+type Result struct {
|
|
|
+ c *Conn
|
|
|
rowsAffected int64
|
|
|
}
|
|
|
|
|
|
-func (r *MssqlResult) RowsAffected() (int64, error) {
|
|
|
+func (r *Result) RowsAffected() (int64, error) {
|
|
|
return r.rowsAffected, nil
|
|
|
}
|
|
|
|
|
|
-func (r *MssqlResult) LastInsertId() (int64, error) {
|
|
|
+func (r *Result) LastInsertId() (int64, error) {
|
|
|
s, err := r.c.Prepare("select cast(@@identity as bigint)")
|
|
|
if err != nil {
|
|
|
return 0, err
|