Просмотр исходного кода

bus: DispatchCtx can now invoke any handler

bergquist 7 лет назад
Родитель
Сommit
e2275701d8
3 измененных файлов с 69 добавлено и 14 удалено
  1. 1 1
      pkg/api/dashboard.go
  2. 19 5
      pkg/bus/bus.go
  3. 49 8
      pkg/bus/bus_test.go

+ 1 - 1
pkg/api/dashboard.go

@@ -103,7 +103,7 @@ func GetDashboard(c *m.ReqContext) Response {
 	}
 
 	isDashboardProvisioned := &m.IsDashboardProvisionedQuery{DashboardId: dash.Id}
-	err = bus.Dispatch(isDashboardProvisioned)
+	err = bus.DispatchCtx(c.Req.Context(), isDashboardProvisioned)
 	if err != nil {
 		return Error(500, "Error while checking if dashboard is provisioned", err)
 	}

+ 19 - 5
pkg/bus/bus.go

@@ -44,6 +44,7 @@ func (b *InProcBus) InTransaction(ctx context.Context, fn func(ctx context.Conte
 
 type InProcBus struct {
 	handlers          map[string]HandlerFunc
+	handlersWithCtx   map[string]HandlerFunc
 	listeners         map[string][]HandlerFunc
 	wildcardListeners []HandlerFunc
 	txMng             TransactionManager
@@ -55,6 +56,7 @@ var globalBus = New()
 func New() Bus {
 	bus := &InProcBus{}
 	bus.handlers = make(map[string]HandlerFunc)
+	bus.handlersWithCtx = make(map[string]HandlerFunc)
 	bus.listeners = make(map[string][]HandlerFunc)
 	bus.wildcardListeners = make([]HandlerFunc, 0)
 	bus.txMng = &noopTransactionManager{}
@@ -74,14 +76,26 @@ func (b *InProcBus) SetTransactionManager(tm TransactionManager) {
 func (b *InProcBus) DispatchCtx(ctx context.Context, msg Msg) error {
 	var msgName = reflect.TypeOf(msg).Elem().Name()
 
-	var handler = b.handlers[msgName]
+	// we prefer to use the handler that support context.Context
+	var handler = b.handlersWithCtx[msgName]
+	var withCtx = true
+
+	// fallback to use classic handlers
+	if handler == nil {
+		withCtx = false
+		handler = b.handlers[msgName]
+	}
+
 	if handler == nil {
 		return ErrHandlerNotFound
 	}
 
-	var params = make([]reflect.Value, 2)
-	params[0] = reflect.ValueOf(ctx)
-	params[1] = reflect.ValueOf(msg)
+	var params = []reflect.Value{}
+	if withCtx {
+		params = append(params, reflect.ValueOf(ctx))
+	}
+
+	params = append(params, reflect.ValueOf(msg))
 
 	ret := reflect.ValueOf(handler).Call(params)
 	err := ret[0].Interface()
@@ -149,7 +163,7 @@ func (b *InProcBus) AddHandler(handler HandlerFunc) {
 func (b *InProcBus) AddHandlerCtx(handler HandlerFunc) {
 	handlerType := reflect.TypeOf(handler)
 	queryTypeName := handlerType.In(1).Elem().Name()
-	b.handlers[queryTypeName] = handler
+	b.handlersWithCtx[queryTypeName] = handler
 }
 
 func (b *InProcBus) AddEventListener(handler HandlerFunc) {

+ 49 - 8
pkg/bus/bus_test.go

@@ -1,24 +1,65 @@
 package bus
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"testing"
 )
 
-type TestQuery struct {
+type testQuery struct {
 	Id   int64
 	Resp string
 }
 
+func TestDispatchCtxCanUseNormalHandlers(t *testing.T) {
+	bus := New()
+
+	handlerWithCtxCallCount := 0
+	handlerCallCount := 0
+
+	handlerWithCtx := func(ctx context.Context, query *testQuery) error {
+		handlerWithCtxCallCount++
+		return nil
+	}
+
+	handler := func(query *testQuery) error {
+		handlerCallCount++
+		return nil
+	}
+
+	err := bus.DispatchCtx(context.Background(), &testQuery{})
+	if err != ErrHandlerNotFound {
+		t.Errorf("expected bus to return HandlerNotFound is no handler is registered")
+	}
+
+	t.Run("when a normal handler is registered", func(t *testing.T) {
+		bus.AddHandler(handler)
+		bus.DispatchCtx(context.Background(), &testQuery{})
+
+		if handlerCallCount != 1 {
+			t.Errorf("Expected normal handler to be called once")
+		}
+
+		t.Run("when a ctx handler is registered", func(t *testing.T) {
+			bus.AddHandlerCtx(handlerWithCtx)
+			bus.DispatchCtx(context.Background(), &testQuery{})
+
+			if handlerWithCtxCallCount != 1 {
+				t.Errorf("Expected ctx handler to be called once")
+			}
+		})
+	})
+}
+
 func TestQueryHandlerReturnsError(t *testing.T) {
 	bus := New()
 
-	bus.AddHandler(func(query *TestQuery) error {
+	bus.AddHandler(func(query *testQuery) error {
 		return errors.New("handler error")
 	})
 
-	err := bus.Dispatch(&TestQuery{})
+	err := bus.Dispatch(&testQuery{})
 
 	if err == nil {
 		t.Fatal("Send query failed " + err.Error())
@@ -30,12 +71,12 @@ func TestQueryHandlerReturnsError(t *testing.T) {
 func TestQueryHandlerReturn(t *testing.T) {
 	bus := New()
 
-	bus.AddHandler(func(q *TestQuery) error {
+	bus.AddHandler(func(q *testQuery) error {
 		q.Resp = "hello from handler"
 		return nil
 	})
 
-	query := &TestQuery{}
+	query := &testQuery{}
 	err := bus.Dispatch(query)
 
 	if err != nil {
@@ -49,17 +90,17 @@ func TestEventListeners(t *testing.T) {
 	bus := New()
 	count := 0
 
-	bus.AddEventListener(func(query *TestQuery) error {
+	bus.AddEventListener(func(query *testQuery) error {
 		count += 1
 		return nil
 	})
 
-	bus.AddEventListener(func(query *TestQuery) error {
+	bus.AddEventListener(func(query *testQuery) error {
 		count += 10
 		return nil
 	})
 
-	err := bus.Publish(&TestQuery{})
+	err := bus.Publish(&testQuery{})
 
 	if err != nil {
 		t.Fatal("Publish event failed " + err.Error())