Parcourir la source

bus: support multiple dispatch in one transaction

this makes it possible to run multiple DispatchCtx
in one transaction. The TransactionManager will
start/end the transaction and pass the dbsession
in the context.Context variable
bergquist il y a 7 ans
Parent
commit
8143610024

+ 50 - 0
pkg/bus/bus.go

@@ -12,21 +12,51 @@ type Msg interface{}
 
 var ErrHandlerNotFound = errors.New("handler not found")
 
+type TransactionManager interface {
+	Begin(ctx context.Context) (context.Context, error)
+	End(ctx context.Context, err error) error
+}
+
 type Bus interface {
 	Dispatch(msg Msg) error
 	DispatchCtx(ctx context.Context, msg Msg) error
 	Publish(msg Msg) error
 
+	// InTransaction starts a transaction and store it in the context.
+	// The caller can then pass a function with multiple DispatchCtx calls that
+	// all will be executed in the same transaction. InTransaction will rollback if the
+	// callback returns an error.s
+	InTransaction(ctx context.Context, fn func(ctx context.Context) error) error
+
 	AddHandler(handler HandlerFunc)
 	AddCtxHandler(handler HandlerFunc)
 	AddEventListener(handler HandlerFunc)
 	AddWildcardListener(handler HandlerFunc)
+
+	// SetTransactionManager allows the user to replace the internal
+	// noop TransactionManager that is responsible for manageing
+	// transactions in `InTransaction`
+	SetTransactionManager(tm TransactionManager)
+}
+
+func (b *InProcBus) InTransaction(ctx context.Context, fn func(ctx context.Context) error) error {
+	ctxWithTran, err := b.transactionManager.Begin(ctx)
+	if err != nil {
+		return err
+	}
+
+	err = fn(ctxWithTran)
+	b.transactionManager.End(ctxWithTran, err)
+
+	return err
 }
 
 type InProcBus struct {
 	handlers          map[string]HandlerFunc
 	listeners         map[string][]HandlerFunc
 	wildcardListeners []HandlerFunc
+
+	transactionManager TransactionManager
 }
 
 // temp stuff, not sure how to handle bus instance, and init yet
@@ -37,6 +67,9 @@ func New() Bus {
 	bus.handlers = make(map[string]HandlerFunc)
 	bus.listeners = make(map[string][]HandlerFunc)
 	bus.wildcardListeners = make([]HandlerFunc, 0)
+
+	bus.transactionManager = &NoopTransactionManager{}
+
 	return bus
 }
 
@@ -45,6 +78,14 @@ func GetBus() Bus {
 	return globalBus
 }
 
+func SetTransactionManager(tm TransactionManager) {
+	globalBus.SetTransactionManager(tm)
+}
+
+func (b *InProcBus) SetTransactionManager(tm TransactionManager) {
+	b.transactionManager = tm
+}
+
 func (b *InProcBus) DispatchCtx(ctx context.Context, msg Msg) error {
 	var msgName = reflect.TypeOf(msg).Elem().Name()
 
@@ -167,6 +208,15 @@ func Publish(msg Msg) error {
 	return globalBus.Publish(msg)
 }
 
+func InTransaction(ctx context.Context, fn func(ctx context.Context) error) error {
+	return globalBus.InTransaction(ctx, fn)
+}
+
 func ClearBusHandlers() {
 	globalBus = New()
 }
+
+type NoopTransactionManager struct{}
+
+func (*NoopTransactionManager) Begin(ctx context.Context) (context.Context, error) { return ctx, nil }
+func (*NoopTransactionManager) End(ctx context.Context, err error) error           { return err }

+ 1 - 0
pkg/services/alerting/notifiers/base.go

@@ -3,6 +3,7 @@ package notifiers
 import (
 	"github.com/grafana/grafana/pkg/components/simplejson"
 	m "github.com/grafana/grafana/pkg/models"
+
 	"github.com/grafana/grafana/pkg/services/alerting"
 )
 

+ 27 - 1
pkg/services/sqlstore/shared.go

@@ -1,6 +1,7 @@
 package sqlstore
 
 import (
+	"context"
 	"reflect"
 	"time"
 
@@ -29,10 +30,35 @@ func inTransaction(callback dbTransactionFunc) error {
 	return inTransactionWithRetry(callback, 0)
 }
 
+func startSession(ctx context.Context) *DBSession {
+	value := ctx.Value(ContextSessionName)
+	var sess *xorm.Session
+	sess, ok := value.(*xorm.Session)
+
+	if !ok {
+		return newSession()
+	}
+
+	old := newSession()
+	old.Session = sess
+
+	return old
+}
+
+func withDbSession(ctx context.Context, callback dbTransactionFunc) error {
+	sess := startSession(ctx)
+
+	return callback(sess)
+}
+
 func inTransactionWithRetry(callback dbTransactionFunc, retry int) error {
+	return inTransactionWithRetryCtx(context.Background(), callback, retry)
+}
+
+func inTransactionWithRetryCtx(ctx context.Context, callback dbTransactionFunc, retry int) error {
 	var err error
 
-	sess := newSession()
+	sess := startSession(ctx)
 	defer sess.Close()
 
 	if err = sess.Begin(); err != nil {

+ 48 - 1
pkg/services/sqlstore/sqlstore.go

@@ -1,6 +1,8 @@
 package sqlstore
 
 import (
+	"context"
+	"errors"
 	"fmt"
 	"net/url"
 	"os"
@@ -35,6 +37,8 @@ var (
 	sqlog log.Logger = log.New("sqlstore")
 )
 
+const ContextSessionName = "db-session"
+
 func init() {
 	registry.Register(&registry.Descriptor{
 		Name:         "SqlStore",
@@ -45,6 +49,7 @@ func init() {
 
 type SqlStore struct {
 	Cfg *setting.Cfg `inject:""`
+	Bus bus.Bus      `inject:""`
 
 	dbCfg           DatabaseConfig
 	engine          *xorm.Engine
@@ -77,6 +82,10 @@ func (ss *SqlStore) Init() error {
 	// Init repo instances
 	annotations.SetRepository(&SqlAnnotationRepo{})
 
+	ss.Bus.SetTransactionManager(&SQLTransactionManager{
+		engine: ss.engine,
+	})
+
 	// ensure admin user
 	if ss.skipEnsureAdmin {
 		return nil
@@ -85,10 +94,47 @@ func (ss *SqlStore) Init() error {
 	return ss.ensureAdminUser()
 }
 
+type SQLTransactionManager struct {
+	engine *xorm.Engine
+}
+
+func (stm *SQLTransactionManager) Begin(ctx context.Context) (context.Context, error) {
+	sess := stm.engine.NewSession()
+	err := sess.Begin()
+	if err != nil {
+		return ctx, err
+	}
+
+	withValue := context.WithValue(ctx, ContextSessionName, sess)
+
+	return withValue, nil
+}
+
+func (stm *SQLTransactionManager) End(ctx context.Context, err error) error {
+	value := ctx.Value(ContextSessionName)
+	sess, ok := value.(*xorm.Session)
+	if !ok {
+		return errors.New("context is missing transaction")
+	}
+
+	if err != nil {
+		sess.Rollback()
+		return err
+	}
+
+	defer sess.Close()
+
+	return sess.Commit()
+}
+
 func (ss *SqlStore) ensureAdminUser() error {
 	systemUserCountQuery := m.GetSystemUserCountStatsQuery{}
 
-	if err := bus.Dispatch(&systemUserCountQuery); err != nil {
+	err := bus.InTransaction(context.Background(), func(ctx context.Context) error {
+		return bus.DispatchCtx(ctx, &systemUserCountQuery)
+	})
+
+	if err != nil {
 		return fmt.Errorf("Could not determine if admin user exists: %v", err)
 	}
 
@@ -240,6 +286,7 @@ func (ss *SqlStore) readConfig() {
 func InitTestDB(t *testing.T) *SqlStore {
 	sqlstore := &SqlStore{}
 	sqlstore.skipEnsureAdmin = true
+	sqlstore.Bus = bus.New()
 
 	dbType := migrator.SQLITE
 

+ 18 - 0
pkg/services/sqlstore/stats.go

@@ -1,6 +1,7 @@
 package sqlstore
 
 import (
+	"context"
 	"time"
 
 	"github.com/grafana/grafana/pkg/bus"
@@ -13,6 +14,7 @@ func init() {
 	bus.AddHandler("sql", GetDataSourceAccessStats)
 	bus.AddHandler("sql", GetAdminStats)
 	bus.AddHandler("sql", GetSystemUserCountStats)
+	bus.AddCtxHandler("sql", GetSystemUserCountStatsCtx)
 }
 
 var activeUserTimeLimit = time.Hour * 24 * 30
@@ -133,6 +135,22 @@ func GetAdminStats(query *m.GetAdminStatsQuery) error {
 	return err
 }
 
+func GetSystemUserCountStatsCtx(ctx context.Context, query *m.GetSystemUserCountStatsQuery) error {
+	return withDbSession(ctx, func(sess *DBSession) error {
+
+		var rawSql = `SELECT COUNT(id) AS Count FROM ` + dialect.Quote("user")
+		var stats m.SystemUserCountStats
+		_, err := sess.SQL(rawSql).Get(&stats)
+		if err != nil {
+			return err
+		}
+
+		query.Result = &stats
+
+		return err
+	})
+}
+
 func GetSystemUserCountStats(query *m.GetSystemUserCountStatsQuery) error {
 	var rawSql = `SELECT COUNT(id) AS Count FROM ` + dialect.Quote("user")
 	var stats m.SystemUserCountStats