浏览代码

Auth: Do not search for the user twice (#18366)

* Auth: Do not search for the user twice

Previously `initContextWithBasicAuth` did not use `LoginUserQuery`, doing
`GetUserByLoginQuery` only i.e. looking user in DB only, things changed when
this function started to check LDAP provider via `LoginUserQuery` (#6940),
however, this request was placed after `GetUserByLoginQuery`, so we first
looking in DB then in the LDAP - if LDAP user hasn't logged in we will
not find it in DB, so `LoginUserQuery` will never be reached.

`LoginUserQuery` request already performs `GetUserByLoginQuery`
request in correct sequence. So we can just remove redundant request.

* Correct sequence execution during authentification &
  introduce tests for it

* Move basic auth tests to separate test file, since main test file already
  pretty large

* Introduce `testing.go` for the middleware module

* Remove redundant test helper function

* Make handler names more explicit

Ref 5777f65d05a8dc141c34e470ef1d5fe956f8173c
Fixes #18329

* Auth: address review comment
Oleg Gaidarenko 6 年之前
父节点
当前提交
d88fdc86fc
共有 4 个文件被更改,包括 249 次插入203 次删除
  1. 4 15
      pkg/middleware/middleware.go
  2. 143 0
      pkg/middleware/middleware_basic_auth_test.go
  3. 0 188
      pkg/middleware/middleware_test.go
  4. 102 0
      pkg/middleware/testing.go

+ 4 - 15
pkg/middleware/middleware.go

@@ -163,24 +163,11 @@ func initContextWithBasicAuth(ctx *models.ReqContext, orgId int64) bool {
 		return true
 	}
 
-	loginQuery := models.GetUserByLoginQuery{LoginOrEmail: username}
-	if err := bus.Dispatch(&loginQuery); err != nil {
-		ctx.Logger.Debug(
-			"Failed to look up the username",
-			"username", username,
-		)
-		ctx.JsonApiErr(401, errStringInvalidUsernamePassword, err)
-
-		return true
-	}
-
-	user := loginQuery.Result
-	loginUserQuery := models.LoginUserQuery{
+	authQuery := models.LoginUserQuery{
 		Username: username,
 		Password: password,
-		User:     user,
 	}
-	if err := bus.Dispatch(&loginUserQuery); err != nil {
+	if err := bus.Dispatch(&authQuery); err != nil {
 		ctx.Logger.Debug(
 			"Failed to authorize the user",
 			"username", username,
@@ -190,6 +177,8 @@ func initContextWithBasicAuth(ctx *models.ReqContext, orgId int64) bool {
 		return true
 	}
 
+	user := authQuery.User
+
 	query := models.GetSignedInUserQuery{UserId: user.Id, OrgId: orgId}
 	if err := bus.Dispatch(&query); err != nil {
 		ctx.Logger.Error(

+ 143 - 0
pkg/middleware/middleware_basic_auth_test.go

@@ -0,0 +1,143 @@
+package middleware
+
+import (
+	"encoding/json"
+	"testing"
+
+	. "github.com/smartystreets/goconvey/convey"
+
+	"github.com/grafana/grafana/pkg/bus"
+	authLogin "github.com/grafana/grafana/pkg/login"
+	"github.com/grafana/grafana/pkg/models"
+	"github.com/grafana/grafana/pkg/setting"
+	"github.com/grafana/grafana/pkg/util"
+)
+
+func TestMiddlewareBasicAuth(t *testing.T) {
+	Convey("Given the basic auth", t, func() {
+		var oldBasicAuthEnabled = setting.BasicAuthEnabled
+		var oldDisableBruteForceLoginProtection = setting.DisableBruteForceLoginProtection
+		var id int64 = 12
+
+		Convey("Setup", func() {
+			setting.BasicAuthEnabled = true
+			setting.DisableBruteForceLoginProtection = true
+			bus.ClearBusHandlers()
+		})
+
+		middlewareScenario(t, "Valid API key", func(sc *scenarioContext) {
+			var orgID int64 = 2
+			keyhash := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd")
+
+			bus.AddHandler("test", func(query *models.GetApiKeyByNameQuery) error {
+				query.Result = &models.ApiKey{OrgId: orgID, Role: models.ROLE_EDITOR, Key: keyhash}
+				return nil
+			})
+
+			authHeader := util.GetBasicAuthHeader("api_key", "eyJrIjoidjVuQXdwTWFmRlA2em5hUzR1cmhkV0RMUzU1MTFNNDIiLCJuIjoiYXNkIiwiaWQiOjF9")
+			sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec()
+
+			Convey("Should return 200", func() {
+				So(sc.resp.Code, ShouldEqual, 200)
+			})
+
+			Convey("Should init middleware context", func() {
+				So(sc.context.IsSignedIn, ShouldEqual, true)
+				So(sc.context.OrgId, ShouldEqual, orgID)
+				So(sc.context.OrgRole, ShouldEqual, models.ROLE_EDITOR)
+			})
+		})
+
+		middlewareScenario(t, "Handle auth", func(sc *scenarioContext) {
+			var password = "MyPass"
+			var salt = "Salt"
+			var orgID int64 = 2
+
+			bus.AddHandler("grafana-auth", func(query *models.LoginUserQuery) error {
+				query.User = &models.User{
+					Password: util.EncodePassword(password, salt),
+					Salt:     salt,
+				}
+				return nil
+			})
+
+			bus.AddHandler("get-sign-user", func(query *models.GetSignedInUserQuery) error {
+				query.Result = &models.SignedInUser{OrgId: orgID, UserId: id}
+				return nil
+			})
+
+			authHeader := util.GetBasicAuthHeader("myUser", password)
+			sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec()
+
+			Convey("Should init middleware context with users", func() {
+				So(sc.context.IsSignedIn, ShouldEqual, true)
+				So(sc.context.OrgId, ShouldEqual, orgID)
+				So(sc.context.UserId, ShouldEqual, id)
+			})
+
+			bus.ClearBusHandlers()
+		})
+
+		middlewareScenario(t, "Auth sequence", func(sc *scenarioContext) {
+			var password = "MyPass"
+			var salt = "Salt"
+
+			authLogin.Init()
+
+			bus.AddHandler("user-query", func(query *models.GetUserByLoginQuery) error {
+				query.Result = &models.User{
+					Password: util.EncodePassword(password, salt),
+					Id:       id,
+					Salt:     salt,
+				}
+				return nil
+			})
+
+			bus.AddHandler("get-sign-user", func(query *models.GetSignedInUserQuery) error {
+				query.Result = &models.SignedInUser{UserId: query.UserId}
+				return nil
+			})
+
+			authHeader := util.GetBasicAuthHeader("myUser", password)
+			sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec()
+
+			Convey("Should init middleware context with user", func() {
+				So(sc.context.IsSignedIn, ShouldEqual, true)
+				So(sc.context.UserId, ShouldEqual, id)
+			})
+		})
+
+		middlewareScenario(t, "Should return error if user is not found", func(sc *scenarioContext) {
+			sc.fakeReq("GET", "/")
+			sc.req.SetBasicAuth("user", "password")
+			sc.exec()
+
+			err := json.NewDecoder(sc.resp.Body).Decode(&sc.respJson)
+			So(err, ShouldNotBeNil)
+
+			So(sc.resp.Code, ShouldEqual, 401)
+			So(sc.respJson["message"], ShouldEqual, errStringInvalidUsernamePassword)
+		})
+
+		middlewareScenario(t, "Should return error if user & password do not match", func(sc *scenarioContext) {
+			bus.AddHandler("user-query", func(loginUserQuery *models.GetUserByLoginQuery) error {
+				return nil
+			})
+
+			sc.fakeReq("GET", "/")
+			sc.req.SetBasicAuth("killa", "gorilla")
+			sc.exec()
+
+			err := json.NewDecoder(sc.resp.Body).Decode(&sc.respJson)
+			So(err, ShouldNotBeNil)
+
+			So(sc.resp.Code, ShouldEqual, 401)
+			So(sc.respJson["message"], ShouldEqual, errStringInvalidUsernamePassword)
+		})
+
+		Convey("Destroy", func() {
+			setting.BasicAuthEnabled = oldBasicAuthEnabled
+			setting.DisableBruteForceLoginProtection = oldDisableBruteForceLoginProtection
+		})
+	})
+}

+ 0 - 188
pkg/middleware/middleware_test.go

@@ -3,10 +3,8 @@ package middleware
 import (
 	"context"
 	"encoding/base32"
-	"encoding/json"
 	"fmt"
 	"net/http"
-	"net/http/httptest"
 	"path/filepath"
 	"testing"
 	"time"
@@ -476,95 +474,6 @@ func TestMiddlewareContext(t *testing.T) {
 	})
 }
 
-func TestMiddlewareBasicAuth(t *testing.T) {
-	Convey("Given the basic auth", t, func() {
-		old := setting.BasicAuthEnabled
-
-		Convey("Setup", func() {
-			setting.BasicAuthEnabled = true
-		})
-
-		middlewareScenario(t, "Valid API key", func(sc *scenarioContext) {
-			keyhash := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd")
-
-			bus.AddHandler("test", func(query *models.GetApiKeyByNameQuery) error {
-				query.Result = &models.ApiKey{OrgId: 12, Role: models.ROLE_EDITOR, Key: keyhash}
-				return nil
-			})
-
-			authHeader := util.GetBasicAuthHeader("api_key", "eyJrIjoidjVuQXdwTWFmRlA2em5hUzR1cmhkV0RMUzU1MTFNNDIiLCJuIjoiYXNkIiwiaWQiOjF9")
-			sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec()
-
-			Convey("Should return 200", func() {
-				So(sc.resp.Code, ShouldEqual, 200)
-			})
-
-			Convey("Should init middleware context", func() {
-				So(sc.context.IsSignedIn, ShouldEqual, true)
-				So(sc.context.OrgId, ShouldEqual, 12)
-				So(sc.context.OrgRole, ShouldEqual, models.ROLE_EDITOR)
-			})
-		})
-
-		middlewareScenario(t, "Handle auth", func(sc *scenarioContext) {
-
-			bus.AddHandler("test", func(query *models.GetUserByLoginQuery) error {
-				query.Result = &models.User{
-					Password: util.EncodePassword("myPass", "Salt"),
-					Salt:     "Salt",
-				}
-				return nil
-			})
-
-			bus.AddHandler("test", func(loginUserQuery *models.LoginUserQuery) error {
-				return nil
-			})
-
-			bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error {
-				query.Result = &models.SignedInUser{OrgId: 2, UserId: 12}
-				return nil
-			})
-
-			authHeader := util.GetBasicAuthHeader("myUser", "myPass")
-			sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec()
-
-			Convey("Should init middleware context with user", func() {
-				So(sc.context.IsSignedIn, ShouldEqual, true)
-				So(sc.context.OrgId, ShouldEqual, 2)
-				So(sc.context.UserId, ShouldEqual, 12)
-			})
-		})
-
-		middlewareScenario(t, "Should return error if user is not found", func(sc *scenarioContext) {
-			sc.fakeReqWithBasicAuth("GET", "/", "test", "test").exec()
-
-			err := json.NewDecoder(sc.resp.Body).Decode(&sc.respJson)
-			So(err, ShouldNotBeNil)
-
-			So(sc.resp.Code, ShouldEqual, 401)
-			So(sc.respJson["message"], ShouldEqual, errStringInvalidUsernamePassword)
-		})
-
-		middlewareScenario(t, "Should return error if user & password do not match", func(sc *scenarioContext) {
-			bus.AddHandler("test", func(loginUserQuery *models.GetUserByLoginQuery) error {
-				return nil
-			})
-
-			sc.fakeReqWithBasicAuth("GET", "/", "test", "test").exec()
-
-			err := json.NewDecoder(sc.resp.Body).Decode(&sc.respJson)
-			So(err, ShouldNotBeNil)
-
-			So(sc.resp.Code, ShouldEqual, 401)
-			So(sc.respJson["message"], ShouldEqual, errStringInvalidUsernamePassword)
-		})
-
-		Convey("Destroy", func() {
-			setting.BasicAuthEnabled = old
-		})
-	})
-}
-
 func middlewareScenario(t *testing.T, desc string, fn scenarioFunc) {
 	Convey(desc, func() {
 		defer bus.ClearBusHandlers()
@@ -602,100 +511,3 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc) {
 		fn(sc)
 	})
 }
-
-type scenarioContext struct {
-	m                    *macaron.Macaron
-	context              *models.ReqContext
-	resp                 *httptest.ResponseRecorder
-	apiKey               string
-	authHeader           string
-	tokenSessionCookie   string
-	respJson             map[string]interface{}
-	handlerFunc          handlerFunc
-	defaultHandler       macaron.Handler
-	url                  string
-	userAuthTokenService *auth.FakeUserAuthTokenService
-	remoteCacheService   *remotecache.RemoteCache
-
-	req *http.Request
-}
-
-func (sc *scenarioContext) withValidApiKey() *scenarioContext {
-	sc.apiKey = "eyJrIjoidjVuQXdwTWFmRlA2em5hUzR1cmhkV0RMUzU1MTFNNDIiLCJuIjoiYXNkIiwiaWQiOjF9"
-	return sc
-}
-
-func (sc *scenarioContext) withTokenSessionCookie(unhashedToken string) *scenarioContext {
-	sc.tokenSessionCookie = unhashedToken
-	return sc
-}
-
-func (sc *scenarioContext) withAuthorizationHeader(authHeader string) *scenarioContext {
-	sc.authHeader = authHeader
-	return sc
-}
-
-func (sc *scenarioContext) fakeReq(method, url string) *scenarioContext {
-	sc.resp = httptest.NewRecorder()
-	req, err := http.NewRequest(method, url, nil)
-	So(err, ShouldBeNil)
-	sc.req = req
-
-	return sc
-}
-
-func (sc *scenarioContext) fakeReqWithBasicAuth(method, url, user, password string) *scenarioContext {
-	sc.resp = httptest.NewRecorder()
-	req, err := http.NewRequest(method, url, nil)
-	req.SetBasicAuth(user, password)
-	So(err, ShouldBeNil)
-	sc.req = req
-
-	return sc
-}
-
-func (sc *scenarioContext) fakeReqWithParams(method, url string, queryParams map[string]string) *scenarioContext {
-	sc.resp = httptest.NewRecorder()
-	req, err := http.NewRequest(method, url, nil)
-	q := req.URL.Query()
-	for k, v := range queryParams {
-		q.Add(k, v)
-	}
-	req.URL.RawQuery = q.Encode()
-	So(err, ShouldBeNil)
-	sc.req = req
-
-	return sc
-}
-
-func (sc *scenarioContext) handler(fn handlerFunc) *scenarioContext {
-	sc.handlerFunc = fn
-	return sc
-}
-
-func (sc *scenarioContext) exec() {
-	if sc.apiKey != "" {
-		sc.req.Header.Add("Authorization", "Bearer "+sc.apiKey)
-	}
-
-	if sc.authHeader != "" {
-		sc.req.Header.Add("Authorization", sc.authHeader)
-	}
-
-	if sc.tokenSessionCookie != "" {
-		sc.req.AddCookie(&http.Cookie{
-			Name:  setting.LoginCookieName,
-			Value: sc.tokenSessionCookie,
-		})
-	}
-
-	sc.m.ServeHTTP(sc.resp, sc.req)
-
-	if sc.resp.Header().Get("Content-Type") == "application/json; charset=UTF-8" {
-		err := json.NewDecoder(sc.resp.Body).Decode(&sc.respJson)
-		So(err, ShouldBeNil)
-	}
-}
-
-type scenarioFunc func(c *scenarioContext)
-type handlerFunc func(c *models.ReqContext)

+ 102 - 0
pkg/middleware/testing.go

@@ -0,0 +1,102 @@
+package middleware
+
+import (
+	"encoding/json"
+	"net/http"
+	"net/http/httptest"
+
+	. "github.com/smartystreets/goconvey/convey"
+	"gopkg.in/macaron.v1"
+
+	"github.com/grafana/grafana/pkg/infra/remotecache"
+	"github.com/grafana/grafana/pkg/models"
+	"github.com/grafana/grafana/pkg/services/auth"
+	"github.com/grafana/grafana/pkg/setting"
+)
+
+type scenarioContext struct {
+	m                    *macaron.Macaron
+	context              *models.ReqContext
+	resp                 *httptest.ResponseRecorder
+	apiKey               string
+	authHeader           string
+	tokenSessionCookie   string
+	respJson             map[string]interface{}
+	handlerFunc          handlerFunc
+	defaultHandler       macaron.Handler
+	url                  string
+	userAuthTokenService *auth.FakeUserAuthTokenService
+	remoteCacheService   *remotecache.RemoteCache
+
+	req *http.Request
+}
+
+func (sc *scenarioContext) withValidApiKey() *scenarioContext {
+	sc.apiKey = "eyJrIjoidjVuQXdwTWFmRlA2em5hUzR1cmhkV0RMUzU1MTFNNDIiLCJuIjoiYXNkIiwiaWQiOjF9"
+	return sc
+}
+
+func (sc *scenarioContext) withTokenSessionCookie(unhashedToken string) *scenarioContext {
+	sc.tokenSessionCookie = unhashedToken
+	return sc
+}
+
+func (sc *scenarioContext) withAuthorizationHeader(authHeader string) *scenarioContext {
+	sc.authHeader = authHeader
+	return sc
+}
+
+func (sc *scenarioContext) fakeReq(method, url string) *scenarioContext {
+	sc.resp = httptest.NewRecorder()
+	req, err := http.NewRequest(method, url, nil)
+	So(err, ShouldBeNil)
+	sc.req = req
+
+	return sc
+}
+
+func (sc *scenarioContext) fakeReqWithParams(method, url string, queryParams map[string]string) *scenarioContext {
+	sc.resp = httptest.NewRecorder()
+	req, err := http.NewRequest(method, url, nil)
+	q := req.URL.Query()
+	for k, v := range queryParams {
+		q.Add(k, v)
+	}
+	req.URL.RawQuery = q.Encode()
+	So(err, ShouldBeNil)
+	sc.req = req
+
+	return sc
+}
+
+func (sc *scenarioContext) handler(fn handlerFunc) *scenarioContext {
+	sc.handlerFunc = fn
+	return sc
+}
+
+func (sc *scenarioContext) exec() {
+	if sc.apiKey != "" {
+		sc.req.Header.Add("Authorization", "Bearer "+sc.apiKey)
+	}
+
+	if sc.authHeader != "" {
+		sc.req.Header.Add("Authorization", sc.authHeader)
+	}
+
+	if sc.tokenSessionCookie != "" {
+		sc.req.AddCookie(&http.Cookie{
+			Name:  setting.LoginCookieName,
+			Value: sc.tokenSessionCookie,
+		})
+	}
+
+	sc.m.ServeHTTP(sc.resp, sc.req)
+
+	if sc.resp.Header().Get("Content-Type") == "application/json; charset=UTF-8" {
+		err := json.NewDecoder(sc.resp.Body).Decode(&sc.respJson)
+		So(err, ShouldBeNil)
+	}
+}
+
+type scenarioFunc func(c *scenarioContext)
+type handlerFunc func(c *models.ReqContext)