소스 검색

Auth: Enable retries and transaction for some db calls for auth tokens (#16785)

the WithSession wrapper handles retries and connection
management so the caller dont have to worry about it.
Carl Bergquist 6 년 전
부모
커밋
9660356638

+ 1 - 1
pkg/api/admin_users.go

@@ -119,7 +119,7 @@ func (server *HTTPServer) AdminLogoutUser(c *m.ReqContext) Response {
 		return Error(400, "You cannot logout yourself", nil)
 	}
 
-	return server.logoutUserFromAllDevicesInternal(userID)
+	return server.logoutUserFromAllDevicesInternal(c.Req.Context(), userID)
 }
 
 // GET /api/admin/users/:id/auth-tokens

+ 2 - 2
pkg/api/login.go

@@ -131,7 +131,7 @@ func (hs *HTTPServer) loginUserWithUser(user *m.User, c *m.ReqContext) {
 		hs.log.Error("user login with nil user")
 	}
 
-	userToken, err := hs.AuthTokenService.CreateToken(user.Id, c.RemoteAddr(), c.Req.UserAgent())
+	userToken, err := hs.AuthTokenService.CreateToken(c.Req.Context(), user.Id, c.RemoteAddr(), c.Req.UserAgent())
 	if err != nil {
 		hs.log.Error("failed to create auth token", "error", err)
 	}
@@ -140,7 +140,7 @@ func (hs *HTTPServer) loginUserWithUser(user *m.User, c *m.ReqContext) {
 }
 
 func (hs *HTTPServer) Logout(c *m.ReqContext) {
-	if err := hs.AuthTokenService.RevokeToken(c.UserToken); err != nil && err != m.ErrUserTokenNotFound {
+	if err := hs.AuthTokenService.RevokeToken(c.Req.Context(), c.UserToken); err != nil && err != m.ErrUserTokenNotFound {
 		hs.log.Error("failed to revoke auth token", "error", err)
 	}
 

+ 6 - 5
pkg/api/user_token.go

@@ -1,6 +1,7 @@
 package api
 
 import (
+	"context"
 	"time"
 
 	"github.com/grafana/grafana/pkg/api/dtos"
@@ -19,7 +20,7 @@ func (server *HTTPServer) RevokeUserAuthToken(c *models.ReqContext, cmd models.R
 	return server.revokeUserAuthTokenInternal(c, c.UserId, cmd)
 }
 
-func (server *HTTPServer) logoutUserFromAllDevicesInternal(userID int64) Response {
+func (server *HTTPServer) logoutUserFromAllDevicesInternal(ctx context.Context, userID int64) Response {
 	userQuery := models.GetUserByIdQuery{Id: userID}
 
 	if err := bus.Dispatch(&userQuery); err != nil {
@@ -29,7 +30,7 @@ func (server *HTTPServer) logoutUserFromAllDevicesInternal(userID int64) Respons
 		return Error(500, "Could not read user from database", err)
 	}
 
-	err := server.AuthTokenService.RevokeAllUserTokens(userID)
+	err := server.AuthTokenService.RevokeAllUserTokens(ctx, userID)
 	if err != nil {
 		return Error(500, "Failed to logout user", err)
 	}
@@ -49,7 +50,7 @@ func (server *HTTPServer) getUserAuthTokensInternal(c *models.ReqContext, userID
 		return Error(500, "Failed to get user", err)
 	}
 
-	tokens, err := server.AuthTokenService.GetUserTokens(userID)
+	tokens, err := server.AuthTokenService.GetUserTokens(c.Req.Context(), userID)
 	if err != nil {
 		return Error(500, "Failed to get user auth tokens", err)
 	}
@@ -84,7 +85,7 @@ func (server *HTTPServer) revokeUserAuthTokenInternal(c *models.ReqContext, user
 		return Error(500, "Failed to get user", err)
 	}
 
-	token, err := server.AuthTokenService.GetUserToken(userID, cmd.AuthTokenId)
+	token, err := server.AuthTokenService.GetUserToken(c.Req.Context(), userID, cmd.AuthTokenId)
 	if err != nil {
 		if err == models.ErrUserTokenNotFound {
 			return Error(404, "User auth token not found", err)
@@ -96,7 +97,7 @@ func (server *HTTPServer) revokeUserAuthTokenInternal(c *models.ReqContext, user
 		return Error(400, "Cannot revoke active user auth token", nil)
 	}
 
-	err = server.AuthTokenService.RevokeToken(token)
+	err = server.AuthTokenService.RevokeToken(c.Req.Context(), token)
 	if err != nil {
 		if err == models.ErrUserTokenNotFound {
 			return Error(404, "User auth token not found", err)

+ 5 - 4
pkg/api/user_token_test.go

@@ -1,6 +1,7 @@
 package api
 
 import (
+	"context"
 	"testing"
 	"time"
 
@@ -75,7 +76,7 @@ func TestUserTokenApiEndpoint(t *testing.T) {
 		token := &m.UserToken{Id: 1}
 
 		revokeUserAuthTokenInternalScenario("Should be successful", cmd, 200, token, func(sc *scenarioContext) {
-			sc.userAuthTokenService.GetUserTokenProvider = func(userId, userTokenId int64) (*m.UserToken, error) {
+			sc.userAuthTokenService.GetUserTokenProvider = func(ctx context.Context, userId, userTokenId int64) (*m.UserToken, error) {
 				return &m.UserToken{Id: 2}, nil
 			}
 			sc.fakeReqWithParams("POST", sc.url, map[string]string{}).exec()
@@ -93,7 +94,7 @@ func TestUserTokenApiEndpoint(t *testing.T) {
 		token := &m.UserToken{Id: 2}
 
 		revokeUserAuthTokenInternalScenario("Should not be successful", cmd, TestUserID, token, func(sc *scenarioContext) {
-			sc.userAuthTokenService.GetUserTokenProvider = func(userId, userTokenId int64) (*m.UserToken, error) {
+			sc.userAuthTokenService.GetUserTokenProvider = func(ctx context.Context, userId, userTokenId int64) (*m.UserToken, error) {
 				return token, nil
 			}
 			sc.fakeReqWithParams("POST", sc.url, map[string]string{}).exec()
@@ -126,7 +127,7 @@ func TestUserTokenApiEndpoint(t *testing.T) {
 					SeenAt:    time.Now().Unix(),
 				},
 			}
-			sc.userAuthTokenService.GetUserTokensProvider = func(userId int64) ([]*m.UserToken, error) {
+			sc.userAuthTokenService.GetUserTokensProvider = func(ctx context.Context, userId int64) ([]*m.UserToken, error) {
 				return tokens, nil
 			}
 			sc.fakeReqWithParams("GET", sc.url, map[string]string{}).exec()
@@ -226,7 +227,7 @@ func logoutUserFromAllDevicesInternalScenario(desc string, userId int64, fn scen
 			sc.context.OrgId = TestOrgID
 			sc.context.OrgRole = m.ROLE_ADMIN
 
-			return hs.logoutUserFromAllDevicesInternal(userId)
+			return hs.logoutUserFromAllDevicesInternal(context.Background(), userId)
 		})
 
 		sc.m.Post("/", sc.defaultHandler)

+ 2 - 2
pkg/middleware/middleware.go

@@ -173,7 +173,7 @@ func initContextWithToken(authTokenService m.UserTokenService, ctx *m.ReqContext
 		return false
 	}
 
-	token, err := authTokenService.LookupToken(rawToken)
+	token, err := authTokenService.LookupToken(ctx.Req.Context(), rawToken)
 	if err != nil {
 		ctx.Logger.Error("failed to look up user based on cookie", "error", err)
 		WriteSessionCookie(ctx, "", -1)
@@ -190,7 +190,7 @@ func initContextWithToken(authTokenService m.UserTokenService, ctx *m.ReqContext
 	ctx.IsSignedIn = true
 	ctx.UserToken = token
 
-	rotated, err := authTokenService.TryRotateToken(token, ctx.RemoteAddr(), ctx.Req.UserAgent())
+	rotated, err := authTokenService.TryRotateToken(ctx.Req.Context(), token, ctx.RemoteAddr(), ctx.Req.UserAgent())
 	if err != nil {
 		ctx.Logger.Error("failed to rotate token", "error", err)
 		return true

+ 5 - 4
pkg/middleware/middleware_test.go

@@ -1,6 +1,7 @@
 package middleware
 
 import (
+	"context"
 	"encoding/json"
 	"fmt"
 	"net/http"
@@ -156,7 +157,7 @@ func TestMiddlewareContext(t *testing.T) {
 				return nil
 			})
 
-			sc.userAuthTokenService.LookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) {
+			sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*m.UserToken, error) {
 				return &m.UserToken{
 					UserId:        12,
 					UnhashedToken: unhashedToken,
@@ -185,14 +186,14 @@ func TestMiddlewareContext(t *testing.T) {
 				return nil
 			})
 
-			sc.userAuthTokenService.LookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) {
+			sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*m.UserToken, error) {
 				return &m.UserToken{
 					UserId:        12,
 					UnhashedToken: "",
 				}, nil
 			}
 
-			sc.userAuthTokenService.TryRotateTokenProvider = func(userToken *m.UserToken, clientIP, userAgent string) (bool, error) {
+			sc.userAuthTokenService.TryRotateTokenProvider = func(ctx context.Context, userToken *m.UserToken, clientIP, userAgent string) (bool, error) {
 				userToken.UnhashedToken = "rotated"
 				return true, nil
 			}
@@ -227,7 +228,7 @@ func TestMiddlewareContext(t *testing.T) {
 		middlewareScenario(t, "Invalid/expired auth token in cookie", func(sc *scenarioContext) {
 			sc.withTokenSessionCookie("token")
 
-			sc.userAuthTokenService.LookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) {
+			sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*m.UserToken, error) {
 				return nil, m.ErrUserTokenNotFound
 			}
 

+ 3 - 2
pkg/middleware/org_redirect_test.go

@@ -1,6 +1,7 @@
 package middleware
 
 import (
+	"context"
 	"fmt"
 	"testing"
 
@@ -23,7 +24,7 @@ func TestOrgRedirectMiddleware(t *testing.T) {
 				return nil
 			})
 
-			sc.userAuthTokenService.LookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) {
+			sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*m.UserToken, error) {
 				return &m.UserToken{
 					UserId:        0,
 					UnhashedToken: "",
@@ -49,7 +50,7 @@ func TestOrgRedirectMiddleware(t *testing.T) {
 				return nil
 			})
 
-			sc.userAuthTokenService.LookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) {
+			sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*m.UserToken, error) {
 				return &m.UserToken{
 					UserId:        12,
 					UnhashedToken: "",

+ 2 - 1
pkg/middleware/quota_test.go

@@ -1,6 +1,7 @@
 package middleware
 
 import (
+	"context"
 	"testing"
 
 	"github.com/grafana/grafana/pkg/bus"
@@ -87,7 +88,7 @@ func TestMiddlewareQuota(t *testing.T) {
 				return nil
 			})
 
-			sc.userAuthTokenService.LookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) {
+			sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*m.UserToken, error) {
 				return &m.UserToken{
 					UserId:        12,
 					UnhashedToken: "",

+ 9 - 8
pkg/models/user_token.go

@@ -1,6 +1,7 @@
 package models
 
 import (
+	"context"
 	"errors"
 )
 
@@ -31,12 +32,12 @@ type RevokeAuthTokenCmd struct {
 
 // UserTokenService are used for generating and validating user tokens
 type UserTokenService interface {
-	CreateToken(userId int64, clientIP, userAgent string) (*UserToken, error)
-	LookupToken(unhashedToken string) (*UserToken, error)
-	TryRotateToken(token *UserToken, clientIP, userAgent string) (bool, error)
-	RevokeToken(token *UserToken) error
-	RevokeAllUserTokens(userId int64) error
-	ActiveTokenCount() (int64, error)
-	GetUserToken(userId, userTokenId int64) (*UserToken, error)
-	GetUserTokens(userId int64) ([]*UserToken, error)
+	CreateToken(ctx context.Context, userId int64, clientIP, userAgent string) (*UserToken, error)
+	LookupToken(ctx context.Context, unhashedToken string) (*UserToken, error)
+	TryRotateToken(ctx context.Context, token *UserToken, clientIP, userAgent string) (bool, error)
+	RevokeToken(ctx context.Context, token *UserToken) error
+	RevokeAllUserTokens(ctx context.Context, userId int64) error
+	ActiveTokenCount(ctx context.Context) (int64, error)
+	GetUserToken(ctx context.Context, userId, userTokenId int64) (*UserToken, error)
+	GetUserTokens(ctx context.Context, userId int64) ([]*UserToken, error)
 }

+ 131 - 50
pkg/services/auth/auth_token.go

@@ -1,6 +1,7 @@
 package auth
 
 import (
+	"context"
 	"crypto/sha256"
 	"encoding/hex"
 	"time"
@@ -35,14 +36,24 @@ func (s *UserAuthTokenService) Init() error {
 	return nil
 }
 
-func (s *UserAuthTokenService) ActiveTokenCount() (int64, error) {
-	var model userAuthToken
-	count, err := s.SQLStore.NewSession().Where(`created_at > ? AND rotated_at > ?`, s.createdAfterParam(), s.rotatedAfterParam()).Count(&model)
+func (s *UserAuthTokenService) ActiveTokenCount(ctx context.Context) (int64, error) {
+
+	var count int64
+	var err error
+	err = s.SQLStore.WithDbSession(ctx, func(dbSession *sqlstore.DBSession) error {
+		var model userAuthToken
+		count, err = dbSession.Where(`created_at > ? AND rotated_at > ?`,
+			s.createdAfterParam(),
+			s.rotatedAfterParam()).
+			Count(&model)
+
+		return err
+	})
 
 	return count, err
 }
 
-func (s *UserAuthTokenService) CreateToken(userId int64, clientIP, userAgent string) (*models.UserToken, error) {
+func (s *UserAuthTokenService) CreateToken(ctx context.Context, userId int64, clientIP, userAgent string) (*models.UserToken, error) {
 	clientIP = util.ParseIPAddress(clientIP)
 	token, err := util.RandomHex(16)
 	if err != nil {
@@ -65,7 +76,12 @@ func (s *UserAuthTokenService) CreateToken(userId int64, clientIP, userAgent str
 		SeenAt:        0,
 		AuthTokenSeen: false,
 	}
-	_, err = s.SQLStore.NewSession().Insert(&userAuthToken)
+
+	err = s.SQLStore.WithDbSession(ctx, func(dbSession *sqlstore.DBSession) error {
+		_, err = dbSession.Insert(&userAuthToken)
+		return err
+	})
+
 	if err != nil {
 		return nil, err
 	}
@@ -80,14 +96,27 @@ func (s *UserAuthTokenService) CreateToken(userId int64, clientIP, userAgent str
 	return &userToken, err
 }
 
-func (s *UserAuthTokenService) LookupToken(unhashedToken string) (*models.UserToken, error) {
+func (s *UserAuthTokenService) LookupToken(ctx context.Context, unhashedToken string) (*models.UserToken, error) {
 	hashedToken := hashToken(unhashedToken)
 	if setting.Env == setting.DEV {
 		s.log.Debug("looking up token", "unhashed", unhashedToken, "hashed", hashedToken)
 	}
 
 	var model userAuthToken
-	exists, err := s.SQLStore.NewSession().Where("(auth_token = ? OR prev_auth_token = ?) AND created_at > ? AND rotated_at > ?", hashedToken, hashedToken, s.createdAfterParam(), s.rotatedAfterParam()).Get(&model)
+	var exists bool
+	var err error
+	err = s.SQLStore.WithDbSession(ctx, func(dbSession *sqlstore.DBSession) error {
+		exists, err = dbSession.Where("(auth_token = ? OR prev_auth_token = ?) AND created_at > ? AND rotated_at > ?",
+			hashedToken,
+			hashedToken,
+			s.createdAfterParam(),
+			s.rotatedAfterParam()).
+			Get(&model)
+
+		return err
+
+	})
+
 	if err != nil {
 		return nil, err
 	}
@@ -100,7 +129,18 @@ func (s *UserAuthTokenService) LookupToken(unhashedToken string) (*models.UserTo
 		modelCopy := model
 		modelCopy.AuthTokenSeen = false
 		expireBefore := getTime().Add(-urgentRotateTime).Unix()
-		affectedRows, err := s.SQLStore.NewSession().Where("id = ? AND prev_auth_token = ? AND rotated_at < ?", modelCopy.Id, modelCopy.PrevAuthToken, expireBefore).AllCols().Update(&modelCopy)
+
+		var affectedRows int64
+		err = s.SQLStore.WithTransactionalDbSession(ctx, func(dbSession *sqlstore.DBSession) error {
+			affectedRows, err = dbSession.Where("id = ? AND prev_auth_token = ? AND rotated_at < ?",
+				modelCopy.Id,
+				modelCopy.PrevAuthToken,
+				expireBefore).
+				AllCols().Update(&modelCopy)
+
+			return err
+		})
+
 		if err != nil {
 			return nil, err
 		}
@@ -116,7 +156,17 @@ func (s *UserAuthTokenService) LookupToken(unhashedToken string) (*models.UserTo
 		modelCopy := model
 		modelCopy.AuthTokenSeen = true
 		modelCopy.SeenAt = getTime().Unix()
-		affectedRows, err := s.SQLStore.NewSession().Where("id = ? AND auth_token = ?", modelCopy.Id, modelCopy.AuthToken).AllCols().Update(&modelCopy)
+
+		var affectedRows int64
+		err = s.SQLStore.WithTransactionalDbSession(ctx, func(dbSession *sqlstore.DBSession) error {
+			affectedRows, err = dbSession.Where("id = ? AND auth_token = ?",
+				modelCopy.Id,
+				modelCopy.AuthToken).
+				AllCols().Update(&modelCopy)
+
+			return err
+		})
+
 		if err != nil {
 			return nil, err
 		}
@@ -140,7 +190,7 @@ func (s *UserAuthTokenService) LookupToken(unhashedToken string) (*models.UserTo
 	return &userToken, err
 }
 
-func (s *UserAuthTokenService) TryRotateToken(token *models.UserToken, clientIP, userAgent string) (bool, error) {
+func (s *UserAuthTokenService) TryRotateToken(ctx context.Context, token *models.UserToken, clientIP, userAgent string) (bool, error) {
 	if token == nil {
 		return false, nil
 	}
@@ -183,12 +233,21 @@ func (s *UserAuthTokenService) TryRotateToken(token *models.UserToken, clientIP,
 			rotated_at = ?
 		WHERE id = ? AND (auth_token_seen = ? OR rotated_at < ?)`
 
-	res, err := s.SQLStore.NewSession().Exec(sql, userAgent, clientIP, s.SQLStore.Dialect.BooleanStr(true), hashedToken, s.SQLStore.Dialect.BooleanStr(false), now.Unix(), model.Id, s.SQLStore.Dialect.BooleanStr(true), now.Add(-30*time.Second).Unix())
+	var affected int64
+	err = s.SQLStore.WithTransactionalDbSession(ctx, func(dbSession *sqlstore.DBSession) error {
+		res, err := dbSession.Exec(sql, userAgent, clientIP, s.SQLStore.Dialect.BooleanStr(true), hashedToken, s.SQLStore.Dialect.BooleanStr(false), now.Unix(), model.Id, s.SQLStore.Dialect.BooleanStr(true), now.Add(-30*time.Second).Unix())
+		if err != nil {
+			return err
+		}
+
+		affected, err = res.RowsAffected()
+		return err
+	})
+
 	if err != nil {
 		return false, err
 	}
 
-	affected, _ := res.RowsAffected()
 	s.log.Debug("auth token rotated", "affected", affected, "auth_token_id", model.Id, "userId", model.UserId)
 	if affected > 0 {
 		model.UnhashedToken = newToken
@@ -199,14 +258,20 @@ func (s *UserAuthTokenService) TryRotateToken(token *models.UserToken, clientIP,
 	return false, nil
 }
 
-func (s *UserAuthTokenService) RevokeToken(token *models.UserToken) error {
+func (s *UserAuthTokenService) RevokeToken(ctx context.Context, token *models.UserToken) error {
 	if token == nil {
 		return models.ErrUserTokenNotFound
 	}
 
 	model := userAuthTokenFromUserToken(token)
 
-	rowsAffected, err := s.SQLStore.NewSession().Delete(model)
+	var rowsAffected int64
+	var err error
+	err = s.SQLStore.WithDbSession(ctx, func(dbSession *sqlstore.DBSession) error {
+		rowsAffected, err = dbSession.Delete(model)
+		return err
+	})
+
 	if err != nil {
 		return err
 	}
@@ -221,55 +286,71 @@ func (s *UserAuthTokenService) RevokeToken(token *models.UserToken) error {
 	return nil
 }
 
-func (s *UserAuthTokenService) RevokeAllUserTokens(userId int64) error {
-	sql := `DELETE from user_auth_token WHERE user_id = ?`
-	res, err := s.SQLStore.NewSession().Exec(sql, userId)
-	if err != nil {
-		return err
-	}
+func (s *UserAuthTokenService) RevokeAllUserTokens(ctx context.Context, userId int64) error {
+	return s.SQLStore.WithDbSession(ctx, func(dbSession *sqlstore.DBSession) error {
+		sql := `DELETE from user_auth_token WHERE user_id = ?`
+		res, err := dbSession.Exec(sql, userId)
+		if err != nil {
+			return err
+		}
 
-	affected, err := res.RowsAffected()
-	if err != nil {
-		return err
-	}
+		affected, err := res.RowsAffected()
+		if err != nil {
+			return err
+		}
 
-	s.log.Debug("all user tokens for user revoked", "userId", userId, "count", affected)
+		s.log.Debug("all user tokens for user revoked", "userId", userId, "count", affected)
 
-	return nil
+		return err
+	})
 }
 
-func (s *UserAuthTokenService) GetUserToken(userId, userTokenId int64) (*models.UserToken, error) {
-	var token userAuthToken
-	exists, err := s.SQLStore.NewSession().Where("id = ? AND user_id = ?", userTokenId, userId).Get(&token)
-	if err != nil {
-		return nil, err
-	}
-
-	if !exists {
-		return nil, models.ErrUserTokenNotFound
-	}
+func (s *UserAuthTokenService) GetUserToken(ctx context.Context, userId, userTokenId int64) (*models.UserToken, error) {
 
 	var result models.UserToken
-	token.toUserToken(&result)
+	err := s.SQLStore.WithDbSession(ctx, func(dbSession *sqlstore.DBSession) error {
+		var token userAuthToken
+		exists, err := dbSession.Where("id = ? AND user_id = ?", userTokenId, userId).Get(&token)
+		if err != nil {
+			return err
+		}
+
+		if !exists {
+			return models.ErrUserTokenNotFound
+		}
+
+		token.toUserToken(&result)
+		return nil
+	})
 
-	return &result, nil
+	return &result, err
 }
 
-func (s *UserAuthTokenService) GetUserTokens(userId int64) ([]*models.UserToken, error) {
-	var tokens []*userAuthToken
-	err := s.SQLStore.NewSession().Where("user_id = ? AND created_at > ? AND rotated_at > ?", userId, s.createdAfterParam(), s.rotatedAfterParam()).Find(&tokens)
-	if err != nil {
-		return nil, err
-	}
+func (s *UserAuthTokenService) GetUserTokens(ctx context.Context, userId int64) ([]*models.UserToken, error) {
 
 	result := []*models.UserToken{}
-	for _, token := range tokens {
-		var userToken models.UserToken
-		token.toUserToken(&userToken)
-		result = append(result, &userToken)
-	}
+	err := s.SQLStore.WithDbSession(ctx, func(dbSession *sqlstore.DBSession) error {
+		var tokens []*userAuthToken
+		err := dbSession.Where("user_id = ? AND created_at > ? AND rotated_at > ?",
+			userId,
+			s.createdAfterParam(),
+			s.rotatedAfterParam()).
+			Find(&tokens)
+
+		if err != nil {
+			return err
+		}
+
+		for _, token := range tokens {
+			var userToken models.UserToken
+			token.toUserToken(&userToken)
+			result = append(result, &userToken)
+		}
+
+		return nil
+	})
 
-	return result, nil
+	return result, err
 }
 
 func (s *UserAuthTokenService) createdAfterParam() int64 {

+ 42 - 41
pkg/services/auth/auth_token_test.go

@@ -1,6 +1,7 @@
 package auth
 
 import (
+	"context"
 	"encoding/json"
 	"testing"
 	"time"
@@ -26,19 +27,19 @@ func TestUserAuthToken(t *testing.T) {
 		}
 
 		Convey("When creating token", func() {
-			userToken, err := userAuthTokenService.CreateToken(userID, "192.168.10.11:1234", "some user agent")
+			userToken, err := userAuthTokenService.CreateToken(context.Background(), userID, "192.168.10.11:1234", "some user agent")
 			So(err, ShouldBeNil)
 			So(userToken, ShouldNotBeNil)
 			So(userToken.AuthTokenSeen, ShouldBeFalse)
 
 			Convey("Can count active tokens", func() {
-				count, err := userAuthTokenService.ActiveTokenCount()
+				count, err := userAuthTokenService.ActiveTokenCount(context.Background())
 				So(err, ShouldBeNil)
 				So(count, ShouldEqual, 1)
 			})
 
 			Convey("When lookup unhashed token should return user auth token", func() {
-				userToken, err := userAuthTokenService.LookupToken(userToken.UnhashedToken)
+				userToken, err := userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken)
 				So(err, ShouldBeNil)
 				So(userToken, ShouldNotBeNil)
 				So(userToken.UserId, ShouldEqual, userID)
@@ -51,13 +52,13 @@ func TestUserAuthToken(t *testing.T) {
 			})
 
 			Convey("When lookup hashed token should return user auth token not found error", func() {
-				userToken, err := userAuthTokenService.LookupToken(userToken.AuthToken)
+				userToken, err := userAuthTokenService.LookupToken(context.Background(), userToken.AuthToken)
 				So(err, ShouldEqual, models.ErrUserTokenNotFound)
 				So(userToken, ShouldBeNil)
 			})
 
 			Convey("revoking existing token should delete token", func() {
-				err = userAuthTokenService.RevokeToken(userToken)
+				err = userAuthTokenService.RevokeToken(context.Background(), userToken)
 				So(err, ShouldBeNil)
 
 				model, err := ctx.getAuthTokenByID(userToken.Id)
@@ -66,37 +67,37 @@ func TestUserAuthToken(t *testing.T) {
 			})
 
 			Convey("revoking nil token should return error", func() {
-				err = userAuthTokenService.RevokeToken(nil)
+				err = userAuthTokenService.RevokeToken(context.Background(), nil)
 				So(err, ShouldEqual, models.ErrUserTokenNotFound)
 			})
 
 			Convey("revoking non-existing token should return error", func() {
 				userToken.Id = 1000
-				err = userAuthTokenService.RevokeToken(userToken)
+				err = userAuthTokenService.RevokeToken(context.Background(), userToken)
 				So(err, ShouldEqual, models.ErrUserTokenNotFound)
 			})
 
 			Convey("When creating an additional token", func() {
-				userToken2, err := userAuthTokenService.CreateToken(userID, "192.168.10.11:1234", "some user agent")
+				userToken2, err := userAuthTokenService.CreateToken(context.Background(), userID, "192.168.10.11:1234", "some user agent")
 				So(err, ShouldBeNil)
 				So(userToken2, ShouldNotBeNil)
 
 				Convey("Can get first user token", func() {
-					token, err := userAuthTokenService.GetUserToken(userID, userToken.Id)
+					token, err := userAuthTokenService.GetUserToken(context.Background(), userID, userToken.Id)
 					So(err, ShouldBeNil)
 					So(token, ShouldNotBeNil)
 					So(token.Id, ShouldEqual, userToken.Id)
 				})
 
 				Convey("Can get second user token", func() {
-					token, err := userAuthTokenService.GetUserToken(userID, userToken2.Id)
+					token, err := userAuthTokenService.GetUserToken(context.Background(), userID, userToken2.Id)
 					So(err, ShouldBeNil)
 					So(token, ShouldNotBeNil)
 					So(token.Id, ShouldEqual, userToken2.Id)
 				})
 
 				Convey("Can get user tokens", func() {
-					tokens, err := userAuthTokenService.GetUserTokens(userID)
+					tokens, err := userAuthTokenService.GetUserTokens(context.Background(), userID)
 					So(err, ShouldBeNil)
 					So(tokens, ShouldHaveLength, 2)
 					So(tokens[0].Id, ShouldEqual, userToken.Id)
@@ -104,7 +105,7 @@ func TestUserAuthToken(t *testing.T) {
 				})
 
 				Convey("Can revoke all user tokens", func() {
-					err := userAuthTokenService.RevokeAllUserTokens(userID)
+					err := userAuthTokenService.RevokeAllUserTokens(context.Background(), userID)
 					So(err, ShouldBeNil)
 
 					model, err := ctx.getAuthTokenByID(userToken.Id)
@@ -119,24 +120,24 @@ func TestUserAuthToken(t *testing.T) {
 		})
 
 		Convey("expires correctly", func() {
-			userToken, err := userAuthTokenService.CreateToken(userID, "192.168.10.11:1234", "some user agent")
+			userToken, err := userAuthTokenService.CreateToken(context.Background(), userID, "192.168.10.11:1234", "some user agent")
 			So(err, ShouldBeNil)
 
-			userToken, err = userAuthTokenService.LookupToken(userToken.UnhashedToken)
+			userToken, err = userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken)
 			So(err, ShouldBeNil)
 
 			getTime = func() time.Time {
 				return t.Add(time.Hour)
 			}
 
-			rotated, err := userAuthTokenService.TryRotateToken(userToken, "192.168.10.11:1234", "some user agent")
+			rotated, err := userAuthTokenService.TryRotateToken(context.Background(), userToken, "192.168.10.11:1234", "some user agent")
 			So(err, ShouldBeNil)
 			So(rotated, ShouldBeTrue)
 
-			userToken, err = userAuthTokenService.LookupToken(userToken.UnhashedToken)
+			userToken, err = userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken)
 			So(err, ShouldBeNil)
 
-			stillGood, err := userAuthTokenService.LookupToken(userToken.UnhashedToken)
+			stillGood, err := userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken)
 			So(err, ShouldBeNil)
 			So(stillGood, ShouldNotBeNil)
 
@@ -148,7 +149,7 @@ func TestUserAuthToken(t *testing.T) {
 					return time.Unix(model.RotatedAt, 0).Add(24 * 7 * time.Hour).Add(-time.Second)
 				}
 
-				stillGood, err = userAuthTokenService.LookupToken(stillGood.UnhashedToken)
+				stillGood, err = userAuthTokenService.LookupToken(context.Background(), stillGood.UnhashedToken)
 				So(err, ShouldBeNil)
 				So(stillGood, ShouldNotBeNil)
 			})
@@ -158,12 +159,12 @@ func TestUserAuthToken(t *testing.T) {
 					return time.Unix(model.RotatedAt, 0).Add(24 * 7 * time.Hour)
 				}
 
-				notGood, err := userAuthTokenService.LookupToken(userToken.UnhashedToken)
+				notGood, err := userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken)
 				So(err, ShouldEqual, models.ErrUserTokenNotFound)
 				So(notGood, ShouldBeNil)
 
 				Convey("should not find active token when expired", func() {
-					count, err := userAuthTokenService.ActiveTokenCount()
+					count, err := userAuthTokenService.ActiveTokenCount(context.Background())
 					So(err, ShouldBeNil)
 					So(count, ShouldEqual, 0)
 				})
@@ -178,7 +179,7 @@ func TestUserAuthToken(t *testing.T) {
 					return time.Unix(model.CreatedAt, 0).Add(24 * 30 * time.Hour).Add(-time.Second)
 				}
 
-				stillGood, err = userAuthTokenService.LookupToken(stillGood.UnhashedToken)
+				stillGood, err = userAuthTokenService.LookupToken(context.Background(), stillGood.UnhashedToken)
 				So(err, ShouldBeNil)
 				So(stillGood, ShouldNotBeNil)
 			})
@@ -192,20 +193,20 @@ func TestUserAuthToken(t *testing.T) {
 					return time.Unix(model.CreatedAt, 0).Add(24 * 30 * time.Hour)
 				}
 
-				notGood, err := userAuthTokenService.LookupToken(userToken.UnhashedToken)
+				notGood, err := userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken)
 				So(err, ShouldEqual, models.ErrUserTokenNotFound)
 				So(notGood, ShouldBeNil)
 			})
 		})
 
 		Convey("can properly rotate tokens", func() {
-			userToken, err := userAuthTokenService.CreateToken(userID, "192.168.10.11:1234", "some user agent")
+			userToken, err := userAuthTokenService.CreateToken(context.Background(), userID, "192.168.10.11:1234", "some user agent")
 			So(err, ShouldBeNil)
 
 			prevToken := userToken.AuthToken
 			unhashedPrev := userToken.UnhashedToken
 
-			rotated, err := userAuthTokenService.TryRotateToken(userToken, "192.168.10.12:1234", "a new user agent")
+			rotated, err := userAuthTokenService.TryRotateToken(context.Background(), userToken, "192.168.10.12:1234", "a new user agent")
 			So(err, ShouldBeNil)
 			So(rotated, ShouldBeFalse)
 
@@ -224,7 +225,7 @@ func TestUserAuthToken(t *testing.T) {
 				return t.Add(time.Hour)
 			}
 
-			rotated, err = userAuthTokenService.TryRotateToken(&tok, "192.168.10.12:1234", "a new user agent")
+			rotated, err = userAuthTokenService.TryRotateToken(context.Background(), &tok, "192.168.10.12:1234", "a new user agent")
 			So(err, ShouldBeNil)
 			So(rotated, ShouldBeTrue)
 
@@ -243,13 +244,13 @@ func TestUserAuthToken(t *testing.T) {
 
 			// ability to auth using an old token
 
-			lookedUpUserToken, err := userAuthTokenService.LookupToken(model.UnhashedToken)
+			lookedUpUserToken, err := userAuthTokenService.LookupToken(context.Background(), model.UnhashedToken)
 			So(err, ShouldBeNil)
 			So(lookedUpUserToken, ShouldNotBeNil)
 			So(lookedUpUserToken.AuthTokenSeen, ShouldBeTrue)
 			So(lookedUpUserToken.SeenAt, ShouldEqual, getTime().Unix())
 
-			lookedUpUserToken, err = userAuthTokenService.LookupToken(unhashedPrev)
+			lookedUpUserToken, err = userAuthTokenService.LookupToken(context.Background(), unhashedPrev)
 			So(err, ShouldBeNil)
 			So(lookedUpUserToken, ShouldNotBeNil)
 			So(lookedUpUserToken.Id, ShouldEqual, model.Id)
@@ -259,7 +260,7 @@ func TestUserAuthToken(t *testing.T) {
 				return t.Add(time.Hour + (2 * time.Minute))
 			}
 
-			lookedUpUserToken, err = userAuthTokenService.LookupToken(unhashedPrev)
+			lookedUpUserToken, err = userAuthTokenService.LookupToken(context.Background(), unhashedPrev)
 			So(err, ShouldBeNil)
 			So(lookedUpUserToken, ShouldNotBeNil)
 			So(lookedUpUserToken.AuthTokenSeen, ShouldBeTrue)
@@ -269,7 +270,7 @@ func TestUserAuthToken(t *testing.T) {
 			So(lookedUpModel, ShouldNotBeNil)
 			So(lookedUpModel.AuthTokenSeen, ShouldBeFalse)
 
-			rotated, err = userAuthTokenService.TryRotateToken(userToken, "192.168.10.12:1234", "a new user agent")
+			rotated, err = userAuthTokenService.TryRotateToken(context.Background(), userToken, "192.168.10.12:1234", "a new user agent")
 			So(err, ShouldBeNil)
 			So(rotated, ShouldBeTrue)
 
@@ -280,11 +281,11 @@ func TestUserAuthToken(t *testing.T) {
 		})
 
 		Convey("keeps prev token valid for 1 minute after it is confirmed", func() {
-			userToken, err := userAuthTokenService.CreateToken(userID, "192.168.10.11:1234", "some user agent")
+			userToken, err := userAuthTokenService.CreateToken(context.Background(), userID, "192.168.10.11:1234", "some user agent")
 			So(err, ShouldBeNil)
 			So(userToken, ShouldNotBeNil)
 
-			lookedUpUserToken, err := userAuthTokenService.LookupToken(userToken.UnhashedToken)
+			lookedUpUserToken, err := userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken)
 			So(err, ShouldBeNil)
 			So(lookedUpUserToken, ShouldNotBeNil)
 
@@ -293,7 +294,7 @@ func TestUserAuthToken(t *testing.T) {
 			}
 
 			prevToken := userToken.UnhashedToken
-			rotated, err := userAuthTokenService.TryRotateToken(userToken, "1.1.1.1", "firefox")
+			rotated, err := userAuthTokenService.TryRotateToken(context.Background(), userToken, "1.1.1.1", "firefox")
 			So(err, ShouldBeNil)
 			So(rotated, ShouldBeTrue)
 
@@ -301,25 +302,25 @@ func TestUserAuthToken(t *testing.T) {
 				return t.Add(20 * time.Minute)
 			}
 
-			currentUserToken, err := userAuthTokenService.LookupToken(userToken.UnhashedToken)
+			currentUserToken, err := userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken)
 			So(err, ShouldBeNil)
 			So(currentUserToken, ShouldNotBeNil)
 
-			prevUserToken, err := userAuthTokenService.LookupToken(prevToken)
+			prevUserToken, err := userAuthTokenService.LookupToken(context.Background(), prevToken)
 			So(err, ShouldBeNil)
 			So(prevUserToken, ShouldNotBeNil)
 		})
 
 		Convey("will not mark token unseen when prev and current are the same", func() {
-			userToken, err := userAuthTokenService.CreateToken(userID, "192.168.10.11:1234", "some user agent")
+			userToken, err := userAuthTokenService.CreateToken(context.Background(), userID, "192.168.10.11:1234", "some user agent")
 			So(err, ShouldBeNil)
 			So(userToken, ShouldNotBeNil)
 
-			lookedUpUserToken, err := userAuthTokenService.LookupToken(userToken.UnhashedToken)
+			lookedUpUserToken, err := userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken)
 			So(err, ShouldBeNil)
 			So(lookedUpUserToken, ShouldNotBeNil)
 
-			lookedUpUserToken, err = userAuthTokenService.LookupToken(userToken.UnhashedToken)
+			lookedUpUserToken, err = userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken)
 			So(err, ShouldBeNil)
 			So(lookedUpUserToken, ShouldNotBeNil)
 
@@ -330,7 +331,7 @@ func TestUserAuthToken(t *testing.T) {
 		})
 
 		Convey("Rotate token", func() {
-			userToken, err := userAuthTokenService.CreateToken(userID, "192.168.10.11:1234", "some user agent")
+			userToken, err := userAuthTokenService.CreateToken(context.Background(), userID, "192.168.10.11:1234", "some user agent")
 			So(err, ShouldBeNil)
 			So(userToken, ShouldNotBeNil)
 
@@ -345,7 +346,7 @@ func TestUserAuthToken(t *testing.T) {
 					return t.Add(10 * time.Minute)
 				}
 
-				rotated, err := userAuthTokenService.TryRotateToken(userToken, "1.1.1.1", "firefox")
+				rotated, err := userAuthTokenService.TryRotateToken(context.Background(), userToken, "1.1.1.1", "firefox")
 				So(err, ShouldBeNil)
 				So(rotated, ShouldBeTrue)
 
@@ -366,7 +367,7 @@ func TestUserAuthToken(t *testing.T) {
 					return t.Add(20 * time.Minute)
 				}
 
-				rotated, err = userAuthTokenService.TryRotateToken(userToken, "1.1.1.1", "firefox")
+				rotated, err = userAuthTokenService.TryRotateToken(context.Background(), userToken, "1.1.1.1", "firefox")
 				So(err, ShouldBeNil)
 				So(rotated, ShouldBeTrue)
 
@@ -385,7 +386,7 @@ func TestUserAuthToken(t *testing.T) {
 					return t.Add(2 * time.Minute)
 				}
 
-				rotated, err := userAuthTokenService.TryRotateToken(userToken, "1.1.1.1", "firefox")
+				rotated, err := userAuthTokenService.TryRotateToken(context.Background(), userToken, "1.1.1.1", "firefox")
 				So(err, ShouldBeNil)
 				So(rotated, ShouldBeTrue)
 

+ 37 - 33
pkg/services/auth/testing.go

@@ -1,81 +1,85 @@
 package auth
 
-import "github.com/grafana/grafana/pkg/models"
+import (
+	"context"
+
+	"github.com/grafana/grafana/pkg/models"
+)
 
 type FakeUserAuthTokenService struct {
-	CreateTokenProvider         func(userId int64, clientIP, userAgent string) (*models.UserToken, error)
-	TryRotateTokenProvider      func(token *models.UserToken, clientIP, userAgent string) (bool, error)
-	LookupTokenProvider         func(unhashedToken string) (*models.UserToken, error)
-	RevokeTokenProvider         func(token *models.UserToken) error
-	RevokeAllUserTokensProvider func(userId int64) error
-	ActiveAuthTokenCount        func() (int64, error)
-	GetUserTokenProvider        func(userId, userTokenId int64) (*models.UserToken, error)
-	GetUserTokensProvider       func(userId int64) ([]*models.UserToken, error)
+	CreateTokenProvider         func(ctx context.Context, userId int64, clientIP, userAgent string) (*models.UserToken, error)
+	TryRotateTokenProvider      func(ctx context.Context, token *models.UserToken, clientIP, userAgent string) (bool, error)
+	LookupTokenProvider         func(ctx context.Context, unhashedToken string) (*models.UserToken, error)
+	RevokeTokenProvider         func(ctx context.Context, token *models.UserToken) error
+	RevokeAllUserTokensProvider func(ctx context.Context, userId int64) error
+	ActiveAuthTokenCount        func(ctx context.Context) (int64, error)
+	GetUserTokenProvider        func(ctx context.Context, userId, userTokenId int64) (*models.UserToken, error)
+	GetUserTokensProvider       func(ctx context.Context, userId int64) ([]*models.UserToken, error)
 }
 
 func NewFakeUserAuthTokenService() *FakeUserAuthTokenService {
 	return &FakeUserAuthTokenService{
-		CreateTokenProvider: func(userId int64, clientIP, userAgent string) (*models.UserToken, error) {
+		CreateTokenProvider: func(ctx context.Context, userId int64, clientIP, userAgent string) (*models.UserToken, error) {
 			return &models.UserToken{
 				UserId:        0,
 				UnhashedToken: "",
 			}, nil
 		},
-		TryRotateTokenProvider: func(token *models.UserToken, clientIP, userAgent string) (bool, error) {
+		TryRotateTokenProvider: func(ctx context.Context, token *models.UserToken, clientIP, userAgent string) (bool, error) {
 			return false, nil
 		},
-		LookupTokenProvider: func(unhashedToken string) (*models.UserToken, error) {
+		LookupTokenProvider: func(ctx context.Context, unhashedToken string) (*models.UserToken, error) {
 			return &models.UserToken{
 				UserId:        0,
 				UnhashedToken: "",
 			}, nil
 		},
-		RevokeTokenProvider: func(token *models.UserToken) error {
+		RevokeTokenProvider: func(ctx context.Context, token *models.UserToken) error {
 			return nil
 		},
-		RevokeAllUserTokensProvider: func(userId int64) error {
+		RevokeAllUserTokensProvider: func(ctx context.Context, userId int64) error {
 			return nil
 		},
-		ActiveAuthTokenCount: func() (int64, error) {
+		ActiveAuthTokenCount: func(ctx context.Context) (int64, error) {
 			return 10, nil
 		},
-		GetUserTokenProvider: func(userId, userTokenId int64) (*models.UserToken, error) {
+		GetUserTokenProvider: func(ctx context.Context, userId, userTokenId int64) (*models.UserToken, error) {
 			return nil, nil
 		},
-		GetUserTokensProvider: func(userId int64) ([]*models.UserToken, error) {
+		GetUserTokensProvider: func(ctx context.Context, userId int64) ([]*models.UserToken, error) {
 			return nil, nil
 		},
 	}
 }
 
-func (s *FakeUserAuthTokenService) CreateToken(userId int64, clientIP, userAgent string) (*models.UserToken, error) {
-	return s.CreateTokenProvider(userId, clientIP, userAgent)
+func (s *FakeUserAuthTokenService) CreateToken(ctx context.Context, userId int64, clientIP, userAgent string) (*models.UserToken, error) {
+	return s.CreateTokenProvider(context.Background(), userId, clientIP, userAgent)
 }
 
-func (s *FakeUserAuthTokenService) LookupToken(unhashedToken string) (*models.UserToken, error) {
-	return s.LookupTokenProvider(unhashedToken)
+func (s *FakeUserAuthTokenService) LookupToken(ctx context.Context, unhashedToken string) (*models.UserToken, error) {
+	return s.LookupTokenProvider(context.Background(), unhashedToken)
 }
 
-func (s *FakeUserAuthTokenService) TryRotateToken(token *models.UserToken, clientIP, userAgent string) (bool, error) {
-	return s.TryRotateTokenProvider(token, clientIP, userAgent)
+func (s *FakeUserAuthTokenService) TryRotateToken(ctx context.Context, token *models.UserToken, clientIP, userAgent string) (bool, error) {
+	return s.TryRotateTokenProvider(context.Background(), token, clientIP, userAgent)
 }
 
-func (s *FakeUserAuthTokenService) RevokeToken(token *models.UserToken) error {
-	return s.RevokeTokenProvider(token)
+func (s *FakeUserAuthTokenService) RevokeToken(ctx context.Context, token *models.UserToken) error {
+	return s.RevokeTokenProvider(context.Background(), token)
 }
 
-func (s *FakeUserAuthTokenService) RevokeAllUserTokens(userId int64) error {
-	return s.RevokeAllUserTokensProvider(userId)
+func (s *FakeUserAuthTokenService) RevokeAllUserTokens(ctx context.Context, userId int64) error {
+	return s.RevokeAllUserTokensProvider(context.Background(), userId)
 }
 
-func (s *FakeUserAuthTokenService) ActiveTokenCount() (int64, error) {
-	return s.ActiveAuthTokenCount()
+func (s *FakeUserAuthTokenService) ActiveTokenCount(ctx context.Context) (int64, error) {
+	return s.ActiveAuthTokenCount(context.Background())
 }
 
-func (s *FakeUserAuthTokenService) GetUserToken(userId, userTokenId int64) (*models.UserToken, error) {
-	return s.GetUserTokenProvider(userId, userTokenId)
+func (s *FakeUserAuthTokenService) GetUserToken(ctx context.Context, userId, userTokenId int64) (*models.UserToken, error) {
+	return s.GetUserTokenProvider(context.Background(), userId, userTokenId)
 }
 
-func (s *FakeUserAuthTokenService) GetUserTokens(userId int64) ([]*models.UserToken, error) {
-	return s.GetUserTokensProvider(userId)
+func (s *FakeUserAuthTokenService) GetUserTokens(ctx context.Context, userId int64) ([]*models.UserToken, error) {
+	return s.GetUserTokensProvider(context.Background(), userId)
 }

+ 25 - 16
pkg/services/auth/token_cleanup.go

@@ -3,6 +3,8 @@ package auth
 import (
 	"context"
 	"time"
+
+	"github.com/grafana/grafana/pkg/services/sqlstore"
 )
 
 func (srv *UserAuthTokenService) Run(ctx context.Context) error {
@@ -11,21 +13,22 @@ func (srv *UserAuthTokenService) Run(ctx context.Context) error {
 	maxLifetime := time.Duration(srv.Cfg.LoginMaxLifetimeDays) * 24 * time.Hour
 
 	err := srv.ServerLockService.LockAndExecute(ctx, "cleanup expired auth tokens", time.Hour*12, func() {
-		srv.deleteExpiredTokens(maxInactiveLifetime, maxLifetime)
+		srv.deleteExpiredTokens(ctx, maxInactiveLifetime, maxLifetime)
 	})
+
 	if err != nil {
-		srv.log.Error("failed to lock and execite cleanup of expired auth token", "erro", err)
+		srv.log.Error("failed to lock and execute cleanup of expired auth token", "error", err)
 	}
 
 	for {
 		select {
 		case <-ticker.C:
 			err := srv.ServerLockService.LockAndExecute(ctx, "cleanup expired auth tokens", time.Hour*12, func() {
-				srv.deleteExpiredTokens(maxInactiveLifetime, maxLifetime)
+				srv.deleteExpiredTokens(ctx, maxInactiveLifetime, maxLifetime)
 			})
 
 			if err != nil {
-				srv.log.Error("failed to lock and execite cleanup of expired auth token", "erro", err)
+				srv.log.Error("failed to lock and execute cleanup of expired auth token", "error", err)
 			}
 
 		case <-ctx.Done():
@@ -34,24 +37,30 @@ func (srv *UserAuthTokenService) Run(ctx context.Context) error {
 	}
 }
 
-func (srv *UserAuthTokenService) deleteExpiredTokens(maxInactiveLifetime, maxLifetime time.Duration) (int64, error) {
+func (srv *UserAuthTokenService) deleteExpiredTokens(ctx context.Context, maxInactiveLifetime, maxLifetime time.Duration) (int64, error) {
 	createdBefore := getTime().Add(-maxLifetime)
 	rotatedBefore := getTime().Add(-maxInactiveLifetime)
 
 	srv.log.Debug("starting cleanup of expired auth tokens", "createdBefore", createdBefore, "rotatedBefore", rotatedBefore)
 
-	sql := `DELETE from user_auth_token WHERE created_at <= ? OR rotated_at <= ?`
-	res, err := srv.SQLStore.NewSession().Exec(sql, createdBefore.Unix(), rotatedBefore.Unix())
-	if err != nil {
-		return 0, err
-	}
+	var affected int64
+	err := srv.SQLStore.WithDbSession(ctx, func(dbSession *sqlstore.DBSession) error {
+		sql := `DELETE from user_auth_token WHERE created_at <= ? OR rotated_at <= ?`
+		res, err := dbSession.Exec(sql, createdBefore.Unix(), rotatedBefore.Unix())
+		if err != nil {
+			return err
+		}
 
-	affected, err := res.RowsAffected()
-	if err != nil {
-		srv.log.Error("failed to cleanup expired auth tokens", "error", err)
-		return 0, nil
-	}
+		affected, err = res.RowsAffected()
+		if err != nil {
+			srv.log.Error("failed to cleanup expired auth tokens", "error", err)
+			return nil
+		}
+
+		srv.log.Debug("cleanup of expired auth tokens done", "count", affected)
+
+		return nil
+	})
 
-	srv.log.Debug("cleanup of expired auth tokens done", "count", affected)
 	return affected, err
 }

+ 3 - 2
pkg/services/auth/token_cleanup_test.go

@@ -1,6 +1,7 @@
 package auth
 
 import (
+	"context"
 	"fmt"
 	"testing"
 	"time"
@@ -40,7 +41,7 @@ func TestUserAuthTokenCleanup(t *testing.T) {
 				insertToken(fmt.Sprintf("newA%d", i), fmt.Sprintf("newB%d", i), from.Unix(), from.Unix())
 			}
 
-			affected, err := ctx.tokenService.deleteExpiredTokens(7*24*time.Hour, 30*24*time.Hour)
+			affected, err := ctx.tokenService.deleteExpiredTokens(context.Background(), 7*24*time.Hour, 30*24*time.Hour)
 			So(err, ShouldBeNil)
 			So(affected, ShouldEqual, 3)
 		})
@@ -60,7 +61,7 @@ func TestUserAuthTokenCleanup(t *testing.T) {
 				insertToken(fmt.Sprintf("newA%d", i), fmt.Sprintf("newB%d", i), from.Unix(), fromRotate.Unix())
 			}
 
-			affected, err := ctx.tokenService.deleteExpiredTokens(7*24*time.Hour, 30*24*time.Hour)
+			affected, err := ctx.tokenService.deleteExpiredTokens(context.Background(), 7*24*time.Hour, 30*24*time.Hour)
 			So(err, ShouldBeNil)
 			So(affected, ShouldEqual, 3)
 		})

+ 1 - 1
pkg/services/quota/quota.go

@@ -43,7 +43,7 @@ func (qs *QuotaService) QuotaReached(c *m.ReqContext, target string) (bool, erro
 			}
 			if target == "session" {
 
-				usedSessions, err := qs.AuthTokenService.ActiveTokenCount()
+				usedSessions, err := qs.AuthTokenService.ActiveTokenCount(c.Req.Context())
 				if err != nil {
 					return false, err
 				}