Browse Source

support get user tokens/revoke all user tokens in UserTokenService

Marcus Efraimsson 6 years ago
parent
commit
8029e48588

+ 7 - 60
pkg/middleware/middleware_test.go

@@ -11,6 +11,7 @@ import (
 	msession "github.com/go-macaron/session"
 	"github.com/grafana/grafana/pkg/bus"
 	m "github.com/grafana/grafana/pkg/models"
+	"github.com/grafana/grafana/pkg/services/auth"
 	"github.com/grafana/grafana/pkg/services/session"
 	"github.com/grafana/grafana/pkg/setting"
 	"github.com/grafana/grafana/pkg/util"
@@ -155,7 +156,7 @@ func TestMiddlewareContext(t *testing.T) {
 				return nil
 			})
 
-			sc.userAuthTokenService.lookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) {
+			sc.userAuthTokenService.LookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) {
 				return &m.UserToken{
 					UserId:        12,
 					UnhashedToken: unhashedToken,
@@ -184,14 +185,14 @@ func TestMiddlewareContext(t *testing.T) {
 				return nil
 			})
 
-			sc.userAuthTokenService.lookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) {
+			sc.userAuthTokenService.LookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) {
 				return &m.UserToken{
 					UserId:        12,
 					UnhashedToken: "",
 				}, nil
 			}
 
-			sc.userAuthTokenService.tryRotateTokenProvider = func(userToken *m.UserToken, clientIP, userAgent string) (bool, error) {
+			sc.userAuthTokenService.TryRotateTokenProvider = func(userToken *m.UserToken, clientIP, userAgent string) (bool, error) {
 				userToken.UnhashedToken = "rotated"
 				return true, nil
 			}
@@ -226,7 +227,7 @@ func TestMiddlewareContext(t *testing.T) {
 		middlewareScenario("Invalid/expired auth token in cookie", func(sc *scenarioContext) {
 			sc.withTokenSessionCookie("token")
 
-			sc.userAuthTokenService.lookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) {
+			sc.userAuthTokenService.LookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) {
 				return nil, m.ErrUserTokenNotFound
 			}
 
@@ -562,7 +563,7 @@ func middlewareScenario(desc string, fn scenarioFunc) {
 		}))
 
 		session.Init(&msession.Options{}, 0)
-		sc.userAuthTokenService = newFakeUserAuthTokenService()
+		sc.userAuthTokenService = auth.NewFakeUserAuthTokenService()
 		sc.m.Use(GetContextHandler(sc.userAuthTokenService))
 		// mock out gc goroutine
 		session.StartSessionGC = func() {}
@@ -595,7 +596,7 @@ type scenarioContext struct {
 	handlerFunc          handlerFunc
 	defaultHandler       macaron.Handler
 	url                  string
-	userAuthTokenService *fakeUserAuthTokenService
+	userAuthTokenService *auth.FakeUserAuthTokenService
 
 	req *http.Request
 }
@@ -676,57 +677,3 @@ func (sc *scenarioContext) exec() {
 
 type scenarioFunc func(c *scenarioContext)
 type handlerFunc func(c *m.ReqContext)
-
-type fakeUserAuthTokenService struct {
-	createTokenProvider    func(userId int64, clientIP, userAgent string) (*m.UserToken, error)
-	tryRotateTokenProvider func(token *m.UserToken, clientIP, userAgent string) (bool, error)
-	lookupTokenProvider    func(unhashedToken string) (*m.UserToken, error)
-	revokeTokenProvider    func(token *m.UserToken) error
-	activeAuthTokenCount   func() (int64, error)
-}
-
-func newFakeUserAuthTokenService() *fakeUserAuthTokenService {
-	return &fakeUserAuthTokenService{
-		createTokenProvider: func(userId int64, clientIP, userAgent string) (*m.UserToken, error) {
-			return &m.UserToken{
-				UserId:        0,
-				UnhashedToken: "",
-			}, nil
-		},
-		tryRotateTokenProvider: func(token *m.UserToken, clientIP, userAgent string) (bool, error) {
-			return false, nil
-		},
-		lookupTokenProvider: func(unhashedToken string) (*m.UserToken, error) {
-			return &m.UserToken{
-				UserId:        0,
-				UnhashedToken: "",
-			}, nil
-		},
-		revokeTokenProvider: func(token *m.UserToken) error {
-			return nil
-		},
-		activeAuthTokenCount: func() (int64, error) {
-			return 10, nil
-		},
-	}
-}
-
-func (s *fakeUserAuthTokenService) CreateToken(userId int64, clientIP, userAgent string) (*m.UserToken, error) {
-	return s.createTokenProvider(userId, clientIP, userAgent)
-}
-
-func (s *fakeUserAuthTokenService) LookupToken(unhashedToken string) (*m.UserToken, error) {
-	return s.lookupTokenProvider(unhashedToken)
-}
-
-func (s *fakeUserAuthTokenService) TryRotateToken(token *m.UserToken, clientIP, userAgent string) (bool, error) {
-	return s.tryRotateTokenProvider(token, clientIP, userAgent)
-}
-
-func (s *fakeUserAuthTokenService) RevokeToken(token *m.UserToken) error {
-	return s.revokeTokenProvider(token)
-}
-
-func (s *fakeUserAuthTokenService) ActiveTokenCount() (int64, error) {
-	return s.activeAuthTokenCount()
-}

+ 2 - 2
pkg/middleware/org_redirect_test.go

@@ -24,7 +24,7 @@ func TestOrgRedirectMiddleware(t *testing.T) {
 				return nil
 			})
 
-			sc.userAuthTokenService.lookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) {
+			sc.userAuthTokenService.LookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) {
 				return &m.UserToken{
 					UserId:        0,
 					UnhashedToken: "",
@@ -50,7 +50,7 @@ func TestOrgRedirectMiddleware(t *testing.T) {
 				return nil
 			})
 
-			sc.userAuthTokenService.lookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) {
+			sc.userAuthTokenService.LookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) {
 				return &m.UserToken{
 					UserId:        12,
 					UnhashedToken: "",

+ 3 - 2
pkg/middleware/quota_test.go

@@ -3,6 +3,7 @@ package middleware
 import (
 	"testing"
 
+	"github.com/grafana/grafana/pkg/services/auth"
 	"github.com/grafana/grafana/pkg/services/quota"
 
 	"github.com/grafana/grafana/pkg/bus"
@@ -36,7 +37,7 @@ func TestMiddlewareQuota(t *testing.T) {
 			},
 		}
 
-		fakeAuthTokenService := newFakeUserAuthTokenService()
+		fakeAuthTokenService := auth.NewFakeUserAuthTokenService()
 		qs := &quota.QuotaService{
 			AuthTokenService: fakeAuthTokenService,
 		}
@@ -87,7 +88,7 @@ func TestMiddlewareQuota(t *testing.T) {
 				return nil
 			})
 
-			sc.userAuthTokenService.lookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) {
+			sc.userAuthTokenService.LookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) {
 				return &m.UserToken{
 					UserId:        12,
 					UnhashedToken: "",

+ 2 - 1
pkg/middleware/recovery_test.go

@@ -6,6 +6,7 @@ import (
 
 	"github.com/grafana/grafana/pkg/bus"
 	m "github.com/grafana/grafana/pkg/models"
+	"github.com/grafana/grafana/pkg/services/auth"
 	"github.com/grafana/grafana/pkg/setting"
 	. "github.com/smartystreets/goconvey/convey"
 	macaron "gopkg.in/macaron.v1"
@@ -62,7 +63,7 @@ func recoveryScenario(desc string, url string, fn scenarioFunc) {
 			Delims:    macaron.Delims{Left: "[[", Right: "]]"},
 		}))
 
-		sc.userAuthTokenService = newFakeUserAuthTokenService()
+		sc.userAuthTokenService = auth.NewFakeUserAuthTokenService()
 		sc.m.Use(GetContextHandler(sc.userAuthTokenService))
 		// mock out gc goroutine
 		sc.m.Use(OrgRedirect())

+ 3 - 0
pkg/models/user_token.go

@@ -29,5 +29,8 @@ type UserTokenService interface {
 	LookupToken(unhashedToken string) (*UserToken, error)
 	TryRotateToken(token *UserToken, clientIP, userAgent string) (bool, error)
 	RevokeToken(token *UserToken) error
+	RevokeAllUserTokens(userId int64) error
 	ActiveTokenCount() (int64, error)
+	GetUserToken(userId, userTokenId int64) (*UserToken, error)
+	GetUserTokens(userId int64) ([]*UserToken, error)
 }

+ 51 - 0
pkg/services/auth/auth_token.go

@@ -221,6 +221,57 @@ func (s *UserAuthTokenService) RevokeToken(token *models.UserToken) error {
 	return nil
 }
 
+func (s *UserAuthTokenService) RevokeAllUserTokens(userId int64) error {
+	sql := `DELETE from user_auth_token WHERE user_id = ?`
+	res, err := s.SQLStore.NewSession().Exec(sql, userId)
+	if err != nil {
+		return err
+	}
+
+	affected, err := res.RowsAffected()
+	if err != nil {
+		return err
+	}
+
+	s.log.Debug("all user tokens for user revoked", "userId", userId, "count", affected)
+
+	return nil
+}
+
+func (s *UserAuthTokenService) GetUserToken(userId, userTokenId int64) (*models.UserToken, error) {
+	var token userAuthToken
+	exists, err := s.SQLStore.NewSession().Where("id = ? AND user_id = ?", userTokenId, userId).Get(&token)
+	if err != nil {
+		return nil, err
+	}
+
+	if !exists {
+		return nil, models.ErrUserTokenNotFound
+	}
+
+	var result models.UserToken
+	token.toUserToken(&result)
+
+	return &result, nil
+}
+
+func (s *UserAuthTokenService) GetUserTokens(userId int64) ([]*models.UserToken, error) {
+	var tokens []*userAuthToken
+	err := s.SQLStore.NewSession().Where("user_id = ? AND created_at > ? AND rotated_at > ?", userId, s.createdAfterParam(), s.rotatedAfterParam()).Find(&tokens)
+	if err != nil {
+		return nil, err
+	}
+
+	result := []*models.UserToken{}
+	for _, token := range tokens {
+		var userToken models.UserToken
+		token.toUserToken(&userToken)
+		result = append(result, &userToken)
+	}
+
+	return result, nil
+}
+
 func (s *UserAuthTokenService) createdAfterParam() int64 {
 	tokenMaxLifetime := time.Duration(s.Cfg.LoginMaxLifetimeDays) * 24 * time.Hour
 	return getTime().Add(-tokenMaxLifetime).Unix()

+ 41 - 0
pkg/services/auth/auth_token_test.go

@@ -75,6 +75,47 @@ func TestUserAuthToken(t *testing.T) {
 				err = userAuthTokenService.RevokeToken(userToken)
 				So(err, ShouldEqual, models.ErrUserTokenNotFound)
 			})
+
+			Convey("When creating an additional token", func() {
+				userToken2, err := userAuthTokenService.CreateToken(userID, "192.168.10.11:1234", "some user agent")
+				So(err, ShouldBeNil)
+				So(userToken2, ShouldNotBeNil)
+
+				Convey("Can get first user token", func() {
+					token, err := userAuthTokenService.GetUserToken(userID, userToken.Id)
+					So(err, ShouldBeNil)
+					So(token, ShouldNotBeNil)
+					So(token.Id, ShouldEqual, userToken.Id)
+				})
+
+				Convey("Can get second user token", func() {
+					token, err := userAuthTokenService.GetUserToken(userID, userToken2.Id)
+					So(err, ShouldBeNil)
+					So(token, ShouldNotBeNil)
+					So(token.Id, ShouldEqual, userToken2.Id)
+				})
+
+				Convey("Can get user tokens", func() {
+					tokens, err := userAuthTokenService.GetUserTokens(userID)
+					So(err, ShouldBeNil)
+					So(tokens, ShouldHaveLength, 2)
+					So(tokens[0].Id, ShouldEqual, userToken.Id)
+					So(tokens[1].Id, ShouldEqual, userToken2.Id)
+				})
+
+				Convey("Can revoke all user tokens", func() {
+					err := userAuthTokenService.RevokeAllUserTokens(userID)
+					So(err, ShouldBeNil)
+
+					model, err := ctx.getAuthTokenByID(userToken.Id)
+					So(err, ShouldBeNil)
+					So(model, ShouldBeNil)
+
+					model2, err := ctx.getAuthTokenByID(userToken2.Id)
+					So(err, ShouldBeNil)
+					So(model2, ShouldBeNil)
+				})
+			})
 		})
 
 		Convey("expires correctly", func() {

+ 81 - 0
pkg/services/auth/testing.go

@@ -0,0 +1,81 @@
+package auth
+
+import "github.com/grafana/grafana/pkg/models"
+
+type FakeUserAuthTokenService struct {
+	CreateTokenProvider         func(userId int64, clientIP, userAgent string) (*models.UserToken, error)
+	TryRotateTokenProvider      func(token *models.UserToken, clientIP, userAgent string) (bool, error)
+	LookupTokenProvider         func(unhashedToken string) (*models.UserToken, error)
+	RevokeTokenProvider         func(token *models.UserToken) error
+	RevokeAllUserTokensProvider func(userId int64) error
+	ActiveAuthTokenCount        func() (int64, error)
+	GetUserTokenProvider        func(userId, userTokenId int64) (*models.UserToken, error)
+	GetUserTokensProvider       func(userId int64) ([]*models.UserToken, error)
+}
+
+func NewFakeUserAuthTokenService() *FakeUserAuthTokenService {
+	return &FakeUserAuthTokenService{
+		CreateTokenProvider: func(userId int64, clientIP, userAgent string) (*models.UserToken, error) {
+			return &models.UserToken{
+				UserId:        0,
+				UnhashedToken: "",
+			}, nil
+		},
+		TryRotateTokenProvider: func(token *models.UserToken, clientIP, userAgent string) (bool, error) {
+			return false, nil
+		},
+		LookupTokenProvider: func(unhashedToken string) (*models.UserToken, error) {
+			return &models.UserToken{
+				UserId:        0,
+				UnhashedToken: "",
+			}, nil
+		},
+		RevokeTokenProvider: func(token *models.UserToken) error {
+			return nil
+		},
+		RevokeAllUserTokensProvider: func(userId int64) error {
+			return nil
+		},
+		ActiveAuthTokenCount: func() (int64, error) {
+			return 10, nil
+		},
+		GetUserTokenProvider: func(userId, userTokenId int64) (*models.UserToken, error) {
+			return nil, nil
+		},
+		GetUserTokensProvider: func(userId int64) ([]*models.UserToken, error) {
+			return nil, nil
+		},
+	}
+}
+
+func (s *FakeUserAuthTokenService) CreateToken(userId int64, clientIP, userAgent string) (*models.UserToken, error) {
+	return s.CreateTokenProvider(userId, clientIP, userAgent)
+}
+
+func (s *FakeUserAuthTokenService) LookupToken(unhashedToken string) (*models.UserToken, error) {
+	return s.LookupTokenProvider(unhashedToken)
+}
+
+func (s *FakeUserAuthTokenService) TryRotateToken(token *models.UserToken, clientIP, userAgent string) (bool, error) {
+	return s.TryRotateTokenProvider(token, clientIP, userAgent)
+}
+
+func (s *FakeUserAuthTokenService) RevokeToken(token *models.UserToken) error {
+	return s.RevokeTokenProvider(token)
+}
+
+func (s *FakeUserAuthTokenService) RevokeAllUserTokens(userId int64) error {
+	return s.RevokeAllUserTokensProvider(userId)
+}
+
+func (s *FakeUserAuthTokenService) ActiveTokenCount() (int64, error) {
+	return s.ActiveAuthTokenCount()
+}
+
+func (s *FakeUserAuthTokenService) GetUserToken(userId, userTokenId int64) (*models.UserToken, error) {
+	return s.GetUserTokenProvider(userId, userTokenId)
+}
+
+func (s *FakeUserAuthTokenService) GetUserTokens(userId int64) ([]*models.UserToken, error) {
+	return s.GetUserTokensProvider(userId)
+}