Ver código fonte

change UserToken from interface to struct

Marcus Efraimsson 6 anos atrás
pai
commit
a60124a88c

+ 1 - 1
pkg/api/login.go

@@ -137,7 +137,7 @@ func (hs *HTTPServer) loginUserWithUser(user *m.User, c *m.ReqContext) {
 		hs.log.Error("failed to create auth token", "error", err)
 	}
 
-	middleware.WriteSessionCookie(c, userToken.GetToken(), hs.Cfg.LoginMaxLifetimeDays)
+	middleware.WriteSessionCookie(c, userToken.UnhashedToken, hs.Cfg.LoginMaxLifetimeDays)
 }
 
 func (hs *HTTPServer) Logout(c *m.ReqContext) {

+ 3 - 3
pkg/middleware/middleware.go

@@ -182,9 +182,9 @@ func initContextWithToken(authTokenService authtoken.UserAuthTokenService, ctx *
 		return false
 	}
 
-	query := m.GetSignedInUserQuery{UserId: token.GetUserId(), OrgId: orgID}
+	query := m.GetSignedInUserQuery{UserId: token.UserId, OrgId: orgID}
 	if err := bus.Dispatch(&query); err != nil {
-		ctx.Logger.Error("failed to get user with id", "userId", token.GetUserId(), "error", err)
+		ctx.Logger.Error("failed to get user with id", "userId", token.UserId, "error", err)
 		return false
 	}
 
@@ -199,7 +199,7 @@ func initContextWithToken(authTokenService authtoken.UserAuthTokenService, ctx *
 	}
 
 	if rotated {
-		WriteSessionCookie(ctx, token.GetToken(), setting.LoginMaxLifetimeDays)
+		WriteSessionCookie(ctx, token.UnhashedToken, setting.LoginMaxLifetimeDays)
 	}
 
 	return true

+ 33 - 55
pkg/middleware/middleware_test.go

@@ -157,10 +157,10 @@ func TestMiddlewareContext(t *testing.T) {
 				return nil
 			})
 
-			sc.userAuthTokenService.lookupTokenProvider = func(unhashedToken string) (auth.UserToken, error) {
-				return &userTokenImpl{
-					userId: 12,
-					token:  unhashedToken,
+			sc.userAuthTokenService.lookupTokenProvider = func(unhashedToken string) (*auth.UserToken, error) {
+				return &auth.UserToken{
+					UserId:        12,
+					UnhashedToken: unhashedToken,
 				}, nil
 			}
 
@@ -169,8 +169,8 @@ func TestMiddlewareContext(t *testing.T) {
 			Convey("should init context with user info", func() {
 				So(sc.context.IsSignedIn, ShouldBeTrue)
 				So(sc.context.UserId, ShouldEqual, 12)
-				So(sc.context.UserToken.GetUserId(), ShouldEqual, 12)
-				So(sc.context.UserToken.GetToken(), ShouldEqual, "token")
+				So(sc.context.UserToken.UserId, ShouldEqual, 12)
+				So(sc.context.UserToken.UnhashedToken, ShouldEqual, "token")
 			})
 
 			Convey("should not set cookie", func() {
@@ -186,15 +186,15 @@ func TestMiddlewareContext(t *testing.T) {
 				return nil
 			})
 
-			sc.userAuthTokenService.lookupTokenProvider = func(unhashedToken string) (auth.UserToken, error) {
-				return &userTokenImpl{
-					userId: 12,
-					token:  unhashedToken,
+			sc.userAuthTokenService.lookupTokenProvider = func(unhashedToken string) (*auth.UserToken, error) {
+				return &auth.UserToken{
+					UserId:        12,
+					UnhashedToken: "",
 				}, nil
 			}
 
-			sc.userAuthTokenService.tryRotateTokenProvider = func(userToken auth.UserToken, clientIP, userAgent string) (bool, error) {
-				userToken.(fakeUserToken).SetToken("rotated")
+			sc.userAuthTokenService.tryRotateTokenProvider = func(userToken *auth.UserToken, clientIP, userAgent string) (bool, error) {
+				userToken.UnhashedToken = "rotated"
 				return true, nil
 			}
 
@@ -216,8 +216,8 @@ func TestMiddlewareContext(t *testing.T) {
 			Convey("should init context with user info", func() {
 				So(sc.context.IsSignedIn, ShouldBeTrue)
 				So(sc.context.UserId, ShouldEqual, 12)
-				So(sc.context.UserToken.GetUserId(), ShouldEqual, 12)
-				So(sc.context.UserToken.GetToken(), ShouldEqual, "rotated")
+				So(sc.context.UserToken.UserId, ShouldEqual, 12)
+				So(sc.context.UserToken.UnhashedToken, ShouldEqual, "rotated")
 			})
 
 			Convey("should set cookie", func() {
@@ -228,7 +228,7 @@ func TestMiddlewareContext(t *testing.T) {
 		middlewareScenario("Invalid/expired auth token in cookie", func(sc *scenarioContext) {
 			sc.withTokenSessionCookie("token")
 
-			sc.userAuthTokenService.lookupTokenProvider = func(unhashedToken string) (auth.UserToken, error) {
+			sc.userAuthTokenService.lookupTokenProvider = func(unhashedToken string) (*auth.UserToken, error) {
 				return nil, authtoken.ErrAuthTokenNotFound
 			}
 
@@ -679,70 +679,48 @@ func (sc *scenarioContext) exec() {
 type scenarioFunc func(c *scenarioContext)
 type handlerFunc func(c *m.ReqContext)
 
-type fakeUserToken interface {
-	auth.UserToken
-	SetToken(token string)
-}
-
-type userTokenImpl struct {
-	userId int64
-	token  string
-}
-
-func (ut *userTokenImpl) GetUserId() int64 {
-	return ut.userId
-}
-
-func (ut *userTokenImpl) GetToken() string {
-	return ut.token
-}
-
-func (ut *userTokenImpl) SetToken(token string) {
-	ut.token = token
-}
-
 type fakeUserAuthTokenService struct {
-	createTokenProvider    func(userId int64, clientIP, userAgent string) (auth.UserToken, error)
-	tryRotateTokenProvider func(token auth.UserToken, clientIP, userAgent string) (bool, error)
-	lookupTokenProvider    func(unhashedToken string) (auth.UserToken, error)
-	revokeTokenProvider    func(token auth.UserToken) error
+	createTokenProvider    func(userId int64, clientIP, userAgent string) (*auth.UserToken, error)
+	tryRotateTokenProvider func(token *auth.UserToken, clientIP, userAgent string) (bool, error)
+	lookupTokenProvider    func(unhashedToken string) (*auth.UserToken, error)
+	revokeTokenProvider    func(token *auth.UserToken) error
 }
 
 func newFakeUserAuthTokenService() *fakeUserAuthTokenService {
 	return &fakeUserAuthTokenService{
-		createTokenProvider: func(userId int64, clientIP, userAgent string) (auth.UserToken, error) {
-			return &userTokenImpl{
-				userId: 0,
-				token:  "",
+		createTokenProvider: func(userId int64, clientIP, userAgent string) (*auth.UserToken, error) {
+			return &auth.UserToken{
+				UserId:        0,
+				UnhashedToken: "",
 			}, nil
 		},
-		tryRotateTokenProvider: func(token auth.UserToken, clientIP, userAgent string) (bool, error) {
+		tryRotateTokenProvider: func(token *auth.UserToken, clientIP, userAgent string) (bool, error) {
 			return false, nil
 		},
-		lookupTokenProvider: func(unhashedToken string) (auth.UserToken, error) {
-			return &userTokenImpl{
-				userId: 0,
-				token:  "",
+		lookupTokenProvider: func(unhashedToken string) (*auth.UserToken, error) {
+			return &auth.UserToken{
+				UserId:        0,
+				UnhashedToken: "",
 			}, nil
 		},
-		revokeTokenProvider: func(token auth.UserToken) error {
+		revokeTokenProvider: func(token *auth.UserToken) error {
 			return nil
 		},
 	}
 }
 
-func (s *fakeUserAuthTokenService) CreateToken(userId int64, clientIP, userAgent string) (auth.UserToken, error) {
+func (s *fakeUserAuthTokenService) CreateToken(userId int64, clientIP, userAgent string) (*auth.UserToken, error) {
 	return s.createTokenProvider(userId, clientIP, userAgent)
 }
 
-func (s *fakeUserAuthTokenService) LookupToken(unhashedToken string) (auth.UserToken, error) {
+func (s *fakeUserAuthTokenService) LookupToken(unhashedToken string) (*auth.UserToken, error) {
 	return s.lookupTokenProvider(unhashedToken)
 }
 
-func (s *fakeUserAuthTokenService) TryRotateToken(token auth.UserToken, clientIP, userAgent string) (bool, error) {
+func (s *fakeUserAuthTokenService) TryRotateToken(token *auth.UserToken, clientIP, userAgent string) (bool, error) {
 	return s.tryRotateTokenProvider(token, clientIP, userAgent)
 }
 
-func (s *fakeUserAuthTokenService) RevokeToken(token auth.UserToken) error {
+func (s *fakeUserAuthTokenService) RevokeToken(token *auth.UserToken) error {
 	return s.revokeTokenProvider(token)
 }

+ 8 - 8
pkg/middleware/org_redirect_test.go

@@ -26,10 +26,10 @@ func TestOrgRedirectMiddleware(t *testing.T) {
 				return nil
 			})
 
-			sc.userAuthTokenService.lookupTokenProvider = func(unhashedToken string) (auth.UserToken, error) {
-				return &userTokenImpl{
-					userId: 12,
-					token:  "",
+			sc.userAuthTokenService.lookupTokenProvider = func(unhashedToken string) (*auth.UserToken, error) {
+				return &auth.UserToken{
+					UserId:        0,
+					UnhashedToken: "",
 				}, nil
 			}
 
@@ -52,10 +52,10 @@ func TestOrgRedirectMiddleware(t *testing.T) {
 				return nil
 			})
 
-			sc.userAuthTokenService.lookupTokenProvider = func(unhashedToken string) (auth.UserToken, error) {
-				return &userTokenImpl{
-					userId: 12,
-					token:  "",
+			sc.userAuthTokenService.lookupTokenProvider = func(unhashedToken string) (*auth.UserToken, error) {
+				return &auth.UserToken{
+					UserId:        12,
+					UnhashedToken: "",
 				}, nil
 			}
 

+ 4 - 4
pkg/middleware/quota_test.go

@@ -81,10 +81,10 @@ func TestMiddlewareQuota(t *testing.T) {
 				return nil
 			})
 
-			sc.userAuthTokenService.lookupTokenProvider = func(unhashedToken string) (auth.UserToken, error) {
-				return &userTokenImpl{
-					userId: 12,
-					token:  "",
+			sc.userAuthTokenService.lookupTokenProvider = func(unhashedToken string) (*auth.UserToken, error) {
+				return &auth.UserToken{
+					UserId:        12,
+					UnhashedToken: "",
 				}, nil
 			}
 

+ 1 - 1
pkg/models/context.go

@@ -14,7 +14,7 @@ import (
 type ReqContext struct {
 	*macaron.Context
 	*SignedInUser
-	UserToken auth.UserToken
+	UserToken *auth.UserToken
 
 	// This should only be used by the auth_proxy
 	Session session.SessionStore

+ 13 - 3
pkg/services/auth/auth.go

@@ -1,6 +1,16 @@
 package auth
 
-type UserToken interface {
-	GetUserId() int64
-	GetToken() string
+type UserToken struct {
+	Id            int64
+	UserId        int64
+	AuthToken     string
+	PrevAuthToken string
+	UserAgent     string
+	ClientIp      string
+	AuthTokenSeen bool
+	SeenAt        int64
+	RotatedAt     int64
+	CreatedAt     int64
+	UpdatedAt     int64
+	UnhashedToken string
 }

+ 16 - 14
pkg/services/auth/authtoken/auth_token.go

@@ -40,7 +40,7 @@ func (s *UserAuthTokenServiceImpl) Init() error {
 	return nil
 }
 
-func (s *UserAuthTokenServiceImpl) CreateToken(userId int64, clientIP, userAgent string) (auth.UserToken, error) {
+func (s *UserAuthTokenServiceImpl) CreateToken(userId int64, clientIP, userAgent string) (*auth.UserToken, error) {
 	clientIP = util.ParseIPAddress(clientIP)
 	token, err := util.RandomHex(16)
 	if err != nil {
@@ -72,10 +72,13 @@ func (s *UserAuthTokenServiceImpl) CreateToken(userId int64, clientIP, userAgent
 
 	s.log.Debug("user auth token created", "tokenId", userAuthToken.Id, "userId", userAuthToken.UserId, "clientIP", userAuthToken.ClientIp, "userAgent", userAuthToken.UserAgent, "authToken", userAuthToken.AuthToken)
 
-	return userAuthToken.toUserToken()
+	var userToken auth.UserToken
+	err = userAuthToken.toUserToken(&userToken)
+
+	return &userToken, err
 }
 
-func (s *UserAuthTokenServiceImpl) LookupToken(unhashedToken string) (auth.UserToken, error) {
+func (s *UserAuthTokenServiceImpl) LookupToken(unhashedToken string) (*auth.UserToken, error) {
 	hashedToken := hashToken(unhashedToken)
 	if setting.Env == setting.DEV {
 		s.log.Debug("looking up token", "unhashed", unhashedToken, "hashed", hashedToken)
@@ -133,18 +136,19 @@ func (s *UserAuthTokenServiceImpl) LookupToken(unhashedToken string) (auth.UserT
 	}
 
 	model.UnhashedToken = unhashedToken
-	return model.toUserToken()
+
+	var userToken auth.UserToken
+	err = model.toUserToken(&userToken)
+
+	return &userToken, err
 }
 
-func (s *UserAuthTokenServiceImpl) TryRotateToken(token auth.UserToken, clientIP, userAgent string) (bool, error) {
+func (s *UserAuthTokenServiceImpl) TryRotateToken(token *auth.UserToken, clientIP, userAgent string) (bool, error) {
 	if token == nil {
 		return false, nil
 	}
 
-	model, err := extractModelFromToken(token)
-	if err != nil {
-		return false, err
-	}
+	model := userAuthTokenFromUserToken(token)
 
 	now := getTime()
 
@@ -191,21 +195,19 @@ func (s *UserAuthTokenServiceImpl) TryRotateToken(token auth.UserToken, clientIP
 	s.log.Debug("auth token rotated", "affected", affected, "auth_token_id", model.Id, "userId", model.UserId)
 	if affected > 0 {
 		model.UnhashedToken = newToken
+		model.toUserToken(token)
 		return true, nil
 	}
 
 	return false, nil
 }
 
-func (s *UserAuthTokenServiceImpl) RevokeToken(token auth.UserToken) error {
+func (s *UserAuthTokenServiceImpl) RevokeToken(token *auth.UserToken) error {
 	if token == nil {
 		return ErrAuthTokenNotFound
 	}
 
-	model, err := extractModelFromToken(token)
-	if err != nil {
-		return err
-	}
+	model := userAuthTokenFromUserToken(token)
 
 	rowsAffected, err := s.SQLStore.NewSession().Delete(model)
 	if err != nil {

+ 125 - 88
pkg/services/auth/authtoken/auth_token_test.go

@@ -1,12 +1,15 @@
 package authtoken
 
 import (
+	"encoding/json"
 	"testing"
 	"time"
 
+	"github.com/grafana/grafana/pkg/components/simplejson"
 	"github.com/grafana/grafana/pkg/setting"
 
 	"github.com/grafana/grafana/pkg/log"
+	"github.com/grafana/grafana/pkg/services/auth"
 	"github.com/grafana/grafana/pkg/services/sqlstore"
 	. "github.com/smartystreets/goconvey/convey"
 )
@@ -25,28 +28,24 @@ func TestUserAuthToken(t *testing.T) {
 		Convey("When creating token", func() {
 			userToken, err := userAuthTokenService.CreateToken(userID, "192.168.10.11:1234", "some user agent")
 			So(err, ShouldBeNil)
-			model, err := extractModelFromToken(userToken)
-			So(err, ShouldBeNil)
-			So(model, ShouldNotBeNil)
-			So(model.AuthTokenSeen, ShouldBeFalse)
+			So(userToken, ShouldNotBeNil)
+			So(userToken.AuthTokenSeen, ShouldBeFalse)
 
 			Convey("When lookup unhashed token should return user auth token", func() {
-				userToken, err := userAuthTokenService.LookupToken(model.UnhashedToken)
-				So(err, ShouldBeNil)
-				lookedUpModel, err := extractModelFromToken(userToken)
+				userToken, err := userAuthTokenService.LookupToken(userToken.UnhashedToken)
 				So(err, ShouldBeNil)
-				So(lookedUpModel, ShouldNotBeNil)
-				So(lookedUpModel.UserId, ShouldEqual, userID)
-				So(lookedUpModel.AuthTokenSeen, ShouldBeTrue)
+				So(userToken, ShouldNotBeNil)
+				So(userToken.UserId, ShouldEqual, userID)
+				So(userToken.AuthTokenSeen, ShouldBeTrue)
 
-				storedAuthToken, err := ctx.getAuthTokenByID(lookedUpModel.Id)
+				storedAuthToken, err := ctx.getAuthTokenByID(userToken.Id)
 				So(err, ShouldBeNil)
 				So(storedAuthToken, ShouldNotBeNil)
 				So(storedAuthToken.AuthTokenSeen, ShouldBeTrue)
 			})
 
 			Convey("When lookup hashed token should return user auth token not found error", func() {
-				userToken, err := userAuthTokenService.LookupToken(model.AuthToken)
+				userToken, err := userAuthTokenService.LookupToken(userToken.AuthToken)
 				So(err, ShouldEqual, ErrAuthTokenNotFound)
 				So(userToken, ShouldBeNil)
 			})
@@ -55,7 +54,7 @@ func TestUserAuthToken(t *testing.T) {
 				err = userAuthTokenService.RevokeToken(userToken)
 				So(err, ShouldBeNil)
 
-				model, err := ctx.getAuthTokenByID(model.Id)
+				model, err := ctx.getAuthTokenByID(userToken.Id)
 				So(err, ShouldBeNil)
 				So(model, ShouldBeNil)
 			})
@@ -66,10 +65,8 @@ func TestUserAuthToken(t *testing.T) {
 			})
 
 			Convey("revoking non-existing token should return error", func() {
-				model.Id = 1000
-				nonExistingToken, err := model.toUserToken()
-				So(err, ShouldBeNil)
-				err = userAuthTokenService.RevokeToken(nonExistingToken)
+				userToken.Id = 1000
+				err = userAuthTokenService.RevokeToken(userToken)
 				So(err, ShouldEqual, ErrAuthTokenNotFound)
 			})
 		})
@@ -77,17 +74,8 @@ func TestUserAuthToken(t *testing.T) {
 		Convey("expires correctly", func() {
 			userToken, err := userAuthTokenService.CreateToken(userID, "192.168.10.11:1234", "some user agent")
 			So(err, ShouldBeNil)
-			model, err := extractModelFromToken(userToken)
-			So(err, ShouldBeNil)
-			So(model, ShouldNotBeNil)
 
-			_, err = userAuthTokenService.LookupToken(model.UnhashedToken)
-			So(err, ShouldBeNil)
-
-			model, err = ctx.getAuthTokenByID(model.Id)
-			So(err, ShouldBeNil)
-
-			userToken, err = model.toUserToken()
+			userToken, err = userAuthTokenService.LookupToken(userToken.UnhashedToken)
 			So(err, ShouldBeNil)
 
 			getTime = func() time.Time {
@@ -98,14 +86,14 @@ func TestUserAuthToken(t *testing.T) {
 			So(err, ShouldBeNil)
 			So(rotated, ShouldBeTrue)
 
-			_, err = userAuthTokenService.LookupToken(model.UnhashedToken)
+			userToken, err = userAuthTokenService.LookupToken(userToken.UnhashedToken)
 			So(err, ShouldBeNil)
 
-			stillGood, err := userAuthTokenService.LookupToken(model.UnhashedToken)
+			stillGood, err := userAuthTokenService.LookupToken(userToken.UnhashedToken)
 			So(err, ShouldBeNil)
 			So(stillGood, ShouldNotBeNil)
 
-			model, err = ctx.getAuthTokenByID(model.Id)
+			model, err := ctx.getAuthTokenByID(userToken.Id)
 			So(err, ShouldBeNil)
 
 			Convey("when rotated_at is 6:59:59 ago should find token", func() {
@@ -113,7 +101,7 @@ func TestUserAuthToken(t *testing.T) {
 					return time.Unix(model.RotatedAt, 0).Add(24 * 7 * time.Hour).Add(-time.Second)
 				}
 
-				stillGood, err = userAuthTokenService.LookupToken(stillGood.GetToken())
+				stillGood, err = userAuthTokenService.LookupToken(stillGood.UnhashedToken)
 				So(err, ShouldBeNil)
 				So(stillGood, ShouldNotBeNil)
 			})
@@ -123,7 +111,7 @@ func TestUserAuthToken(t *testing.T) {
 					return time.Unix(model.RotatedAt, 0).Add(24 * 7 * time.Hour)
 				}
 
-				notGood, err := userAuthTokenService.LookupToken(userToken.GetToken())
+				notGood, err := userAuthTokenService.LookupToken(userToken.UnhashedToken)
 				So(err, ShouldEqual, ErrAuthTokenNotFound)
 				So(notGood, ShouldBeNil)
 			})
@@ -137,7 +125,7 @@ func TestUserAuthToken(t *testing.T) {
 					return time.Unix(model.CreatedAt, 0).Add(24 * 30 * time.Hour).Add(-time.Second)
 				}
 
-				stillGood, err = userAuthTokenService.LookupToken(stillGood.GetToken())
+				stillGood, err = userAuthTokenService.LookupToken(stillGood.UnhashedToken)
 				So(err, ShouldBeNil)
 				So(stillGood, ShouldNotBeNil)
 			})
@@ -151,7 +139,7 @@ func TestUserAuthToken(t *testing.T) {
 					return time.Unix(model.CreatedAt, 0).Add(24 * 30 * time.Hour)
 				}
 
-				notGood, err := userAuthTokenService.LookupToken(userToken.GetToken())
+				notGood, err := userAuthTokenService.LookupToken(userToken.UnhashedToken)
 				So(err, ShouldEqual, ErrAuthTokenNotFound)
 				So(notGood, ShouldBeNil)
 			})
@@ -160,37 +148,35 @@ func TestUserAuthToken(t *testing.T) {
 		Convey("can properly rotate tokens", func() {
 			userToken, err := userAuthTokenService.CreateToken(userID, "192.168.10.11:1234", "some user agent")
 			So(err, ShouldBeNil)
-			model, err := extractModelFromToken(userToken)
-			So(err, ShouldBeNil)
-			So(model, ShouldNotBeNil)
 
-			prevToken := model.AuthToken
-			unhashedPrev := model.UnhashedToken
+			prevToken := userToken.AuthToken
+			unhashedPrev := userToken.UnhashedToken
 
 			rotated, err := userAuthTokenService.TryRotateToken(userToken, "192.168.10.12:1234", "a new user agent")
 			So(err, ShouldBeNil)
 			So(rotated, ShouldBeFalse)
 
-			updated, err := ctx.markAuthTokenAsSeen(model.Id)
+			updated, err := ctx.markAuthTokenAsSeen(userToken.Id)
 			So(err, ShouldBeNil)
 			So(updated, ShouldBeTrue)
 
-			model, err = ctx.getAuthTokenByID(model.Id)
-			So(err, ShouldBeNil)
-			tok, err := model.toUserToken()
+			model, err := ctx.getAuthTokenByID(userToken.Id)
 			So(err, ShouldBeNil)
 
+			var tok auth.UserToken
+			model.toUserToken(&tok)
+
 			getTime = func() time.Time {
 				return t.Add(time.Hour)
 			}
 
-			rotated, err = userAuthTokenService.TryRotateToken(tok, "192.168.10.12:1234", "a new user agent")
+			rotated, err = userAuthTokenService.TryRotateToken(&tok, "192.168.10.12:1234", "a new user agent")
 			So(err, ShouldBeNil)
 			So(rotated, ShouldBeTrue)
 
-			unhashedToken := model.UnhashedToken
+			unhashedToken := tok.UnhashedToken
 
-			model, err = ctx.getAuthTokenByID(model.Id)
+			model, err = ctx.getAuthTokenByID(tok.Id)
 			So(err, ShouldBeNil)
 			model.UnhashedToken = unhashedToken
 
@@ -205,17 +191,15 @@ func TestUserAuthToken(t *testing.T) {
 
 			lookedUpUserToken, err := userAuthTokenService.LookupToken(model.UnhashedToken)
 			So(err, ShouldBeNil)
-			lookedUpModel, err := extractModelFromToken(lookedUpUserToken)
-			So(err, ShouldBeNil)
-			So(lookedUpModel, ShouldNotBeNil)
-			So(lookedUpModel.AuthTokenSeen, ShouldBeTrue)
-			So(lookedUpModel.SeenAt, ShouldEqual, getTime().Unix())
+			So(lookedUpUserToken, ShouldNotBeNil)
+			So(lookedUpUserToken.AuthTokenSeen, ShouldBeTrue)
+			So(lookedUpUserToken.SeenAt, ShouldEqual, getTime().Unix())
 
 			lookedUpUserToken, err = userAuthTokenService.LookupToken(unhashedPrev)
 			So(err, ShouldBeNil)
-			So(lookedUpModel, ShouldNotBeNil)
-			So(lookedUpModel.Id, ShouldEqual, model.Id)
-			So(lookedUpModel.AuthTokenSeen, ShouldBeTrue)
+			So(lookedUpUserToken, ShouldNotBeNil)
+			So(lookedUpUserToken.Id, ShouldEqual, model.Id)
+			So(lookedUpUserToken.AuthTokenSeen, ShouldBeTrue)
 
 			getTime = func() time.Time {
 				return t.Add(time.Hour + (2 * time.Minute))
@@ -223,12 +207,10 @@ func TestUserAuthToken(t *testing.T) {
 
 			lookedUpUserToken, err = userAuthTokenService.LookupToken(unhashedPrev)
 			So(err, ShouldBeNil)
-			lookedUpModel, err = extractModelFromToken(lookedUpUserToken)
-			So(err, ShouldBeNil)
-			So(lookedUpModel, ShouldNotBeNil)
-			So(lookedUpModel.AuthTokenSeen, ShouldBeTrue)
+			So(lookedUpUserToken, ShouldNotBeNil)
+			So(lookedUpUserToken.AuthTokenSeen, ShouldBeTrue)
 
-			lookedUpModel, err = ctx.getAuthTokenByID(lookedUpModel.Id)
+			lookedUpModel, err := ctx.getAuthTokenByID(lookedUpUserToken.Id)
 			So(err, ShouldBeNil)
 			So(lookedUpModel, ShouldNotBeNil)
 			So(lookedUpModel.AuthTokenSeen, ShouldBeFalse)
@@ -237,7 +219,7 @@ func TestUserAuthToken(t *testing.T) {
 			So(err, ShouldBeNil)
 			So(rotated, ShouldBeTrue)
 
-			model, err = ctx.getAuthTokenByID(model.Id)
+			model, err = ctx.getAuthTokenByID(userToken.Id)
 			So(err, ShouldBeNil)
 			So(model, ShouldNotBeNil)
 			So(model.SeenAt, ShouldEqual, 0)
@@ -246,11 +228,9 @@ func TestUserAuthToken(t *testing.T) {
 		Convey("keeps prev token valid for 1 minute after it is confirmed", func() {
 			userToken, err := userAuthTokenService.CreateToken(userID, "192.168.10.11:1234", "some user agent")
 			So(err, ShouldBeNil)
-			model, err := extractModelFromToken(userToken)
-			So(err, ShouldBeNil)
-			So(model, ShouldNotBeNil)
+			So(userToken, ShouldNotBeNil)
 
-			lookedUpUserToken, err := userAuthTokenService.LookupToken(model.UnhashedToken)
+			lookedUpUserToken, err := userAuthTokenService.LookupToken(userToken.UnhashedToken)
 			So(err, ShouldBeNil)
 			So(lookedUpUserToken, ShouldNotBeNil)
 
@@ -258,7 +238,7 @@ func TestUserAuthToken(t *testing.T) {
 				return t.Add(10 * time.Minute)
 			}
 
-			prevToken := model.UnhashedToken
+			prevToken := userToken.UnhashedToken
 			rotated, err := userAuthTokenService.TryRotateToken(userToken, "1.1.1.1", "firefox")
 			So(err, ShouldBeNil)
 			So(rotated, ShouldBeTrue)
@@ -267,7 +247,7 @@ func TestUserAuthToken(t *testing.T) {
 				return t.Add(20 * time.Minute)
 			}
 
-			currentUserToken, err := userAuthTokenService.LookupToken(model.UnhashedToken)
+			currentUserToken, err := userAuthTokenService.LookupToken(userToken.UnhashedToken)
 			So(err, ShouldBeNil)
 			So(currentUserToken, ShouldNotBeNil)
 
@@ -279,23 +259,17 @@ func TestUserAuthToken(t *testing.T) {
 		Convey("will not mark token unseen when prev and current are the same", func() {
 			userToken, err := userAuthTokenService.CreateToken(userID, "192.168.10.11:1234", "some user agent")
 			So(err, ShouldBeNil)
-			model, err := extractModelFromToken(userToken)
-			So(err, ShouldBeNil)
-			So(model, ShouldNotBeNil)
+			So(userToken, ShouldNotBeNil)
 
-			lookedUpUserToken, err := userAuthTokenService.LookupToken(model.UnhashedToken)
-			So(err, ShouldBeNil)
-			lookedUpModel, err := extractModelFromToken(lookedUpUserToken)
+			lookedUpUserToken, err := userAuthTokenService.LookupToken(userToken.UnhashedToken)
 			So(err, ShouldBeNil)
-			So(lookedUpModel, ShouldNotBeNil)
+			So(lookedUpUserToken, ShouldNotBeNil)
 
-			lookedUpUserToken, err = userAuthTokenService.LookupToken(model.UnhashedToken)
+			lookedUpUserToken, err = userAuthTokenService.LookupToken(userToken.UnhashedToken)
 			So(err, ShouldBeNil)
-			lookedUpModel, err = extractModelFromToken(lookedUpUserToken)
-			So(err, ShouldBeNil)
-			So(lookedUpModel, ShouldNotBeNil)
+			So(lookedUpUserToken, ShouldNotBeNil)
 
-			lookedUpModel, err = ctx.getAuthTokenByID(lookedUpModel.Id)
+			lookedUpModel, err := ctx.getAuthTokenByID(lookedUpUserToken.Id)
 			So(err, ShouldBeNil)
 			So(lookedUpModel, ShouldNotBeNil)
 			So(lookedUpModel.AuthTokenSeen, ShouldBeTrue)
@@ -304,14 +278,12 @@ func TestUserAuthToken(t *testing.T) {
 		Convey("Rotate token", func() {
 			userToken, err := userAuthTokenService.CreateToken(userID, "192.168.10.11:1234", "some user agent")
 			So(err, ShouldBeNil)
-			model, err := extractModelFromToken(userToken)
-			So(err, ShouldBeNil)
-			So(model, ShouldNotBeNil)
+			So(userToken, ShouldNotBeNil)
 
-			prevToken := model.AuthToken
+			prevToken := userToken.AuthToken
 
 			Convey("Should rotate current token and previous token when auth token seen", func() {
-				updated, err := ctx.markAuthTokenAsSeen(model.Id)
+				updated, err := ctx.markAuthTokenAsSeen(userToken.Id)
 				So(err, ShouldBeNil)
 				So(updated, ShouldBeTrue)
 
@@ -323,7 +295,7 @@ func TestUserAuthToken(t *testing.T) {
 				So(err, ShouldBeNil)
 				So(rotated, ShouldBeTrue)
 
-				storedToken, err := ctx.getAuthTokenByID(model.Id)
+				storedToken, err := ctx.getAuthTokenByID(userToken.Id)
 				So(err, ShouldBeNil)
 				So(storedToken, ShouldNotBeNil)
 				So(storedToken.AuthTokenSeen, ShouldBeFalse)
@@ -332,7 +304,7 @@ func TestUserAuthToken(t *testing.T) {
 
 				prevToken = storedToken.AuthToken
 
-				updated, err = ctx.markAuthTokenAsSeen(model.Id)
+				updated, err = ctx.markAuthTokenAsSeen(userToken.Id)
 				So(err, ShouldBeNil)
 				So(updated, ShouldBeTrue)
 
@@ -344,7 +316,7 @@ func TestUserAuthToken(t *testing.T) {
 				So(err, ShouldBeNil)
 				So(rotated, ShouldBeTrue)
 
-				storedToken, err = ctx.getAuthTokenByID(model.Id)
+				storedToken, err = ctx.getAuthTokenByID(userToken.Id)
 				So(err, ShouldBeNil)
 				So(storedToken, ShouldNotBeNil)
 				So(storedToken.AuthTokenSeen, ShouldBeFalse)
@@ -353,7 +325,7 @@ func TestUserAuthToken(t *testing.T) {
 			})
 
 			Convey("Should rotate current token, but keep previous token when auth token not seen", func() {
-				model.RotatedAt = getTime().Add(-2 * time.Minute).Unix()
+				userToken.RotatedAt = getTime().Add(-2 * time.Minute).Unix()
 
 				getTime = func() time.Time {
 					return t.Add(2 * time.Minute)
@@ -363,7 +335,7 @@ func TestUserAuthToken(t *testing.T) {
 				So(err, ShouldBeNil)
 				So(rotated, ShouldBeTrue)
 
-				storedToken, err := ctx.getAuthTokenByID(model.Id)
+				storedToken, err := ctx.getAuthTokenByID(userToken.Id)
 				So(err, ShouldBeNil)
 				So(storedToken, ShouldNotBeNil)
 				So(storedToken.AuthTokenSeen, ShouldBeFalse)
@@ -372,6 +344,71 @@ func TestUserAuthToken(t *testing.T) {
 			})
 		})
 
+		Convey("When populating userAuthToken from UserToken should copy all properties", func() {
+			ut := auth.UserToken{
+				Id:            1,
+				UserId:        2,
+				AuthToken:     "a",
+				PrevAuthToken: "b",
+				UserAgent:     "c",
+				ClientIp:      "d",
+				AuthTokenSeen: true,
+				SeenAt:        3,
+				RotatedAt:     4,
+				CreatedAt:     5,
+				UpdatedAt:     6,
+				UnhashedToken: "e",
+			}
+			utBytes, err := json.Marshal(ut)
+			So(err, ShouldBeNil)
+			utJSON, err := simplejson.NewJson(utBytes)
+			So(err, ShouldBeNil)
+			utMap := utJSON.MustMap()
+
+			var uat userAuthToken
+			uat.fromUserToken(&ut)
+			uatBytes, err := json.Marshal(uat)
+			So(err, ShouldBeNil)
+			uatJSON, err := simplejson.NewJson(uatBytes)
+			So(err, ShouldBeNil)
+			uatMap := uatJSON.MustMap()
+
+			So(uatMap, ShouldResemble, utMap)
+		})
+
+		Convey("When populating userToken from userAuthToken should copy all properties", func() {
+			uat := userAuthToken{
+				Id:            1,
+				UserId:        2,
+				AuthToken:     "a",
+				PrevAuthToken: "b",
+				UserAgent:     "c",
+				ClientIp:      "d",
+				AuthTokenSeen: true,
+				SeenAt:        3,
+				RotatedAt:     4,
+				CreatedAt:     5,
+				UpdatedAt:     6,
+				UnhashedToken: "e",
+			}
+			uatBytes, err := json.Marshal(uat)
+			So(err, ShouldBeNil)
+			uatJSON, err := simplejson.NewJson(uatBytes)
+			So(err, ShouldBeNil)
+			uatMap := uatJSON.MustMap()
+
+			var ut auth.UserToken
+			err = uat.toUserToken(&ut)
+			So(err, ShouldBeNil)
+			utBytes, err := json.Marshal(ut)
+			So(err, ShouldBeNil)
+			utJSON, err := simplejson.NewJson(utBytes)
+			So(err, ShouldBeNil)
+			utMap := utJSON.MustMap()
+
+			So(utMap, ShouldResemble, uatMap)
+		})
+
 		Reset(func() {
 			getTime = time.Now
 		})

+ 38 - 36
pkg/services/auth/authtoken/model.go

@@ -27,50 +27,52 @@ type userAuthToken struct {
 	UnhashedToken string `xorm:"-"`
 }
 
-func (uat *userAuthToken) toUserToken() (auth.UserToken, error) {
-	if uat == nil {
-		return nil, fmt.Errorf("needs pointer to userAuthToken struct")
-	}
-
-	return &userTokenImpl{
-		userAuthToken: uat,
-	}, nil
-}
-
-type userToken interface {
-	auth.UserToken
-	GetModel() *userAuthToken
-}
-
-type userTokenImpl struct {
-	*userAuthToken
+func userAuthTokenFromUserToken(ut *auth.UserToken) *userAuthToken {
+	var uat userAuthToken
+	uat.fromUserToken(ut)
+	return &uat
 }
 
-func (ut *userTokenImpl) GetUserId() int64 {
-	return ut.UserId
+func (uat *userAuthToken) fromUserToken(ut *auth.UserToken) {
+	uat.Id = ut.Id
+	uat.UserId = ut.UserId
+	uat.AuthToken = ut.AuthToken
+	uat.PrevAuthToken = ut.PrevAuthToken
+	uat.UserAgent = ut.UserAgent
+	uat.ClientIp = ut.ClientIp
+	uat.AuthTokenSeen = ut.AuthTokenSeen
+	uat.SeenAt = ut.SeenAt
+	uat.RotatedAt = ut.RotatedAt
+	uat.CreatedAt = ut.CreatedAt
+	uat.UpdatedAt = ut.UpdatedAt
+	uat.UnhashedToken = ut.UnhashedToken
 }
 
-func (ut *userTokenImpl) GetToken() string {
-	return ut.UnhashedToken
-}
-
-func (ut *userTokenImpl) GetModel() *userAuthToken {
-	return ut.userAuthToken
-}
-
-func extractModelFromToken(token auth.UserToken) (*userAuthToken, error) {
-	ut, ok := token.(userToken)
-	if !ok {
-		return nil, fmt.Errorf("failed to cast token")
+func (uat *userAuthToken) toUserToken(ut *auth.UserToken) error {
+	if uat == nil {
+		return fmt.Errorf("needs pointer to userAuthToken struct")
 	}
 
-	return ut.GetModel(), nil
+	ut.Id = uat.Id
+	ut.UserId = uat.UserId
+	ut.AuthToken = uat.AuthToken
+	ut.PrevAuthToken = uat.PrevAuthToken
+	ut.UserAgent = uat.UserAgent
+	ut.ClientIp = uat.ClientIp
+	ut.AuthTokenSeen = uat.AuthTokenSeen
+	ut.SeenAt = uat.SeenAt
+	ut.RotatedAt = uat.RotatedAt
+	ut.CreatedAt = uat.CreatedAt
+	ut.UpdatedAt = uat.UpdatedAt
+	ut.UnhashedToken = uat.UnhashedToken
+
+	return nil
 }
 
 // UserAuthTokenService are used for generating and validating user auth tokens
 type UserAuthTokenService interface {
-	CreateToken(userId int64, clientIP, userAgent string) (auth.UserToken, error)
-	LookupToken(unhashedToken string) (auth.UserToken, error)
-	TryRotateToken(token auth.UserToken, clientIP, userAgent string) (bool, error)
-	RevokeToken(token auth.UserToken) error
+	CreateToken(userId int64, clientIP, userAgent string) (*auth.UserToken, error)
+	LookupToken(unhashedToken string) (*auth.UserToken, error)
+	TryRotateToken(token *auth.UserToken, clientIP, userAgent string) (bool, error)
+	RevokeToken(token *auth.UserToken) error
 }