Browse Source

fix(quota): fixed failing quota unit tests

Torkel Ödegaard 10 years ago
parent
commit
5e949b0564
4 changed files with 117 additions and 95 deletions
  1. 0 94
      pkg/middleware/middleware.go
  2. 106 0
      pkg/middleware/quota.go
  3. 7 1
      pkg/middleware/quota_test.go
  4. 4 0
      pkg/middleware/session.go

+ 0 - 94
pkg/middleware/middleware.go

@@ -1,7 +1,6 @@
 package middleware
 package middleware
 
 
 import (
 import (
-	"fmt"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
 
 
@@ -254,96 +253,3 @@ func (ctx *Context) JsonApiErr(status int, message string, err error) {
 
 
 	ctx.JSON(status, resp)
 	ctx.JSON(status, resp)
 }
 }
-
-func Quota(target string) macaron.Handler {
-	return func(c *Context) {
-		limitReached, err := QuotaReached(c, target)
-		if err != nil {
-			c.JsonApiErr(500, "failed to get quota", err)
-			return
-		}
-		if limitReached {
-			c.JsonApiErr(403, fmt.Sprintf("%s Quota reached", target), nil)
-			return
-		}
-	}
-}
-
-func QuotaReached(c *Context, target string) (bool, error) {
-	if !setting.Quota.Enabled {
-		return false, nil
-	}
-
-	// get the list of scopes that this target is valid for. Org, User, Global
-	scopes, err := m.GetQuotaScopes(target)
-	if err != nil {
-		return false, err
-	}
-	log.Info(fmt.Sprintf("checking quota for %s in scopes %v", target, scopes))
-
-	for _, scope := range scopes {
-		log.Info(fmt.Sprintf("checking scope %s", scope.Name))
-		switch scope.Name {
-		case "global":
-			if scope.DefaultLimit < 0 {
-				continue
-			}
-			if scope.DefaultLimit == 0 {
-				return true, nil
-			}
-			if target == "session" {
-				usedSessions := sessionManager.Count()
-				if int64(usedSessions) > scope.DefaultLimit {
-					log.Info(fmt.Sprintf("%d sessions active, limit is %d", usedSessions, scope.DefaultLimit))
-					return true, nil
-				}
-				continue
-			}
-			query := m.GetGlobalQuotaByTargetQuery{Target: scope.Target}
-			if err := bus.Dispatch(&query); err != nil {
-				return true, err
-			}
-			if query.Result.Used >= scope.DefaultLimit {
-				return true, nil
-			}
-		case "org":
-			if !c.IsSignedIn {
-				continue
-			}
-			query := m.GetOrgQuotaByTargetQuery{OrgId: c.OrgId, Target: scope.Target, Default: scope.DefaultLimit}
-			if err := bus.Dispatch(&query); err != nil {
-				return true, err
-			}
-			if query.Result.Limit < 0 {
-				continue
-			}
-			if query.Result.Limit == 0 {
-				return true, nil
-			}
-
-			if query.Result.Used >= query.Result.Limit {
-				return true, nil
-			}
-		case "user":
-			if !c.IsSignedIn || c.UserId == 0 {
-				continue
-			}
-			query := m.GetUserQuotaByTargetQuery{UserId: c.UserId, Target: scope.Target, Default: scope.DefaultLimit}
-			if err := bus.Dispatch(&query); err != nil {
-				return true, err
-			}
-			if query.Result.Limit < 0 {
-				continue
-			}
-			if query.Result.Limit == 0 {
-				return true, nil
-			}
-
-			if query.Result.Used >= query.Result.Limit {
-				return true, nil
-			}
-		}
-	}
-
-	return false, nil
-}

+ 106 - 0
pkg/middleware/quota.go

@@ -0,0 +1,106 @@
+package middleware
+
+import (
+	"fmt"
+
+	"github.com/Unknwon/macaron"
+	"github.com/grafana/grafana/pkg/bus"
+	"github.com/grafana/grafana/pkg/log"
+	m "github.com/grafana/grafana/pkg/models"
+	"github.com/grafana/grafana/pkg/setting"
+)
+
+func Quota(target string) macaron.Handler {
+	return func(c *Context) {
+		limitReached, err := QuotaReached(c, target)
+		if err != nil {
+			c.JsonApiErr(500, "failed to get quota", err)
+			return
+		}
+		if limitReached {
+			c.JsonApiErr(403, fmt.Sprintf("%s Quota reached", target), nil)
+			return
+		}
+	}
+}
+
+func QuotaReached(c *Context, target string) (bool, error) {
+	if !setting.Quota.Enabled {
+		return false, nil
+	}
+
+	// get the list of scopes that this target is valid for. Org, User, Global
+	scopes, err := m.GetQuotaScopes(target)
+	if err != nil {
+		return false, err
+	}
+
+	log.Debug(fmt.Sprintf("checking quota for %s in scopes %v", target, scopes))
+
+	for _, scope := range scopes {
+		log.Debug(fmt.Sprintf("checking scope %s", scope.Name))
+
+		switch scope.Name {
+		case "global":
+			if scope.DefaultLimit < 0 {
+				continue
+			}
+			if scope.DefaultLimit == 0 {
+				return true, nil
+			}
+			if target == "session" {
+				usedSessions := getSessionCount()
+				if int64(usedSessions) > scope.DefaultLimit {
+					log.Debug(fmt.Sprintf("%d sessions active, limit is %d", usedSessions, scope.DefaultLimit))
+					return true, nil
+				}
+				continue
+			}
+			query := m.GetGlobalQuotaByTargetQuery{Target: scope.Target}
+			if err := bus.Dispatch(&query); err != nil {
+				return true, err
+			}
+			if query.Result.Used >= scope.DefaultLimit {
+				return true, nil
+			}
+		case "org":
+			if !c.IsSignedIn {
+				continue
+			}
+			query := m.GetOrgQuotaByTargetQuery{OrgId: c.OrgId, Target: scope.Target, Default: scope.DefaultLimit}
+			if err := bus.Dispatch(&query); err != nil {
+				return true, err
+			}
+			if query.Result.Limit < 0 {
+				continue
+			}
+			if query.Result.Limit == 0 {
+				return true, nil
+			}
+
+			if query.Result.Used >= query.Result.Limit {
+				return true, nil
+			}
+		case "user":
+			if !c.IsSignedIn || c.UserId == 0 {
+				continue
+			}
+			query := m.GetUserQuotaByTargetQuery{UserId: c.UserId, Target: scope.Target, Default: scope.DefaultLimit}
+			if err := bus.Dispatch(&query); err != nil {
+				return true, err
+			}
+			if query.Result.Limit < 0 {
+				continue
+			}
+			if query.Result.Limit == 0 {
+				return true, nil
+			}
+
+			if query.Result.Used >= query.Result.Limit {
+				return true, nil
+			}
+		}
+	}
+
+	return false, nil
+}

+ 7 - 1
pkg/middleware/quota_test.go

@@ -1,16 +1,22 @@
 package middleware
 package middleware
 
 
 import (
 import (
+	"testing"
+
 	"github.com/grafana/grafana/pkg/bus"
 	"github.com/grafana/grafana/pkg/bus"
 	m "github.com/grafana/grafana/pkg/models"
 	m "github.com/grafana/grafana/pkg/models"
 	"github.com/grafana/grafana/pkg/setting"
 	"github.com/grafana/grafana/pkg/setting"
 	. "github.com/smartystreets/goconvey/convey"
 	. "github.com/smartystreets/goconvey/convey"
-	"testing"
 )
 )
 
 
 func TestMiddlewareQuota(t *testing.T) {
 func TestMiddlewareQuota(t *testing.T) {
 
 
 	Convey("Given the grafana quota middleware", t, func() {
 	Convey("Given the grafana quota middleware", t, func() {
+		getSessionCount = func() int {
+			return 4
+		}
+
+		setting.AnonymousEnabled = false
 		setting.Quota = setting.QuotaSettings{
 		setting.Quota = setting.QuotaSettings{
 			Enabled: true,
 			Enabled: true,
 			Org: &setting.OrgQuota{
 			Org: &setting.OrgQuota{

+ 4 - 0
pkg/middleware/session.go

@@ -18,12 +18,16 @@ const (
 var sessionManager *session.Manager
 var sessionManager *session.Manager
 var sessionOptions *session.Options
 var sessionOptions *session.Options
 var startSessionGC func()
 var startSessionGC func()
+var getSessionCount func() int
 
 
 func init() {
 func init() {
 	startSessionGC = func() {
 	startSessionGC = func() {
 		sessionManager.GC()
 		sessionManager.GC()
 		time.AfterFunc(time.Duration(sessionOptions.Gclifetime)*time.Second, startSessionGC)
 		time.AfterFunc(time.Duration(sessionOptions.Gclifetime)*time.Second, startSessionGC)
 	}
 	}
+	getSessionCount = func() int {
+		return sessionManager.Count()
+	}
 }
 }
 
 
 func prepareOptions(opt *session.Options) *session.Options {
 func prepareOptions(opt *session.Options) *session.Options {