access_token_provider.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. package pluginproxy
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "net/http"
  8. "net/url"
  9. "strconv"
  10. "sync"
  11. "time"
  12. "golang.org/x/oauth2"
  13. "github.com/grafana/grafana/pkg/plugins"
  14. "golang.org/x/oauth2/jwt"
  15. )
  16. var (
  17. tokenCache = tokenCacheType{
  18. cache: map[string]*jwtToken{},
  19. }
  20. oauthJwtTokenCache = oauthJwtTokenCacheType{
  21. cache: map[string]*oauth2.Token{},
  22. }
  23. )
  24. type tokenCacheType struct {
  25. cache map[string]*jwtToken
  26. sync.Mutex
  27. }
  28. type oauthJwtTokenCacheType struct {
  29. cache map[string]*oauth2.Token
  30. sync.Mutex
  31. }
  32. type accessTokenProvider struct {
  33. route *plugins.AppPluginRoute
  34. datasourceID int64
  35. }
  36. type jwtToken struct {
  37. ExpiresOn time.Time `json:"-"`
  38. ExpiresOnString string `json:"expires_on"`
  39. AccessToken string `json:"access_token"`
  40. }
  41. func newAccessTokenProvider(dsID int64, pluginRoute *plugins.AppPluginRoute) *accessTokenProvider {
  42. return &accessTokenProvider{
  43. datasourceID: dsID,
  44. route: pluginRoute,
  45. }
  46. }
  47. func (provider *accessTokenProvider) getAccessToken(data templateData) (string, error) {
  48. tokenCache.Lock()
  49. defer tokenCache.Unlock()
  50. if cachedToken, found := tokenCache.cache[provider.getAccessTokenCacheKey()]; found {
  51. if cachedToken.ExpiresOn.After(time.Now().Add(time.Second * 10)) {
  52. logger.Info("Using token from cache")
  53. return cachedToken.AccessToken, nil
  54. }
  55. }
  56. urlInterpolated, err := interpolateString(provider.route.TokenAuth.Url, data)
  57. if err != nil {
  58. return "", err
  59. }
  60. params := make(url.Values)
  61. for key, value := range provider.route.TokenAuth.Params {
  62. interpolatedParam, err := interpolateString(value, data)
  63. if err != nil {
  64. return "", err
  65. }
  66. params.Add(key, interpolatedParam)
  67. }
  68. getTokenReq, _ := http.NewRequest("POST", urlInterpolated, bytes.NewBufferString(params.Encode()))
  69. getTokenReq.Header.Add("Content-Type", "application/x-www-form-urlencoded")
  70. getTokenReq.Header.Add("Content-Length", strconv.Itoa(len(params.Encode())))
  71. resp, err := client.Do(getTokenReq)
  72. if err != nil {
  73. return "", err
  74. }
  75. defer resp.Body.Close()
  76. var token jwtToken
  77. if err := json.NewDecoder(resp.Body).Decode(&token); err != nil {
  78. return "", err
  79. }
  80. expiresOnEpoch, _ := strconv.ParseInt(token.ExpiresOnString, 10, 64)
  81. token.ExpiresOn = time.Unix(expiresOnEpoch, 0)
  82. tokenCache.cache[provider.getAccessTokenCacheKey()] = &token
  83. logger.Info("Got new access token", "ExpiresOn", token.ExpiresOn)
  84. return token.AccessToken, nil
  85. }
  86. func (provider *accessTokenProvider) getJwtAccessToken(ctx context.Context, data templateData) (string, error) {
  87. oauthJwtTokenCache.Lock()
  88. defer oauthJwtTokenCache.Unlock()
  89. if cachedToken, found := oauthJwtTokenCache.cache[provider.getAccessTokenCacheKey()]; found {
  90. if cachedToken.Expiry.After(time.Now().Add(time.Second * 10)) {
  91. logger.Info("Using token from cache")
  92. return cachedToken.AccessToken, nil
  93. }
  94. }
  95. conf := &jwt.Config{}
  96. if val, ok := provider.route.JwtTokenAuth.Params["client_email"]; ok {
  97. interpolatedVal, err := interpolateString(val, data)
  98. if err != nil {
  99. return "", err
  100. }
  101. conf.Email = interpolatedVal
  102. }
  103. if val, ok := provider.route.JwtTokenAuth.Params["private_key"]; ok {
  104. interpolatedVal, err := interpolateString(val, data)
  105. if err != nil {
  106. return "", err
  107. }
  108. conf.PrivateKey = []byte(interpolatedVal)
  109. }
  110. if val, ok := provider.route.JwtTokenAuth.Params["token_uri"]; ok {
  111. interpolatedVal, err := interpolateString(val, data)
  112. if err != nil {
  113. return "", err
  114. }
  115. conf.TokenURL = interpolatedVal
  116. }
  117. conf.Scopes = provider.route.JwtTokenAuth.Scopes
  118. token, err := getTokenSource(conf, ctx)
  119. if err != nil {
  120. return "", err
  121. }
  122. oauthJwtTokenCache.cache[provider.getAccessTokenCacheKey()] = token
  123. logger.Info("Got new access token", "ExpiresOn", token.Expiry)
  124. return token.AccessToken, nil
  125. }
  126. var getTokenSource = func(conf *jwt.Config, ctx context.Context) (*oauth2.Token, error) {
  127. tokenSrc := conf.TokenSource(ctx)
  128. token, err := tokenSrc.Token()
  129. if err != nil {
  130. return nil, err
  131. }
  132. return token, nil
  133. }
  134. func (provider *accessTokenProvider) getAccessTokenCacheKey() string {
  135. return fmt.Sprintf("%v_%v_%v", provider.datasourceID, provider.route.Path, provider.route.Method)
  136. }