浏览代码

inital code for rotate

bergquist 7 年之前
父节点
当前提交
c2accfa4c0
共有 5 个文件被更改,包括 141 次插入49 次删除
  1. 32 3
      pkg/middleware/middleware.go
  2. 18 3
      pkg/models/context.go
  3. 65 18
      pkg/services/auth/auth_token.go
  4. 12 11
      pkg/services/auth/auth_token_test.go
  5. 14 14
      pkg/services/auth/model.go

+ 32 - 3
pkg/middleware/middleware.go

@@ -1,7 +1,10 @@
 package middleware
 
 import (
+	"net/http"
+	"net/url"
 	"strconv"
+	"time"
 
 	"github.com/grafana/grafana/pkg/bus"
 	"github.com/grafana/grafana/pkg/components/apikeygen"
@@ -11,7 +14,7 @@ import (
 	"github.com/grafana/grafana/pkg/services/session"
 	"github.com/grafana/grafana/pkg/setting"
 	"github.com/grafana/grafana/pkg/util"
-	"gopkg.in/macaron.v1"
+	macaron "gopkg.in/macaron.v1"
 )
 
 var (
@@ -62,7 +65,27 @@ func GetContextHandler(ats *auth.UserAuthTokenService) macaron.Handler {
 		c.Next()
 
 		//if signed in with token
-		//ats.RefreshToken()
+		rotated, err := ats.RefreshToken(ctx.UserToken, ctx.RemoteAddr(), ctx.Req.UserAgent())
+		if err != nil {
+			ctx.Logger.Error("failed to rotate token", "error", err)
+			return
+		}
+
+		if rotated {
+			ctx.Logger.Info("new token", "unhashed token", ctx.UserToken.UnhashedToken)
+			//c.SetCookie("grafana_session", url.QueryEscape(ctx.UserToken.UnhashedToken), nil, setting.AppSubUrl+"/", setting.Domain, false, true)
+			// ctx.Resp.Header().Del("Set-Cookie")
+			cookie := http.Cookie{
+				Name:     "grafana_session",
+				Value:    url.QueryEscape(ctx.UserToken.UnhashedToken),
+				HttpOnly: true,
+				MaxAge:   int(time.Minute * 10),
+				Domain:   setting.Domain,
+				Path:     setting.AppSubUrl + "/",
+			}
+
+			ctx.Resp.Header().Add("Set-Cookie", cookie.String())
+		}
 
 		// update last seen every 5min
 		if ctx.ShouldUpdateLastSeenAt() {
@@ -95,7 +118,12 @@ func initContextWithAnonymousUser(ctx *m.ReqContext) bool {
 }
 
 func initContextWithToken(ctx *m.ReqContext, orgID int64, ts *auth.UserAuthTokenService) bool {
-	user, err := ts.LookupToken(ctx)
+	unhashedToken := ctx.GetCookie("grafana_session")
+	if unhashedToken == "" {
+		return false
+	}
+
+	user, err := ts.LookupToken(unhashedToken)
 	if err != nil {
 		ctx.Logger.Info("failed to look up user based on cookie")
 		return false
@@ -109,6 +137,7 @@ func initContextWithToken(ctx *m.ReqContext, orgID int64, ts *auth.UserAuthToken
 
 	ctx.SignedInUser = query.Result
 	ctx.IsSignedIn = true
+	ctx.UserToken = user
 
 	return true
 }

+ 18 - 3
pkg/models/context.go

@@ -3,17 +3,32 @@ package models
 import (
 	"strings"
 
-	"github.com/prometheus/client_golang/prometheus"
-	"gopkg.in/macaron.v1"
-
 	"github.com/grafana/grafana/pkg/log"
 	"github.com/grafana/grafana/pkg/services/session"
 	"github.com/grafana/grafana/pkg/setting"
+	"github.com/prometheus/client_golang/prometheus"
+	"gopkg.in/macaron.v1"
 )
 
+type UserAuthToken struct {
+	Id            int64
+	UserId        int64
+	AuthToken     string
+	PrevAuthToken string
+	UserAgent     string
+	ClientIp      string
+	AuthTokenSeen bool
+	SeenAt        int64
+	RotatedAt     int64
+	CreatedAt     int64
+	UpdatedAt     int64
+	UnhashedToken string `xorm:"-"`
+}
+
 type ReqContext struct {
 	*macaron.Context
 	*SignedInUser
+	UserToken *UserAuthToken
 
 	Session session.SessionStore
 

+ 65 - 18
pkg/services/auth/auth_token.go

@@ -3,7 +3,6 @@ package auth
 import (
 	"crypto/sha256"
 	"encoding/hex"
-	"fmt"
 	"net/http"
 	"net/url"
 	"time"
@@ -45,10 +44,11 @@ func (s *UserAuthTokenService) UserAuthenticatedHook(user *models.User, c *model
 	c.Resp.Header().Del("Set-Cookie")
 	cookie := http.Cookie{
 		Name:     sessionCookieKey,
-		Value:    url.QueryEscape(userToken.unhashedToken),
+		Value:    url.QueryEscape(userToken.UnhashedToken),
 		HttpOnly: true,
-		Expires:  time.Now().Add(time.Minute * 10),
+		MaxAge:   int(time.Minute * 10),
 		Domain:   setting.Domain,
+		Path:     setting.AppSubUrl + "/",
 	}
 
 	c.Resp.Header().Add("Set-Cookie", cookie.String())
@@ -57,7 +57,18 @@ func (s *UserAuthTokenService) UserAuthenticatedHook(user *models.User, c *model
 }
 
 func (s *UserAuthTokenService) UserSignedOutHook(c *models.ReqContext) {
-	c.SetCookie(sessionCookieKey, "", -1, setting.AppSubUrl+"/", setting.Domain, false, true)
+	//c.SetCookie(sessionCookieKey, "", -1, setting.AppSubUrl+"/", setting.Domain, false, true)
+	c.Resp.Header().Del("Set-Cookie")
+	cookie := http.Cookie{
+		Name:     sessionCookieKey,
+		Value:    "",
+		HttpOnly: true,
+		MaxAge:   -1,
+		Domain:   setting.Domain,
+		Path:     setting.AppSubUrl + "/",
+	}
+
+	c.Resp.Header().Add("Set-Cookie", cookie.String())
 }
 
 // func (s *UserAuthTokenService) RequestMiddleware() macaron.Handler {
@@ -82,7 +93,7 @@ func (s *UserAuthTokenService) UserSignedOutHook(c *models.ReqContext) {
 // 	}
 // }
 
-func (s *UserAuthTokenService) CreateToken(userId int64, clientIP, userAgent string) (*userAuthToken, error) {
+func (s *UserAuthTokenService) CreateToken(userId int64, clientIP, userAgent string) (*models.UserAuthToken, error) {
 	clientIP = util.ParseIPAddress(clientIP)
 	token, err := util.RandomHex(16)
 	if err != nil {
@@ -91,7 +102,7 @@ func (s *UserAuthTokenService) CreateToken(userId int64, clientIP, userAgent str
 
 	hashedToken := hashToken(token)
 
-	userToken := userAuthToken{
+	userToken := models.UserAuthToken{
 		UserId:        userId,
 		AuthToken:     hashedToken,
 		PrevAuthToken: hashedToken,
@@ -108,20 +119,15 @@ func (s *UserAuthTokenService) CreateToken(userId int64, clientIP, userAgent str
 		return nil, err
 	}
 
-	userToken.unhashedToken = token
+	userToken.UnhashedToken = token
 
 	return &userToken, nil
 }
 
-func (s *UserAuthTokenService) LookupToken(ctx *models.ReqContext) (*userAuthToken, error) {
-	unhashedToken := ctx.GetCookie(sessionCookieKey)
-	if unhashedToken == "" {
-		return nil, fmt.Errorf("session token cookie is empty")
-	}
-
+func (s *UserAuthTokenService) LookupToken(unhashedToken string) (*models.UserAuthToken, error) {
 	hashedToken := hashToken(unhashedToken)
 
-	var userToken userAuthToken
+	var userToken models.UserAuthToken
 	exists, err := s.SQLStore.NewSession().Where("auth_token = ? OR prev_auth_token = ?", hashedToken, hashedToken).Get(&userToken)
 	if err != nil {
 		return nil, err
@@ -166,14 +172,55 @@ func (s *UserAuthTokenService) LookupToken(ctx *models.ReqContext) (*userAuthTok
 		}
 	}
 
-	userToken.unhashedToken = unhashedToken
+	userToken.UnhashedToken = unhashedToken
 
 	return &userToken, nil
 }
 
-func (s *UserAuthTokenService) RefreshToken(token *userAuthToken, clientIP, userAgent string) (bool, error) {
-	// lookup token in db
-	// refresh token if needed
+func (s *UserAuthTokenService) RefreshToken(token *models.UserAuthToken, clientIP, userAgent string) (bool, error) {
+	if token == nil {
+		return false, nil
+	}
+
+	var needsRotation = false
+	rotatedAt := time.Unix(token.RotatedAt, 0)
+	if token.AuthTokenSeen {
+		needsRotation = rotatedAt.Before(now().Add(time.Duration(-1) * time.Minute))
+	} else {
+		needsRotation = rotatedAt.Before(now().Add(time.Duration(-30) * time.Second))
+	}
+
+	s.log.Info("refresh token", "needs rotation?", needsRotation, "auth_token_seen", token.AuthTokenSeen, "rotated_at", rotatedAt, "token.Id", token.Id)
+	if !needsRotation {
+		return false, nil
+	}
+
+	newToken, _ := util.RandomHex(16)
+	hashedToken := hashToken(newToken)
+
+	sql := `
+		UPDATE user_auth_token
+		SET
+			auth_token_seen = false,
+			seen_at = null,
+			user_agent = ?,
+			client_ip = ?,
+			prev_auth_token = case when auth_token_seen then auth_token else prev_auth_token end,
+			auth_token = ?,
+			rotated_at = ?
+		WHERE id = ? AND (auth_token_seen or rotated_at < ?)`
+
+	res, err := s.SQLStore.NewSession().Exec(sql, userAgent, clientIP, hashedToken, now().Unix(), token.Id, now().Add(time.Duration(-30)*time.Second))
+	if err != nil {
+		return false, err
+	}
+
+	affected, _ := res.RowsAffected()
+	s.log.Info("rotated", "affected", affected, "auth_token_id", token.Id, "userId", token.UserId, "user_agent", userAgent, "client_ip", clientIP)
+	if affected > 0 {
+		token.UnhashedToken = newToken
+		return true, nil
+	}
 
 	return false, nil
 }

+ 12 - 11
pkg/services/auth/auth_token_test.go

@@ -5,6 +5,7 @@ import (
 	"time"
 
 	"github.com/grafana/grafana/pkg/log"
+	"github.com/grafana/grafana/pkg/models"
 	"github.com/grafana/grafana/pkg/services/sqlstore"
 	. "github.com/smartystreets/goconvey/convey"
 )
@@ -27,7 +28,7 @@ func TestUserAuthToken(t *testing.T) {
 			So(token.AuthTokenSeen, ShouldBeFalse)
 
 			Convey("When lookup unhashed token should return user auth token", func() {
-				LookupToken, err := userAuthTokenService.LookupToken(token.unhashedToken)
+				LookupToken, err := userAuthTokenService.LookupToken(token.UnhashedToken)
 				So(err, ShouldBeNil)
 				So(LookupToken, ShouldNotBeNil)
 				So(LookupToken.UserId, ShouldEqual, userID)
@@ -51,7 +52,7 @@ func TestUserAuthToken(t *testing.T) {
 			So(err, ShouldBeNil)
 			So(token, ShouldNotBeNil)
 
-			_, err = userAuthTokenService.LookupToken(token.unhashedToken)
+			_, err = userAuthTokenService.LookupToken(token.UnhashedToken)
 			So(err, ShouldBeNil)
 
 			token, err = ctx.getAuthTokenByID(token.Id)
@@ -61,15 +62,15 @@ func TestUserAuthToken(t *testing.T) {
 			_, err = userAuthTokenService.RefreshToken(token, "192.168.10.11:1234", "some user agent")
 			So(err, ShouldBeNil)
 
-			_, err = userAuthTokenService.LookupToken(token.unhashedToken)
+			_, err = userAuthTokenService.LookupToken(token.UnhashedToken)
 			So(err, ShouldBeNil)
 
-			stillGood, err := userAuthTokenService.LookupToken(token.unhashedToken)
+			stillGood, err := userAuthTokenService.LookupToken(token.UnhashedToken)
 			So(err, ShouldBeNil)
 			So(stillGood, ShouldNotBeNil)
 
 			// set now (new - 2 hours)
-			notGood, err := userAuthTokenService.LookupToken(token.unhashedToken)
+			notGood, err := userAuthTokenService.LookupToken(token.UnhashedToken)
 			So(err, ShouldEqual, ErrAuthTokenNotFound)
 			So(notGood, ShouldBeNil)
 		})
@@ -80,7 +81,7 @@ func TestUserAuthToken(t *testing.T) {
 			So(token, ShouldNotBeNil)
 
 			prevToken := token.AuthToken
-			unhashedPrev := token.unhashedToken
+			unhashedPrev := token.UnhashedToken
 
 			refreshed, err := userAuthTokenService.RefreshToken(token, "192.168.10.12:1234", "a new user agent")
 			So(err, ShouldBeNil)
@@ -99,11 +100,11 @@ func TestUserAuthToken(t *testing.T) {
 			So(err, ShouldBeNil)
 			So(refreshed, ShouldBeTrue)
 
-			unhashedToken := token.unhashedToken
+			unhashedToken := token.UnhashedToken
 
 			token, err = ctx.getAuthTokenByID(token.Id)
 			So(err, ShouldBeNil)
-			token.unhashedToken = unhashedToken
+			token.UnhashedToken = unhashedToken
 
 			So(token.RotatedAt, ShouldEqual, t.Unix())
 			So(token.ClientIp, ShouldEqual, "192.168.10.12")
@@ -112,7 +113,7 @@ func TestUserAuthToken(t *testing.T) {
 			So(token.SeenAt, ShouldEqual, 0)
 			So(token.PrevAuthToken, ShouldEqual, prevToken)
 
-			lookedUp, err := userAuthTokenService.LookupToken(token.unhashedToken)
+			lookedUp, err := userAuthTokenService.LookupToken(token.UnhashedToken)
 			So(err, ShouldBeNil)
 			So(lookedUp, ShouldNotBeNil)
 			So(lookedUp.AuthTokenSeen, ShouldBeTrue)
@@ -180,9 +181,9 @@ type testContext struct {
 	tokenService *UserAuthTokenService
 }
 
-func (c *testContext) getAuthTokenByID(id int64) (*userAuthToken, error) {
+func (c *testContext) getAuthTokenByID(id int64) (*models.UserAuthToken, error) {
 	sess := c.sqlstore.NewSession()
-	var t userAuthToken
+	var t models.UserAuthToken
 	found, err := sess.ID(id).Get(&t)
 	if err != nil || !found {
 		return nil, err

+ 14 - 14
pkg/services/auth/model.go

@@ -9,17 +9,17 @@ var (
 	ErrAuthTokenNotFound = errors.New("User auth token not found")
 )
 
-type userAuthToken struct {
-	Id            int64
-	UserId        int64
-	AuthToken     string
-	PrevAuthToken string
-	UserAgent     string
-	ClientIp      string
-	AuthTokenSeen bool
-	SeenAt        int64
-	RotatedAt     int64
-	CreatedAt     int64
-	UpdatedAt     int64
-	unhashedToken string `xorm:"-"`
-}
+// type userAuthToken struct {
+// 	Id            int64
+// 	UserId        int64
+// 	AuthToken     string
+// 	PrevAuthToken string
+// 	UserAgent     string
+// 	ClientIp      string
+// 	AuthTokenSeen bool
+// 	SeenAt        int64
+// 	RotatedAt     int64
+// 	CreatedAt     int64
+// 	UpdatedAt     int64
+// 	unhashedToken string `xorm:"-"`
+// }