Browse Source

Chore: refactor auth proxy (#16504)

* Chore: refactor auth proxy

Introduced the helper struct for auth_proxy middleware.
Added couple unit-tests, but it seems "integration" tests already cover
most of the code paths.

Although it might be good idea to test every bit of it, hm.
Haven't refactored the extraction of the header logic that much

Fixes #16147

* Fix: make linters happy
Oleg Gaidarenko 6 years ago
parent
commit
318182ccc9

+ 31 - 156
pkg/middleware/auth_proxy.go

@@ -1,188 +1,63 @@
 package middleware
 package middleware
 
 
 import (
 import (
-	"fmt"
-	"net"
-	"net/mail"
-	"reflect"
-	"strings"
-	"time"
-
-	"github.com/grafana/grafana/pkg/bus"
 	"github.com/grafana/grafana/pkg/infra/remotecache"
 	"github.com/grafana/grafana/pkg/infra/remotecache"
-	"github.com/grafana/grafana/pkg/login"
+	authproxy "github.com/grafana/grafana/pkg/middleware/auth_proxy"
 	m "github.com/grafana/grafana/pkg/models"
 	m "github.com/grafana/grafana/pkg/models"
-	"github.com/grafana/grafana/pkg/setting"
 )
 )
 
 
 const (
 const (
 
 
 	// cachePrefix is a prefix for the cache key
 	// cachePrefix is a prefix for the cache key
-	cachePrefix = "auth-proxy-sync-ttl:%s"
+	cachePrefix = authproxy.CachePrefix
 )
 )
 
 
 func initContextWithAuthProxy(store *remotecache.RemoteCache, ctx *m.ReqContext, orgID int64) bool {
 func initContextWithAuthProxy(store *remotecache.RemoteCache, ctx *m.ReqContext, orgID int64) bool {
-	if !setting.AuthProxyEnabled {
+	auth := authproxy.New(&authproxy.Options{
+		Store: store,
+		Ctx:   ctx,
+		OrgID: orgID,
+	})
+
+	// Bail if auth proxy is not enabled
+	if auth.IsEnabled() == false {
 		return false
 		return false
 	}
 	}
 
 
-	proxyHeaderValue := ctx.Req.Header.Get(setting.AuthProxyHeaderName)
-	if len(proxyHeaderValue) == 0 {
+	// If the there is no header - we can't move forward
+	if auth.HasHeader() == false {
 		return false
 		return false
 	}
 	}
 
 
-	// if auth proxy ip(s) defined, check if request comes from one of those
-	if err := checkAuthenticationProxy(ctx.Req.RemoteAddr, proxyHeaderValue); err != nil {
-		ctx.Handle(407, "Proxy authentication required", err)
+	// Check if allowed to continue with this IP
+	if result, err := auth.IsAllowedIP(); result == false {
+		ctx.Handle(407, err.Error(), err.DetailsError)
 		return true
 		return true
 	}
 	}
 
 
-	query := &m.GetSignedInUserQuery{OrgId: orgID}
-	cacheKey := fmt.Sprintf(cachePrefix, proxyHeaderValue)
-	userID, err := store.Get(cacheKey)
-	inCache := err == nil
-
-	// load the user if we have them
-	if inCache {
-		query.UserId = userID.(int64)
-
-		// if we're using ldap, pass authproxy login name to ldap user sync
-	} else if setting.LdapEnabled {
-		syncQuery := &m.LoginUserQuery{
-			ReqContext: ctx,
-			Username:   proxyHeaderValue,
-		}
-
-		if err := syncGrafanaUserWithLdapUser(syncQuery); err != nil {
-			if err == login.ErrInvalidCredentials {
-				ctx.Handle(500, "Unable to authenticate user", err)
-				return false
-			}
-		}
-
-		if syncQuery.User == nil {
-			ctx.Handle(500, "Failed to sync user", nil)
-			return false
-		}
-
-		query.UserId = syncQuery.User.Id
-		// no ldap, just use the info we have
-	} else {
-		extUser := &m.ExternalUserInfo{
-			AuthModule: "authproxy",
-			AuthId:     proxyHeaderValue,
-		}
-
-		if setting.AuthProxyHeaderProperty == "username" {
-			extUser.Login = proxyHeaderValue
-
-			// only set Email if it can be parsed as an email address
-			emailAddr, emailErr := mail.ParseAddress(proxyHeaderValue)
-			if emailErr == nil {
-				extUser.Email = emailAddr.Address
-			}
-		} else if setting.AuthProxyHeaderProperty == "email" {
-			extUser.Email = proxyHeaderValue
-			extUser.Login = proxyHeaderValue
-		} else {
-			ctx.Handle(500, "Auth proxy header property invalid", nil)
-			return true
-		}
-
-		for _, field := range []string{"Name", "Email", "Login"} {
-			if setting.AuthProxyHeaders[field] == "" {
-				continue
-			}
-
-			if val := ctx.Req.Header.Get(setting.AuthProxyHeaders[field]); val != "" {
-				reflect.ValueOf(extUser).Elem().FieldByName(field).SetString(val)
-			}
-		}
-
-		// add/update user in grafana
-		cmd := &m.UpsertUserCommand{
-			ReqContext:    ctx,
-			ExternalUser:  extUser,
-			SignupAllowed: setting.AuthProxyAutoSignUp,
-		}
-		err := bus.Dispatch(cmd)
-		if err != nil {
-			ctx.Handle(500, "Failed to login as user specified in auth proxy header", err)
-			return true
-		}
-
-		query.UserId = cmd.Result.Id
+	// Try to get user id from various sources
+	id, err := auth.GetUserID()
+	if err != nil {
+		ctx.Handle(500, err.Error(), err.DetailsError)
+		return true
 	}
 	}
 
 
-	if err := bus.Dispatch(query); err != nil {
-		ctx.Handle(500, "Failed to find user", err)
+	// Get full user info
+	user, err := auth.GetSignedUser(id)
+	if err != nil {
+		ctx.Handle(500, err.Error(), err.DetailsError)
 		return true
 		return true
 	}
 	}
-	ctx.SignedInUser = query.Result
-	ctx.IsSignedIn = true
 
 
-	expiration := time.Duration(-setting.AuthProxyLdapSyncTtl) * time.Minute
-	value := query.UserId
+	// Add user info to context
+	ctx.SignedInUser = user
+	ctx.IsSignedIn = true
 
 
-	// This <if> is here to make sure we do not
-	// rewrite the expiration all the time
-	if inCache == false {
-		if err = store.Set(cacheKey, value, expiration); err != nil {
-			ctx.Handle(500, "Couldn't write a user in cache key", err)
-			return true
-		}
+	// Remember user data it in cache
+	if err := auth.Remember(); err != nil {
+		ctx.Handle(500, err.Error(), err.DetailsError)
+		return true
 	}
 	}
 
 
 	return true
 	return true
 }
 }
-
-var syncGrafanaUserWithLdapUser = func(query *m.LoginUserQuery) error {
-	ldapCfg := login.LdapCfg
-	if len(ldapCfg.Servers) < 1 {
-		return fmt.Errorf("No LDAP servers available")
-	}
-
-	for _, server := range ldapCfg.Servers {
-		author := login.NewLdapAuthenticator(server)
-		if err := author.SyncUser(query); err != nil {
-			return err
-		}
-	}
-
-	return nil
-}
-
-func checkAuthenticationProxy(remoteAddr string, proxyHeaderValue string) error {
-	if len(strings.TrimSpace(setting.AuthProxyWhitelist)) == 0 {
-		return nil
-	}
-
-	proxies := strings.Split(setting.AuthProxyWhitelist, ",")
-	var proxyObjs []*net.IPNet
-	for _, proxy := range proxies {
-		proxyObjs = append(proxyObjs, coerceProxyAddress(proxy))
-	}
-
-	sourceIP, _, _ := net.SplitHostPort(remoteAddr)
-	sourceObj := net.ParseIP(sourceIP)
-
-	for _, proxyObj := range proxyObjs {
-		if proxyObj.Contains(sourceObj) {
-			return nil
-		}
-	}
-	return fmt.Errorf("Request for user (%s) from %s is not from the authentication proxy", proxyHeaderValue, sourceIP)
-}
-
-func coerceProxyAddress(proxyAddr string) *net.IPNet {
-	proxyAddr = strings.TrimSpace(proxyAddr)
-	if !strings.Contains(proxyAddr, "/") {
-		proxyAddr = strings.Join([]string{proxyAddr, "32"}, "/")
-	}
-
-	_, network, err := net.ParseCIDR(proxyAddr)
-	if err != nil {
-		fmt.Println(err)
-	}
-	return network
-}

+ 320 - 0
pkg/middleware/auth_proxy/auth_proxy.go

@@ -0,0 +1,320 @@
+package authproxy
+
+import (
+	"fmt"
+	"net"
+	"net/mail"
+	"reflect"
+	"strings"
+	"time"
+
+	"github.com/grafana/grafana/pkg/bus"
+	"github.com/grafana/grafana/pkg/infra/remotecache"
+	"github.com/grafana/grafana/pkg/login"
+	models "github.com/grafana/grafana/pkg/models"
+	"github.com/grafana/grafana/pkg/setting"
+)
+
+const (
+
+	// CachePrefix is a prefix for the cache key
+	CachePrefix = "auth-proxy-sync-ttl:%s"
+)
+
+// AuthProxy struct
+type AuthProxy struct {
+	store  *remotecache.RemoteCache
+	ctx    *models.ReqContext
+	orgID  int64
+	header string
+
+	LDAP func(server *login.LdapServerConf) login.ILdapAuther
+
+	enabled     bool
+	whitelistIP string
+	headerType  string
+	headers     map[string]string
+	cacheTTL    int
+	ldapEnabled bool
+}
+
+// Error auth proxy specific error
+type Error struct {
+	Message      string
+	DetailsError error
+}
+
+// newError creates the Error
+func newError(message string, err error) *Error {
+	return &Error{
+		Message:      message,
+		DetailsError: err,
+	}
+}
+
+// Error returns a Error error string
+func (err *Error) Error() string {
+	return fmt.Sprintf("%s", err.Message)
+}
+
+// Options for the AuthProxy
+type Options struct {
+	Store *remotecache.RemoteCache
+	Ctx   *models.ReqContext
+	OrgID int64
+}
+
+// New instance of the AuthProxy
+func New(options *Options) *AuthProxy {
+	header := options.Ctx.Req.Header.Get(setting.AuthProxyHeaderName)
+
+	return &AuthProxy{
+		store:  options.Store,
+		ctx:    options.Ctx,
+		orgID:  options.OrgID,
+		header: header,
+
+		LDAP: login.NewLdapAuthenticator,
+
+		enabled:     setting.AuthProxyEnabled,
+		headerType:  setting.AuthProxyHeaderProperty,
+		headers:     setting.AuthProxyHeaders,
+		whitelistIP: setting.AuthProxyWhitelist,
+		cacheTTL:    setting.AuthProxyLdapSyncTtl,
+		ldapEnabled: setting.LdapEnabled,
+	}
+}
+
+// IsEnabled checks if the proxy auth is enabled
+func (auth *AuthProxy) IsEnabled() bool {
+
+	// Bail if the setting is not enabled
+	if auth.enabled == false {
+		return false
+	}
+
+	return true
+}
+
+// HasHeader checks if the we have specified header
+func (auth *AuthProxy) HasHeader() bool {
+	if len(auth.header) == 0 {
+		return false
+	}
+
+	return true
+}
+
+// IsAllowedIP compares presented IP with the whitelist one
+func (auth *AuthProxy) IsAllowedIP() (bool, *Error) {
+	ip := auth.ctx.Req.RemoteAddr
+
+	if len(strings.TrimSpace(auth.whitelistIP)) == 0 {
+		return true, nil
+	}
+
+	proxies := strings.Split(auth.whitelistIP, ",")
+	var proxyObjs []*net.IPNet
+	for _, proxy := range proxies {
+		result, err := coerceProxyAddress(proxy)
+		if err != nil {
+			return false, newError("Could not get the network", err)
+		}
+
+		proxyObjs = append(proxyObjs, result)
+	}
+
+	sourceIP, _, _ := net.SplitHostPort(ip)
+	sourceObj := net.ParseIP(sourceIP)
+
+	for _, proxyObj := range proxyObjs {
+		if proxyObj.Contains(sourceObj) {
+			return true, nil
+		}
+	}
+
+	err := fmt.Errorf(
+		"Request for user (%s) from %s is not from the authentication proxy", auth.header,
+		sourceIP,
+	)
+
+	return false, newError("Proxy authentication required", err)
+}
+
+// InCache checks if we have user in cache
+func (auth *AuthProxy) InCache() bool {
+	userID, _ := auth.GetUserIDViaCache()
+
+	if userID == 0 {
+		return false
+	}
+
+	return true
+}
+
+// getKey forms a key for the cache
+func (auth *AuthProxy) getKey() string {
+	return fmt.Sprintf(CachePrefix, auth.header)
+}
+
+// GetUserID gets user id with whatever means possible
+func (auth *AuthProxy) GetUserID() (int64, *Error) {
+	if auth.InCache() {
+
+		// Error here means absent cache - we don't need to handle that
+		id, _ := auth.GetUserIDViaCache()
+
+		return id, nil
+	}
+
+	if auth.ldapEnabled {
+		id, err := auth.GetUserIDViaLDAP()
+
+		if err == login.ErrInvalidCredentials {
+			return 0, newError("Proxy authentication required", login.ErrInvalidCredentials)
+		}
+
+		if err != nil {
+			return 0, newError("Failed to sync user", err)
+		}
+
+		return id, nil
+	}
+
+	id, err := auth.GetUserIDViaHeader()
+	if err != nil {
+		return 0, newError("Failed to login as user specified in auth proxy header", err)
+	}
+
+	return id, nil
+}
+
+// GetUserIDViaCache gets the user from cache
+func (auth *AuthProxy) GetUserIDViaCache() (int64, error) {
+	var (
+		cacheKey    = auth.getKey()
+		userID, err = auth.store.Get(cacheKey)
+	)
+
+	if err != nil {
+		return 0, err
+	}
+
+	return userID.(int64), nil
+}
+
+// GetUserIDViaLDAP gets user via LDAP request
+func (auth *AuthProxy) GetUserIDViaLDAP() (int64, *Error) {
+	query := &models.LoginUserQuery{
+		ReqContext: auth.ctx,
+		Username:   auth.header,
+	}
+
+	ldapCfg := login.LdapCfg
+	if len(ldapCfg.Servers) < 1 {
+		return 0, newError("No LDAP servers available", nil)
+	}
+
+	for _, server := range ldapCfg.Servers {
+		author := auth.LDAP(server)
+		if err := author.SyncUser(query); err != nil {
+			return 0, newError(err.Error(), nil)
+		}
+	}
+
+	return query.User.Id, nil
+}
+
+// GetUserIDViaHeader gets user from the header only
+func (auth *AuthProxy) GetUserIDViaHeader() (int64, error) {
+	extUser := &models.ExternalUserInfo{
+		AuthModule: "authproxy",
+		AuthId:     auth.header,
+	}
+
+	if auth.headerType == "username" {
+		extUser.Login = auth.header
+
+		// only set Email if it can be parsed as an email address
+		emailAddr, emailErr := mail.ParseAddress(auth.header)
+		if emailErr == nil {
+			extUser.Email = emailAddr.Address
+		}
+	} else if auth.headerType == "email" {
+		extUser.Email = auth.header
+		extUser.Login = auth.header
+	} else {
+		return 0, newError("Auth proxy header property invalid", nil)
+	}
+
+	for _, field := range []string{"Name", "Email", "Login"} {
+		if auth.headers[field] == "" {
+			continue
+		}
+
+		if val := auth.ctx.Req.Header.Get(auth.headers[field]); val != "" {
+			reflect.ValueOf(extUser).Elem().FieldByName(field).SetString(val)
+		}
+	}
+
+	// add/update user in grafana
+	cmd := &models.UpsertUserCommand{
+		ReqContext:    auth.ctx,
+		ExternalUser:  extUser,
+		SignupAllowed: setting.AuthProxyAutoSignUp,
+	}
+	err := bus.Dispatch(cmd)
+	if err != nil {
+		return 0, err
+	}
+
+	return cmd.Result.Id, nil
+}
+
+// GetSignedUser get full signed user info
+func (auth *AuthProxy) GetSignedUser(userID int64) (*models.SignedInUser, *Error) {
+	query := &models.GetSignedInUserQuery{
+		OrgId:  auth.orgID,
+		UserId: userID,
+	}
+
+	if err := bus.Dispatch(query); err != nil {
+		return nil, newError(err.Error(), nil)
+	}
+
+	return query.Result, nil
+}
+
+// Remember user in cache
+func (auth *AuthProxy) Remember() *Error {
+
+	// Make sure we do not rewrite the expiration time
+	if auth.InCache() {
+		return nil
+	}
+
+	var (
+		key        = auth.getKey()
+		value, _   = auth.GetUserIDViaCache()
+		expiration = time.Duration(-auth.cacheTTL) * time.Minute
+
+		err = auth.store.Set(key, value, expiration)
+	)
+
+	if err != nil {
+		return newError(err.Error(), nil)
+	}
+
+	return nil
+}
+
+// coerceProxyAddress gets network of the presented CIDR notation
+func coerceProxyAddress(proxyAddr string) (*net.IPNet, error) {
+	proxyAddr = strings.TrimSpace(proxyAddr)
+	if !strings.Contains(proxyAddr, "/") {
+		proxyAddr = strings.Join([]string{proxyAddr, "32"}, "/")
+	}
+
+	_, network, err := net.ParseCIDR(proxyAddr)
+	return network, err
+}

+ 124 - 0
pkg/middleware/auth_proxy/auth_proxy_test.go

@@ -0,0 +1,124 @@
+package authproxy
+
+import (
+	"fmt"
+	"net/http"
+	"testing"
+
+	"github.com/grafana/grafana/pkg/infra/remotecache"
+	"github.com/grafana/grafana/pkg/login"
+	models "github.com/grafana/grafana/pkg/models"
+	"github.com/grafana/grafana/pkg/setting"
+	. "github.com/smartystreets/goconvey/convey"
+	"gopkg.in/macaron.v1"
+)
+
+type TestLDAP struct {
+	login.ILdapAuther
+	ID         int64
+	syncCalled bool
+}
+
+func (stub *TestLDAP) SyncUser(query *models.LoginUserQuery) error {
+	stub.syncCalled = true
+	query.User = &models.User{
+		Id: stub.ID,
+	}
+	return nil
+}
+
+func TestMiddlewareContext(t *testing.T) {
+	Convey("auth_proxy helper", t, func() {
+		req, _ := http.NewRequest("POST", "http://example.com", nil)
+		setting.AuthProxyHeaderName = "X-Killa"
+		name := "markelog"
+
+		req.Header.Add(setting.AuthProxyHeaderName, name)
+
+		ctx := &models.ReqContext{
+			Context: &macaron.Context{
+				Req: macaron.Request{
+					Request: req,
+				},
+			},
+		}
+
+		Convey("gets data from the cache", func() {
+			store := remotecache.NewFakeStore(t)
+			key := fmt.Sprintf(CachePrefix, name)
+			store.Set(key, int64(33), 0)
+
+			auth := New(&Options{
+				Store: store,
+				Ctx:   ctx,
+				OrgID: 4,
+			})
+
+			id, err := auth.GetUserID()
+
+			So(err, ShouldBeNil)
+			So(id, ShouldEqual, 33)
+		})
+
+		Convey("LDAP", func() {
+			Convey("gets data from the LDAP", func() {
+				login.LdapCfg = login.LdapConfig{
+					Servers: []*login.LdapServerConf{
+						{},
+					},
+				}
+
+				setting.LdapEnabled = true
+
+				store := remotecache.NewFakeStore(t)
+
+				auth := New(&Options{
+					Store: store,
+					Ctx:   ctx,
+					OrgID: 4,
+				})
+
+				stub := &TestLDAP{
+					ID: 42,
+				}
+
+				auth.LDAP = func(server *login.LdapServerConf) login.ILdapAuther {
+					return stub
+				}
+
+				id, err := auth.GetUserID()
+
+				So(err, ShouldBeNil)
+				So(id, ShouldEqual, 42)
+				So(stub.syncCalled, ShouldEqual, true)
+			})
+
+			Convey("gets nice error if ldap is enabled but not configured", func() {
+				setting.LdapEnabled = false
+
+				store := remotecache.NewFakeStore(t)
+
+				auth := New(&Options{
+					Store: store,
+					Ctx:   ctx,
+					OrgID: 4,
+				})
+
+				stub := &TestLDAP{
+					ID: 42,
+				}
+
+				auth.LDAP = func(server *login.LdapServerConf) login.ILdapAuther {
+					return stub
+				}
+
+				id, err := auth.GetUserID()
+
+				So(err, ShouldNotBeNil)
+				So(id, ShouldNotEqual, 42)
+				So(stub.syncCalled, ShouldEqual, false)
+			})
+
+		})
+	})
+}

+ 2 - 52
pkg/middleware/middleware_test.go

@@ -276,52 +276,9 @@ func TestMiddlewareContext(t *testing.T) {
 			setting.AuthProxyHeaderProperty = "username"
 			setting.AuthProxyHeaderProperty = "username"
 			name := "markelog"
 			name := "markelog"
 
 
-			middlewareScenario(t, "should sync the user if it's not in the cache", func(sc *scenarioContext) {
-				called := false
-				syncGrafanaUserWithLdapUser = func(query *m.LoginUserQuery) error {
-					called = true
-					query.User = &m.User{Id: 32}
-					return nil
-				}
-
-				bus.AddHandler("test", func(query *m.UpsertUserCommand) error {
-					query.Result = &m.User{Id: 32}
-					return nil
-				})
-
-				bus.AddHandler("test", func(query *m.GetSignedInUserQuery) error {
-					query.Result = &m.SignedInUser{OrgId: 4, UserId: 32}
-					return nil
-				})
-
-				sc.fakeReq("GET", "/")
-
-				sc.req.Header.Add(setting.AuthProxyHeaderName, name)
-				sc.exec()
-
-				Convey("Should init user via ldap", func() {
-					So(called, ShouldBeTrue)
-					So(sc.context.IsSignedIn, ShouldBeTrue)
-					So(sc.context.UserId, ShouldEqual, 32)
-					So(sc.context.OrgId, ShouldEqual, 4)
-				})
-			})
-
 			middlewareScenario(t, "should not sync the user if it's in the cache", func(sc *scenarioContext) {
 			middlewareScenario(t, "should not sync the user if it's in the cache", func(sc *scenarioContext) {
-				called := false
-				syncGrafanaUserWithLdapUser = func(query *m.LoginUserQuery) error {
-					called = true
-					query.User = &m.User{Id: 32}
-					return nil
-				}
-
-				bus.AddHandler("test", func(query *m.UpsertUserCommand) error {
-					query.Result = &m.User{Id: 32}
-					return nil
-				})
-
 				bus.AddHandler("test", func(query *m.GetSignedInUserQuery) error {
 				bus.AddHandler("test", func(query *m.GetSignedInUserQuery) error {
-					query.Result = &m.SignedInUser{OrgId: 4, UserId: 32}
+					query.Result = &m.SignedInUser{OrgId: 4, UserId: query.UserId}
 					return nil
 					return nil
 				})
 				})
 
 
@@ -332,17 +289,10 @@ func TestMiddlewareContext(t *testing.T) {
 				sc.req.Header.Add(setting.AuthProxyHeaderName, name)
 				sc.req.Header.Add(setting.AuthProxyHeaderName, name)
 				sc.exec()
 				sc.exec()
 
 
-				cacheValue, cacheErr := sc.remoteCacheService.Get(key)
-
 				Convey("Should init user via cache", func() {
 				Convey("Should init user via cache", func() {
-					So(called, ShouldBeFalse)
-
 					So(sc.context.IsSignedIn, ShouldBeTrue)
 					So(sc.context.IsSignedIn, ShouldBeTrue)
-					So(sc.context.UserId, ShouldEqual, 32)
+					So(sc.context.UserId, ShouldEqual, 33)
 					So(sc.context.OrgId, ShouldEqual, 4)
 					So(sc.context.OrgId, ShouldEqual, 4)
-
-					So(cacheValue, ShouldEqual, 33)
-					So(cacheErr, ShouldBeNil)
 				})
 				})
 			})
 			})