Преглед на файлове

begin user auth token implementation

Marcus Efraimsson преди 7 години
родител
ревизия
b0df7280be

+ 170 - 0
pkg/services/auth/auth_token.go

@@ -0,0 +1,170 @@
+package auth
+
+import (
+	"crypto/sha256"
+	"encoding/hex"
+	"time"
+
+	"github.com/grafana/grafana/pkg/models"
+	"github.com/grafana/grafana/pkg/setting"
+	"github.com/grafana/grafana/pkg/util"
+	macaron "gopkg.in/macaron.v1"
+
+	"github.com/grafana/grafana/pkg/log"
+	"github.com/grafana/grafana/pkg/registry"
+	"github.com/grafana/grafana/pkg/services/sqlstore"
+)
+
+func init() {
+	registry.RegisterService(&UserAuthTokenService{})
+}
+
+var now = time.Now
+
+// UserAuthTokenService are used for generating and validating user auth tokens
+type UserAuthTokenService struct {
+	SQLStore *sqlstore.SqlStore `inject:""`
+	log      log.Logger
+}
+
+// Init this service
+func (s *UserAuthTokenService) Init() error {
+	s.log = log.New("auth")
+	return nil
+}
+
+const sessionCookieKey = "grafana_session"
+
+func (s *UserAuthTokenService) UserAuthenticatedHook(user *models.User, c *models.ReqContext) error {
+	userToken, err := s.CreateToken(user.Id, c.RemoteAddr(), c.Req.UserAgent())
+	if err != nil {
+		return err
+	}
+
+	c.Resp.Header().Del("Set-Cookie")
+	c.SetCookie(sessionCookieKey, userToken.unhashedToken, setting.AppSubUrl+"/", setting.Domain, false, true)
+
+	return nil
+}
+
+func (s *UserAuthTokenService) UserSignedOutHook(c *models.ReqContext) {
+	c.SetCookie(sessionCookieKey, "", -1, setting.AppSubUrl+"/", setting.Domain, false, true)
+}
+
+func (s *UserAuthTokenService) RequestMiddleware() macaron.Handler {
+	return func(ctx *models.ReqContext) {
+		authToken := ctx.GetCookie(sessionCookieKey)
+		userToken, err := s.lookupToken(authToken)
+		if err != nil {
+
+		}
+
+		ctx.Next()
+
+		refreshed, err := s.refreshToken(userToken, ctx.RemoteAddr(), ctx.Req.UserAgent())
+		if err != nil {
+
+		}
+
+		if refreshed {
+			ctx.Resp.Header().Del("Set-Cookie")
+			ctx.SetCookie(sessionCookieKey, userToken.unhashedToken, setting.AppSubUrl+"/", setting.Domain, false, true)
+		}
+	}
+}
+
+func (s *UserAuthTokenService) CreateToken(userId int64, clientIP, userAgent string) (*userAuthToken, error) {
+	clientIP = util.ParseIPAddress(clientIP)
+	token, err := util.RandomHex(16)
+	if err != nil {
+		return nil, err
+	}
+
+	hashedToken := hashToken(token)
+
+	userToken := userAuthToken{
+		UserId:        userId,
+		AuthToken:     hashedToken,
+		PrevAuthToken: hashedToken,
+		ClientIp:      clientIP,
+		UserAgent:     userAgent,
+		RotatedAt:     now().Unix(),
+		CreatedAt:     now().Unix(),
+		UpdatedAt:     now().Unix(),
+		SeenAt:        0,
+		AuthTokenSeen: false,
+	}
+	_, err = s.SQLStore.NewSession().Insert(&userToken)
+	if err != nil {
+		return nil, err
+	}
+
+	userToken.unhashedToken = token
+
+	return &userToken, nil
+}
+
+func (s *UserAuthTokenService) lookupToken(unhashedToken string) (*userAuthToken, error) {
+	hashedToken := hashToken(unhashedToken)
+
+	var userToken userAuthToken
+	exists, err := s.SQLStore.NewSession().Where("auth_token = ? OR prev_auth_token = ?", hashedToken, hashedToken).Get(&userToken)
+	if err != nil {
+		return nil, err
+	}
+
+	if !exists {
+		return nil, ErrAuthTokenNotFound
+	}
+
+	if userToken.AuthToken != hashedToken && userToken.PrevAuthToken == hashedToken && userToken.AuthTokenSeen {
+		userToken.AuthTokenSeen = false
+		expireBefore := now().Add(-1 * time.Minute).Unix()
+		affectedRows, err := s.SQLStore.NewSession().Where("id = ? AND prev_auth_token = ? AND rotated_at < ?", userToken.Id, userToken.PrevAuthToken, expireBefore).AllCols().Update(&userToken)
+		if err != nil {
+			return nil, err
+		}
+
+		if affectedRows == 0 {
+			s.log.Debug("prev seen token unchanged", "userTokenId", userToken.Id, "userId", userToken.UserId, "authToken", userToken.AuthToken, "clientIP", userToken.ClientIp, "userAgent", userToken.UserAgent)
+		} else {
+			s.log.Debug("prev seen token", "userTokenId", userToken.Id, "userId", userToken.UserId, "authToken", userToken.AuthToken, "clientIP", userToken.ClientIp, "userAgent", userToken.UserAgent)
+		}
+	}
+
+	if !userToken.AuthTokenSeen && userToken.AuthToken == hashedToken {
+		userTokenCopy := userToken
+		userTokenCopy.AuthTokenSeen = true
+		userTokenCopy.SeenAt = now().Unix()
+		affectedRows, err := s.SQLStore.NewSession().Where("id = ? AND auth_token = ?", userTokenCopy.Id, userTokenCopy.AuthToken).AllCols().Update(&userTokenCopy)
+		if err != nil {
+			return nil, err
+		}
+
+		if affectedRows == 1 {
+			userToken = userTokenCopy
+		}
+
+		if affectedRows == 0 {
+			s.log.Debug("seen wrong token", "userTokenId", userToken.Id, "userId", userToken.UserId, "authToken", userToken.AuthToken, "clientIP", userToken.ClientIp, "userAgent", userToken.UserAgent)
+		} else {
+			s.log.Debug("seen token", "userTokenId", userToken.Id, "userId", userToken.UserId, "authToken", userToken.AuthToken, "clientIP", userToken.ClientIp, "userAgent", userToken.UserAgent)
+		}
+	}
+
+	userToken.unhashedToken = unhashedToken
+
+	return &userToken, nil
+}
+
+func (s *UserAuthTokenService) refreshToken(token *userAuthToken, clientIP, userAgent string) (bool, error) {
+	// lookup token in db
+	// refresh token if needed
+
+	return false, nil
+}
+
+func hashToken(token string) string {
+	hashBytes := sha256.Sum256([]byte(token + setting.SecretKey))
+	return hex.EncodeToString(hashBytes[:])
+}

+ 206 - 0
pkg/services/auth/auth_token_test.go

@@ -0,0 +1,206 @@
+package auth
+
+import (
+	"testing"
+	"time"
+
+	"github.com/grafana/grafana/pkg/log"
+	"github.com/grafana/grafana/pkg/services/sqlstore"
+	. "github.com/smartystreets/goconvey/convey"
+)
+
+func TestUserAuthToken(t *testing.T) {
+	Convey("Test user auth token", t, func() {
+		ctx := createTestContext(t)
+		userAuthTokenService := ctx.tokenService
+		userID := int64(10)
+
+		t := time.Date(2018, 12, 13, 13, 45, 0, 0, time.UTC)
+		now = func() time.Time {
+			return t
+		}
+
+		Convey("When creating token", func() {
+			token, err := userAuthTokenService.CreateToken(userID, "192.168.10.11:1234", "some user agent")
+			So(err, ShouldBeNil)
+			So(token, ShouldNotBeNil)
+			So(token.AuthTokenSeen, ShouldBeFalse)
+
+			Convey("When lookup unhashed token should return user auth token", func() {
+				lookupToken, err := userAuthTokenService.lookupToken(token.unhashedToken)
+				So(err, ShouldBeNil)
+				So(lookupToken, ShouldNotBeNil)
+				So(lookupToken.UserId, ShouldEqual, userID)
+				So(lookupToken.AuthTokenSeen, ShouldBeTrue)
+
+				storedAuthToken, err := ctx.getAuthTokenByID(lookupToken.Id)
+				So(err, ShouldBeNil)
+				So(storedAuthToken, ShouldNotBeNil)
+				So(storedAuthToken.AuthTokenSeen, ShouldBeTrue)
+			})
+
+			Convey("When lookup hashed token should return user auth token not found error", func() {
+				lookupToken, err := userAuthTokenService.lookupToken(token.AuthToken)
+				So(err, ShouldEqual, ErrAuthTokenNotFound)
+				So(lookupToken, ShouldBeNil)
+			})
+		})
+
+		Convey("expires correctly", func() {
+			token, err := userAuthTokenService.CreateToken(userID, "192.168.10.11:1234", "some user agent")
+			So(err, ShouldBeNil)
+			So(token, ShouldNotBeNil)
+
+			_, err = userAuthTokenService.lookupToken(token.unhashedToken)
+			So(err, ShouldBeNil)
+
+			token, err = ctx.getAuthTokenByID(token.Id)
+			So(err, ShouldBeNil)
+
+			// set now (now - 23 hours)
+			_, err = userAuthTokenService.refreshToken(token, "192.168.10.11:1234", "some user agent")
+			So(err, ShouldBeNil)
+
+			_, err = userAuthTokenService.lookupToken(token.unhashedToken)
+			So(err, ShouldBeNil)
+
+			stillGood, err := userAuthTokenService.lookupToken(token.unhashedToken)
+			So(err, ShouldBeNil)
+			So(stillGood, ShouldNotBeNil)
+
+			// set now (new - 2 hours)
+			notGood, err := userAuthTokenService.lookupToken(token.unhashedToken)
+			So(err, ShouldEqual, ErrAuthTokenNotFound)
+			So(notGood, ShouldBeNil)
+		})
+
+		Convey("can properly rotate tokens", func() {
+			token, err := userAuthTokenService.CreateToken(userID, "192.168.10.11:1234", "some user agent")
+			So(err, ShouldBeNil)
+			So(token, ShouldNotBeNil)
+
+			prevToken := token.AuthToken
+			unhashedPrev := token.unhashedToken
+
+			refreshed, err := userAuthTokenService.refreshToken(token, "192.168.10.12:1234", "a new user agent")
+			So(err, ShouldBeNil)
+			So(refreshed, ShouldBeFalse)
+
+			ctx.markAuthTokenAsSeen(token.Id)
+			token, err = ctx.getAuthTokenByID(token.Id)
+			So(err, ShouldBeNil)
+
+			// ability to auth using an old token
+			now = func() time.Time {
+				return t
+			}
+
+			refreshed, err = userAuthTokenService.refreshToken(token, "192.168.10.12:1234", "a new user agent")
+			So(err, ShouldBeNil)
+			So(refreshed, ShouldBeTrue)
+
+			unhashedToken := token.unhashedToken
+
+			token, err = ctx.getAuthTokenByID(token.Id)
+			So(err, ShouldBeNil)
+			token.unhashedToken = unhashedToken
+
+			So(token.RotatedAt, ShouldEqual, t.Unix())
+			So(token.ClientIp, ShouldEqual, "192.168.10.12")
+			So(token.UserAgent, ShouldEqual, "a new user agent")
+			So(token.AuthTokenSeen, ShouldBeFalse)
+			So(token.SeenAt, ShouldEqual, 0)
+			So(token.PrevAuthToken, ShouldEqual, prevToken)
+
+			lookedUp, err := userAuthTokenService.lookupToken(token.unhashedToken)
+			So(err, ShouldBeNil)
+			So(lookedUp, ShouldNotBeNil)
+			So(lookedUp.AuthTokenSeen, ShouldBeTrue)
+			So(lookedUp.SeenAt, ShouldEqual, t.Unix())
+
+			lookedUp, err = userAuthTokenService.lookupToken(unhashedPrev)
+			So(err, ShouldBeNil)
+			So(lookedUp, ShouldNotBeNil)
+			So(lookedUp.Id, ShouldEqual, token.Id)
+
+			now = func() time.Time {
+				return t.Add(2 * time.Minute)
+			}
+
+			lookedUp, err = userAuthTokenService.lookupToken(unhashedPrev)
+			So(err, ShouldBeNil)
+			So(lookedUp, ShouldNotBeNil)
+
+			lookedUp, err = ctx.getAuthTokenByID(lookedUp.Id)
+			So(err, ShouldBeNil)
+			So(lookedUp, ShouldNotBeNil)
+			So(lookedUp.AuthTokenSeen, ShouldBeFalse)
+
+			refreshed, err = userAuthTokenService.refreshToken(token, "192.168.10.12:1234", "a new user agent")
+			So(err, ShouldBeNil)
+			So(refreshed, ShouldBeTrue)
+
+			token, err = ctx.getAuthTokenByID(token.Id)
+			So(err, ShouldBeNil)
+			So(token, ShouldNotBeNil)
+			So(token.SeenAt, ShouldEqual, 0)
+		})
+
+		Convey("keeps prev token valid for 1 minute after it is confirmed", func() {
+
+		})
+
+		Convey("will not mark token unseen when prev and current are the same", func() {
+
+		})
+
+		Reset(func() {
+			now = time.Now
+		})
+	})
+}
+
+func createTestContext(t *testing.T) *testContext {
+	t.Helper()
+
+	sqlstore := sqlstore.InitTestDB(t)
+	tokenService := &UserAuthTokenService{
+		SQLStore: sqlstore,
+		log:      log.New("test-logger"),
+	}
+
+	return &testContext{
+		sqlstore:     sqlstore,
+		tokenService: tokenService,
+	}
+}
+
+type testContext struct {
+	sqlstore     *sqlstore.SqlStore
+	tokenService *UserAuthTokenService
+}
+
+func (c *testContext) getAuthTokenByID(id int64) (*userAuthToken, error) {
+	sess := c.sqlstore.NewSession()
+	var t userAuthToken
+	found, err := sess.ID(id).Get(&t)
+	if err != nil || !found {
+		return nil, err
+	}
+
+	return &t, nil
+}
+
+func (c *testContext) markAuthTokenAsSeen(id int64) (bool, error) {
+	sess := c.sqlstore.NewSession()
+	res, err := sess.Exec("UPDATE user_auth_token SET auth_token_seen = ? WHERE id = ?", c.sqlstore.Dialect.BooleanStr(true), id)
+	if err != nil {
+		return false, err
+	}
+
+	rowsAffected, err := res.RowsAffected()
+	if err != nil {
+		return false, err
+	}
+	return rowsAffected == 1, nil
+}

+ 25 - 0
pkg/services/auth/model.go

@@ -0,0 +1,25 @@
+package auth
+
+import (
+	"errors"
+)
+
+// Typed errors
+var (
+	ErrAuthTokenNotFound = errors.New("User auth token not found")
+)
+
+type userAuthToken struct {
+	Id            int64
+	UserId        int64
+	AuthToken     string
+	PrevAuthToken string
+	UserAgent     string
+	ClientIp      string
+	AuthTokenSeen bool
+	SeenAt        int64
+	RotatedAt     int64
+	CreatedAt     int64
+	UpdatedAt     int64
+	unhashedToken string `xorm:"-"`
+}

+ 1 - 0
pkg/services/sqlstore/migrations/migrations.go

@@ -32,6 +32,7 @@ func AddMigrations(mg *Migrator) {
 	addLoginAttemptMigrations(mg)
 	addUserAuthMigrations(mg)
 	addServerlockMigrations(mg)
+	addUserAuthTokenMigrations(mg)
 }
 
 func addMigrationLogMigrations(mg *Migrator) {

+ 32 - 0
pkg/services/sqlstore/migrations/user_auth_token_mig.go

@@ -0,0 +1,32 @@
+package migrations
+
+import (
+	. "github.com/grafana/grafana/pkg/services/sqlstore/migrator"
+)
+
+func addUserAuthTokenMigrations(mg *Migrator) {
+	userAuthTokenV1 := Table{
+		Name: "user_auth_token",
+		Columns: []*Column{
+			{Name: "id", Type: DB_BigInt, IsPrimaryKey: true, IsAutoIncrement: true},
+			{Name: "user_id", Type: DB_BigInt, Nullable: false},
+			{Name: "auth_token", Type: DB_NVarchar, Length: 100, Nullable: false},
+			{Name: "prev_auth_token", Type: DB_NVarchar, Length: 100, Nullable: false},
+			{Name: "user_agent", Type: DB_NVarchar, Length: 255, Nullable: false},
+			{Name: "client_ip", Type: DB_NVarchar, Length: 255, Nullable: false},
+			{Name: "auth_token_seen", Type: DB_Bool, Nullable: false},
+			{Name: "seen_at", Type: DB_Int, Nullable: true},
+			{Name: "rotated_at", Type: DB_Int, Nullable: false},
+			{Name: "created_at", Type: DB_Int, Nullable: false},
+			{Name: "updated_at", Type: DB_Int, Nullable: false},
+		},
+		Indices: []*Index{
+			{Cols: []string{"auth_token"}, Type: UniqueIndex},
+			{Cols: []string{"prev_auth_token"}, Type: UniqueIndex},
+		},
+	}
+
+	mg.AddMigration("create user auth token table", NewAddTableMigration(userAuthTokenV1))
+	mg.AddMigration("add unique index user_auth_token.auth_token", NewAddIndexMigration(userAuthTokenV1, userAuthTokenV1.Indices[0]))
+	mg.AddMigration("add unique index user_auth_token.prev_auth_token", NewAddIndexMigration(userAuthTokenV1, userAuthTokenV1.Indices[1]))
+}