session.go 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. package sqlstore
  2. import (
  3. "context"
  4. "reflect"
  5. "github.com/go-xorm/xorm"
  6. )
  7. type DBSession struct {
  8. *xorm.Session
  9. events []interface{}
  10. }
  11. type dbTransactionFunc func(sess *DBSession) error
  12. func (sess *DBSession) publishAfterCommit(msg interface{}) {
  13. sess.events = append(sess.events, msg)
  14. }
  15. // NewSession returns a new DBSession
  16. func (ss *SqlStore) NewSession() *DBSession {
  17. return &DBSession{Session: ss.engine.NewSession()}
  18. }
  19. func newSession() *DBSession {
  20. return &DBSession{Session: x.NewSession()}
  21. }
  22. func startSession(ctx context.Context, engine *xorm.Engine, beginTran bool) (*DBSession, error) {
  23. value := ctx.Value(ContextSessionName)
  24. var sess *DBSession
  25. sess, ok := value.(*DBSession)
  26. if ok {
  27. return sess, nil
  28. }
  29. newSess := &DBSession{Session: engine.NewSession()}
  30. if beginTran {
  31. err := newSess.Begin()
  32. if err != nil {
  33. return nil, err
  34. }
  35. }
  36. return newSess, nil
  37. }
  38. // WithDbSession calls the callback with an session attached to the context.
  39. func (ss *SqlStore) WithDbSession(ctx context.Context, callback dbTransactionFunc) error {
  40. sess, err := startSession(ctx, ss.engine, false)
  41. if err != nil {
  42. return err
  43. }
  44. return callback(sess)
  45. }
  46. func withDbSession(ctx context.Context, callback dbTransactionFunc) error {
  47. sess, err := startSession(ctx, x, false)
  48. if err != nil {
  49. return err
  50. }
  51. return callback(sess)
  52. }
  53. func (sess *DBSession) InsertId(bean interface{}) (int64, error) {
  54. table := sess.DB().Mapper.Obj2Table(getTypeName(bean))
  55. dialect.PreInsertId(table, sess.Session)
  56. id, err := sess.Session.InsertOne(bean)
  57. dialect.PostInsertId(table, sess.Session)
  58. return id, err
  59. }
  60. func getTypeName(bean interface{}) (res string) {
  61. t := reflect.TypeOf(bean)
  62. for t.Kind() == reflect.Ptr {
  63. t = t.Elem()
  64. }
  65. return t.Name()
  66. }