|
@@ -2,7 +2,9 @@ package pq
|
|
|
|
|
|
|
|
import (
|
|
import (
|
|
|
"bufio"
|
|
"bufio"
|
|
|
|
|
+ "context"
|
|
|
"crypto/md5"
|
|
"crypto/md5"
|
|
|
|
|
+ "crypto/sha256"
|
|
|
"database/sql"
|
|
"database/sql"
|
|
|
"database/sql/driver"
|
|
"database/sql/driver"
|
|
|
"encoding/binary"
|
|
"encoding/binary"
|
|
@@ -20,6 +22,7 @@ import (
|
|
|
"unicode"
|
|
"unicode"
|
|
|
|
|
|
|
|
"github.com/lib/pq/oid"
|
|
"github.com/lib/pq/oid"
|
|
|
|
|
+ "github.com/lib/pq/scram"
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
// Common error types
|
|
// Common error types
|
|
@@ -89,13 +92,25 @@ type Dialer interface {
|
|
|
DialTimeout(network, address string, timeout time.Duration) (net.Conn, error)
|
|
DialTimeout(network, address string, timeout time.Duration) (net.Conn, error)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-type defaultDialer struct{}
|
|
|
|
|
|
|
+// DialerContext is the context-aware dialer interface.
|
|
|
|
|
+type DialerContext interface {
|
|
|
|
|
+ DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+type defaultDialer struct {
|
|
|
|
|
+ d net.Dialer
|
|
|
|
|
+}
|
|
|
|
|
|
|
|
-func (d defaultDialer) Dial(ntw, addr string) (net.Conn, error) {
|
|
|
|
|
- return net.Dial(ntw, addr)
|
|
|
|
|
|
|
+func (d defaultDialer) Dial(network, address string) (net.Conn, error) {
|
|
|
|
|
+ return d.d.Dial(network, address)
|
|
|
}
|
|
}
|
|
|
-func (d defaultDialer) DialTimeout(ntw, addr string, timeout time.Duration) (net.Conn, error) {
|
|
|
|
|
- return net.DialTimeout(ntw, addr, timeout)
|
|
|
|
|
|
|
+func (d defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
|
|
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
|
|
|
+ defer cancel()
|
|
|
|
|
+ return d.DialContext(ctx, network, address)
|
|
|
|
|
+}
|
|
|
|
|
+func (d defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
|
|
|
|
+ return d.d.DialContext(ctx, network, address)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
type conn struct {
|
|
type conn struct {
|
|
@@ -244,90 +259,35 @@ func (cn *conn) writeBuf(b byte) *writeBuf {
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-// Open opens a new connection to the database. name is a connection string.
|
|
|
|
|
|
|
+// Open opens a new connection to the database. dsn is a connection string.
|
|
|
// Most users should only use it through database/sql package from the standard
|
|
// Most users should only use it through database/sql package from the standard
|
|
|
// library.
|
|
// library.
|
|
|
-func Open(name string) (_ driver.Conn, err error) {
|
|
|
|
|
- return DialOpen(defaultDialer{}, name)
|
|
|
|
|
|
|
+func Open(dsn string) (_ driver.Conn, err error) {
|
|
|
|
|
+ return DialOpen(defaultDialer{}, dsn)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// DialOpen opens a new connection to the database using a dialer.
|
|
// DialOpen opens a new connection to the database using a dialer.
|
|
|
-func DialOpen(d Dialer, name string) (_ driver.Conn, err error) {
|
|
|
|
|
|
|
+func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) {
|
|
|
|
|
+ c, err := NewConnector(dsn)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
|
|
+ }
|
|
|
|
|
+ c.dialer = d
|
|
|
|
|
+ return c.open(context.Background())
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func (c *Connector) open(ctx context.Context) (cn *conn, err error) {
|
|
|
// Handle any panics during connection initialization. Note that we
|
|
// Handle any panics during connection initialization. Note that we
|
|
|
// specifically do *not* want to use errRecover(), as that would turn any
|
|
// specifically do *not* want to use errRecover(), as that would turn any
|
|
|
// connection errors into ErrBadConns, hiding the real error message from
|
|
// connection errors into ErrBadConns, hiding the real error message from
|
|
|
// the user.
|
|
// the user.
|
|
|
defer errRecoverNoErrBadConn(&err)
|
|
defer errRecoverNoErrBadConn(&err)
|
|
|
|
|
|
|
|
- o := make(values)
|
|
|
|
|
-
|
|
|
|
|
- // A number of defaults are applied here, in this order:
|
|
|
|
|
- //
|
|
|
|
|
- // * Very low precedence defaults applied in every situation
|
|
|
|
|
- // * Environment variables
|
|
|
|
|
- // * Explicitly passed connection information
|
|
|
|
|
- o["host"] = "localhost"
|
|
|
|
|
- o["port"] = "5432"
|
|
|
|
|
- // N.B.: Extra float digits should be set to 3, but that breaks
|
|
|
|
|
- // Postgres 8.4 and older, where the max is 2.
|
|
|
|
|
- o["extra_float_digits"] = "2"
|
|
|
|
|
- for k, v := range parseEnviron(os.Environ()) {
|
|
|
|
|
- o[k] = v
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- if strings.HasPrefix(name, "postgres://") || strings.HasPrefix(name, "postgresql://") {
|
|
|
|
|
- name, err = ParseURL(name)
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- return nil, err
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- if err := parseOpts(name, o); err != nil {
|
|
|
|
|
- return nil, err
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- // Use the "fallback" application name if necessary
|
|
|
|
|
- if fallback, ok := o["fallback_application_name"]; ok {
|
|
|
|
|
- if _, ok := o["application_name"]; !ok {
|
|
|
|
|
- o["application_name"] = fallback
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- // We can't work with any client_encoding other than UTF-8 currently.
|
|
|
|
|
- // However, we have historically allowed the user to set it to UTF-8
|
|
|
|
|
- // explicitly, and there's no reason to break such programs, so allow that.
|
|
|
|
|
- // Note that the "options" setting could also set client_encoding, but
|
|
|
|
|
- // parsing its value is not worth it. Instead, we always explicitly send
|
|
|
|
|
- // client_encoding as a separate run-time parameter, which should override
|
|
|
|
|
- // anything set in options.
|
|
|
|
|
- if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) {
|
|
|
|
|
- return nil, errors.New("client_encoding must be absent or 'UTF8'")
|
|
|
|
|
- }
|
|
|
|
|
- o["client_encoding"] = "UTF8"
|
|
|
|
|
- // DateStyle needs a similar treatment.
|
|
|
|
|
- if datestyle, ok := o["datestyle"]; ok {
|
|
|
|
|
- if datestyle != "ISO, MDY" {
|
|
|
|
|
- panic(fmt.Sprintf("setting datestyle must be absent or %v; got %v",
|
|
|
|
|
- "ISO, MDY", datestyle))
|
|
|
|
|
- }
|
|
|
|
|
- } else {
|
|
|
|
|
- o["datestyle"] = "ISO, MDY"
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ o := c.opts
|
|
|
|
|
|
|
|
- // If a user is not provided by any other means, the last
|
|
|
|
|
- // resort is to use the current operating system provided user
|
|
|
|
|
- // name.
|
|
|
|
|
- if _, ok := o["user"]; !ok {
|
|
|
|
|
- u, err := userCurrent()
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- return nil, err
|
|
|
|
|
- }
|
|
|
|
|
- o["user"] = u
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- cn := &conn{
|
|
|
|
|
|
|
+ cn = &conn{
|
|
|
opts: o,
|
|
opts: o,
|
|
|
- dialer: d,
|
|
|
|
|
|
|
+ dialer: c.dialer,
|
|
|
}
|
|
}
|
|
|
err = cn.handleDriverSettings(o)
|
|
err = cn.handleDriverSettings(o)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
@@ -335,13 +295,16 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) {
|
|
|
}
|
|
}
|
|
|
cn.handlePgpass(o)
|
|
cn.handlePgpass(o)
|
|
|
|
|
|
|
|
- cn.c, err = dial(d, o)
|
|
|
|
|
|
|
+ cn.c, err = dial(ctx, c.dialer, o)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return nil, err
|
|
return nil, err
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
err = cn.ssl(o)
|
|
err = cn.ssl(o)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
|
|
+ if cn.c != nil {
|
|
|
|
|
+ cn.c.Close()
|
|
|
|
|
+ }
|
|
|
return nil, err
|
|
return nil, err
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -364,10 +327,10 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) {
|
|
|
return cn, err
|
|
return cn, err
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func dial(d Dialer, o values) (net.Conn, error) {
|
|
|
|
|
- ntw, addr := network(o)
|
|
|
|
|
|
|
+func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) {
|
|
|
|
|
+ network, address := network(o)
|
|
|
// SSL is not necessary or supported over UNIX domain sockets
|
|
// SSL is not necessary or supported over UNIX domain sockets
|
|
|
- if ntw == "unix" {
|
|
|
|
|
|
|
+ if network == "unix" {
|
|
|
o["sslmode"] = "disable"
|
|
o["sslmode"] = "disable"
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -378,19 +341,30 @@ func dial(d Dialer, o values) (net.Conn, error) {
|
|
|
return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err)
|
|
return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err)
|
|
|
}
|
|
}
|
|
|
duration := time.Duration(seconds) * time.Second
|
|
duration := time.Duration(seconds) * time.Second
|
|
|
|
|
+
|
|
|
// connect_timeout should apply to the entire connection establishment
|
|
// connect_timeout should apply to the entire connection establishment
|
|
|
// procedure, so we both use a timeout for the TCP connection
|
|
// procedure, so we both use a timeout for the TCP connection
|
|
|
// establishment and set a deadline for doing the initial handshake.
|
|
// establishment and set a deadline for doing the initial handshake.
|
|
|
// The deadline is then reset after startup() is done.
|
|
// The deadline is then reset after startup() is done.
|
|
|
deadline := time.Now().Add(duration)
|
|
deadline := time.Now().Add(duration)
|
|
|
- conn, err := d.DialTimeout(ntw, addr, duration)
|
|
|
|
|
|
|
+ var conn net.Conn
|
|
|
|
|
+ if dctx, ok := d.(DialerContext); ok {
|
|
|
|
|
+ ctx, cancel := context.WithTimeout(ctx, duration)
|
|
|
|
|
+ defer cancel()
|
|
|
|
|
+ conn, err = dctx.DialContext(ctx, network, address)
|
|
|
|
|
+ } else {
|
|
|
|
|
+ conn, err = d.DialTimeout(network, address, duration)
|
|
|
|
|
+ }
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return nil, err
|
|
return nil, err
|
|
|
}
|
|
}
|
|
|
err = conn.SetDeadline(deadline)
|
|
err = conn.SetDeadline(deadline)
|
|
|
return conn, err
|
|
return conn, err
|
|
|
}
|
|
}
|
|
|
- return d.Dial(ntw, addr)
|
|
|
|
|
|
|
+ if dctx, ok := d.(DialerContext); ok {
|
|
|
|
|
+ return dctx.DialContext(ctx, network, address)
|
|
|
|
|
+ }
|
|
|
|
|
+ return d.Dial(network, address)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func network(o values) (string, string) {
|
|
func network(o values) (string, string) {
|
|
@@ -576,7 +550,7 @@ func (cn *conn) Commit() (err error) {
|
|
|
// would get the same behaviour if you issued a COMMIT in a failed
|
|
// would get the same behaviour if you issued a COMMIT in a failed
|
|
|
// transaction, so it's also the least surprising thing to do here.
|
|
// transaction, so it's also the least surprising thing to do here.
|
|
|
if cn.txnStatus == txnStatusInFailedTransaction {
|
|
if cn.txnStatus == txnStatusInFailedTransaction {
|
|
|
- if err := cn.Rollback(); err != nil {
|
|
|
|
|
|
|
+ if err := cn.rollback(); err != nil {
|
|
|
return err
|
|
return err
|
|
|
}
|
|
}
|
|
|
return ErrInFailedTransaction
|
|
return ErrInFailedTransaction
|
|
@@ -603,7 +577,10 @@ func (cn *conn) Rollback() (err error) {
|
|
|
return driver.ErrBadConn
|
|
return driver.ErrBadConn
|
|
|
}
|
|
}
|
|
|
defer cn.errRecover(&err)
|
|
defer cn.errRecover(&err)
|
|
|
|
|
+ return cn.rollback()
|
|
|
|
|
+}
|
|
|
|
|
|
|
|
|
|
+func (cn *conn) rollback() (err error) {
|
|
|
cn.checkIsInTransaction(true)
|
|
cn.checkIsInTransaction(true)
|
|
|
_, commandTag, err := cn.simpleExec("ROLLBACK")
|
|
_, commandTag, err := cn.simpleExec("ROLLBACK")
|
|
|
if err != nil {
|
|
if err != nil {
|
|
@@ -704,7 +681,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) {
|
|
|
// res might be non-nil here if we received a previous
|
|
// res might be non-nil here if we received a previous
|
|
|
// CommandComplete, but that's fine; just overwrite it
|
|
// CommandComplete, but that's fine; just overwrite it
|
|
|
res = &rows{cn: cn}
|
|
res = &rows{cn: cn}
|
|
|
- res.colNames, res.colFmts, res.colTyps = parsePortalRowDescribe(r)
|
|
|
|
|
|
|
+ res.rowsHeader = parsePortalRowDescribe(r)
|
|
|
|
|
|
|
|
// To work around a bug in QueryRow in Go 1.2 and earlier, wait
|
|
// To work around a bug in QueryRow in Go 1.2 and earlier, wait
|
|
|
// until the first DataRow has been received.
|
|
// until the first DataRow has been received.
|
|
@@ -861,17 +838,15 @@ func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
|
|
|
cn.readParseResponse()
|
|
cn.readParseResponse()
|
|
|
cn.readBindResponse()
|
|
cn.readBindResponse()
|
|
|
rows := &rows{cn: cn}
|
|
rows := &rows{cn: cn}
|
|
|
- rows.colNames, rows.colFmts, rows.colTyps = cn.readPortalDescribeResponse()
|
|
|
|
|
|
|
+ rows.rowsHeader = cn.readPortalDescribeResponse()
|
|
|
cn.postExecuteWorkaround()
|
|
cn.postExecuteWorkaround()
|
|
|
return rows, nil
|
|
return rows, nil
|
|
|
}
|
|
}
|
|
|
st := cn.prepareTo(query, "")
|
|
st := cn.prepareTo(query, "")
|
|
|
st.exec(args)
|
|
st.exec(args)
|
|
|
return &rows{
|
|
return &rows{
|
|
|
- cn: cn,
|
|
|
|
|
- colNames: st.colNames,
|
|
|
|
|
- colTyps: st.colTyps,
|
|
|
|
|
- colFmts: st.colFmts,
|
|
|
|
|
|
|
+ cn: cn,
|
|
|
|
|
+ rowsHeader: st.rowsHeader,
|
|
|
}, nil
|
|
}, nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -992,7 +967,6 @@ func (cn *conn) recv() (t byte, r *readBuf) {
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
panic(err)
|
|
panic(err)
|
|
|
}
|
|
}
|
|
|
-
|
|
|
|
|
switch t {
|
|
switch t {
|
|
|
case 'E':
|
|
case 'E':
|
|
|
panic(parseError(r))
|
|
panic(parseError(r))
|
|
@@ -1163,6 +1137,55 @@ func (cn *conn) auth(r *readBuf, o values) {
|
|
|
if r.int32() != 0 {
|
|
if r.int32() != 0 {
|
|
|
errorf("unexpected authentication response: %q", t)
|
|
errorf("unexpected authentication response: %q", t)
|
|
|
}
|
|
}
|
|
|
|
|
+ case 10:
|
|
|
|
|
+ sc := scram.NewClient(sha256.New, o["user"], o["password"])
|
|
|
|
|
+ sc.Step(nil)
|
|
|
|
|
+ if sc.Err() != nil {
|
|
|
|
|
+ errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
|
|
|
|
|
+ }
|
|
|
|
|
+ scOut := sc.Out()
|
|
|
|
|
+
|
|
|
|
|
+ w := cn.writeBuf('p')
|
|
|
|
|
+ w.string("SCRAM-SHA-256")
|
|
|
|
|
+ w.int32(len(scOut))
|
|
|
|
|
+ w.bytes(scOut)
|
|
|
|
|
+ cn.send(w)
|
|
|
|
|
+
|
|
|
|
|
+ t, r := cn.recv()
|
|
|
|
|
+ if t != 'R' {
|
|
|
|
|
+ errorf("unexpected password response: %q", t)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if r.int32() != 11 {
|
|
|
|
|
+ errorf("unexpected authentication response: %q", t)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ nextStep := r.next(len(*r))
|
|
|
|
|
+ sc.Step(nextStep)
|
|
|
|
|
+ if sc.Err() != nil {
|
|
|
|
|
+ errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ scOut = sc.Out()
|
|
|
|
|
+ w = cn.writeBuf('p')
|
|
|
|
|
+ w.bytes(scOut)
|
|
|
|
|
+ cn.send(w)
|
|
|
|
|
+
|
|
|
|
|
+ t, r = cn.recv()
|
|
|
|
|
+ if t != 'R' {
|
|
|
|
|
+ errorf("unexpected password response: %q", t)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if r.int32() != 12 {
|
|
|
|
|
+ errorf("unexpected authentication response: %q", t)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ nextStep = r.next(len(*r))
|
|
|
|
|
+ sc.Step(nextStep)
|
|
|
|
|
+ if sc.Err() != nil {
|
|
|
|
|
+ errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
default:
|
|
default:
|
|
|
errorf("unknown authentication response: %d", code)
|
|
errorf("unknown authentication response: %d", code)
|
|
|
}
|
|
}
|
|
@@ -1180,12 +1203,10 @@ var colFmtDataAllBinary = []byte{0, 1, 0, 1}
|
|
|
var colFmtDataAllText = []byte{0, 0}
|
|
var colFmtDataAllText = []byte{0, 0}
|
|
|
|
|
|
|
|
type stmt struct {
|
|
type stmt struct {
|
|
|
- cn *conn
|
|
|
|
|
- name string
|
|
|
|
|
- colNames []string
|
|
|
|
|
- colFmts []format
|
|
|
|
|
|
|
+ cn *conn
|
|
|
|
|
+ name string
|
|
|
|
|
+ rowsHeader
|
|
|
colFmtData []byte
|
|
colFmtData []byte
|
|
|
- colTyps []fieldDesc
|
|
|
|
|
paramTyps []oid.Oid
|
|
paramTyps []oid.Oid
|
|
|
closed bool
|
|
closed bool
|
|
|
}
|
|
}
|
|
@@ -1231,10 +1252,8 @@ func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
|
|
|
|
|
|
|
|
st.exec(v)
|
|
st.exec(v)
|
|
|
return &rows{
|
|
return &rows{
|
|
|
- cn: st.cn,
|
|
|
|
|
- colNames: st.colNames,
|
|
|
|
|
- colTyps: st.colTyps,
|
|
|
|
|
- colFmts: st.colFmts,
|
|
|
|
|
|
|
+ cn: st.cn,
|
|
|
|
|
+ rowsHeader: st.rowsHeader,
|
|
|
}, nil
|
|
}, nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -1344,16 +1363,22 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
|
|
|
return driver.RowsAffected(n), commandTag
|
|
return driver.RowsAffected(n), commandTag
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-type rows struct {
|
|
|
|
|
- cn *conn
|
|
|
|
|
- finish func()
|
|
|
|
|
|
|
+type rowsHeader struct {
|
|
|
colNames []string
|
|
colNames []string
|
|
|
colTyps []fieldDesc
|
|
colTyps []fieldDesc
|
|
|
colFmts []format
|
|
colFmts []format
|
|
|
- done bool
|
|
|
|
|
- rb readBuf
|
|
|
|
|
- result driver.Result
|
|
|
|
|
- tag string
|
|
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+type rows struct {
|
|
|
|
|
+ cn *conn
|
|
|
|
|
+ finish func()
|
|
|
|
|
+ rowsHeader
|
|
|
|
|
+ done bool
|
|
|
|
|
+ rb readBuf
|
|
|
|
|
+ result driver.Result
|
|
|
|
|
+ tag string
|
|
|
|
|
+
|
|
|
|
|
+ next *rowsHeader
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (rs *rows) Close() error {
|
|
func (rs *rows) Close() error {
|
|
@@ -1440,7 +1465,8 @@ func (rs *rows) Next(dest []driver.Value) (err error) {
|
|
|
}
|
|
}
|
|
|
return
|
|
return
|
|
|
case 'T':
|
|
case 'T':
|
|
|
- rs.colNames, rs.colFmts, rs.colTyps = parsePortalRowDescribe(&rs.rb)
|
|
|
|
|
|
|
+ next := parsePortalRowDescribe(&rs.rb)
|
|
|
|
|
+ rs.next = &next
|
|
|
return io.EOF
|
|
return io.EOF
|
|
|
default:
|
|
default:
|
|
|
errorf("unexpected message after execute: %q", t)
|
|
errorf("unexpected message after execute: %q", t)
|
|
@@ -1449,10 +1475,16 @@ func (rs *rows) Next(dest []driver.Value) (err error) {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (rs *rows) HasNextResultSet() bool {
|
|
func (rs *rows) HasNextResultSet() bool {
|
|
|
- return !rs.done
|
|
|
|
|
|
|
+ hasNext := rs.next != nil && !rs.done
|
|
|
|
|
+ return hasNext
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (rs *rows) NextResultSet() error {
|
|
func (rs *rows) NextResultSet() error {
|
|
|
|
|
+ if rs.next == nil {
|
|
|
|
|
+ return io.EOF
|
|
|
|
|
+ }
|
|
|
|
|
+ rs.rowsHeader = *rs.next
|
|
|
|
|
+ rs.next = nil
|
|
|
return nil
|
|
return nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -1475,6 +1507,39 @@ func QuoteIdentifier(name string) string {
|
|
|
return `"` + strings.Replace(name, `"`, `""`, -1) + `"`
|
|
return `"` + strings.Replace(name, `"`, `""`, -1) + `"`
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+// QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal
|
|
|
|
|
+// to DDL and other statements that do not accept parameters) to be used as part
|
|
|
|
|
+// of an SQL statement. For example:
|
|
|
|
|
+//
|
|
|
|
|
+// exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z")
|
|
|
|
|
+// err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date))
|
|
|
|
|
+//
|
|
|
|
|
+// Any single quotes in name will be escaped. Any backslashes (i.e. "\") will be
|
|
|
|
|
+// replaced by two backslashes (i.e. "\\") and the C-style escape identifier
|
|
|
|
|
+// that PostgreSQL provides ('E') will be prepended to the string.
|
|
|
|
|
+func QuoteLiteral(literal string) string {
|
|
|
|
|
+ // This follows the PostgreSQL internal algorithm for handling quoted literals
|
|
|
|
|
+ // from libpq, which can be found in the "PQEscapeStringInternal" function,
|
|
|
|
|
+ // which is found in the libpq/fe-exec.c source file:
|
|
|
|
|
+ // https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/interfaces/libpq/fe-exec.c
|
|
|
|
|
+ //
|
|
|
|
|
+ // substitute any single-quotes (') with two single-quotes ('')
|
|
|
|
|
+ literal = strings.Replace(literal, `'`, `''`, -1)
|
|
|
|
|
+ // determine if the string has any backslashes (\) in it.
|
|
|
|
|
+ // if it does, replace any backslashes (\) with two backslashes (\\)
|
|
|
|
|
+ // then, we need to wrap the entire string with a PostgreSQL
|
|
|
|
|
+ // C-style escape. Per how "PQEscapeStringInternal" handles this case, we
|
|
|
|
|
+ // also add a space before the "E"
|
|
|
|
|
+ if strings.Contains(literal, `\`) {
|
|
|
|
|
+ literal = strings.Replace(literal, `\`, `\\`, -1)
|
|
|
|
|
+ literal = ` E'` + literal + `'`
|
|
|
|
|
+ } else {
|
|
|
|
|
+ // otherwise, we can just wrap the literal with a pair of single quotes
|
|
|
|
|
+ literal = `'` + literal + `'`
|
|
|
|
|
+ }
|
|
|
|
|
+ return literal
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
func md5s(s string) string {
|
|
func md5s(s string) string {
|
|
|
h := md5.New()
|
|
h := md5.New()
|
|
|
h.Write([]byte(s))
|
|
h.Write([]byte(s))
|
|
@@ -1630,13 +1695,13 @@ func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames [
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func (cn *conn) readPortalDescribeResponse() (colNames []string, colFmts []format, colTyps []fieldDesc) {
|
|
|
|
|
|
|
+func (cn *conn) readPortalDescribeResponse() rowsHeader {
|
|
|
t, r := cn.recv1()
|
|
t, r := cn.recv1()
|
|
|
switch t {
|
|
switch t {
|
|
|
case 'T':
|
|
case 'T':
|
|
|
return parsePortalRowDescribe(r)
|
|
return parsePortalRowDescribe(r)
|
|
|
case 'n':
|
|
case 'n':
|
|
|
- return nil, nil, nil
|
|
|
|
|
|
|
+ return rowsHeader{}
|
|
|
case 'E':
|
|
case 'E':
|
|
|
err := parseError(r)
|
|
err := parseError(r)
|
|
|
cn.readReadyForQuery()
|
|
cn.readReadyForQuery()
|
|
@@ -1742,11 +1807,11 @@ func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDe
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, colTyps []fieldDesc) {
|
|
|
|
|
|
|
+func parsePortalRowDescribe(r *readBuf) rowsHeader {
|
|
|
n := r.int16()
|
|
n := r.int16()
|
|
|
- colNames = make([]string, n)
|
|
|
|
|
- colFmts = make([]format, n)
|
|
|
|
|
- colTyps = make([]fieldDesc, n)
|
|
|
|
|
|
|
+ colNames := make([]string, n)
|
|
|
|
|
+ colFmts := make([]format, n)
|
|
|
|
|
+ colTyps := make([]fieldDesc, n)
|
|
|
for i := range colNames {
|
|
for i := range colNames {
|
|
|
colNames[i] = r.string()
|
|
colNames[i] = r.string()
|
|
|
r.next(6)
|
|
r.next(6)
|
|
@@ -1755,7 +1820,11 @@ func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, co
|
|
|
colTyps[i].Mod = r.int32()
|
|
colTyps[i].Mod = r.int32()
|
|
|
colFmts[i] = format(r.int16())
|
|
colFmts[i] = format(r.int16())
|
|
|
}
|
|
}
|
|
|
- return
|
|
|
|
|
|
|
+ return rowsHeader{
|
|
|
|
|
+ colNames: colNames,
|
|
|
|
|
+ colFmts: colFmts,
|
|
|
|
|
+ colTyps: colTyps,
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// parseEnviron tries to mimic some of libpq's environment handling
|
|
// parseEnviron tries to mimic some of libpq's environment handling
|