encode.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437
  1. package pq
  2. import (
  3. "bytes"
  4. "database/sql/driver"
  5. "encoding/hex"
  6. "fmt"
  7. "github.com/lib/pq/oid"
  8. "math"
  9. "strconv"
  10. "strings"
  11. "sync"
  12. "time"
  13. )
  14. func encode(parameterStatus *parameterStatus, x interface{}, pgtypOid oid.Oid) []byte {
  15. switch v := x.(type) {
  16. case int64:
  17. return []byte(fmt.Sprintf("%d", v))
  18. case float32:
  19. return []byte(fmt.Sprintf("%.9f", v))
  20. case float64:
  21. return []byte(fmt.Sprintf("%.17f", v))
  22. case []byte:
  23. if pgtypOid == oid.T_bytea {
  24. return encodeBytea(parameterStatus.serverVersion, v)
  25. }
  26. return v
  27. case string:
  28. if pgtypOid == oid.T_bytea {
  29. return encodeBytea(parameterStatus.serverVersion, []byte(v))
  30. }
  31. return []byte(v)
  32. case bool:
  33. return []byte(fmt.Sprintf("%t", v))
  34. case time.Time:
  35. return formatTs(v)
  36. default:
  37. errorf("encode: unknown type for %T", v)
  38. }
  39. panic("not reached")
  40. }
  41. func decode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} {
  42. switch typ {
  43. case oid.T_bytea:
  44. return parseBytea(s)
  45. case oid.T_timestamptz:
  46. return parseTs(parameterStatus.currentLocation, string(s))
  47. case oid.T_timestamp, oid.T_date:
  48. return parseTs(nil, string(s))
  49. case oid.T_time:
  50. return mustParse("15:04:05", typ, s)
  51. case oid.T_timetz:
  52. return mustParse("15:04:05-07", typ, s)
  53. case oid.T_bool:
  54. return s[0] == 't'
  55. case oid.T_int8, oid.T_int2, oid.T_int4:
  56. i, err := strconv.ParseInt(string(s), 10, 64)
  57. if err != nil {
  58. errorf("%s", err)
  59. }
  60. return i
  61. case oid.T_float4, oid.T_float8:
  62. bits := 64
  63. if typ == oid.T_float4 {
  64. bits = 32
  65. }
  66. f, err := strconv.ParseFloat(string(s), bits)
  67. if err != nil {
  68. errorf("%s", err)
  69. }
  70. return f
  71. }
  72. return s
  73. }
  74. // appendEncodedText encodes item in text format as required by COPY
  75. // and appends to buf
  76. func appendEncodedText(parameterStatus *parameterStatus, buf []byte, x interface{}) []byte {
  77. switch v := x.(type) {
  78. case int64:
  79. return strconv.AppendInt(buf, v, 10)
  80. case float32:
  81. return strconv.AppendFloat(buf, float64(v), 'f', -1, 32)
  82. case float64:
  83. return strconv.AppendFloat(buf, v, 'f', -1, 64)
  84. case []byte:
  85. encodedBytea := encodeBytea(parameterStatus.serverVersion, v)
  86. return appendEscapedText(buf, string(encodedBytea))
  87. case string:
  88. return appendEscapedText(buf, v)
  89. case bool:
  90. return strconv.AppendBool(buf, v)
  91. case time.Time:
  92. return append(buf, formatTs(v)...)
  93. case nil:
  94. return append(buf, "\\N"...)
  95. default:
  96. errorf("encode: unknown type for %T", v)
  97. }
  98. panic("not reached")
  99. }
  100. func appendEscapedText(buf []byte, text string) []byte {
  101. escapeNeeded := false
  102. startPos := 0
  103. var c byte
  104. // check if we need to escape
  105. for i := 0; i < len(text); i++ {
  106. c = text[i]
  107. if c == '\\' || c == '\n' || c == '\r' || c == '\t' {
  108. escapeNeeded = true
  109. startPos = i
  110. break
  111. }
  112. }
  113. if !escapeNeeded {
  114. return append(buf, text...)
  115. }
  116. // copy till first char to escape, iterate the rest
  117. result := append(buf, text[:startPos]...)
  118. for i := startPos; i < len(text); i++ {
  119. c = text[i]
  120. switch c {
  121. case '\\':
  122. result = append(result, '\\', '\\')
  123. case '\n':
  124. result = append(result, '\\', 'n')
  125. case '\r':
  126. result = append(result, '\\', 'r')
  127. case '\t':
  128. result = append(result, '\\', 't')
  129. default:
  130. result = append(result, c)
  131. }
  132. }
  133. return result
  134. }
  135. func mustParse(f string, typ oid.Oid, s []byte) time.Time {
  136. str := string(s)
  137. // Special case until time.Parse bug is fixed:
  138. // http://code.google.com/p/go/issues/detail?id=3487
  139. if str[len(str)-2] == '.' {
  140. str += "0"
  141. }
  142. // check for a 30-minute-offset timezone
  143. if (typ == oid.T_timestamptz || typ == oid.T_timetz) &&
  144. str[len(str)-3] == ':' {
  145. f += ":00"
  146. }
  147. t, err := time.Parse(f, str)
  148. if err != nil {
  149. errorf("decode: %s", err)
  150. }
  151. return t
  152. }
  153. func expect(str, char string, pos int) {
  154. if c := str[pos : pos+1]; c != char {
  155. errorf("expected '%v' at position %v; got '%v'", char, pos, c)
  156. }
  157. }
  158. func mustAtoi(str string) int {
  159. result, err := strconv.Atoi(str)
  160. if err != nil {
  161. errorf("expected number; got '%v'", str)
  162. }
  163. return result
  164. }
  165. // The location cache caches the time zones typically used by the client.
  166. type locationCache struct {
  167. cache map[int]*time.Location
  168. lock sync.Mutex
  169. }
  170. // All connections share the same list of timezones. Benchmarking shows that
  171. // about 5% speed could be gained by putting the cache in the connection and
  172. // losing the mutex, at the cost of a small amount of memory and a somewhat
  173. // significant increase in code complexity.
  174. var globalLocationCache *locationCache = newLocationCache()
  175. func newLocationCache() *locationCache {
  176. return &locationCache{cache: make(map[int]*time.Location)}
  177. }
  178. // Returns the cached timezone for the specified offset, creating and caching
  179. // it if necessary.
  180. func (c *locationCache) getLocation(offset int) *time.Location {
  181. c.lock.Lock()
  182. defer c.lock.Unlock()
  183. location, ok := c.cache[offset]
  184. if !ok {
  185. location = time.FixedZone("", offset)
  186. c.cache[offset] = location
  187. }
  188. return location
  189. }
  190. // This is a time function specific to the Postgres default DateStyle
  191. // setting ("ISO, MDY"), the only one we currently support. This
  192. // accounts for the discrepancies between the parsing available with
  193. // time.Parse and the Postgres date formatting quirks.
  194. func parseTs(currentLocation *time.Location, str string) (result time.Time) {
  195. monSep := strings.IndexRune(str, '-')
  196. // this is Gregorian year, not ISO Year
  197. // In Gregorian system, the year 1 BC is followed by AD 1
  198. year := mustAtoi(str[:monSep])
  199. daySep := monSep + 3
  200. month := mustAtoi(str[monSep+1 : daySep])
  201. expect(str, "-", daySep)
  202. timeSep := daySep + 3
  203. day := mustAtoi(str[daySep+1 : timeSep])
  204. var hour, minute, second int
  205. if len(str) > monSep+len("01-01")+1 {
  206. expect(str, " ", timeSep)
  207. minSep := timeSep + 3
  208. expect(str, ":", minSep)
  209. hour = mustAtoi(str[timeSep+1 : minSep])
  210. secSep := minSep + 3
  211. expect(str, ":", secSep)
  212. minute = mustAtoi(str[minSep+1 : secSep])
  213. secEnd := secSep + 3
  214. second = mustAtoi(str[secSep+1 : secEnd])
  215. }
  216. remainderIdx := monSep + len("01-01 00:00:00") + 1
  217. // Three optional (but ordered) sections follow: the
  218. // fractional seconds, the time zone offset, and the BC
  219. // designation. We set them up here and adjust the other
  220. // offsets if the preceding sections exist.
  221. nanoSec := 0
  222. tzOff := 0
  223. if remainderIdx < len(str) && str[remainderIdx:remainderIdx+1] == "." {
  224. fracStart := remainderIdx + 1
  225. fracOff := strings.IndexAny(str[fracStart:], "-+ ")
  226. if fracOff < 0 {
  227. fracOff = len(str) - fracStart
  228. }
  229. fracSec := mustAtoi(str[fracStart : fracStart+fracOff])
  230. nanoSec = fracSec * (1000000000 / int(math.Pow(10, float64(fracOff))))
  231. remainderIdx += fracOff + 1
  232. }
  233. if tzStart := remainderIdx; tzStart < len(str) && (str[tzStart:tzStart+1] == "-" || str[tzStart:tzStart+1] == "+") {
  234. // time zone separator is always '-' or '+' (UTC is +00)
  235. var tzSign int
  236. if c := str[tzStart : tzStart+1]; c == "-" {
  237. tzSign = -1
  238. } else if c == "+" {
  239. tzSign = +1
  240. } else {
  241. errorf("expected '-' or '+' at position %v; got %v", tzStart, c)
  242. }
  243. tzHours := mustAtoi(str[tzStart+1 : tzStart+3])
  244. remainderIdx += 3
  245. var tzMin, tzSec int
  246. if tzStart+3 < len(str) && str[tzStart+3:tzStart+4] == ":" {
  247. tzMin = mustAtoi(str[tzStart+4 : tzStart+6])
  248. remainderIdx += 3
  249. }
  250. if tzStart+6 < len(str) && str[tzStart+6:tzStart+7] == ":" {
  251. tzSec = mustAtoi(str[tzStart+7 : tzStart+9])
  252. remainderIdx += 3
  253. }
  254. tzOff = tzSign * ((tzHours * 60 * 60) + (tzMin * 60) + tzSec)
  255. }
  256. var isoYear int
  257. if remainderIdx < len(str) && str[remainderIdx:remainderIdx+3] == " BC" {
  258. isoYear = 1 - year
  259. remainderIdx += 3
  260. } else {
  261. isoYear = year
  262. }
  263. if remainderIdx < len(str) {
  264. errorf("expected end of input, got %v", str[remainderIdx:])
  265. }
  266. t := time.Date(isoYear, time.Month(month), day,
  267. hour, minute, second, nanoSec,
  268. globalLocationCache.getLocation(tzOff))
  269. if currentLocation != nil {
  270. // Set the location of the returned Time based on the session's
  271. // TimeZone value, but only if the local time zone database agrees with
  272. // the remote database on the offset.
  273. lt := t.In(currentLocation)
  274. _, newOff := lt.Zone()
  275. if newOff == tzOff {
  276. t = lt
  277. }
  278. }
  279. return t
  280. }
  281. // formatTs formats t as time.RFC3339Nano and appends time zone seconds if
  282. // needed.
  283. func formatTs(t time.Time) (b []byte) {
  284. b = []byte(t.Format(time.RFC3339Nano))
  285. // Need to send dates before 0001 A.D. with " BC" suffix, instead of the
  286. // minus sign preferred by Go.
  287. // Beware, "0000" in ISO is "1 BC", "-0001" is "2 BC" and so on
  288. bc := false
  289. if t.Year() <= 0 {
  290. // flip year sign, and add 1, e.g: "0" will be "1", and "-10" will be "11"
  291. t = t.AddDate((-t.Year())*2+1, 0, 0)
  292. bc = true
  293. }
  294. b = []byte(t.Format(time.RFC3339Nano))
  295. if bc {
  296. b = append(b, " BC"...)
  297. }
  298. _, offset := t.Zone()
  299. offset = offset % 60
  300. if offset == 0 {
  301. return b
  302. }
  303. if offset < 0 {
  304. offset = -offset
  305. }
  306. b = append(b, ':')
  307. if offset < 10 {
  308. b = append(b, '0')
  309. }
  310. return strconv.AppendInt(b, int64(offset), 10)
  311. }
  312. // Parse a bytea value received from the server. Both "hex" and the legacy
  313. // "escape" format are supported.
  314. func parseBytea(s []byte) (result []byte) {
  315. if len(s) >= 2 && bytes.Equal(s[:2], []byte("\\x")) {
  316. // bytea_output = hex
  317. s = s[2:] // trim off leading "\\x"
  318. result = make([]byte, hex.DecodedLen(len(s)))
  319. _, err := hex.Decode(result, s)
  320. if err != nil {
  321. errorf("%s", err)
  322. }
  323. } else {
  324. // bytea_output = escape
  325. for len(s) > 0 {
  326. if s[0] == '\\' {
  327. // escaped '\\'
  328. if len(s) >= 2 && s[1] == '\\' {
  329. result = append(result, '\\')
  330. s = s[2:]
  331. continue
  332. }
  333. // '\\' followed by an octal number
  334. if len(s) < 4 {
  335. errorf("invalid bytea sequence %v", s)
  336. }
  337. r, err := strconv.ParseInt(string(s[1:4]), 8, 9)
  338. if err != nil {
  339. errorf("could not parse bytea value: %s", err.Error())
  340. }
  341. result = append(result, byte(r))
  342. s = s[4:]
  343. } else {
  344. // We hit an unescaped, raw byte. Try to read in as many as
  345. // possible in one go.
  346. i := bytes.IndexByte(s, '\\')
  347. if i == -1 {
  348. result = append(result, s...)
  349. break
  350. }
  351. result = append(result, s[:i]...)
  352. s = s[i:]
  353. }
  354. }
  355. }
  356. return result
  357. }
  358. func encodeBytea(serverVersion int, v []byte) (result []byte) {
  359. if serverVersion >= 90000 {
  360. // Use the hex format if we know that the server supports it
  361. result = []byte(fmt.Sprintf("\\x%x", v))
  362. } else {
  363. // .. or resort to "escape"
  364. for _, b := range v {
  365. if b == '\\' {
  366. result = append(result, '\\', '\\')
  367. } else if b < 0x20 || b > 0x7e {
  368. result = append(result, []byte(fmt.Sprintf("\\%03o", b))...)
  369. } else {
  370. result = append(result, b)
  371. }
  372. }
  373. }
  374. return result
  375. }
  376. // NullTime represents a time.Time that may be null. NullTime implements the
  377. // sql.Scanner interface so it can be used as a scan destination, similar to
  378. // sql.NullString.
  379. type NullTime struct {
  380. Time time.Time
  381. Valid bool // Valid is true if Time is not NULL
  382. }
  383. // Scan implements the Scanner interface.
  384. func (nt *NullTime) Scan(value interface{}) error {
  385. nt.Time, nt.Valid = value.(time.Time)
  386. return nil
  387. }
  388. // Value implements the driver Valuer interface.
  389. func (nt NullTime) Value() (driver.Value, error) {
  390. if !nt.Valid {
  391. return nil, nil
  392. }
  393. return nt.Time, nil
  394. }