db.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648
  1. package core
  2. import (
  3. "database/sql"
  4. "database/sql/driver"
  5. "errors"
  6. "reflect"
  7. "regexp"
  8. "sync"
  9. )
  10. func MapToSlice(query string, mp interface{}) (string, []interface{}, error) {
  11. vv := reflect.ValueOf(mp)
  12. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  13. return "", []interface{}{}, ErrNoMapPointer
  14. }
  15. args := make([]interface{}, 0)
  16. query = re.ReplaceAllStringFunc(query, func(src string) string {
  17. args = append(args, vv.Elem().MapIndex(reflect.ValueOf(src[1:])).Interface())
  18. return "?"
  19. })
  20. return query, args, nil
  21. }
  22. func StructToSlice(query string, st interface{}) (string, []interface{}, error) {
  23. vv := reflect.ValueOf(st)
  24. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  25. return "", []interface{}{}, ErrNoStructPointer
  26. }
  27. args := make([]interface{}, 0)
  28. var err error
  29. query = re.ReplaceAllStringFunc(query, func(src string) string {
  30. fv := vv.Elem().FieldByName(src[1:]).Interface()
  31. if v, ok := fv.(driver.Valuer); ok {
  32. var value driver.Value
  33. value, err = v.Value()
  34. if err != nil {
  35. return "?"
  36. }
  37. args = append(args, value)
  38. } else {
  39. args = append(args, fv)
  40. }
  41. return "?"
  42. })
  43. if err != nil {
  44. return "", []interface{}{}, err
  45. }
  46. return query, args, nil
  47. }
  48. type DB struct {
  49. *sql.DB
  50. Mapper IMapper
  51. }
  52. func Open(driverName, dataSourceName string) (*DB, error) {
  53. db, err := sql.Open(driverName, dataSourceName)
  54. if err != nil {
  55. return nil, err
  56. }
  57. return &DB{db, NewCacheMapper(&SnakeMapper{})}, nil
  58. }
  59. func FromDB(db *sql.DB) *DB {
  60. return &DB{db, NewCacheMapper(&SnakeMapper{})}
  61. }
  62. func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
  63. rows, err := db.DB.Query(query, args...)
  64. if err != nil {
  65. if rows != nil {
  66. rows.Close()
  67. }
  68. return nil, err
  69. }
  70. return &Rows{rows, db.Mapper}, nil
  71. }
  72. func (db *DB) QueryMap(query string, mp interface{}) (*Rows, error) {
  73. query, args, err := MapToSlice(query, mp)
  74. if err != nil {
  75. return nil, err
  76. }
  77. return db.Query(query, args...)
  78. }
  79. func (db *DB) QueryStruct(query string, st interface{}) (*Rows, error) {
  80. query, args, err := StructToSlice(query, st)
  81. if err != nil {
  82. return nil, err
  83. }
  84. return db.Query(query, args...)
  85. }
  86. type Row struct {
  87. rows *Rows
  88. // One of these two will be non-nil:
  89. err error // deferred error for easy chaining
  90. }
  91. func (row *Row) Columns() ([]string, error) {
  92. if row.err != nil {
  93. return nil, row.err
  94. }
  95. return row.rows.Columns()
  96. }
  97. func (row *Row) Scan(dest ...interface{}) error {
  98. if row.err != nil {
  99. return row.err
  100. }
  101. defer row.rows.Close()
  102. for _, dp := range dest {
  103. if _, ok := dp.(*sql.RawBytes); ok {
  104. return errors.New("sql: RawBytes isn't allowed on Row.Scan")
  105. }
  106. }
  107. if !row.rows.Next() {
  108. if err := row.rows.Err(); err != nil {
  109. return err
  110. }
  111. return sql.ErrNoRows
  112. }
  113. err := row.rows.Scan(dest...)
  114. if err != nil {
  115. return err
  116. }
  117. // Make sure the query can be processed to completion with no errors.
  118. if err := row.rows.Close(); err != nil {
  119. return err
  120. }
  121. return nil
  122. }
  123. func (row *Row) ScanStructByName(dest interface{}) error {
  124. if row.err != nil {
  125. return row.err
  126. }
  127. return row.rows.ScanStructByName(dest)
  128. }
  129. func (row *Row) ScanStructByIndex(dest interface{}) error {
  130. if row.err != nil {
  131. return row.err
  132. }
  133. return row.rows.ScanStructByIndex(dest)
  134. }
  135. // scan data to a slice's pointer, slice's length should equal to columns' number
  136. func (row *Row) ScanSlice(dest interface{}) error {
  137. if row.err != nil {
  138. return row.err
  139. }
  140. return row.rows.ScanSlice(dest)
  141. }
  142. // scan data to a map's pointer
  143. func (row *Row) ScanMap(dest interface{}) error {
  144. if row.err != nil {
  145. return row.err
  146. }
  147. return row.rows.ScanMap(dest)
  148. }
  149. func (db *DB) QueryRow(query string, args ...interface{}) *Row {
  150. rows, err := db.Query(query, args...)
  151. return &Row{rows, err}
  152. }
  153. func (db *DB) QueryRowMap(query string, mp interface{}) *Row {
  154. query, args, err := MapToSlice(query, mp)
  155. if err != nil {
  156. return &Row{nil, err}
  157. }
  158. return db.QueryRow(query, args...)
  159. }
  160. func (db *DB) QueryRowStruct(query string, st interface{}) *Row {
  161. query, args, err := StructToSlice(query, st)
  162. if err != nil {
  163. return &Row{nil, err}
  164. }
  165. return db.QueryRow(query, args...)
  166. }
  167. type Stmt struct {
  168. *sql.Stmt
  169. Mapper IMapper
  170. names map[string]int
  171. }
  172. func (db *DB) Prepare(query string) (*Stmt, error) {
  173. names := make(map[string]int)
  174. var i int
  175. query = re.ReplaceAllStringFunc(query, func(src string) string {
  176. names[src[1:]] = i
  177. i += 1
  178. return "?"
  179. })
  180. stmt, err := db.DB.Prepare(query)
  181. if err != nil {
  182. return nil, err
  183. }
  184. return &Stmt{stmt, db.Mapper, names}, nil
  185. }
  186. func (s *Stmt) ExecMap(mp interface{}) (sql.Result, error) {
  187. vv := reflect.ValueOf(mp)
  188. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  189. return nil, errors.New("mp should be a map's pointer")
  190. }
  191. args := make([]interface{}, len(s.names))
  192. for k, i := range s.names {
  193. args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
  194. }
  195. return s.Stmt.Exec(args...)
  196. }
  197. func (s *Stmt) ExecStruct(st interface{}) (sql.Result, error) {
  198. vv := reflect.ValueOf(st)
  199. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  200. return nil, errors.New("mp should be a map's pointer")
  201. }
  202. args := make([]interface{}, len(s.names))
  203. for k, i := range s.names {
  204. args[i] = vv.Elem().FieldByName(k).Interface()
  205. }
  206. return s.Stmt.Exec(args...)
  207. }
  208. func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
  209. rows, err := s.Stmt.Query(args...)
  210. if err != nil {
  211. return nil, err
  212. }
  213. return &Rows{rows, s.Mapper}, nil
  214. }
  215. func (s *Stmt) QueryMap(mp interface{}) (*Rows, error) {
  216. vv := reflect.ValueOf(mp)
  217. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  218. return nil, errors.New("mp should be a map's pointer")
  219. }
  220. args := make([]interface{}, len(s.names))
  221. for k, i := range s.names {
  222. args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
  223. }
  224. return s.Query(args...)
  225. }
  226. func (s *Stmt) QueryStruct(st interface{}) (*Rows, error) {
  227. vv := reflect.ValueOf(st)
  228. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  229. return nil, errors.New("mp should be a map's pointer")
  230. }
  231. args := make([]interface{}, len(s.names))
  232. for k, i := range s.names {
  233. args[i] = vv.Elem().FieldByName(k).Interface()
  234. }
  235. return s.Query(args...)
  236. }
  237. func (s *Stmt) QueryRow(args ...interface{}) *Row {
  238. rows, err := s.Query(args...)
  239. return &Row{rows, err}
  240. }
  241. func (s *Stmt) QueryRowMap(mp interface{}) *Row {
  242. vv := reflect.ValueOf(mp)
  243. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  244. return &Row{nil, errors.New("mp should be a map's pointer")}
  245. }
  246. args := make([]interface{}, len(s.names))
  247. for k, i := range s.names {
  248. args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
  249. }
  250. return s.QueryRow(args...)
  251. }
  252. func (s *Stmt) QueryRowStruct(st interface{}) *Row {
  253. vv := reflect.ValueOf(st)
  254. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  255. return &Row{nil, errors.New("st should be a struct's pointer")}
  256. }
  257. args := make([]interface{}, len(s.names))
  258. for k, i := range s.names {
  259. args[i] = vv.Elem().FieldByName(k).Interface()
  260. }
  261. return s.QueryRow(args...)
  262. }
  // re matches a named placeholder: a '?' immediately followed by one or more
  // word characters. The word characters are captured as group 1.
  var re = regexp.MustCompile(`[?](\w+)`)
  266. // insert into (name) values (?)
  267. // insert into (name) values (?name)
  268. func (db *DB) ExecMap(query string, mp interface{}) (sql.Result, error) {
  269. query, args, err := MapToSlice(query, mp)
  270. if err != nil {
  271. return nil, err
  272. }
  273. return db.DB.Exec(query, args...)
  274. }
  275. func (db *DB) ExecStruct(query string, st interface{}) (sql.Result, error) {
  276. query, args, err := StructToSlice(query, st)
  277. if err != nil {
  278. return nil, err
  279. }
  280. return db.DB.Exec(query, args...)
  281. }
  282. type Rows struct {
  283. *sql.Rows
  284. Mapper IMapper
  285. }
  286. // scan data to a struct's pointer according field index
  287. func (rs *Rows) ScanStructByIndex(dest ...interface{}) error {
  288. if len(dest) == 0 {
  289. return errors.New("at least one struct")
  290. }
  291. vvvs := make([]reflect.Value, len(dest))
  292. for i, s := range dest {
  293. vv := reflect.ValueOf(s)
  294. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  295. return errors.New("dest should be a struct's pointer")
  296. }
  297. vvvs[i] = vv.Elem()
  298. }
  299. cols, err := rs.Columns()
  300. if err != nil {
  301. return err
  302. }
  303. newDest := make([]interface{}, len(cols))
  304. var i = 0
  305. for _, vvv := range vvvs {
  306. for j := 0; j < vvv.NumField(); j++ {
  307. newDest[i] = vvv.Field(j).Addr().Interface()
  308. i = i + 1
  309. }
  310. }
  311. return rs.Rows.Scan(newDest...)
  312. }
  // EmptyScanner is a scan destination that discards whatever it is given.
  type EmptyScanner struct{}

  // Scan ignores src and always reports success.
  func (EmptyScanner) Scan(src interface{}) error { return nil }
  var (
  	// fieldCache maps a struct type to a name -> field index table.
  	fieldCache = make(map[reflect.Type]map[string]int)
  	// fieldCacheMutex guards fieldCache.
  	fieldCacheMutex sync.RWMutex
  )

  // fieldByName returns the field of struct value v called name, using a
  // per-type cache of field indexes. When v's type has no such field it returns
  // the zero reflect.Value, for which IsValid reports false.
  func fieldByName(v reflect.Value, name string) reflect.Value {
  	t := v.Type()

  	fieldCacheMutex.RLock()
  	cache, ok := fieldCache[t]
  	fieldCacheMutex.RUnlock()

  	if !ok {
  		// Build the index table outside the lock. Two goroutines may build it
  		// at the same time, but the tables are identical and are not modified
  		// after being stored.
  		cache = make(map[string]int, t.NumField())
  		for i := 0; i < t.NumField(); i++ {
  			cache[t.Field(i).Name] = i
  		}
  		fieldCacheMutex.Lock()
  		fieldCache[t] = cache
  		fieldCacheMutex.Unlock()
  	}

  	if i, ok := cache[name]; ok {
  		return v.Field(i)
  	}
  	// reflect.Zero(t) would be a valid but unaddressable struct value, which
  	// callers checking IsValid would mistake for a real field.
  	return reflect.Value{}
  }
  341. // scan data to a struct's pointer according field name
  342. func (rs *Rows) ScanStructByName(dest interface{}) error {
  343. vv := reflect.ValueOf(dest)
  344. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  345. return errors.New("dest should be a struct's pointer")
  346. }
  347. cols, err := rs.Columns()
  348. if err != nil {
  349. return err
  350. }
  351. newDest := make([]interface{}, len(cols))
  352. var v EmptyScanner
  353. for j, name := range cols {
  354. f := fieldByName(vv.Elem(), rs.Mapper.Table2Obj(name))
  355. if f.IsValid() {
  356. newDest[j] = f.Addr().Interface()
  357. } else {
  358. newDest[j] = &v
  359. }
  360. }
  361. return rs.Rows.Scan(newDest...)
  362. }
  // cacheStruct is a slab of pre-allocated values of one type together with the
  // index of the slot handed out most recently.
  type cacheStruct struct {
  	value reflect.Value // slice holding the slab
  	idx   int           // last slot handed out
  }

  var (
  	// reflectCache holds the current slab for each type.
  	reflectCache = make(map[reflect.Type]*cacheStruct)
  	// reflectCacheMutex guards reflectCache and the idx of every slab in it.
  	reflectCacheMutex sync.RWMutex
  )

  // ReflectNew returns a pointer Value to a zero value of typ, like reflect.New,
  // but takes the storage from a slab of 200 elements allocated in one go. A
  // fresh slab is allocated when the current one is used up.
  //
  // The lookup, the slot increment and the slot read all happen under a single
  // lock, so concurrent callers never receive the same slot.
  func ReflectNew(typ reflect.Type) reflect.Value {
  	const newSize = 200

  	reflectCacheMutex.Lock()
  	defer reflectCacheMutex.Unlock()

  	cs, ok := reflectCache[typ]
  	if !ok || cs.idx+1 >= newSize {
  		// No slab yet, or the last slot was already handed out.
  		cs = &cacheStruct{reflect.MakeSlice(reflect.SliceOf(typ), newSize, newSize), 0}
  		reflectCache[typ] = cs
  	} else {
  		cs.idx++
  	}
  	return cs.value.Index(cs.idx).Addr()
  }
  388. // scan data to a slice's pointer, slice's length should equal to columns' number
  389. func (rs *Rows) ScanSlice(dest interface{}) error {
  390. vv := reflect.ValueOf(dest)
  391. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Slice {
  392. return errors.New("dest should be a slice's pointer")
  393. }
  394. vvv := vv.Elem()
  395. cols, err := rs.Columns()
  396. if err != nil {
  397. return err
  398. }
  399. newDest := make([]interface{}, len(cols))
  400. for j := 0; j < len(cols); j++ {
  401. if j >= vvv.Len() {
  402. newDest[j] = reflect.New(vvv.Type().Elem()).Interface()
  403. } else {
  404. newDest[j] = vvv.Index(j).Addr().Interface()
  405. }
  406. }
  407. err = rs.Rows.Scan(newDest...)
  408. if err != nil {
  409. return err
  410. }
  411. srcLen := vvv.Len()
  412. for i := srcLen; i < len(cols); i++ {
  413. vvv = reflect.Append(vvv, reflect.ValueOf(newDest[i]).Elem())
  414. }
  415. return nil
  416. }
  417. // scan data to a map's pointer
  418. func (rs *Rows) ScanMap(dest interface{}) error {
  419. vv := reflect.ValueOf(dest)
  420. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  421. return errors.New("dest should be a map's pointer")
  422. }
  423. cols, err := rs.Columns()
  424. if err != nil {
  425. return err
  426. }
  427. newDest := make([]interface{}, len(cols))
  428. vvv := vv.Elem()
  429. for i, _ := range cols {
  430. newDest[i] = ReflectNew(vvv.Type().Elem()).Interface()
  431. //v := reflect.New(vvv.Type().Elem())
  432. //newDest[i] = v.Interface()
  433. }
  434. err = rs.Rows.Scan(newDest...)
  435. if err != nil {
  436. return err
  437. }
  438. for i, name := range cols {
  439. vname := reflect.ValueOf(name)
  440. vvv.SetMapIndex(vname, reflect.ValueOf(newDest[i]).Elem())
  441. }
  442. return nil
  443. }
  444. /*func (rs *Rows) ScanMap(dest interface{}) error {
  445. vv := reflect.ValueOf(dest)
  446. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  447. return errors.New("dest should be a map's pointer")
  448. }
  449. cols, err := rs.Columns()
  450. if err != nil {
  451. return err
  452. }
  453. newDest := make([]interface{}, len(cols))
  454. err = rs.ScanSlice(newDest)
  455. if err != nil {
  456. return err
  457. }
  458. vvv := vv.Elem()
  459. for i, name := range cols {
  460. vname := reflect.ValueOf(name)
  461. vvv.SetMapIndex(vname, reflect.ValueOf(newDest[i]).Elem())
  462. }
  463. return nil
  464. }*/
  465. type Tx struct {
  466. *sql.Tx
  467. Mapper IMapper
  468. }
  469. func (db *DB) Begin() (*Tx, error) {
  470. tx, err := db.DB.Begin()
  471. if err != nil {
  472. return nil, err
  473. }
  474. return &Tx{tx, db.Mapper}, nil
  475. }
  476. func (tx *Tx) Prepare(query string) (*Stmt, error) {
  477. names := make(map[string]int)
  478. var i int
  479. query = re.ReplaceAllStringFunc(query, func(src string) string {
  480. names[src[1:]] = i
  481. i += 1
  482. return "?"
  483. })
  484. stmt, err := tx.Tx.Prepare(query)
  485. if err != nil {
  486. return nil, err
  487. }
  488. return &Stmt{stmt, tx.Mapper, names}, nil
  489. }
  490. func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
  491. // TODO:
  492. return stmt
  493. }
  494. func (tx *Tx) ExecMap(query string, mp interface{}) (sql.Result, error) {
  495. query, args, err := MapToSlice(query, mp)
  496. if err != nil {
  497. return nil, err
  498. }
  499. return tx.Tx.Exec(query, args...)
  500. }
  501. func (tx *Tx) ExecStruct(query string, st interface{}) (sql.Result, error) {
  502. query, args, err := StructToSlice(query, st)
  503. if err != nil {
  504. return nil, err
  505. }
  506. return tx.Tx.Exec(query, args...)
  507. }
  508. func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
  509. rows, err := tx.Tx.Query(query, args...)
  510. if err != nil {
  511. return nil, err
  512. }
  513. return &Rows{rows, tx.Mapper}, nil
  514. }
  515. func (tx *Tx) QueryMap(query string, mp interface{}) (*Rows, error) {
  516. query, args, err := MapToSlice(query, mp)
  517. if err != nil {
  518. return nil, err
  519. }
  520. return tx.Query(query, args...)
  521. }
  522. func (tx *Tx) QueryStruct(query string, st interface{}) (*Rows, error) {
  523. query, args, err := StructToSlice(query, st)
  524. if err != nil {
  525. return nil, err
  526. }
  527. return tx.Query(query, args...)
  528. }
  529. func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
  530. rows, err := tx.Query(query, args...)
  531. return &Row{rows, err}
  532. }
  533. func (tx *Tx) QueryRowMap(query string, mp interface{}) *Row {
  534. query, args, err := MapToSlice(query, mp)
  535. if err != nil {
  536. return &Row{nil, err}
  537. }
  538. return tx.QueryRow(query, args...)
  539. }
  540. func (tx *Tx) QueryRowStruct(query string, st interface{}) *Row {
  541. query, args, err := StructToSlice(query, st)
  542. if err != nil {
  543. return &Row{nil, err}
  544. }
  545. return tx.QueryRow(query, args...)
  546. }