瀏覽代碼

Add function in ds_proxy to handle oauthPassThru headers

Sean Lafferty 6 年之前
父節點
當前提交
7e62394d01
共有 1 個文件被更改,包括 44 次插入36 次删除
  1. 44 36
      pkg/api/pluginproxy/ds_proxy.go

+ 44 - 36
pkg/api/pluginproxy/ds_proxy.go

@@ -220,42 +220,7 @@ func (proxy *DataSourceProxy) getDirector() func(req *http.Request) {
 		}
 
 		if proxy.ds.JsonData != nil && proxy.ds.JsonData.Get("oauthPassThru").MustBool() {
-			cmd := &m.GetAuthInfoQuery{UserId: proxy.ctx.UserId}
-			if err := bus.Dispatch(cmd); err != nil {
-				logger.Error("Error feching oauth information for user", "error", err)
-			}
-
-			provider := cmd.Result.AuthModule
-			connect, ok := social.SocialMap[strings.TrimPrefix(provider, "oauth_")] // The socialMap keys don't have "oauth_" prefix, but everywhere else in the system does
-			if !ok {
-				logger.Error("Failed to find oauth provider with given name", "provider", provider)
-			}
-
-			// TokenSource handles refreshing the token if it has expired
-			token, err := connect.TokenSource(proxy.ctx.Req.Context(), &oauth2.Token{
-				AccessToken:  cmd.Result.OAuthAccessToken,
-				Expiry:       cmd.Result.OAuthExpiry,
-				RefreshToken: cmd.Result.OAuthRefreshToken,
-				TokenType:    cmd.Result.OAuthTokenType,
-			}).Token()
-			if err != nil {
-				logger.Error("Failed to retrieve access token from oauth provider", "provider", cmd.Result.AuthModule)
-			}
-
-			// If the tokens are not the same, update the entry in the DB
-			if token.AccessToken != cmd.Result.OAuthAccessToken {
-				cmd2 := &m.UpdateAuthInfoCommand{
-					UserId:     cmd.Result.Id,
-					AuthModule: cmd.Result.AuthModule,
-					AuthId:     cmd.Result.AuthId,
-					OAuthToken: token,
-				}
-				if err := bus.Dispatch(cmd2); err != nil {
-					logger.Error("Failed to update access token during token refresh", "error", err)
-				}
-			}
-			req.Header.Del("Authorization")
-			req.Header.Add("Authorization", fmt.Sprintf("%s %s", token.Type(), token.AccessToken))
+			addOAuthPassThruAuth(proxy.ctx, req)
 		}
 	}
 }
@@ -347,3 +312,46 @@ func checkWhiteList(c *m.ReqContext, host string) bool {
 
 	return true
 }
+
+func addOAuthPassThruAuth(c *m.ReqContext, req *http.Request) {
+	cmd := &m.GetAuthInfoQuery{UserId: c.UserId}
+	if err := bus.Dispatch(cmd); err != nil {
+		logger.Error("Error feching oauth information for user", "error", err)
+		return
+	}
+
+	provider := cmd.Result.AuthModule
+	connect, ok := social.SocialMap[strings.TrimPrefix(provider, "oauth_")] // The socialMap keys don't have "oauth_" prefix, but everywhere else in the system does
+	if !ok {
+		logger.Error("Failed to find oauth provider with given name", "provider", provider)
+		return
+	}
+
+	// TokenSource handles refreshing the token if it has expired
+	token, err := connect.TokenSource(c.Req.Context(), &oauth2.Token{
+		AccessToken:  cmd.Result.OAuthAccessToken,
+		Expiry:       cmd.Result.OAuthExpiry,
+		RefreshToken: cmd.Result.OAuthRefreshToken,
+		TokenType:    cmd.Result.OAuthTokenType,
+	}).Token()
+	if err != nil {
+		logger.Error("Failed to retrieve access token from oauth provider", "provider", cmd.Result.AuthModule)
+		return
+	}
+
+	// If the tokens are not the same, update the entry in the DB
+	if token.AccessToken != cmd.Result.OAuthAccessToken {
+		cmd2 := &m.UpdateAuthInfoCommand{
+			UserId:     cmd.Result.Id,
+			AuthModule: cmd.Result.AuthModule,
+			AuthId:     cmd.Result.AuthId,
+			OAuthToken: token,
+		}
+		if err := bus.Dispatch(cmd2); err != nil {
+			logger.Error("Failed to update access token during token refresh", "error", err)
+			return
+		}
+	}
+	req.Header.Del("Authorization")
+	req.Header.Add("Authorization", fmt.Sprintf("%s %s", token.Type(), token.AccessToken))
+}