conn.go 43 KB


  1. package pq
  2. import (
  3. "bufio"
  4. "crypto/md5"
  5. "database/sql"
  6. "database/sql/driver"
  7. "encoding/binary"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "net"
  12. "os"
  13. "os/user"
  14. "path"
  15. "path/filepath"
  16. "strconv"
  17. "strings"
  18. "time"
  19. "unicode"
  20. "github.com/lib/pq/oid"
  21. )
// Common error types.
var (
	// ErrNotSupported carries the message "pq: Unsupported command".
	ErrNotSupported = errors.New("pq: Unsupported command")
	// ErrInFailedTransaction is returned by Commit when the transaction was
	// already in a failed state; the transaction is rolled back instead.
	ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction")
	// ErrSSLNotSupported is raised during the SSL handshake when the server
	// does not answer the SSL request with 'S'.
	ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server")
	// ErrSSLKeyHasWorldPermissions reports a private key file whose
	// permissions are broader than u=rw (0600).
	ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less")
	// ErrCouldNotDetectUsername reports that no default user name could be
	// determined and one must be supplied explicitly.
	ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly")

	// errUnexpectedReady is returned by simpleExec when ReadyForQuery
	// arrives before any result or error has been seen.
	errUnexpectedReady = errors.New("unexpected ReadyForQuery")
	// errNoRowsAffected and errNoLastInsertID are the errors returned by the
	// noRows result, which stands in for the result of an empty statement.
	errNoRowsAffected = errors.New("no RowsAffected available after the empty statement")
	errNoLastInsertID = errors.New("no LastInsertId available after the empty statement")
)
// Driver is the Postgres database driver. It carries no state; every call to
// Open creates an independent connection.
type Driver struct{}

// Open opens a new connection to the database. name is a connection string.
// Most users should only use it through database/sql package from the standard
// library.
//
// It simply delegates to the package-level Open function.
func (d *Driver) Open(name string) (driver.Conn, error) {
	return Open(name)
}
// init registers this driver with database/sql under the name "postgres", so
// that importing the package is enough to make sql.Open("postgres", ...) work.
func init() {
	sql.Register("postgres", &Driver{})
}
// parameterStatus holds session parameters reported by the server that the
// driver keeps track of for each connection.
type parameterStatus struct {
	// server version in the same format as server_version_num, or 0 if
	// unavailable
	serverVersion int
	// the current location based on the TimeZone value of the session, if
	// available
	currentLocation *time.Location
}
// transactionStatus is the one-byte transaction state of a connection. The
// constant values are the status characters used on the wire.
type transactionStatus byte

const (
	// txnStatusIdle: not inside a transaction block.
	txnStatusIdle transactionStatus = 'I'
	// txnStatusIdleInTransaction: inside a transaction block.
	txnStatusIdleInTransaction transactionStatus = 'T'
	// txnStatusInFailedTransaction: inside a failed transaction block.
	txnStatusInFailedTransaction transactionStatus = 'E'
)
  58. func (s transactionStatus) String() string {
  59. switch s {
  60. case txnStatusIdle:
  61. return "idle"
  62. case txnStatusIdleInTransaction:
  63. return "idle in transaction"
  64. case txnStatusInFailedTransaction:
  65. return "in a failed transaction"
  66. default:
  67. errorf("unknown transactionStatus %d", s)
  68. }
  69. panic("not reached")
  70. }
// Dialer is the dialer interface. It can be used to obtain more control over
// how pq creates network connections.
type Dialer interface {
	// Dial connects to address on the named network. It is used when no
	// connect_timeout is in effect.
	Dial(network, address string) (net.Conn, error)
	// DialTimeout is like Dial but gives up after timeout. It is used when
	// a non-zero connect_timeout is configured.
	DialTimeout(network, address string, timeout time.Duration) (net.Conn, error)
}
// defaultDialer is the Dialer used by Open; it forwards directly to the
// standard library's net package.
type defaultDialer struct{}

// Dial calls net.Dial.
func (d defaultDialer) Dial(ntw, addr string) (net.Conn, error) {
	return net.Dial(ntw, addr)
}

// DialTimeout calls net.DialTimeout.
func (d defaultDialer) DialTimeout(ntw, addr string, timeout time.Duration) (net.Conn, error) {
	return net.DialTimeout(ntw, addr, timeout)
}
// conn is a single connection to a Postgres server. Besides being the
// driver connection it is also returned as the transaction object by begin.
type conn struct {
	// c is the underlying network connection; it is replaced by the upgraded
	// connection when SSL is negotiated.
	c net.Conn
	// buf is a buffered reader over c used for all message reads after the
	// SSL negotiation.
	buf *bufio.Reader
	// namei is a counter used by gname to generate statement names.
	namei int
	// scratch is a reusable buffer shared by outgoing message construction
	// (writeBuf) and incoming message reads (recvMessage).
	scratch [512]byte

	// txnStatus is the transaction status last recorded for this connection.
	txnStatus transactionStatus
	// txnFinish, if non-nil, is invoked by closeTxn when Commit or Rollback
	// returns.
	txnFinish func()

	// Save connection arguments to use during CancelRequest.
	dialer Dialer
	opts   values

	// Cancellation key data for use with CancelRequest messages.
	processID int
	secretKey int

	parameterStatus parameterStatus

	// saveMessageType and saveMessageBuffer hold one message stashed by
	// saveMessage; recvMessage returns it before reading from the network.
	// A zero saveMessageType means nothing is stashed.
	saveMessageType   byte
	saveMessageBuffer []byte

	// If true, this connection is bad and all public-facing functions should
	// return ErrBadConn.
	bad bool

	// If set, this connection should never use the binary format when
	// receiving query results from prepared statements. Only provided for
	// debugging.
	disablePreparedBinaryResult bool

	// Whether to always send []byte parameters over as binary. Enables single
	// round-trip mode for non-prepared Query calls.
	binaryParameters bool

	// If true this connection is in the middle of a COPY
	inCopy bool
}
  113. // Handle driver-side settings in parsed connection string.
  114. func (cn *conn) handleDriverSettings(o values) (err error) {
  115. boolSetting := func(key string, val *bool) error {
  116. if value, ok := o[key]; ok {
  117. if value == "yes" {
  118. *val = true
  119. } else if value == "no" {
  120. *val = false
  121. } else {
  122. return fmt.Errorf("unrecognized value %q for %s", value, key)
  123. }
  124. }
  125. return nil
  126. }
  127. err = boolSetting("disable_prepared_binary_result", &cn.disablePreparedBinaryResult)
  128. if err != nil {
  129. return err
  130. }
  131. return boolSetting("binary_parameters", &cn.binaryParameters)
  132. }
  133. func (cn *conn) handlePgpass(o values) {
  134. // if a password was supplied, do not process .pgpass
  135. if _, ok := o["password"]; ok {
  136. return
  137. }
  138. filename := os.Getenv("PGPASSFILE")
  139. if filename == "" {
  140. // XXX this code doesn't work on Windows where the default filename is
  141. // XXX %APPDATA%\postgresql\pgpass.conf
  142. // Prefer $HOME over user.Current due to glibc bug: golang.org/issue/13470
  143. userHome := os.Getenv("HOME")
  144. if userHome == "" {
  145. user, err := user.Current()
  146. if err != nil {
  147. return
  148. }
  149. userHome = user.HomeDir
  150. }
  151. filename = filepath.Join(userHome, ".pgpass")
  152. }
  153. fileinfo, err := os.Stat(filename)
  154. if err != nil {
  155. return
  156. }
  157. mode := fileinfo.Mode()
  158. if mode&(0x77) != 0 {
  159. // XXX should warn about incorrect .pgpass permissions as psql does
  160. return
  161. }
  162. file, err := os.Open(filename)
  163. if err != nil {
  164. return
  165. }
  166. defer file.Close()
  167. scanner := bufio.NewScanner(io.Reader(file))
  168. hostname := o["host"]
  169. ntw, _ := network(o)
  170. port := o["port"]
  171. db := o["dbname"]
  172. username := o["user"]
  173. // From: https://github.com/tg/pgpass/blob/master/reader.go
  174. getFields := func(s string) []string {
  175. fs := make([]string, 0, 5)
  176. f := make([]rune, 0, len(s))
  177. var esc bool
  178. for _, c := range s {
  179. switch {
  180. case esc:
  181. f = append(f, c)
  182. esc = false
  183. case c == '\\':
  184. esc = true
  185. case c == ':':
  186. fs = append(fs, string(f))
  187. f = f[:0]
  188. default:
  189. f = append(f, c)
  190. }
  191. }
  192. return append(fs, string(f))
  193. }
  194. for scanner.Scan() {
  195. line := scanner.Text()
  196. if len(line) == 0 || line[0] == '#' {
  197. continue
  198. }
  199. split := getFields(line)
  200. if len(split) != 5 {
  201. continue
  202. }
  203. if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) {
  204. o["password"] = split[4]
  205. return
  206. }
  207. }
  208. }
// writeBuf starts a new outgoing message of type b in the connection's
// scratch buffer. The type byte is stored at index 0 and the returned
// writeBuf covers the first five bytes with its position set to 1.
// NOTE(review): bytes 1-4 are presumably reserved for the message length
// filled in later by wrap — confirm against writeBuf's definition.
//
// Because the buffer aliases cn.scratch, only one message can be under
// construction at a time.
func (cn *conn) writeBuf(b byte) *writeBuf {
	cn.scratch[0] = b
	return &writeBuf{
		buf: cn.scratch[:5],
		pos: 1,
	}
}
// Open opens a new connection to the database. name is a connection string.
// Most users should only use it through database/sql package from the standard
// library.
//
// It is equivalent to DialOpen with the default dialer.
func Open(name string) (_ driver.Conn, err error) {
	return DialOpen(defaultDialer{}, name)
}
  222. // DialOpen opens a new connection to the database using a dialer.
  223. func DialOpen(d Dialer, name string) (_ driver.Conn, err error) {
  224. // Handle any panics during connection initialization. Note that we
  225. // specifically do *not* want to use errRecover(), as that would turn any
  226. // connection errors into ErrBadConns, hiding the real error message from
  227. // the user.
  228. defer errRecoverNoErrBadConn(&err)
  229. o := make(values)
  230. // A number of defaults are applied here, in this order:
  231. //
  232. // * Very low precedence defaults applied in every situation
  233. // * Environment variables
  234. // * Explicitly passed connection information
  235. o["host"] = "localhost"
  236. o["port"] = "5432"
  237. // N.B.: Extra float digits should be set to 3, but that breaks
  238. // Postgres 8.4 and older, where the max is 2.
  239. o["extra_float_digits"] = "2"
  240. for k, v := range parseEnviron(os.Environ()) {
  241. o[k] = v
  242. }
  243. if strings.HasPrefix(name, "postgres://") || strings.HasPrefix(name, "postgresql://") {
  244. name, err = ParseURL(name)
  245. if err != nil {
  246. return nil, err
  247. }
  248. }
  249. if err := parseOpts(name, o); err != nil {
  250. return nil, err
  251. }
  252. // Use the "fallback" application name if necessary
  253. if fallback, ok := o["fallback_application_name"]; ok {
  254. if _, ok := o["application_name"]; !ok {
  255. o["application_name"] = fallback
  256. }
  257. }
  258. // We can't work with any client_encoding other than UTF-8 currently.
  259. // However, we have historically allowed the user to set it to UTF-8
  260. // explicitly, and there's no reason to break such programs, so allow that.
  261. // Note that the "options" setting could also set client_encoding, but
  262. // parsing its value is not worth it. Instead, we always explicitly send
  263. // client_encoding as a separate run-time parameter, which should override
  264. // anything set in options.
  265. if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) {
  266. return nil, errors.New("client_encoding must be absent or 'UTF8'")
  267. }
  268. o["client_encoding"] = "UTF8"
  269. // DateStyle needs a similar treatment.
  270. if datestyle, ok := o["datestyle"]; ok {
  271. if datestyle != "ISO, MDY" {
  272. panic(fmt.Sprintf("setting datestyle must be absent or %v; got %v",
  273. "ISO, MDY", datestyle))
  274. }
  275. } else {
  276. o["datestyle"] = "ISO, MDY"
  277. }
  278. // If a user is not provided by any other means, the last
  279. // resort is to use the current operating system provided user
  280. // name.
  281. if _, ok := o["user"]; !ok {
  282. u, err := userCurrent()
  283. if err != nil {
  284. return nil, err
  285. }
  286. o["user"] = u
  287. }
  288. cn := &conn{
  289. opts: o,
  290. dialer: d,
  291. }
  292. err = cn.handleDriverSettings(o)
  293. if err != nil {
  294. return nil, err
  295. }
  296. cn.handlePgpass(o)
  297. cn.c, err = dial(d, o)
  298. if err != nil {
  299. return nil, err
  300. }
  301. cn.ssl(o)
  302. cn.buf = bufio.NewReader(cn.c)
  303. cn.startup(o)
  304. // reset the deadline, in case one was set (see dial)
  305. if timeout, ok := o["connect_timeout"]; ok && timeout != "0" {
  306. err = cn.c.SetDeadline(time.Time{})
  307. }
  308. return cn, err
  309. }
  310. func dial(d Dialer, o values) (net.Conn, error) {
  311. ntw, addr := network(o)
  312. // SSL is not necessary or supported over UNIX domain sockets
  313. if ntw == "unix" {
  314. o["sslmode"] = "disable"
  315. }
  316. // Zero or not specified means wait indefinitely.
  317. if timeout, ok := o["connect_timeout"]; ok && timeout != "0" {
  318. seconds, err := strconv.ParseInt(timeout, 10, 0)
  319. if err != nil {
  320. return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err)
  321. }
  322. duration := time.Duration(seconds) * time.Second
  323. // connect_timeout should apply to the entire connection establishment
  324. // procedure, so we both use a timeout for the TCP connection
  325. // establishment and set a deadline for doing the initial handshake.
  326. // The deadline is then reset after startup() is done.
  327. deadline := time.Now().Add(duration)
  328. conn, err := d.DialTimeout(ntw, addr, duration)
  329. if err != nil {
  330. return nil, err
  331. }
  332. err = conn.SetDeadline(deadline)
  333. return conn, err
  334. }
  335. return d.Dial(ntw, addr)
  336. }
  337. func network(o values) (string, string) {
  338. host := o["host"]
  339. if strings.HasPrefix(host, "/") {
  340. sockPath := path.Join(host, ".s.PGSQL."+o["port"])
  341. return "unix", sockPath
  342. }
  343. return "tcp", net.JoinHostPort(host, o["port"])
  344. }
  345. type values map[string]string
  346. // scanner implements a tokenizer for libpq-style option strings.
  347. type scanner struct {
  348. s []rune
  349. i int
  350. }
  351. // newScanner returns a new scanner initialized with the option string s.
  352. func newScanner(s string) *scanner {
  353. return &scanner{[]rune(s), 0}
  354. }
  355. // Next returns the next rune.
  356. // It returns 0, false if the end of the text has been reached.
  357. func (s *scanner) Next() (rune, bool) {
  358. if s.i >= len(s.s) {
  359. return 0, false
  360. }
  361. r := s.s[s.i]
  362. s.i++
  363. return r, true
  364. }
  365. // SkipSpaces returns the next non-whitespace rune.
  366. // It returns 0, false if the end of the text has been reached.
  367. func (s *scanner) SkipSpaces() (rune, bool) {
  368. r, ok := s.Next()
  369. for unicode.IsSpace(r) && ok {
  370. r, ok = s.Next()
  371. }
  372. return r, ok
  373. }
  374. // parseOpts parses the options from name and adds them to the values.
  375. //
  376. // The parsing code is based on conninfo_parse from libpq's fe-connect.c
  377. func parseOpts(name string, o values) error {
  378. s := newScanner(name)
  379. for {
  380. var (
  381. keyRunes, valRunes []rune
  382. r rune
  383. ok bool
  384. )
  385. if r, ok = s.SkipSpaces(); !ok {
  386. break
  387. }
  388. // Scan the key
  389. for !unicode.IsSpace(r) && r != '=' {
  390. keyRunes = append(keyRunes, r)
  391. if r, ok = s.Next(); !ok {
  392. break
  393. }
  394. }
  395. // Skip any whitespace if we're not at the = yet
  396. if r != '=' {
  397. r, ok = s.SkipSpaces()
  398. }
  399. // The current character should be =
  400. if r != '=' || !ok {
  401. return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes))
  402. }
  403. // Skip any whitespace after the =
  404. if r, ok = s.SkipSpaces(); !ok {
  405. // If we reach the end here, the last value is just an empty string as per libpq.
  406. o[string(keyRunes)] = ""
  407. break
  408. }
  409. if r != '\'' {
  410. for !unicode.IsSpace(r) {
  411. if r == '\\' {
  412. if r, ok = s.Next(); !ok {
  413. return fmt.Errorf(`missing character after backslash`)
  414. }
  415. }
  416. valRunes = append(valRunes, r)
  417. if r, ok = s.Next(); !ok {
  418. break
  419. }
  420. }
  421. } else {
  422. quote:
  423. for {
  424. if r, ok = s.Next(); !ok {
  425. return fmt.Errorf(`unterminated quoted string literal in connection string`)
  426. }
  427. switch r {
  428. case '\'':
  429. break quote
  430. case '\\':
  431. r, _ = s.Next()
  432. fallthrough
  433. default:
  434. valRunes = append(valRunes, r)
  435. }
  436. }
  437. }
  438. o[string(keyRunes)] = string(valRunes)
  439. }
  440. return nil
  441. }
  442. func (cn *conn) isInTransaction() bool {
  443. return cn.txnStatus == txnStatusIdleInTransaction ||
  444. cn.txnStatus == txnStatusInFailedTransaction
  445. }
  446. func (cn *conn) checkIsInTransaction(intxn bool) {
  447. if cn.isInTransaction() != intxn {
  448. cn.bad = true
  449. errorf("unexpected transaction status %v", cn.txnStatus)
  450. }
  451. }
// Begin starts a transaction with the server's default transaction mode and
// returns the connection itself as the transaction handle.
func (cn *conn) Begin() (_ driver.Tx, err error) {
	return cn.begin("")
}
  455. func (cn *conn) begin(mode string) (_ driver.Tx, err error) {
  456. if cn.bad {
  457. return nil, driver.ErrBadConn
  458. }
  459. defer cn.errRecover(&err)
  460. cn.checkIsInTransaction(false)
  461. _, commandTag, err := cn.simpleExec("BEGIN" + mode)
  462. if err != nil {
  463. return nil, err
  464. }
  465. if commandTag != "BEGIN" {
  466. cn.bad = true
  467. return nil, fmt.Errorf("unexpected command tag %s", commandTag)
  468. }
  469. if cn.txnStatus != txnStatusIdleInTransaction {
  470. cn.bad = true
  471. return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus)
  472. }
  473. return cn, nil
  474. }
// closeTxn invokes the txnFinish callback, if one is set. It is deferred by
// both Commit and Rollback so the callback runs however they return.
func (cn *conn) closeTxn() {
	if finish := cn.txnFinish; finish != nil {
		finish()
	}
}
  480. func (cn *conn) Commit() (err error) {
  481. defer cn.closeTxn()
  482. if cn.bad {
  483. return driver.ErrBadConn
  484. }
  485. defer cn.errRecover(&err)
  486. cn.checkIsInTransaction(true)
  487. // We don't want the client to think that everything is okay if it tries
  488. // to commit a failed transaction. However, no matter what we return,
  489. // database/sql will release this connection back into the free connection
  490. // pool so we have to abort the current transaction here. Note that you
  491. // would get the same behaviour if you issued a COMMIT in a failed
  492. // transaction, so it's also the least surprising thing to do here.
  493. if cn.txnStatus == txnStatusInFailedTransaction {
  494. if err := cn.Rollback(); err != nil {
  495. return err
  496. }
  497. return ErrInFailedTransaction
  498. }
  499. _, commandTag, err := cn.simpleExec("COMMIT")
  500. if err != nil {
  501. if cn.isInTransaction() {
  502. cn.bad = true
  503. }
  504. return err
  505. }
  506. if commandTag != "COMMIT" {
  507. cn.bad = true
  508. return fmt.Errorf("unexpected command tag %s", commandTag)
  509. }
  510. cn.checkIsInTransaction(false)
  511. return nil
  512. }
  513. func (cn *conn) Rollback() (err error) {
  514. defer cn.closeTxn()
  515. if cn.bad {
  516. return driver.ErrBadConn
  517. }
  518. defer cn.errRecover(&err)
  519. cn.checkIsInTransaction(true)
  520. _, commandTag, err := cn.simpleExec("ROLLBACK")
  521. if err != nil {
  522. if cn.isInTransaction() {
  523. cn.bad = true
  524. }
  525. return err
  526. }
  527. if commandTag != "ROLLBACK" {
  528. return fmt.Errorf("unexpected command tag %s", commandTag)
  529. }
  530. cn.checkIsInTransaction(false)
  531. return nil
  532. }
  533. func (cn *conn) gname() string {
  534. cn.namei++
  535. return strconv.FormatInt(int64(cn.namei), 10)
  536. }
  537. func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) {
  538. b := cn.writeBuf('Q')
  539. b.string(q)
  540. cn.send(b)
  541. for {
  542. t, r := cn.recv1()
  543. switch t {
  544. case 'C':
  545. res, commandTag = cn.parseComplete(r.string())
  546. case 'Z':
  547. cn.processReadyForQuery(r)
  548. if res == nil && err == nil {
  549. err = errUnexpectedReady
  550. }
  551. // done
  552. return
  553. case 'E':
  554. err = parseError(r)
  555. case 'I':
  556. res = emptyRows
  557. case 'T', 'D':
  558. // ignore any results
  559. default:
  560. cn.bad = true
  561. errorf("unknown response for simple query: %q", t)
  562. }
  563. }
  564. }
// simpleQuery runs q as a simple query ('Q' message) and returns a rows
// object for reading its result.
//
// It returns as soon as either the first DataRow arrives (the message is
// stashed with saveMessage for the rows object to pick up) or ReadyForQuery
// is received. In the latter case res is either a rows value already marked
// done (the query produced no rows) or nil together with the server's error.
// Protocol violations mark the connection bad and are reported via errorf.
func (cn *conn) simpleQuery(q string) (res *rows, err error) {
	defer cn.errRecover(&err)

	b := cn.writeBuf('Q')
	b.string(q)
	cn.send(b)

	for {
		t, r := cn.recv1()
		switch t {
		case 'C', 'I':
			// We allow queries which don't return any results through Query as
			// well as Exec. We still have to give database/sql a rows object
			// the user can close, though, to avoid connections from being
			// leaked. A "rows" with done=true works fine for that purpose.
			if err != nil {
				cn.bad = true
				errorf("unexpected message %q in simple query execution", t)
			}
			if res == nil {
				res = &rows{
					cn: cn,
				}
			}
			// Set the result and tag to the last command complete if there wasn't a
			// query already run. Although queries usually return from here and cede
			// control to Next, a query with zero results does not.
			if t == 'C' && res.colNames == nil {
				res.result, res.tag = cn.parseComplete(r.string())
			}
			res.done = true
		case 'Z':
			cn.processReadyForQuery(r)
			// done
			return
		case 'E':
			// Drop any partial result; keep reading until ReadyForQuery.
			res = nil
			err = parseError(r)
		case 'D':
			// A DataRow is only valid after a row description ('T') has
			// created res.
			if res == nil {
				cn.bad = true
				errorf("unexpected DataRow in simple query execution")
			}
			// the query didn't fail; kick off to Next
			cn.saveMessage(t, r)
			return
		case 'T':
			// res might be non-nil here if we received a previous
			// CommandComplete, but that's fine; just overwrite it
			res = &rows{cn: cn}
			res.colNames, res.colFmts, res.colTyps = parsePortalRowDescribe(r)

			// To work around a bug in QueryRow in Go 1.2 and earlier, wait
			// until the first DataRow has been received.
		default:
			cn.bad = true
			errorf("unknown response for simple query: %q", t)
		}
	}
}
// noRows is the driver.Result used for an empty statement; both of its
// methods return an error because there is nothing to report.
type noRows struct{}

// emptyRows is the shared noRows value handed out by simpleExec.
var emptyRows noRows

// Compile-time check that noRows implements driver.Result.
var _ driver.Result = noRows{}

// LastInsertId always fails with errNoLastInsertID.
func (noRows) LastInsertId() (int64, error) {
	return 0, errNoLastInsertID
}

// RowsAffected always fails with errNoRowsAffected.
func (noRows) RowsAffected() (int64, error) {
	return 0, errNoRowsAffected
}
  631. // Decides which column formats to use for a prepared statement. The input is
  632. // an array of type oids, one element per result column.
  633. func decideColumnFormats(colTyps []fieldDesc, forceText bool) (colFmts []format, colFmtData []byte) {
  634. if len(colTyps) == 0 {
  635. return nil, colFmtDataAllText
  636. }
  637. colFmts = make([]format, len(colTyps))
  638. if forceText {
  639. return colFmts, colFmtDataAllText
  640. }
  641. allBinary := true
  642. allText := true
  643. for i, t := range colTyps {
  644. switch t.OID {
  645. // This is the list of types to use binary mode for when receiving them
  646. // through a prepared statement. If a type appears in this list, it
  647. // must also be implemented in binaryDecode in encode.go.
  648. case oid.T_bytea:
  649. fallthrough
  650. case oid.T_int8:
  651. fallthrough
  652. case oid.T_int4:
  653. fallthrough
  654. case oid.T_int2:
  655. fallthrough
  656. case oid.T_uuid:
  657. colFmts[i] = formatBinary
  658. allText = false
  659. default:
  660. allBinary = false
  661. }
  662. }
  663. if allBinary {
  664. return colFmts, colFmtDataAllBinary
  665. } else if allText {
  666. return colFmts, colFmtDataAllText
  667. } else {
  668. colFmtData = make([]byte, 2+len(colFmts)*2)
  669. binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts)))
  670. for i, v := range colFmts {
  671. binary.BigEndian.PutUint16(colFmtData[2+i*2:], uint16(v))
  672. }
  673. return colFmts, colFmtData
  674. }
  675. }
// prepareTo prepares query q under the statement name stmtName (the empty
// string selects the unnamed statement) and returns a stmt populated with
// the parameter types, column names/types and the chosen result formats.
//
// A single write carries three messages — 'P' (parse), 'D' (describe the
// statement) and 'S' (sync) — after which the responses are read in the same
// order. Failures are reported by panicking, as elsewhere in this file, so
// callers are expected to have an error-recovery defer in place.
func (cn *conn) prepareTo(q, stmtName string) *stmt {
	st := &stmt{cn: cn, name: stmtName}

	b := cn.writeBuf('P')
	b.string(st.name)
	b.string(q)
	// no parameter types are specified up front
	b.int16(0)

	b.next('D')
	// 'S' selects a statement (as opposed to a portal) to describe
	b.byte('S')
	b.string(st.name)

	b.next('S')
	cn.send(b)

	cn.readParseResponse()
	st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse()
	st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult)
	cn.readReadyForQuery()
	return st
}
  693. func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) {
  694. if cn.bad {
  695. return nil, driver.ErrBadConn
  696. }
  697. defer cn.errRecover(&err)
  698. if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") {
  699. s, err := cn.prepareCopyIn(q)
  700. if err == nil {
  701. cn.inCopy = true
  702. }
  703. return s, err
  704. }
  705. return cn.prepareTo(q, cn.gname()), nil
  706. }
// Close sends a terminate ('X') message to the server and closes the
// underlying network connection. The network connection is closed even if
// sending the message fails; the first error encountered is returned.
func (cn *conn) Close() (err error) {
	// Skip cn.bad return here because we always want to close a connection.
	defer cn.errRecover(&err)

	// Ensure that cn.c.Close is always run. Since error handling is done with
	// panics and cn.errRecover, the Close must be in a defer.
	defer func() {
		cerr := cn.c.Close()
		// Keep the earlier error, if any; only report the close error when
		// everything else succeeded.
		if err == nil {
			err = cerr
		}
	}()

	// Don't go through send(); ListenerConn relies on us not scribbling on the
	// scratch buffer of this connection.
	return cn.sendSimpleMessage('X')
}
// Query implements the "Queryer" interface. It runs query with args and
// returns the resulting rows; the work is done by the unexported query
// method.
func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
	return cn.query(query, args)
}
  726. func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
  727. if cn.bad {
  728. return nil, driver.ErrBadConn
  729. }
  730. if cn.inCopy {
  731. return nil, errCopyInProgress
  732. }
  733. defer cn.errRecover(&err)
  734. // Check to see if we can use the "simpleQuery" interface, which is
  735. // *much* faster than going through prepare/exec
  736. if len(args) == 0 {
  737. return cn.simpleQuery(query)
  738. }
  739. if cn.binaryParameters {
  740. cn.sendBinaryModeQuery(query, args)
  741. cn.readParseResponse()
  742. cn.readBindResponse()
  743. rows := &rows{cn: cn}
  744. rows.colNames, rows.colFmts, rows.colTyps = cn.readPortalDescribeResponse()
  745. cn.postExecuteWorkaround()
  746. return rows, nil
  747. }
  748. st := cn.prepareTo(query, "")
  749. st.exec(args)
  750. return &rows{
  751. cn: cn,
  752. colNames: st.colNames,
  753. colTyps: st.colTyps,
  754. colFmts: st.colFmts,
  755. }, nil
  756. }
// Exec implements the optional "Execer" interface for one-shot queries.
//
// Without arguments the simple query protocol is used. With
// binaryParameters enabled the query is sent in a single round trip.
// Otherwise the query is prepared as the unnamed statement and executed.
func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) {
	if cn.bad {
		return nil, driver.ErrBadConn
	}
	defer cn.errRecover(&err)

	// Check to see if we can use the "simpleExec" interface, which is
	// *much* faster than going through prepare/exec
	if len(args) == 0 {
		// ignore commandTag, our caller doesn't care
		r, _, err := cn.simpleExec(query)
		return r, err
	}

	if cn.binaryParameters {
		cn.sendBinaryModeQuery(query, args)

		cn.readParseResponse()
		cn.readBindResponse()
		cn.readPortalDescribeResponse()
		cn.postExecuteWorkaround()
		res, _, err = cn.readExecuteResponse("Execute")
		return res, err
	}
	// Use the unnamed statement to defer planning until bind
	// time, or else value-based selectivity estimates cannot be
	// used.
	st := cn.prepareTo(query, "")
	r, err := st.Exec(args)
	if err != nil {
		// Re-raise as a panic so the error goes through the deferred
		// cn.errRecover above, like errors raised on the other paths.
		panic(err)
	}
	return r, err
}
// send writes the message built in m to the server, including its leading
// type byte. A write error is raised as a panic, to be handled by the
// caller's error-recovery defer.
// NOTE(review): m.wrap() presumably fills in the length prefix and returns
// the complete message bytes — confirm against writeBuf's definition.
func (cn *conn) send(m *writeBuf) {
	_, err := cn.c.Write(m.wrap())
	if err != nil {
		panic(err)
	}
}
// sendStartupPacket writes the message built in m to the server without its
// first byte, i.e. without the message type slot that writeBuf reserves.
// It is used for the SSL request and the startup message. Unlike send, it
// returns the write error instead of panicking.
func (cn *conn) sendStartupPacket(m *writeBuf) error {
	_, err := cn.c.Write((m.wrap())[1:])
	return err
}
// Send a message of type typ to the server on the other end of cn. The
// message should have no payload. This method does not use the scratch
// buffer.
//
// The five bytes written are the type byte followed by the big-endian
// length 4, which is the size of the length field itself.
func (cn *conn) sendSimpleMessage(typ byte) (err error) {
	_, err = cn.c.Write([]byte{typ, '\x00', '\x00', '\x00', '\x04'})
	return err
}
// saveMessage memorizes a message and its buffer in the conn struct.
// recvMessage will then return these values on the next call to it. This
// method is useful in cases where you have to see what the next message is
// going to be (e.g. to see whether it's an error or not) but you can't handle
// the message yourself.
//
// Only one message can be saved at a time; attempting to save a second one
// marks the connection bad and reports the problem through errorf.
func (cn *conn) saveMessage(typ byte, buf *readBuf) {
	if cn.saveMessageType != 0 {
		cn.bad = true
		errorf("unexpected saveMessageType %d", cn.saveMessageType)
	}
	cn.saveMessageType = typ
	cn.saveMessageBuffer = *buf
}
  819. // recvMessage receives any message from the backend, or returns an error if
  820. // a problem occurred while reading the message.
  821. func (cn *conn) recvMessage(r *readBuf) (byte, error) {
  822. // workaround for a QueryRow bug, see exec
  823. if cn.saveMessageType != 0 {
  824. t := cn.saveMessageType
  825. *r = cn.saveMessageBuffer
  826. cn.saveMessageType = 0
  827. cn.saveMessageBuffer = nil
  828. return t, nil
  829. }
  830. x := cn.scratch[:5]
  831. _, err := io.ReadFull(cn.buf, x)
  832. if err != nil {
  833. return 0, err
  834. }
  835. // read the type and length of the message that follows
  836. t := x[0]
  837. n := int(binary.BigEndian.Uint32(x[1:])) - 4
  838. var y []byte
  839. if n <= len(cn.scratch) {
  840. y = cn.scratch[:n]
  841. } else {
  842. y = make([]byte, n)
  843. }
  844. _, err = io.ReadFull(cn.buf, y)
  845. if err != nil {
  846. return 0, err
  847. }
  848. *r = y
  849. return t, nil
  850. }
  851. // recv receives a message from the backend, but if an error happened while
  852. // reading the message or the received message was an ErrorResponse, it panics.
  853. // NoticeResponses are ignored. This function should generally be used only
  854. // during the startup sequence.
  855. func (cn *conn) recv() (t byte, r *readBuf) {
  856. for {
  857. var err error
  858. r = &readBuf{}
  859. t, err = cn.recvMessage(r)
  860. if err != nil {
  861. panic(err)
  862. }
  863. switch t {
  864. case 'E':
  865. panic(parseError(r))
  866. case 'N':
  867. // ignore
  868. default:
  869. return
  870. }
  871. }
  872. }
  873. // recv1Buf is exactly equivalent to recv1, except it uses a buffer supplied by
  874. // the caller to avoid an allocation.
  875. func (cn *conn) recv1Buf(r *readBuf) byte {
  876. for {
  877. t, err := cn.recvMessage(r)
  878. if err != nil {
  879. panic(err)
  880. }
  881. switch t {
  882. case 'A', 'N':
  883. // ignore
  884. case 'S':
  885. cn.processParameterStatus(r)
  886. default:
  887. return t
  888. }
  889. }
  890. }
// recv1 receives a message from the backend, panicking if an error occurs
// while attempting to read it. All asynchronous messages are ignored, with
// the exception of ErrorResponse.
//
// It allocates a fresh readBuf for the payload; use recv1Buf to supply one.
func (cn *conn) recv1() (t byte, r *readBuf) {
	r = &readBuf{}
	t = cn.recv1Buf(r)
	return t, r
}
// ssl negotiates SSL on the freshly dialed connection if the options in o
// call for it, replacing cn.c with the upgraded connection. Failures are
// raised as panics; ErrSSLNotSupported is raised when the server answers the
// request with anything other than 'S'.
func (cn *conn) ssl(o values) {
	upgrade := ssl(o)
	if upgrade == nil {
		// Nothing to do
		return
	}

	// Send the SSL request: a type-less packet carrying the request code.
	w := cn.writeBuf(0)
	w.int32(80877103)
	if err := cn.sendStartupPacket(w); err != nil {
		panic(err)
	}

	// The server answers with a single byte. It is read directly from the
	// connection rather than through a buffered reader.
	b := cn.scratch[:1]
	_, err := io.ReadFull(cn.c, b)
	if err != nil {
		panic(err)
	}

	if b[0] != 'S' {
		panic(ErrSSLNotSupported)
	}

	cn.c = upgrade(cn.c)
}
  920. // isDriverSetting returns true iff a setting is purely for configuring the
  921. // driver's options and should not be sent to the server in the connection
  922. // startup packet.
  923. func isDriverSetting(key string) bool {
  924. switch key {
  925. case "host", "port":
  926. return true
  927. case "password":
  928. return true
  929. case "sslmode", "sslcert", "sslkey", "sslrootcert":
  930. return true
  931. case "fallback_application_name":
  932. return true
  933. case "connect_timeout":
  934. return true
  935. case "disable_prepared_binary_result":
  936. return true
  937. case "binary_parameters":
  938. return true
  939. default:
  940. return false
  941. }
  942. }
  943. func (cn *conn) startup(o values) {
  944. w := cn.writeBuf(0)
  945. w.int32(196608)
  946. // Send the backend the name of the database we want to connect to, and the
  947. // user we want to connect as. Additionally, we send over any run-time
  948. // parameters potentially included in the connection string. If the server
  949. // doesn't recognize any of them, it will reply with an error.
  950. for k, v := range o {
  951. if isDriverSetting(k) {
  952. // skip options which can't be run-time parameters
  953. continue
  954. }
  955. // The protocol requires us to supply the database name as "database"
  956. // instead of "dbname".
  957. if k == "dbname" {
  958. k = "database"
  959. }
  960. w.string(k)
  961. w.string(v)
  962. }
  963. w.string("")
  964. if err := cn.sendStartupPacket(w); err != nil {
  965. panic(err)
  966. }
  967. for {
  968. t, r := cn.recv()
  969. switch t {
  970. case 'K':
  971. cn.processBackendKeyData(r)
  972. case 'S':
  973. cn.processParameterStatus(r)
  974. case 'R':
  975. cn.auth(r, o)
  976. case 'Z':
  977. cn.processReadyForQuery(r)
  978. return
  979. default:
  980. errorf("unknown response for startup: %q", t)
  981. }
  982. }
  983. }
  984. func (cn *conn) auth(r *readBuf, o values) {
  985. switch code := r.int32(); code {
  986. case 0:
  987. // OK
  988. case 3:
  989. w := cn.writeBuf('p')
  990. w.string(o["password"])
  991. cn.send(w)
  992. t, r := cn.recv()
  993. if t != 'R' {
  994. errorf("unexpected password response: %q", t)
  995. }
  996. if r.int32() != 0 {
  997. errorf("unexpected authentication response: %q", t)
  998. }
  999. case 5:
  1000. s := string(r.next(4))
  1001. w := cn.writeBuf('p')
  1002. w.string("md5" + md5s(md5s(o["password"]+o["user"])+s))
  1003. cn.send(w)
  1004. t, r := cn.recv()
  1005. if t != 'R' {
  1006. errorf("unexpected password response: %q", t)
  1007. }
  1008. if r.int32() != 0 {
  1009. errorf("unexpected authentication response: %q", t)
  1010. }
  1011. default:
  1012. errorf("unknown authentication response: %d", code)
  1013. }
  1014. }
  // format is a column format code.
  type format int

  const (
  	formatText   format = 0
  	formatBinary format = 1
  )

  // colFmtDataAllBinary encodes a list of exactly one result-column format
  // code whose value is 1, i.e. every column is binary.
  var colFmtDataAllBinary = []byte{0, 1, 0, 1}

  // colFmtDataAllText encodes an empty list of result-column format codes,
  // i.e. every column is text.
  var colFmtDataAllText = []byte{0, 0}
  // stmt is a named statement tied to a single connection.
  type stmt struct {
  	// cn is the connection every message for this statement is sent on.
  	cn *conn
  	// name is the statement name written into the Bind and Close messages.
  	name string
  	// colNames, colFmts and colTyps describe the result columns; Query
  	// hands them to the rows it returns.
  	colNames []string
  	colFmts []format
  	// colFmtData is appended verbatim to the Bind message after the
  	// parameter values.
  	colFmtData []byte
  	colTyps []fieldDesc
  	// paramTyps holds one type OID per statement parameter; its length is
  	// the number of arguments the statement requires.
  	paramTyps []oid.Oid
  	// closed is set by Close so that closing again is a no-op.
  	closed bool
  }
  1032. func (st *stmt) Close() (err error) {
  1033. if st.closed {
  1034. return nil
  1035. }
  1036. if st.cn.bad {
  1037. return driver.ErrBadConn
  1038. }
  1039. defer st.cn.errRecover(&err)
  1040. w := st.cn.writeBuf('C')
  1041. w.byte('S')
  1042. w.string(st.name)
  1043. st.cn.send(w)
  1044. st.cn.send(st.cn.writeBuf('S'))
  1045. t, _ := st.cn.recv1()
  1046. if t != '3' {
  1047. st.cn.bad = true
  1048. errorf("unexpected close response: %q", t)
  1049. }
  1050. st.closed = true
  1051. t, r := st.cn.recv1()
  1052. if t != 'Z' {
  1053. st.cn.bad = true
  1054. errorf("expected ready for query, but got: %q", t)
  1055. }
  1056. st.cn.processReadyForQuery(r)
  1057. return nil
  1058. }
  1059. func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
  1060. if st.cn.bad {
  1061. return nil, driver.ErrBadConn
  1062. }
  1063. defer st.cn.errRecover(&err)
  1064. st.exec(v)
  1065. return &rows{
  1066. cn: st.cn,
  1067. colNames: st.colNames,
  1068. colTyps: st.colTyps,
  1069. colFmts: st.colFmts,
  1070. }, nil
  1071. }
  1072. func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) {
  1073. if st.cn.bad {
  1074. return nil, driver.ErrBadConn
  1075. }
  1076. defer st.cn.errRecover(&err)
  1077. st.exec(v)
  1078. res, _, err = st.cn.readExecuteResponse("simple query")
  1079. return res, err
  1080. }
  1081. func (st *stmt) exec(v []driver.Value) {
  1082. if len(v) >= 65536 {
  1083. errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(v))
  1084. }
  1085. if len(v) != len(st.paramTyps) {
  1086. errorf("got %d parameters but the statement requires %d", len(v), len(st.paramTyps))
  1087. }
  1088. cn := st.cn
  1089. w := cn.writeBuf('B')
  1090. w.byte(0) // unnamed portal
  1091. w.string(st.name)
  1092. if cn.binaryParameters {
  1093. cn.sendBinaryParameters(w, v)
  1094. } else {
  1095. w.int16(0)
  1096. w.int16(len(v))
  1097. for i, x := range v {
  1098. if x == nil {
  1099. w.int32(-1)
  1100. } else {
  1101. b := encode(&cn.parameterStatus, x, st.paramTyps[i])
  1102. w.int32(len(b))
  1103. w.bytes(b)
  1104. }
  1105. }
  1106. }
  1107. w.bytes(st.colFmtData)
  1108. w.next('E')
  1109. w.byte(0)
  1110. w.int32(0)
  1111. w.next('S')
  1112. cn.send(w)
  1113. cn.readBindResponse()
  1114. cn.postExecuteWorkaround()
  1115. }
  // NumInput returns the number of parameters the statement takes, i.e. the
  // number of entries in st.paramTyps.
  func (st *stmt) NumInput() int {
  	return len(st.paramTyps)
  }
  1119. // parseComplete parses the "command tag" from a CommandComplete message, and
  1120. // returns the number of rows affected (if applicable) and a string
  1121. // identifying only the command that was executed, e.g. "ALTER TABLE". If the
  1122. // command tag could not be parsed, parseComplete panics.
  1123. func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
  1124. commandsWithAffectedRows := []string{
  1125. "SELECT ",
  1126. // INSERT is handled below
  1127. "UPDATE ",
  1128. "DELETE ",
  1129. "FETCH ",
  1130. "MOVE ",
  1131. "COPY ",
  1132. }
  1133. var affectedRows *string
  1134. for _, tag := range commandsWithAffectedRows {
  1135. if strings.HasPrefix(commandTag, tag) {
  1136. t := commandTag[len(tag):]
  1137. affectedRows = &t
  1138. commandTag = tag[:len(tag)-1]
  1139. break
  1140. }
  1141. }
  1142. // INSERT also includes the oid of the inserted row in its command tag.
  1143. // Oids in user tables are deprecated, and the oid is only returned when
  1144. // exactly one row is inserted, so it's unlikely to be of value to any
  1145. // real-world application and we can ignore it.
  1146. if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") {
  1147. parts := strings.Split(commandTag, " ")
  1148. if len(parts) != 3 {
  1149. cn.bad = true
  1150. errorf("unexpected INSERT command tag %s", commandTag)
  1151. }
  1152. affectedRows = &parts[len(parts)-1]
  1153. commandTag = "INSERT"
  1154. }
  1155. // There should be no affected rows attached to the tag, just return it
  1156. if affectedRows == nil {
  1157. return driver.RowsAffected(0), commandTag
  1158. }
  1159. n, err := strconv.ParseInt(*affectedRows, 10, 64)
  1160. if err != nil {
  1161. cn.bad = true
  1162. errorf("could not parse commandTag: %s", err)
  1163. }
  1164. return driver.RowsAffected(n), commandTag
  1165. }
  // rows iterates over the result messages of an executed query.
  type rows struct {
  	// cn is the connection the result messages are read from.
  	cn *conn
  	// finish, when non-nil, is called by Close once it returns.
  	finish func()
  	// colNames, colTyps and colFmts describe the current result columns;
  	// Next replaces them when it reads a 'T' message.
  	colNames []string
  	colTyps []fieldDesc
  	colFmts []format
  	// done is set once Next has processed the 'Z' message.
  	done bool
  	// rb is the buffer Next reads each incoming message into.
  	rb readBuf
  	// result and tag are filled from parseComplete when Next reads a 'C'
  	// message.
  	result driver.Result
  	tag string
  }
  1177. func (rs *rows) Close() error {
  1178. if finish := rs.finish; finish != nil {
  1179. defer finish()
  1180. }
  1181. // no need to look at cn.bad as Next() will
  1182. for {
  1183. err := rs.Next(nil)
  1184. switch err {
  1185. case nil:
  1186. case io.EOF:
  1187. // rs.Next can return io.EOF on both 'Z' (ready for query) and 'T' (row
  1188. // description, used with HasNextResultSet). We need to fetch messages until
  1189. // we hit a 'Z', which is done by waiting for done to be set.
  1190. if rs.done {
  1191. return nil
  1192. }
  1193. default:
  1194. return err
  1195. }
  1196. }
  1197. }
  // Columns returns the column names of the current result set.
  func (rs *rows) Columns() []string {
  	return rs.colNames
  }
  1201. func (rs *rows) Result() driver.Result {
  1202. if rs.result == nil {
  1203. return emptyRows
  1204. }
  1205. return rs.result
  1206. }
  // Tag returns the command tag recorded from the last CommandComplete message
  // seen by Next; it is empty when none has been recorded.
  func (rs *rows) Tag() string {
  	return rs.tag
  }
  1210. func (rs *rows) Next(dest []driver.Value) (err error) {
  1211. if rs.done {
  1212. return io.EOF
  1213. }
  1214. conn := rs.cn
  1215. if conn.bad {
  1216. return driver.ErrBadConn
  1217. }
  1218. defer conn.errRecover(&err)
  1219. for {
  1220. t := conn.recv1Buf(&rs.rb)
  1221. switch t {
  1222. case 'E':
  1223. err = parseError(&rs.rb)
  1224. case 'C', 'I':
  1225. if t == 'C' {
  1226. rs.result, rs.tag = conn.parseComplete(rs.rb.string())
  1227. }
  1228. continue
  1229. case 'Z':
  1230. conn.processReadyForQuery(&rs.rb)
  1231. rs.done = true
  1232. if err != nil {
  1233. return err
  1234. }
  1235. return io.EOF
  1236. case 'D':
  1237. n := rs.rb.int16()
  1238. if err != nil {
  1239. conn.bad = true
  1240. errorf("unexpected DataRow after error %s", err)
  1241. }
  1242. if n < len(dest) {
  1243. dest = dest[:n]
  1244. }
  1245. for i := range dest {
  1246. l := rs.rb.int32()
  1247. if l == -1 {
  1248. dest[i] = nil
  1249. continue
  1250. }
  1251. dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i].OID, rs.colFmts[i])
  1252. }
  1253. return
  1254. case 'T':
  1255. rs.colNames, rs.colFmts, rs.colTyps = parsePortalRowDescribe(&rs.rb)
  1256. return io.EOF
  1257. default:
  1258. errorf("unexpected message after execute: %q", t)
  1259. }
  1260. }
  1261. }
  // HasNextResultSet reports whether more of the result may follow, which is
  // the case until Next has processed the 'Z' message and set rs.done.
  func (rs *rows) HasNextResultSet() bool {
  	return !rs.done
  }
  // NextResultSet always returns nil. There is nothing to do here: Next has
  // already installed the new column description when it read the 'T' message.
  func (rs *rows) NextResultSet() error {
  	return nil
  }
  // QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) so
  // that it can be embedded in an SQL statement. For example:
  //
  //	tblname := "my_table"
  //	data := "my_data"
  //	quoted := pq.QuoteIdentifier(tblname)
  //	_, err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data)
  //
  // Every double quote in name is escaped by doubling it, and the quoted
  // identifier is case sensitive when used in a query. If name contains a zero
  // byte, the result is truncated immediately before it.
  func QuoteIdentifier(name string) string {
  	if nul := strings.IndexRune(name, 0); nul > -1 {
  		name = name[:nul]
  	}
  	escaped := strings.Replace(name, `"`, `""`, -1)
  	return `"` + escaped + `"`
  }
  // md5s returns the MD5 digest of s as a lowercase hexadecimal string.
  func md5s(s string) string {
  	hasher := md5.New()
  	hasher.Write([]byte(s))
  	digest := hasher.Sum(nil)
  	return fmt.Sprintf("%x", digest)
  }
  1291. func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) {
  1292. // Do one pass over the parameters to see if we're going to send any of
  1293. // them over in binary. If we are, create a paramFormats array at the
  1294. // same time.
  1295. var paramFormats []int
  1296. for i, x := range args {
  1297. _, ok := x.([]byte)
  1298. if ok {
  1299. if paramFormats == nil {
  1300. paramFormats = make([]int, len(args))
  1301. }
  1302. paramFormats[i] = 1
  1303. }
  1304. }
  1305. if paramFormats == nil {
  1306. b.int16(0)
  1307. } else {
  1308. b.int16(len(paramFormats))
  1309. for _, x := range paramFormats {
  1310. b.int16(x)
  1311. }
  1312. }
  1313. b.int16(len(args))
  1314. for _, x := range args {
  1315. if x == nil {
  1316. b.int32(-1)
  1317. } else {
  1318. datum := binaryEncode(&cn.parameterStatus, x)
  1319. b.int32(len(datum))
  1320. b.bytes(datum)
  1321. }
  1322. }
  1323. }
  1324. func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) {
  1325. if len(args) >= 65536 {
  1326. errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args))
  1327. }
  1328. b := cn.writeBuf('P')
  1329. b.byte(0) // unnamed statement
  1330. b.string(query)
  1331. b.int16(0)
  1332. b.next('B')
  1333. b.int16(0) // unnamed portal and statement
  1334. cn.sendBinaryParameters(b, args)
  1335. b.bytes(colFmtDataAllText)
  1336. b.next('D')
  1337. b.byte('P')
  1338. b.byte(0) // unnamed portal
  1339. b.next('E')
  1340. b.byte(0)
  1341. b.int32(0)
  1342. b.next('S')
  1343. cn.send(b)
  1344. }
  1345. func (cn *conn) processParameterStatus(r *readBuf) {
  1346. var err error
  1347. param := r.string()
  1348. switch param {
  1349. case "server_version":
  1350. var major1 int
  1351. var major2 int
  1352. var minor int
  1353. _, err = fmt.Sscanf(r.string(), "%d.%d.%d", &major1, &major2, &minor)
  1354. if err == nil {
  1355. cn.parameterStatus.serverVersion = major1*10000 + major2*100 + minor
  1356. }
  1357. case "TimeZone":
  1358. cn.parameterStatus.currentLocation, err = time.LoadLocation(r.string())
  1359. if err != nil {
  1360. cn.parameterStatus.currentLocation = nil
  1361. }
  1362. default:
  1363. // ignore
  1364. }
  1365. }
  // processReadyForQuery stores the status byte read from r as the
  // connection's transaction status.
  func (cn *conn) processReadyForQuery(r *readBuf) {
  	cn.txnStatus = transactionStatus(r.byte())
  }
  1369. func (cn *conn) readReadyForQuery() {
  1370. t, r := cn.recv1()
  1371. switch t {
  1372. case 'Z':
  1373. cn.processReadyForQuery(r)
  1374. return
  1375. default:
  1376. cn.bad = true
  1377. errorf("unexpected message %q; expected ReadyForQuery", t)
  1378. }
  1379. }
  // processBackendKeyData stores the two 32-bit values of a BackendKeyData
  // message, in wire order, as the connection's process ID and secret key.
  func (cn *conn) processBackendKeyData(r *readBuf) {
  	cn.processID = r.int32()
  	cn.secretKey = r.int32()
  }
  1384. func (cn *conn) readParseResponse() {
  1385. t, r := cn.recv1()
  1386. switch t {
  1387. case '1':
  1388. return
  1389. case 'E':
  1390. err := parseError(r)
  1391. cn.readReadyForQuery()
  1392. panic(err)
  1393. default:
  1394. cn.bad = true
  1395. errorf("unexpected Parse response %q", t)
  1396. }
  1397. }
  1398. func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []fieldDesc) {
  1399. for {
  1400. t, r := cn.recv1()
  1401. switch t {
  1402. case 't':
  1403. nparams := r.int16()
  1404. paramTyps = make([]oid.Oid, nparams)
  1405. for i := range paramTyps {
  1406. paramTyps[i] = r.oid()
  1407. }
  1408. case 'n':
  1409. return paramTyps, nil, nil
  1410. case 'T':
  1411. colNames, colTyps = parseStatementRowDescribe(r)
  1412. return paramTyps, colNames, colTyps
  1413. case 'E':
  1414. err := parseError(r)
  1415. cn.readReadyForQuery()
  1416. panic(err)
  1417. default:
  1418. cn.bad = true
  1419. errorf("unexpected Describe statement response %q", t)
  1420. }
  1421. }
  1422. }
  1423. func (cn *conn) readPortalDescribeResponse() (colNames []string, colFmts []format, colTyps []fieldDesc) {
  1424. t, r := cn.recv1()
  1425. switch t {
  1426. case 'T':
  1427. return parsePortalRowDescribe(r)
  1428. case 'n':
  1429. return nil, nil, nil
  1430. case 'E':
  1431. err := parseError(r)
  1432. cn.readReadyForQuery()
  1433. panic(err)
  1434. default:
  1435. cn.bad = true
  1436. errorf("unexpected Describe response %q", t)
  1437. }
  1438. panic("not reached")
  1439. }
  1440. func (cn *conn) readBindResponse() {
  1441. t, r := cn.recv1()
  1442. switch t {
  1443. case '2':
  1444. return
  1445. case 'E':
  1446. err := parseError(r)
  1447. cn.readReadyForQuery()
  1448. panic(err)
  1449. default:
  1450. cn.bad = true
  1451. errorf("unexpected Bind response %q", t)
  1452. }
  1453. }
  1454. func (cn *conn) postExecuteWorkaround() {
  1455. // Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores
  1456. // any errors from rows.Next, which masks errors that happened during the
  1457. // execution of the query. To avoid the problem in common cases, we wait
  1458. // here for one more message from the database. If it's not an error the
  1459. // query will likely succeed (or perhaps has already, if it's a
  1460. // CommandComplete), so we push the message into the conn struct; recv1
  1461. // will return it as the next message for rows.Next or rows.Close.
  1462. // However, if it's an error, we wait until ReadyForQuery and then return
  1463. // the error to our caller.
  1464. for {
  1465. t, r := cn.recv1()
  1466. switch t {
  1467. case 'E':
  1468. err := parseError(r)
  1469. cn.readReadyForQuery()
  1470. panic(err)
  1471. case 'C', 'D', 'I':
  1472. // the query didn't fail, but we can't process this message
  1473. cn.saveMessage(t, r)
  1474. return
  1475. default:
  1476. cn.bad = true
  1477. errorf("unexpected message during extended query execution: %q", t)
  1478. }
  1479. }
  1480. }
  1481. // Only for Exec(), since we ignore the returned data
  1482. func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, err error) {
  1483. for {
  1484. t, r := cn.recv1()
  1485. switch t {
  1486. case 'C':
  1487. if err != nil {
  1488. cn.bad = true
  1489. errorf("unexpected CommandComplete after error %s", err)
  1490. }
  1491. res, commandTag = cn.parseComplete(r.string())
  1492. case 'Z':
  1493. cn.processReadyForQuery(r)
  1494. if res == nil && err == nil {
  1495. err = errUnexpectedReady
  1496. }
  1497. return res, commandTag, err
  1498. case 'E':
  1499. err = parseError(r)
  1500. case 'T', 'D', 'I':
  1501. if err != nil {
  1502. cn.bad = true
  1503. errorf("unexpected %q after error %s", t, err)
  1504. }
  1505. if t == 'I' {
  1506. res = emptyRows
  1507. }
  1508. // ignore any results
  1509. default:
  1510. cn.bad = true
  1511. errorf("unknown %s response: %q", protocolState, t)
  1512. }
  1513. }
  1514. }
  1515. func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDesc) {
  1516. n := r.int16()
  1517. colNames = make([]string, n)
  1518. colTyps = make([]fieldDesc, n)
  1519. for i := range colNames {
  1520. colNames[i] = r.string()
  1521. r.next(6)
  1522. colTyps[i].OID = r.oid()
  1523. colTyps[i].Len = r.int16()
  1524. colTyps[i].Mod = r.int32()
  1525. // format code not known when describing a statement; always 0
  1526. r.next(2)
  1527. }
  1528. return
  1529. }
  1530. func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, colTyps []fieldDesc) {
  1531. n := r.int16()
  1532. colNames = make([]string, n)
  1533. colFmts = make([]format, n)
  1534. colTyps = make([]fieldDesc, n)
  1535. for i := range colNames {
  1536. colNames[i] = r.string()
  1537. r.next(6)
  1538. colTyps[i].OID = r.oid()
  1539. colTyps[i].Len = r.int16()
  1540. colTyps[i].Mod = r.int32()
  1541. colFmts[i] = format(r.int16())
  1542. }
  1543. return
  1544. }
  // parseEnviron tries to mimic some of libpq's environment handling.
  //
  // To ease testing, it does not directly reference os.Environ, but is
  // designed to accept its output: a list of "NAME=value" strings. Entries
  // without an "=" carry no value and are skipped.
  //
  // Environment-set connection information is intended to have a higher
  // precedence than a library default but lower than any explicitly
  // passed information (such as in the URL or connection string).
  //
  // The returned map is keyed by connection-string option name. Variables that
  // are well-defined by libpq but not supported here cause a panic; they
  // should be unset prior to execution. All other variables are ignored.
  func parseEnviron(env []string) (out map[string]string) {
  	out = make(map[string]string)

  	for _, v := range env {
  		parts := strings.SplitN(v, "=", 2)
  		if len(parts) != 2 {
  			// Not a NAME=value pair, so there is no value to record.
  			continue
  		}
  		name, value := parts[0], parts[1]

  		accrue := func(keyname string) {
  			out[keyname] = value
  		}
  		unsupported := func() {
  			panic(fmt.Sprintf("setting %v not supported", name))
  		}

  		// The order of these is the same as is seen in the
  		// PostgreSQL 9.1 manual.
  		switch name {
  		case "PGHOST":
  			accrue("host")
  		case "PGHOSTADDR":
  			unsupported()
  		case "PGPORT":
  			accrue("port")
  		case "PGDATABASE":
  			accrue("dbname")
  		case "PGUSER":
  			accrue("user")
  		case "PGPASSWORD":
  			accrue("password")
  		case "PGSERVICE", "PGSERVICEFILE", "PGREALM":
  			unsupported()
  		case "PGOPTIONS":
  			accrue("options")
  		case "PGAPPNAME":
  			accrue("application_name")
  		case "PGSSLMODE":
  			accrue("sslmode")
  		case "PGSSLCERT":
  			accrue("sslcert")
  		case "PGSSLKEY":
  			accrue("sslkey")
  		case "PGSSLROOTCERT":
  			accrue("sslrootcert")
  		case "PGREQUIRESSL", "PGSSLCRL":
  			unsupported()
  		case "PGREQUIREPEER":
  			unsupported()
  		case "PGKRBSRVNAME", "PGGSSLIB":
  			unsupported()
  		case "PGCONNECT_TIMEOUT":
  			accrue("connect_timeout")
  		case "PGCLIENTENCODING":
  			accrue("client_encoding")
  		case "PGDATESTYLE":
  			accrue("datestyle")
  		case "PGTZ":
  			accrue("timezone")
  		case "PGGEQO":
  			accrue("geqo")
  		case "PGSYSCONFDIR", "PGLOCALEDIR":
  			unsupported()
  		}
  	}

  	return out
  }
  1618. // isUTF8 returns whether name is a fuzzy variation of the string "UTF-8".
  1619. func isUTF8(name string) bool {
  1620. // Recognize all sorts of silly things as "UTF-8", like Postgres does
  1621. s := strings.Map(alnumLowerASCII, name)
  1622. return s == "utf8" || s == "unicode"
  1623. }
  // alnumLowerASCII is a mapping function for strings.Map. ASCII uppercase
  // letters are lowercased, ASCII lowercase letters and digits are kept, and
  // every other rune is mapped to -1 so that strings.Map drops it.
  func alnumLowerASCII(ch rune) rune {
  	switch {
  	case 'A' <= ch && ch <= 'Z':
  		return ch + ('a' - 'A')
  	case 'a' <= ch && ch <= 'z', '0' <= ch && ch <= '9':
  		return ch
  	default:
  		return -1 // discard
  	}
  }