Prechádzať zdrojové kódy

Re-organize packages and add basic auth test

Bob Shannon 7 rokov pred
rodič
commit
16ff8a182b

+ 2 - 2
pkg/util/auth.go → pkg/api/basic_auth.go

@@ -1,11 +1,11 @@
-package util
+package api
 
 import (
 	"crypto/subtle"
 	macaron "gopkg.in/macaron.v1"
 )
 
-// BasicAuthenticated parses the provided HTTP request for basic authentication credentials
+// BasicAuthenticatedRequest parses the provided HTTP request for basic authentication credentials
 // and returns true if the provided credentials match the expected username and password.
 // Returns false if the request is unauthenticated.
 // Uses constant-time comparison in order to mitigate timing attacks.

+ 45 - 0
pkg/api/basic_auth_test.go

@@ -0,0 +1,45 @@
+package api
+
+import (
+	"encoding/base64"
+	"fmt"
+	"net/http"
+	"testing"
+
+	. "github.com/smartystreets/goconvey/convey"
+	"gopkg.in/macaron.v1"
+)
+
+func TestBasicAuthenticatedRequest(t *testing.T) {
+	expectedUser := "prometheus"
+	expectedPass := "password"
+
+	Convey("Given a valid set of basic auth credentials", t, func() {
+		httpReq, err := http.NewRequest("GET", "http://localhost:3000/metrics", nil)
+		So(err, ShouldBeNil)
+		req := macaron.Request{
+			Request: httpReq,
+		}
+		encodedCreds := encodeBasicAuthCredentials(expectedUser, expectedPass)
+		req.Header.Add("Authorization", fmt.Sprintf("Basic %s", encodedCreds))
+		authenticated := BasicAuthenticatedRequest(req, expectedUser, expectedPass)
+		So(authenticated, ShouldBeTrue)
+	})
+
+	Convey("Given an invalid set of basic auth credentials", t, func() {
+		httpReq, err := http.NewRequest("GET", "http://localhost:3000/metrics", nil)
+		So(err, ShouldBeNil)
+		req := macaron.Request{
+			Request: httpReq,
+		}
+		encodedCreds := encodeBasicAuthCredentials("invaliduser", "invalidpass")
+		req.Header.Add("Authorization", fmt.Sprintf("Basic %s", encodedCreds))
+		authenticated := BasicAuthenticatedRequest(req, expectedUser, expectedPass)
+		So(authenticated, ShouldBeFalse)
+	})
+}
+
+func encodeBasicAuthCredentials(user, pass string) string {
+	creds := fmt.Sprintf("%s:%s", user, pass)
+	return base64.StdEncoding.EncodeToString([]byte(creds))
+}

+ 5 - 4
pkg/api/http_server.go

@@ -32,7 +32,6 @@ import (
 	"github.com/grafana/grafana/pkg/services/hooks"
 	"github.com/grafana/grafana/pkg/services/rendering"
 	"github.com/grafana/grafana/pkg/setting"
-	"github.com/grafana/grafana/pkg/util"
 )
 
 func init() {
@@ -246,9 +245,7 @@ func (hs *HTTPServer) metricsEndpoint(ctx *macaron.Context) {
 		return
 	}
 
-	if hs.Cfg.MetricsEndpointBasicAuthUsername != "" &&
-		hs.Cfg.MetricsEndpointBasicAuthPassword != "" &&
-		!util.BasicAuthenticatedRequest(ctx.Req, hs.Cfg.MetricsEndpointBasicAuthUsername, hs.Cfg.MetricsEndpointBasicAuthPassword) {
+	if hs.metricsEndpointBasicAuthEnabled() && !BasicAuthenticatedRequest(ctx.Req, hs.Cfg.MetricsEndpointBasicAuthUsername, hs.Cfg.MetricsEndpointBasicAuthPassword) {
 		ctx.Resp.WriteHeader(http.StatusUnauthorized)
 		return
 	}
@@ -307,3 +304,7 @@ func (hs *HTTPServer) mapStatic(m *macaron.Macaron, rootDir string, dir string,
 		},
 	))
 }
+
+func (hs *HTTPServer) metricsEndpointBasicAuthEnabled() bool {
+	return hs.Cfg.MetricsEndpointBasicAuthUsername != "" && hs.Cfg.MetricsEndpointBasicAuthPassword != ""
+}

+ 30 - 0
pkg/api/http_server_test.go

@@ -0,0 +1,30 @@
+package api
+
+import (
+	"testing"
+
+	"github.com/grafana/grafana/pkg/setting"
+	. "github.com/smartystreets/goconvey/convey"
+)
+
+func TestHTTPServer(t *testing.T) {
+	Convey("Given a HTTPServer", t, func() {
+		ts := &HTTPServer{
+			Cfg: setting.NewCfg(),
+		}
+
+		Convey("Given that basic auth on the metrics endpoint is enabled", func() {
+			ts.Cfg.MetricsEndpointBasicAuthUsername = "foo"
+			ts.Cfg.MetricsEndpointBasicAuthPassword = "bar"
+
+			So(ts.metricsEndpointBasicAuthEnabled(), ShouldBeTrue)
+		})
+
+		Convey("Given that basic auth on the metrics endpoint is disabled", func() {
+			ts.Cfg.MetricsEndpointBasicAuthUsername = ""
+			ts.Cfg.MetricsEndpointBasicAuthPassword = ""
+
+			So(ts.metricsEndpointBasicAuthEnabled(), ShouldBeFalse)
+		})
+	})
+}