copy.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. package pq
  2. import (
  3. "database/sql/driver"
  4. "encoding/binary"
  5. "errors"
  6. "fmt"
  7. "sync"
  8. )
  // Sentinel errors returned by the COPY FROM STDIN support in this file.
  var (
  	// errCopyInClosed is returned by Exec and Close once the statement is closed.
  	errCopyInClosed = errors.New("pq: copyin statement has already been closed")
  	// errBinaryCopyNotSupported is returned when the server asks for a non-text COPY.
  	errBinaryCopyNotSupported = errors.New("pq: only text format supported for COPY")
  	// errCopyToNotSupported is returned when the prepared statement is a COPY ... TO.
  	errCopyToNotSupported = errors.New("pq: COPY TO is not supported")
  	// errCopyNotSupportedOutsideTxn is returned when COPY is prepared outside a transaction.
  	errCopyNotSupportedOutsideTxn = errors.New("pq: COPY is only allowed inside a transaction")
  )
  15. // CopyIn creates a COPY FROM statement which can be prepared with
  16. // Tx.Prepare(). The target table should be visible in search_path.
  17. func CopyIn(table string, columns ...string) string {
  18. stmt := "COPY " + QuoteIdentifier(table) + " ("
  19. for i, col := range columns {
  20. if i != 0 {
  21. stmt += ", "
  22. }
  23. stmt += QuoteIdentifier(col)
  24. }
  25. stmt += ") FROM STDIN"
  26. return stmt
  27. }
  28. // CopyInSchema creates a COPY FROM statement which can be prepared with
  29. // Tx.Prepare().
  30. func CopyInSchema(schema, table string, columns ...string) string {
  31. stmt := "COPY " + QuoteIdentifier(schema) + "." + QuoteIdentifier(table) + " ("
  32. for i, col := range columns {
  33. if i != 0 {
  34. stmt += ", "
  35. }
  36. stmt += QuoteIdentifier(col)
  37. }
  38. stmt += ") FROM STDIN"
  39. return stmt
  40. }
  41. type copyin struct {
  42. cn *conn
  43. buffer []byte
  44. rowData chan []byte
  45. done chan bool
  46. closed bool
  47. sync.Mutex // guards err
  48. err error
  49. }
  // Sizing of the CopyData message buffer held by each copyin statement.
  const (
  	// ciBufferSize is the initial capacity allocated for copyin.buffer.
  	ciBufferSize = 64 * 1024
  	// ciBufferFlushSize is the fill level past which Exec sends the buffer,
  	// chosen below ciBufferSize so the buffer is normally flushed before it
  	// fills up and needs reallocation.
  	ciBufferFlushSize = 63 * 1024
  )
  53. func (cn *conn) prepareCopyIn(q string) (_ driver.Stmt, err error) {
  54. if !cn.isInTransaction() {
  55. return nil, errCopyNotSupportedOutsideTxn
  56. }
  57. ci := &copyin{
  58. cn: cn,
  59. buffer: make([]byte, 0, ciBufferSize),
  60. rowData: make(chan []byte),
  61. done: make(chan bool, 1),
  62. }
  63. // add CopyData identifier + 4 bytes for message length
  64. ci.buffer = append(ci.buffer, 'd', 0, 0, 0, 0)
  65. b := cn.writeBuf('Q')
  66. b.string(q)
  67. cn.send(b)
  68. awaitCopyInResponse:
  69. for {
  70. t, r := cn.recv1()
  71. switch t {
  72. case 'G':
  73. if r.byte() != 0 {
  74. err = errBinaryCopyNotSupported
  75. break awaitCopyInResponse
  76. }
  77. go ci.resploop()
  78. return ci, nil
  79. case 'H':
  80. err = errCopyToNotSupported
  81. break awaitCopyInResponse
  82. case 'E':
  83. err = parseError(r)
  84. case 'Z':
  85. if err == nil {
  86. cn.bad = true
  87. errorf("unexpected ReadyForQuery in response to COPY")
  88. }
  89. cn.processReadyForQuery(r)
  90. return nil, err
  91. default:
  92. cn.bad = true
  93. errorf("unknown response for copy query: %q", t)
  94. }
  95. }
  96. // something went wrong, abort COPY before we return
  97. b = cn.writeBuf('f')
  98. b.string(err.Error())
  99. cn.send(b)
  100. for {
  101. t, r := cn.recv1()
  102. switch t {
  103. case 'c', 'C', 'E':
  104. case 'Z':
  105. // correctly aborted, we're done
  106. cn.processReadyForQuery(r)
  107. return nil, err
  108. default:
  109. cn.bad = true
  110. errorf("unknown response for CopyFail: %q", t)
  111. }
  112. }
  113. }
  114. func (ci *copyin) flush(buf []byte) {
  115. // set message length (without message identifier)
  116. binary.BigEndian.PutUint32(buf[1:], uint32(len(buf)-1))
  117. _, err := ci.cn.c.Write(buf)
  118. if err != nil {
  119. panic(err)
  120. }
  121. }
  122. func (ci *copyin) resploop() {
  123. for {
  124. var r readBuf
  125. t, err := ci.cn.recvMessage(&r)
  126. if err != nil {
  127. ci.cn.bad = true
  128. ci.setError(err)
  129. ci.done <- true
  130. return
  131. }
  132. switch t {
  133. case 'C':
  134. // complete
  135. case 'Z':
  136. ci.cn.processReadyForQuery(&r)
  137. ci.done <- true
  138. return
  139. case 'E':
  140. err := parseError(&r)
  141. ci.setError(err)
  142. default:
  143. ci.cn.bad = true
  144. ci.setError(fmt.Errorf("unknown response during CopyIn: %q", t))
  145. ci.done <- true
  146. return
  147. }
  148. }
  149. }
  150. func (ci *copyin) isErrorSet() bool {
  151. ci.Lock()
  152. isSet := (ci.err != nil)
  153. ci.Unlock()
  154. return isSet
  155. }
  156. // setError() sets ci.err if one has not been set already. Caller must not be
  157. // holding ci.Mutex.
  158. func (ci *copyin) setError(err error) {
  159. ci.Lock()
  160. if ci.err == nil {
  161. ci.err = err
  162. }
  163. ci.Unlock()
  164. }
  165. func (ci *copyin) NumInput() int {
  166. return -1
  167. }
  168. func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) {
  169. return nil, ErrNotSupported
  170. }
  171. // Exec inserts values into the COPY stream. The insert is asynchronous
  172. // and Exec can return errors from previous Exec calls to the same
  173. // COPY stmt.
  174. //
  175. // You need to call Exec(nil) to sync the COPY stream and to get any
  176. // errors from pending data, since Stmt.Close() doesn't return errors
  177. // to the user.
  178. func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) {
  179. if ci.closed {
  180. return nil, errCopyInClosed
  181. }
  182. if ci.cn.bad {
  183. return nil, driver.ErrBadConn
  184. }
  185. defer ci.cn.errRecover(&err)
  186. if ci.isErrorSet() {
  187. return nil, ci.err
  188. }
  189. if len(v) == 0 {
  190. err = ci.Close()
  191. ci.closed = true
  192. return nil, err
  193. }
  194. numValues := len(v)
  195. for i, value := range v {
  196. ci.buffer = appendEncodedText(&ci.cn.parameterStatus, ci.buffer, value)
  197. if i < numValues-1 {
  198. ci.buffer = append(ci.buffer, '\t')
  199. }
  200. }
  201. ci.buffer = append(ci.buffer, '\n')
  202. if len(ci.buffer) > ciBufferFlushSize {
  203. ci.flush(ci.buffer)
  204. // reset buffer, keep bytes for message identifier and length
  205. ci.buffer = ci.buffer[:5]
  206. }
  207. return driver.RowsAffected(0), nil
  208. }
  209. func (ci *copyin) Close() (err error) {
  210. if ci.closed {
  211. return errCopyInClosed
  212. }
  213. if ci.cn.bad {
  214. return driver.ErrBadConn
  215. }
  216. defer ci.cn.errRecover(&err)
  217. if len(ci.buffer) > 0 {
  218. ci.flush(ci.buffer)
  219. }
  220. // Avoid touching the scratch buffer as resploop could be using it.
  221. err = ci.cn.sendSimpleMessage('c')
  222. if err != nil {
  223. return err
  224. }
  225. <-ci.done
  226. if ci.isErrorSet() {
  227. err = ci.err
  228. return err
  229. }
  230. return nil
  231. }