session.go 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. package sqlstore
  2. import (
  3. "context"
  4. "reflect"
  5. "github.com/go-xorm/xorm"
  6. )
  7. type DBSession struct {
  8. *xorm.Session
  9. events []interface{}
  10. }
  11. type dbTransactionFunc func(sess *DBSession) error
  12. func (sess *DBSession) publishAfterCommit(msg interface{}) {
  13. sess.events = append(sess.events, msg)
  14. }
  15. func newSession() *DBSession {
  16. return &DBSession{Session: x.NewSession()}
  17. }
  18. func startSession(ctx context.Context, engine *xorm.Engine, beginTran bool) (*DBSession, error) {
  19. value := ctx.Value(ContextSessionName)
  20. var sess *DBSession
  21. sess, ok := value.(*DBSession)
  22. if !ok {
  23. newSess := &DBSession{Session: engine.NewSession()}
  24. if beginTran {
  25. err := newSess.Begin()
  26. if err != nil {
  27. return nil, err
  28. }
  29. }
  30. return newSess, nil
  31. }
  32. return sess, nil
  33. }
  34. func withDbSession(ctx context.Context, callback dbTransactionFunc) error {
  35. sess, err := startSession(ctx, x, false)
  36. if err != nil {
  37. return err
  38. }
  39. return callback(sess)
  40. }
  41. func (sess *DBSession) InsertId(bean interface{}) (int64, error) {
  42. table := sess.DB().Mapper.Obj2Table(getTypeName(bean))
  43. dialect.PreInsertId(table, sess.Session)
  44. id, err := sess.Session.InsertOne(bean)
  45. dialect.PostInsertId(table, sess.Session)
  46. return id, err
  47. }
  // getTypeName returns the name of bean's dynamic type after stripping every
  // level of pointer indirection, so a T, *T and **T all yield "T".
  //
  // It returns "" when bean is an untyped nil (there is no type to inspect)
  // and, as reflect.Type.Name does, for types that have no name.
  func getTypeName(bean interface{}) string {
  	t := reflect.TypeOf(bean)
  	if t == nil {
  		// reflect.TypeOf(nil) is a nil Type; calling Kind on it would panic.
  		return ""
  	}
  	for t.Kind() == reflect.Ptr {
  		t = t.Elem()
  	}
  	return t.Name()
  }
  54. }