Преглед изворни кода

dsproxy: add mutex protection to the token caches

Daniel Lee пре 7 година
родитељ
комит
982e095f85
1 измењених фајлова са 27 додато и 18 уклоњено
  1. 27 18
      pkg/api/pluginproxy/access_token_provider.go

+ 27 - 18
pkg/api/pluginproxy/access_token_provider.go

@@ -8,6 +8,7 @@ import (
 	"net/http"
 	"net/url"
 	"strconv"
+	"sync"
 	"time"
 
 	"golang.org/x/oauth2"
@@ -17,10 +18,24 @@ import (
 )
 
 var (
-	tokenCache         = map[string]*jwtToken{}
-	oauthJwtTokenCache = map[string]*oauth2.Token{}
+	tokenCache = tokenCacheType{
+		cache: map[string]*jwtToken{},
+	}
+	oauthJwtTokenCache = oauthJwtTokenCacheType{
+		cache: map[string]*oauth2.Token{},
+	}
 )
 
+type tokenCacheType struct {
+	cache map[string]*jwtToken
+	sync.Mutex
+}
+
+type oauthJwtTokenCacheType struct {
+	cache map[string]*oauth2.Token
+	sync.Mutex
+}
+
 type accessTokenProvider struct {
 	route        *plugins.AppPluginRoute
 	datasourceID int64
@@ -40,7 +55,9 @@ func newAccessTokenProvider(dsID int64, pluginRoute *plugins.AppPluginRoute) *ac
 }
 
 func (provider *accessTokenProvider) getAccessToken(data templateData) (string, error) {
-	if cachedToken, found := tokenCache[provider.getAccessTokenCacheKey()]; found {
+	tokenCache.Lock()
+	defer tokenCache.Unlock()
+	if cachedToken, found := tokenCache.cache[provider.getAccessTokenCacheKey()]; found {
 		if cachedToken.ExpiresOn.After(time.Now().Add(time.Second * 10)) {
 			logger.Info("Using token from cache")
 			return cachedToken.AccessToken, nil
@@ -79,7 +96,7 @@ func (provider *accessTokenProvider) getAccessToken(data templateData) (string,
 
 	expiresOnEpoch, _ := strconv.ParseInt(token.ExpiresOnString, 10, 64)
 	token.ExpiresOn = time.Unix(expiresOnEpoch, 0)
-	tokenCache[provider.getAccessTokenCacheKey()] = &token
+	tokenCache.cache[provider.getAccessTokenCacheKey()] = &token
 
 	logger.Info("Got new access token", "ExpiresOn", token.ExpiresOn)
 
@@ -87,7 +104,9 @@ func (provider *accessTokenProvider) getAccessToken(data templateData) (string,
 }
 
 func (provider *accessTokenProvider) getJwtAccessToken(ctx context.Context, data templateData) (string, error) {
-	if cachedToken, found := oauthJwtTokenCache[provider.getAccessTokenCacheKey()]; found {
+	oauthJwtTokenCache.Lock()
+	defer oauthJwtTokenCache.Unlock()
+	if cachedToken, found := oauthJwtTokenCache.cache[provider.getAccessTokenCacheKey()]; found {
 		if cachedToken.Expiry.After(time.Now().Add(time.Second * 10)) {
 			logger.Info("Using token from cache")
 			return cachedToken.AccessToken, nil
@@ -127,7 +146,9 @@ func (provider *accessTokenProvider) getJwtAccessToken(ctx context.Context, data
 		return "", err
 	}
 
-	oauthJwtTokenCache[provider.getAccessTokenCacheKey()] = token
+	oauthJwtTokenCache.cache[provider.getAccessTokenCacheKey()] = token
+
+	logger.Info("Got new access token", "ExpiresOn", token.Expiry)
 
 	return token.AccessToken, nil
 }
@@ -139,21 +160,9 @@ var getTokenSource = func(conf *jwt.Config, ctx context.Context) (*oauth2.Token,
 		return nil, err
 	}
 
-	// logger.Info("interpolatedVal", "token.AccessToken", token.AccessToken)
-
 	return token, nil
 }
 
 func (provider *accessTokenProvider) getAccessTokenCacheKey() string {
 	return fmt.Sprintf("%v_%v_%v", provider.datasourceID, provider.route.Path, provider.route.Method)
 }
-
-//Export access token lookup
-func GetAccessTokenFromCache(datasourceID int64, path string, method string) (string, error) {
-	key := fmt.Sprintf("%v_%v_%v", datasourceID, path, method)
-	if cachedToken, found := oauthJwtTokenCache[key]; found {
-		return cachedToken.AccessToken, nil
-	} else {
-		return "", fmt.Errorf("Key doesnt exist")
-	}
-}