Browse Source

Merge pull request #12203 from bergquist/bus_multi_dispatch

bus: support multiple dispatch in one transaction
Carl Bergquist 7 years ago
parent
commit
d6f4313c2f

+ 64 - 12
pkg/bus/bus.go

@@ -12,21 +12,42 @@ type Msg interface{}
 
 var ErrHandlerNotFound = errors.New("handler not found")
 
+type TransactionManager interface {
+	InTransaction(ctx context.Context, fn func(ctx context.Context) error) error
+}
+
 type Bus interface {
 	Dispatch(msg Msg) error
 	DispatchCtx(ctx context.Context, msg Msg) error
 	Publish(msg Msg) error
 
+	// InTransaction starts a transaction and store it in the context.
+	// The caller can then pass a function with multiple DispatchCtx calls that
+	// all will be executed in the same transaction. InTransaction will rollback if the
+	// callback returns an error.
+	InTransaction(ctx context.Context, fn func(ctx context.Context) error) error
+
 	AddHandler(handler HandlerFunc)
-	AddCtxHandler(handler HandlerFunc)
+	AddHandlerCtx(handler HandlerFunc)
 	AddEventListener(handler HandlerFunc)
 	AddWildcardListener(handler HandlerFunc)
+
+	// SetTransactionManager allows the user to replace the internal
+	// noop TransactionManager that is responsible for manageing
+	// transactions in `InTransaction`
+	SetTransactionManager(tm TransactionManager)
+}
+
+func (b *InProcBus) InTransaction(ctx context.Context, fn func(ctx context.Context) error) error {
+	return b.txMng.InTransaction(ctx, fn)
 }
 
 type InProcBus struct {
 	handlers          map[string]HandlerFunc
+	handlersWithCtx   map[string]HandlerFunc
 	listeners         map[string][]HandlerFunc
 	wildcardListeners []HandlerFunc
+	txMng             TransactionManager
 }
 
 // temp stuff, not sure how to handle bus instance, and init yet
@@ -35,8 +56,11 @@ var globalBus = New()
 func New() Bus {
 	bus := &InProcBus{}
 	bus.handlers = make(map[string]HandlerFunc)
+	bus.handlersWithCtx = make(map[string]HandlerFunc)
 	bus.listeners = make(map[string][]HandlerFunc)
 	bus.wildcardListeners = make([]HandlerFunc, 0)
+	bus.txMng = &noopTransactionManager{}
+
 	return bus
 }
 
@@ -45,17 +69,21 @@ func GetBus() Bus {
 	return globalBus
 }
 
+func (b *InProcBus) SetTransactionManager(tm TransactionManager) {
+	b.txMng = tm
+}
+
 func (b *InProcBus) DispatchCtx(ctx context.Context, msg Msg) error {
 	var msgName = reflect.TypeOf(msg).Elem().Name()
 
-	var handler = b.handlers[msgName]
+	var handler = b.handlersWithCtx[msgName]
 	if handler == nil {
 		return ErrHandlerNotFound
 	}
 
-	var params = make([]reflect.Value, 2)
-	params[0] = reflect.ValueOf(ctx)
-	params[1] = reflect.ValueOf(msg)
+	var params = []reflect.Value{}
+	params = append(params, reflect.ValueOf(ctx))
+	params = append(params, reflect.ValueOf(msg))
 
 	ret := reflect.ValueOf(handler).Call(params)
 	err := ret[0].Interface()
@@ -68,13 +96,23 @@ func (b *InProcBus) DispatchCtx(ctx context.Context, msg Msg) error {
 func (b *InProcBus) Dispatch(msg Msg) error {
 	var msgName = reflect.TypeOf(msg).Elem().Name()
 
-	var handler = b.handlers[msgName]
+	var handler = b.handlersWithCtx[msgName]
+	withCtx := true
+
+	if handler == nil {
+		withCtx = false
+		handler = b.handlers[msgName]
+	}
+
 	if handler == nil {
 		return ErrHandlerNotFound
 	}
 
-	var params = make([]reflect.Value, 1)
-	params[0] = reflect.ValueOf(msg)
+	var params = []reflect.Value{}
+	if withCtx {
+		params = append(params, reflect.ValueOf(context.Background()))
+	}
+	params = append(params, reflect.ValueOf(msg))
 
 	ret := reflect.ValueOf(handler).Call(params)
 	err := ret[0].Interface()
@@ -120,10 +158,10 @@ func (b *InProcBus) AddHandler(handler HandlerFunc) {
 	b.handlers[queryTypeName] = handler
 }
 
-func (b *InProcBus) AddCtxHandler(handler HandlerFunc) {
+func (b *InProcBus) AddHandlerCtx(handler HandlerFunc) {
 	handlerType := reflect.TypeOf(handler)
 	queryTypeName := handlerType.In(1).Elem().Name()
-	b.handlers[queryTypeName] = handler
+	b.handlersWithCtx[queryTypeName] = handler
 }
 
 func (b *InProcBus) AddEventListener(handler HandlerFunc) {
@@ -142,8 +180,8 @@ func AddHandler(implName string, handler HandlerFunc) {
 }
 
 // Package level functions
-func AddCtxHandler(implName string, handler HandlerFunc) {
-	globalBus.AddCtxHandler(handler)
+func AddHandlerCtx(implName string, handler HandlerFunc) {
+	globalBus.AddHandlerCtx(handler)
 }
 
 // Package level functions
@@ -167,6 +205,20 @@ func Publish(msg Msg) error {
 	return globalBus.Publish(msg)
 }
 
+// InTransaction starts a transaction and store it in the context.
+// The caller can then pass a function with multiple DispatchCtx calls that
+// all will be executed in the same transaction. InTransaction will rollback if the
+// callback returns an error.
+func InTransaction(ctx context.Context, fn func(ctx context.Context) error) error {
+	return globalBus.InTransaction(ctx, fn)
+}
+
 func ClearBusHandlers() {
 	globalBus = New()
 }
+
+type noopTransactionManager struct{}
+
+func (*noopTransactionManager) InTransaction(ctx context.Context, fn func(ctx context.Context) error) error {
+	return fn(ctx)
+}

+ 51 - 8
pkg/bus/bus_test.go

@@ -1,24 +1,67 @@
 package bus
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"testing"
 )
 
-type TestQuery struct {
+type testQuery struct {
 	Id   int64
 	Resp string
 }
 
+func TestDispatchCtxCanUseNormalHandlers(t *testing.T) {
+	bus := New()
+
+	handlerWithCtxCallCount := 0
+	handlerCallCount := 0
+
+	handlerWithCtx := func(ctx context.Context, query *testQuery) error {
+		handlerWithCtxCallCount++
+		return nil
+	}
+
+	handler := func(query *testQuery) error {
+		handlerCallCount++
+		return nil
+	}
+
+	err := bus.DispatchCtx(context.Background(), &testQuery{})
+	if err != ErrHandlerNotFound {
+		t.Errorf("expected bus to return HandlerNotFound is no handler is registered")
+	}
+
+	bus.AddHandler(handler)
+
+	t.Run("when a normal handler is registered", func(t *testing.T) {
+		bus.Dispatch(&testQuery{})
+
+		if handlerCallCount != 1 {
+			t.Errorf("Expected normal handler to be called 1 time. was called %d", handlerCallCount)
+		}
+
+		t.Run("when a ctx handler is registered", func(t *testing.T) {
+			bus.AddHandlerCtx(handlerWithCtx)
+			bus.Dispatch(&testQuery{})
+
+			if handlerWithCtxCallCount != 1 {
+				t.Errorf("Expected ctx handler to be called 1 time. was called %d", handlerWithCtxCallCount)
+			}
+		})
+	})
+
+}
+
 func TestQueryHandlerReturnsError(t *testing.T) {
 	bus := New()
 
-	bus.AddHandler(func(query *TestQuery) error {
+	bus.AddHandler(func(query *testQuery) error {
 		return errors.New("handler error")
 	})
 
-	err := bus.Dispatch(&TestQuery{})
+	err := bus.Dispatch(&testQuery{})
 
 	if err == nil {
 		t.Fatal("Send query failed " + err.Error())
@@ -30,12 +73,12 @@ func TestQueryHandlerReturnsError(t *testing.T) {
 func TestQueryHandlerReturn(t *testing.T) {
 	bus := New()
 
-	bus.AddHandler(func(q *TestQuery) error {
+	bus.AddHandler(func(q *testQuery) error {
 		q.Resp = "hello from handler"
 		return nil
 	})
 
-	query := &TestQuery{}
+	query := &testQuery{}
 	err := bus.Dispatch(query)
 
 	if err != nil {
@@ -49,17 +92,17 @@ func TestEventListeners(t *testing.T) {
 	bus := New()
 	count := 0
 
-	bus.AddEventListener(func(query *TestQuery) error {
+	bus.AddEventListener(func(query *testQuery) error {
 		count += 1
 		return nil
 	})
 
-	bus.AddEventListener(func(query *TestQuery) error {
+	bus.AddEventListener(func(query *testQuery) error {
 		count += 10
 		return nil
 	})
 
-	err := bus.Publish(&TestQuery{})
+	err := bus.Publish(&testQuery{})
 
 	if err != nil {
 		t.Fatal("Publish event failed " + err.Error())

+ 1 - 0
pkg/services/alerting/notifiers/base.go

@@ -3,6 +3,7 @@ package notifiers
 import (
 	"github.com/grafana/grafana/pkg/components/simplejson"
 	m "github.com/grafana/grafana/pkg/models"
+
 	"github.com/grafana/grafana/pkg/services/alerting"
 )
 

+ 2 - 2
pkg/services/notifications/notifications.go

@@ -45,8 +45,8 @@ func (ns *NotificationService) Init() error {
 	ns.Bus.AddHandler(ns.validateResetPasswordCode)
 	ns.Bus.AddHandler(ns.sendEmailCommandHandler)
 
-	ns.Bus.AddCtxHandler(ns.sendEmailCommandHandlerSync)
-	ns.Bus.AddCtxHandler(ns.SendWebhookSync)
+	ns.Bus.AddHandlerCtx(ns.sendEmailCommandHandlerSync)
+	ns.Bus.AddHandlerCtx(ns.SendWebhookSync)
 
 	ns.Bus.AddEventListener(ns.signUpStartedHandler)
 	ns.Bus.AddEventListener(ns.signUpCompletedHandler)

+ 4 - 3
pkg/services/sqlstore/apikey.go

@@ -1,6 +1,7 @@
 package sqlstore
 
 import (
+	"context"
 	"time"
 
 	"github.com/grafana/grafana/pkg/bus"
@@ -11,7 +12,7 @@ func init() {
 	bus.AddHandler("sql", GetApiKeys)
 	bus.AddHandler("sql", GetApiKeyById)
 	bus.AddHandler("sql", GetApiKeyByName)
-	bus.AddHandler("sql", DeleteApiKey)
+	bus.AddHandlerCtx("sql", DeleteApiKeyCtx)
 	bus.AddHandler("sql", AddApiKey)
 }
 
@@ -22,8 +23,8 @@ func GetApiKeys(query *m.GetApiKeysQuery) error {
 	return sess.Find(&query.Result)
 }
 
-func DeleteApiKey(cmd *m.DeleteApiKeyCommand) error {
-	return inTransaction(func(sess *DBSession) error {
+func DeleteApiKeyCtx(ctx context.Context, cmd *m.DeleteApiKeyCommand) error {
+	return withDbSession(ctx, func(sess *DBSession) error {
 		var rawSql = "DELETE FROM api_key WHERE id=? and org_id=?"
 		_, err := sess.Exec(rawSql, cmd.Id, cmd.OrgId)
 		return err

+ 2 - 1
pkg/services/sqlstore/dashboard_test.go

@@ -1,6 +1,7 @@
 package sqlstore
 
 import (
+	"context"
 	"fmt"
 	"testing"
 	"time"
@@ -389,7 +390,7 @@ func createUser(name string, role string, isAdmin bool) m.User {
 	setting.AutoAssignOrgRole = role
 
 	currentUserCmd := m.CreateUserCommand{Login: name, Email: name + "@test.com", Name: "a " + name, IsAdmin: isAdmin}
-	err := CreateUser(&currentUserCmd)
+	err := CreateUser(context.Background(), &currentUserCmd)
 	So(err, ShouldBeNil)
 
 	q1 := m.GetUserOrgListQuery{UserId: currentUserCmd.Result.Id}

+ 6 - 5
pkg/services/sqlstore/org_test.go

@@ -1,6 +1,7 @@
 package sqlstore
 
 import (
+	"context"
 	"testing"
 	"time"
 
@@ -22,9 +23,9 @@ func TestAccountDataAccess(t *testing.T) {
 				ac1cmd := m.CreateUserCommand{Login: "ac1", Email: "ac1@test.com", Name: "ac1 name"}
 				ac2cmd := m.CreateUserCommand{Login: "ac2", Email: "ac2@test.com", Name: "ac2 name"}
 
-				err := CreateUser(&ac1cmd)
+				err := CreateUser(context.Background(), &ac1cmd)
 				So(err, ShouldBeNil)
-				err = CreateUser(&ac2cmd)
+				err = CreateUser(context.Background(), &ac2cmd)
 				So(err, ShouldBeNil)
 
 				q1 := m.GetUserOrgListQuery{UserId: ac1cmd.Result.Id}
@@ -43,8 +44,8 @@ func TestAccountDataAccess(t *testing.T) {
 			ac1cmd := m.CreateUserCommand{Login: "ac1", Email: "ac1@test.com", Name: "ac1 name"}
 			ac2cmd := m.CreateUserCommand{Login: "ac2", Email: "ac2@test.com", Name: "ac2 name", IsAdmin: true}
 
-			err := CreateUser(&ac1cmd)
-			err = CreateUser(&ac2cmd)
+			err := CreateUser(context.Background(), &ac1cmd)
+			err = CreateUser(context.Background(), &ac2cmd)
 			So(err, ShouldBeNil)
 
 			ac1 := ac1cmd.Result
@@ -182,7 +183,7 @@ func TestAccountDataAccess(t *testing.T) {
 
 				Convey("Given an org user with dashboard permissions", func() {
 					ac3cmd := m.CreateUserCommand{Login: "ac3", Email: "ac3@test.com", Name: "ac3 name", IsAdmin: false}
-					err := CreateUser(&ac3cmd)
+					err := CreateUser(context.Background(), &ac3cmd)
 					So(err, ShouldBeNil)
 					ac3 := ac3cmd.Result
 

+ 71 - 0
pkg/services/sqlstore/session.go

@@ -0,0 +1,71 @@
+package sqlstore
+
+import (
+	"context"
+	"reflect"
+
+	"github.com/go-xorm/xorm"
+)
+
+type DBSession struct {
+	*xorm.Session
+	events []interface{}
+}
+
+type dbTransactionFunc func(sess *DBSession) error
+
+func (sess *DBSession) publishAfterCommit(msg interface{}) {
+	sess.events = append(sess.events, msg)
+}
+
+func newSession() *DBSession {
+	return &DBSession{Session: x.NewSession()}
+}
+
+func startSession(ctx context.Context, engine *xorm.Engine, beginTran bool) (*DBSession, error) {
+	value := ctx.Value(ContextSessionName)
+	var sess *DBSession
+	sess, ok := value.(*DBSession)
+
+	if !ok {
+		newSess := &DBSession{Session: engine.NewSession()}
+		if beginTran {
+			err := newSess.Begin()
+			if err != nil {
+				return nil, err
+			}
+		}
+		return newSess, nil
+	}
+
+	return sess, nil
+}
+
+func withDbSession(ctx context.Context, callback dbTransactionFunc) error {
+	sess, err := startSession(ctx, x, false)
+	if err != nil {
+		return err
+	}
+
+	return callback(sess)
+}
+
+func (sess *DBSession) InsertId(bean interface{}) (int64, error) {
+	table := sess.DB().Mapper.Obj2Table(getTypeName(bean))
+
+	dialect.PreInsertId(table, sess.Session)
+
+	id, err := sess.Session.InsertOne(bean)
+
+	dialect.PostInsertId(table, sess.Session)
+
+	return id, err
+}
+
+func getTypeName(bean interface{}) (res string) {
+	t := reflect.TypeOf(bean)
+	for t.Kind() == reflect.Ptr {
+		t = t.Elem()
+	}
+	return t.Name()
+}

+ 0 - 90
pkg/services/sqlstore/shared.go

@@ -1,90 +0,0 @@
-package sqlstore
-
-import (
-	"reflect"
-	"time"
-
-	"github.com/go-xorm/xorm"
-	"github.com/grafana/grafana/pkg/bus"
-	"github.com/grafana/grafana/pkg/log"
-	sqlite3 "github.com/mattn/go-sqlite3"
-)
-
-type DBSession struct {
-	*xorm.Session
-	events []interface{}
-}
-
-type dbTransactionFunc func(sess *DBSession) error
-
-func (sess *DBSession) publishAfterCommit(msg interface{}) {
-	sess.events = append(sess.events, msg)
-}
-
-func newSession() *DBSession {
-	return &DBSession{Session: x.NewSession()}
-}
-
-func inTransaction(callback dbTransactionFunc) error {
-	return inTransactionWithRetry(callback, 0)
-}
-
-func inTransactionWithRetry(callback dbTransactionFunc, retry int) error {
-	var err error
-
-	sess := newSession()
-	defer sess.Close()
-
-	if err = sess.Begin(); err != nil {
-		return err
-	}
-
-	err = callback(sess)
-
-	// special handling of database locked errors for sqlite, then we can retry 3 times
-	if sqlError, ok := err.(sqlite3.Error); ok && retry < 5 {
-		if sqlError.Code == sqlite3.ErrLocked {
-			sess.Rollback()
-			time.Sleep(time.Millisecond * time.Duration(10))
-			sqlog.Info("Database table locked, sleeping then retrying", "retry", retry)
-			return inTransactionWithRetry(callback, retry+1)
-		}
-	}
-
-	if err != nil {
-		sess.Rollback()
-		return err
-	} else if err = sess.Commit(); err != nil {
-		return err
-	}
-
-	if len(sess.events) > 0 {
-		for _, e := range sess.events {
-			if err = bus.Publish(e); err != nil {
-				log.Error(3, "Failed to publish event after commit", err)
-			}
-		}
-	}
-
-	return nil
-}
-
-func (sess *DBSession) InsertId(bean interface{}) (int64, error) {
-	table := sess.DB().Mapper.Obj2Table(getTypeName(bean))
-
-	dialect.PreInsertId(table, sess.Session)
-
-	id, err := sess.Session.InsertOne(bean)
-
-	dialect.PostInsertId(table, sess.Session)
-
-	return id, err
-}
-
-func getTypeName(bean interface{}) (res string) {
-	t := reflect.TypeOf(bean)
-	for t.Kind() == reflect.Ptr {
-		t = t.Elem()
-	}
-	return t.Name()
-}

+ 32 - 18
pkg/services/sqlstore/sqlstore.go

@@ -1,6 +1,7 @@
 package sqlstore
 
 import (
+	"context"
 	"fmt"
 	"net/url"
 	"os"
@@ -22,10 +23,10 @@ import (
 
 	"github.com/go-sql-driver/mysql"
 	"github.com/go-xorm/xorm"
-	_ "github.com/lib/pq"
-	_ "github.com/mattn/go-sqlite3"
 
 	_ "github.com/grafana/grafana/pkg/tsdb/mssql"
+	_ "github.com/lib/pq"
+	_ "github.com/mattn/go-sqlite3"
 )
 
 var (
@@ -35,6 +36,8 @@ var (
 	sqlog log.Logger = log.New("sqlstore")
 )
 
+const ContextSessionName = "db-session"
+
 func init() {
 	registry.Register(&registry.Descriptor{
 		Name:         "SqlStore",
@@ -45,6 +48,7 @@ func init() {
 
 type SqlStore struct {
 	Cfg *setting.Cfg `inject:""`
+	Bus bus.Bus      `inject:""`
 
 	dbCfg           DatabaseConfig
 	engine          *xorm.Engine
@@ -77,6 +81,8 @@ func (ss *SqlStore) Init() error {
 	// Init repo instances
 	annotations.SetRepository(&SqlAnnotationRepo{})
 
+	ss.Bus.SetTransactionManager(ss)
+
 	// ensure admin user
 	if ss.skipEnsureAdmin {
 		return nil
@@ -88,27 +94,33 @@ func (ss *SqlStore) Init() error {
 func (ss *SqlStore) ensureAdminUser() error {
 	systemUserCountQuery := m.GetSystemUserCountStatsQuery{}
 
-	if err := bus.Dispatch(&systemUserCountQuery); err != nil {
-		return fmt.Errorf("Could not determine if admin user exists: %v", err)
-	}
+	err := ss.InTransaction(context.Background(), func(ctx context.Context) error {
 
-	if systemUserCountQuery.Result.Count > 0 {
-		return nil
-	}
+		err := bus.DispatchCtx(ctx, &systemUserCountQuery)
+		if err != nil {
+			return fmt.Errorf("Could not determine if admin user exists: %v", err)
+		}
+
+		if systemUserCountQuery.Result.Count > 0 {
+			return nil
+		}
 
-	cmd := m.CreateUserCommand{}
-	cmd.Login = setting.AdminUser
-	cmd.Email = setting.AdminUser + "@localhost"
-	cmd.Password = setting.AdminPassword
-	cmd.IsAdmin = true
+		cmd := m.CreateUserCommand{}
+		cmd.Login = setting.AdminUser
+		cmd.Email = setting.AdminUser + "@localhost"
+		cmd.Password = setting.AdminPassword
+		cmd.IsAdmin = true
 
-	if err := bus.Dispatch(&cmd); err != nil {
-		return fmt.Errorf("Failed to create admin user: %v", err)
-	}
+		if err := bus.DispatchCtx(ctx, &cmd); err != nil {
+			return fmt.Errorf("Failed to create admin user: %v", err)
+		}
+
+		ss.log.Info("Created default admin", "user", setting.AdminUser)
 
-	ss.log.Info("Created default admin user: %v", setting.AdminUser)
+		return nil
+	})
 
-	return nil
+	return err
 }
 
 func (ss *SqlStore) buildConnectionString() (string, error) {
@@ -238,8 +250,10 @@ func (ss *SqlStore) readConfig() {
 }
 
 func InitTestDB(t *testing.T) *SqlStore {
+	t.Helper()
 	sqlstore := &SqlStore{}
 	sqlstore.skipEnsureAdmin = true
+	sqlstore.Bus = bus.New()
 
 	dbType := migrator.SQLITE
 

+ 14 - 10
pkg/services/sqlstore/stats.go

@@ -1,6 +1,7 @@
 package sqlstore
 
 import (
+	"context"
 	"time"
 
 	"github.com/grafana/grafana/pkg/bus"
@@ -12,7 +13,7 @@ func init() {
 	bus.AddHandler("sql", GetDataSourceStats)
 	bus.AddHandler("sql", GetDataSourceAccessStats)
 	bus.AddHandler("sql", GetAdminStats)
-	bus.AddHandler("sql", GetSystemUserCountStats)
+	bus.AddHandlerCtx("sql", GetSystemUserCountStats)
 }
 
 var activeUserTimeLimit = time.Hour * 24 * 30
@@ -133,15 +134,18 @@ func GetAdminStats(query *m.GetAdminStatsQuery) error {
 	return err
 }
 
-func GetSystemUserCountStats(query *m.GetSystemUserCountStatsQuery) error {
-	var rawSql = `SELECT COUNT(id) AS Count FROM ` + dialect.Quote("user")
-	var stats m.SystemUserCountStats
-	_, err := x.SQL(rawSql).Get(&stats)
-	if err != nil {
-		return err
-	}
+func GetSystemUserCountStats(ctx context.Context, query *m.GetSystemUserCountStatsQuery) error {
+	return withDbSession(ctx, func(sess *DBSession) error {
 
-	query.Result = &stats
+		var rawSql = `SELECT COUNT(id) AS Count FROM ` + dialect.Quote("user")
+		var stats m.SystemUserCountStats
+		_, err := sess.SQL(rawSql).Get(&stats)
+		if err != nil {
+			return err
+		}
 
-	return err
+		query.Result = &stats
+
+		return err
+	})
 }

+ 2 - 1
pkg/services/sqlstore/stats_test.go

@@ -1,6 +1,7 @@
 package sqlstore
 
 import (
+	"context"
 	"testing"
 
 	m "github.com/grafana/grafana/pkg/models"
@@ -20,7 +21,7 @@ func TestStatsDataAccess(t *testing.T) {
 
 		Convey("Get system user count stats should not results in error", func() {
 			query := m.GetSystemUserCountStatsQuery{}
-			err := GetSystemUserCountStats(&query)
+			err := GetSystemUserCountStats(context.Background(), &query)
 			So(err, ShouldBeNil)
 		})
 

+ 2 - 1
pkg/services/sqlstore/team_test.go

@@ -1,6 +1,7 @@
 package sqlstore
 
 import (
+	"context"
 	"fmt"
 	"testing"
 
@@ -22,7 +23,7 @@ func TestTeamCommandsAndQueries(t *testing.T) {
 					Name:  fmt.Sprint("user", i),
 					Login: fmt.Sprint("loginuser", i),
 				}
-				err := CreateUser(userCmd)
+				err := CreateUser(context.Background(), userCmd)
 				So(err, ShouldBeNil)
 				userIds = append(userIds, userCmd.Result.Id)
 			}

+ 106 - 0
pkg/services/sqlstore/transactions.go

@@ -0,0 +1,106 @@
+package sqlstore
+
+import (
+	"context"
+	"time"
+
+	"github.com/grafana/grafana/pkg/bus"
+	"github.com/grafana/grafana/pkg/log"
+	sqlite3 "github.com/mattn/go-sqlite3"
+)
+
+func (ss *SqlStore) InTransaction(ctx context.Context, fn func(ctx context.Context) error) error {
+	return ss.inTransactionWithRetry(ctx, fn, 0)
+}
+
+func (ss *SqlStore) inTransactionWithRetry(ctx context.Context, fn func(ctx context.Context) error, retry int) error {
+	sess, err := startSession(ctx, ss.engine, true)
+	if err != nil {
+		return err
+	}
+
+	defer sess.Close()
+
+	withValue := context.WithValue(ctx, ContextSessionName, sess)
+
+	err = fn(withValue)
+
+	// special handling of database locked errors for sqlite, then we can retry 3 times
+	if sqlError, ok := err.(sqlite3.Error); ok && retry < 5 {
+		if sqlError.Code == sqlite3.ErrLocked {
+			sess.Rollback()
+			time.Sleep(time.Millisecond * time.Duration(10))
+			ss.log.Info("Database table locked, sleeping then retrying", "retry", retry)
+			return ss.inTransactionWithRetry(ctx, fn, retry+1)
+		}
+	}
+
+	if err != nil {
+		sess.Rollback()
+		return err
+	}
+
+	if err = sess.Commit(); err != nil {
+		return err
+	}
+
+	if len(sess.events) > 0 {
+		for _, e := range sess.events {
+			if err = bus.Publish(e); err != nil {
+				ss.log.Error("Failed to publish event after commit", err)
+			}
+		}
+	}
+
+	return nil
+}
+
+func inTransactionWithRetry(callback dbTransactionFunc, retry int) error {
+	return inTransactionWithRetryCtx(context.Background(), callback, retry)
+}
+
+func inTransactionWithRetryCtx(ctx context.Context, callback dbTransactionFunc, retry int) error {
+	sess, err := startSession(ctx, x, true)
+	if err != nil {
+		return err
+	}
+
+	defer sess.Close()
+
+	err = callback(sess)
+
+	// special handling of database locked errors for sqlite, then we can retry 3 times
+	if sqlError, ok := err.(sqlite3.Error); ok && retry < 5 {
+		if sqlError.Code == sqlite3.ErrLocked {
+			sess.Rollback()
+			time.Sleep(time.Millisecond * time.Duration(10))
+			sqlog.Info("Database table locked, sleeping then retrying", "retry", retry)
+			return inTransactionWithRetry(callback, retry+1)
+		}
+	}
+
+	if err != nil {
+		sess.Rollback()
+		return err
+	} else if err = sess.Commit(); err != nil {
+		return err
+	}
+
+	if len(sess.events) > 0 {
+		for _, e := range sess.events {
+			if err = bus.Publish(e); err != nil {
+				log.Error(3, "Failed to publish event after commit", err)
+			}
+		}
+	}
+
+	return nil
+}
+
+func inTransaction(callback dbTransactionFunc) error {
+	return inTransactionWithRetry(callback, 0)
+}
+
+func inTransactionCtx(ctx context.Context, callback dbTransactionFunc) error {
+	return inTransactionWithRetryCtx(ctx, callback, 0)
+}

+ 60 - 0
pkg/services/sqlstore/transactions_test.go

@@ -0,0 +1,60 @@
+package sqlstore
+
+import (
+	"context"
+	"errors"
+	"testing"
+
+	"github.com/grafana/grafana/pkg/models"
+
+	. "github.com/smartystreets/goconvey/convey"
+)
+
+type testQuery struct {
+	result bool
+}
+
+var ProvokedError = errors.New("testing error.")
+
+func TestTransaction(t *testing.T) {
+	ss := InitTestDB(t)
+
+	Convey("InTransaction asdf asdf", t, func() {
+		cmd := &models.AddApiKeyCommand{Key: "secret-key", Name: "key", OrgId: 1}
+
+		err := AddApiKey(cmd)
+		So(err, ShouldBeNil)
+
+		deleteApiKeyCmd := &models.DeleteApiKeyCommand{Id: cmd.Result.Id, OrgId: 1}
+
+		Convey("can update key", func() {
+			err := ss.InTransaction(context.Background(), func(ctx context.Context) error {
+				return DeleteApiKeyCtx(ctx, deleteApiKeyCmd)
+			})
+
+			So(err, ShouldBeNil)
+
+			query := &models.GetApiKeyByIdQuery{ApiKeyId: cmd.Result.Id}
+			err = GetApiKeyById(query)
+			So(err, ShouldEqual, models.ErrInvalidApiKey)
+		})
+
+		Convey("wont update if one handler fails", func() {
+			err := ss.InTransaction(context.Background(), func(ctx context.Context) error {
+				err := DeleteApiKeyCtx(ctx, deleteApiKeyCmd)
+				if err != nil {
+					return err
+				}
+
+				return ProvokedError
+			})
+
+			So(err, ShouldEqual, ProvokedError)
+
+			query := &models.GetApiKeyByIdQuery{ApiKeyId: cmd.Result.Id}
+			err = GetApiKeyById(query)
+			So(err, ShouldBeNil)
+			So(query.Result.Id, ShouldEqual, cmd.Result.Id)
+		})
+	})
+}

+ 5 - 3
pkg/services/sqlstore/user.go

@@ -1,6 +1,7 @@
 package sqlstore
 
 import (
+	"context"
 	"strconv"
 	"strings"
 	"time"
@@ -15,7 +16,7 @@ import (
 )
 
 func init() {
-	bus.AddHandler("sql", CreateUser)
+	//bus.AddHandler("sql", CreateUser)
 	bus.AddHandler("sql", GetUserById)
 	bus.AddHandler("sql", UpdateUser)
 	bus.AddHandler("sql", ChangeUserPassword)
@@ -30,6 +31,7 @@ func init() {
 	bus.AddHandler("sql", DeleteUser)
 	bus.AddHandler("sql", UpdateUserPermissions)
 	bus.AddHandler("sql", SetUserHelpFlag)
+	bus.AddHandlerCtx("sql", CreateUser)
 }
 
 func getOrgIdForNewUser(cmd *m.CreateUserCommand, sess *DBSession) (int64, error) {
@@ -79,8 +81,8 @@ func getOrgIdForNewUser(cmd *m.CreateUserCommand, sess *DBSession) (int64, error
 	return org.Id, nil
 }
 
-func CreateUser(cmd *m.CreateUserCommand) error {
-	return inTransaction(func(sess *DBSession) error {
+func CreateUser(ctx context.Context, cmd *m.CreateUserCommand) error {
+	return inTransactionCtx(ctx, func(sess *DBSession) error {
 		orgId, err := getOrgIdForNewUser(cmd, sess)
 		if err != nil {
 			return err

+ 2 - 1
pkg/services/sqlstore/user_auth_test.go

@@ -1,6 +1,7 @@
 package sqlstore
 
 import (
+	"context"
 	"fmt"
 	"testing"
 
@@ -22,7 +23,7 @@ func TestUserAuth(t *testing.T) {
 				Name:  fmt.Sprint("user", i),
 				Login: fmt.Sprint("loginuser", i),
 			}
-			err = CreateUser(cmd)
+			err = CreateUser(context.Background(), cmd)
 			So(err, ShouldBeNil)
 			users = append(users, cmd.Result)
 		}

+ 2 - 1
pkg/services/sqlstore/user_test.go

@@ -1,6 +1,7 @@
 package sqlstore
 
 import (
+	"context"
 	"fmt"
 	"testing"
 
@@ -24,7 +25,7 @@ func TestUserDataAccess(t *testing.T) {
 					Name:  fmt.Sprint("user", i),
 					Login: fmt.Sprint("loginuser", i),
 				}
-				err = CreateUser(cmd)
+				err = CreateUser(context.Background(), cmd)
 				So(err, ShouldBeNil)
 				users = append(users, cmd.Result)
 			}