浏览代码

Support large github organisations (#8846)

* Add new HttpGetResponse struct type
* Modify HttpGet() return to use HttpGetResponse
* Look up _all_ the teams the user is a member of
Dave Hall 8 年之前
父节点
当前提交
0c70d271dc
共有 5 个文件被更改,包括 80 次插入36 次删除
  1. 15 6
      pkg/social/common.go
  2. 9 9
      pkg/social/generic_oauth.go
  3. 52 17
      pkg/social/github_oauth.go
  4. 2 2
      pkg/social/google_oauth.go
  5. 2 2
      pkg/social/grafana_com_oauth.go

+ 15 - 6
pkg/social/common.go

@@ -9,6 +9,11 @@ import (
 	"github.com/grafana/grafana/pkg/log"
 )
 
+type HttpGetResponse struct {
+	Body    []byte
+	Headers http.Header
+}
+
 func isEmailAllowed(email string, allowedDomains []string) bool {
 	if len(allowedDomains) == 0 {
 		return true
@@ -23,24 +28,28 @@ func isEmailAllowed(email string, allowedDomains []string) bool {
 	return valid
 }
 
-func HttpGet(client *http.Client, url string) ([]byte, error) {
+func HttpGet(client *http.Client, url string) (response HttpGetResponse, err error) {
 	r, err := client.Get(url)
 	if err != nil {
-		return nil, err
+		return
 	}
 
 	defer r.Body.Close()
 
 	body, err := ioutil.ReadAll(r.Body)
 	if err != nil {
-		return nil, err
+		return
 	}
 
+	response = HttpGetResponse{body, r.Header}
+
 	if r.StatusCode >= 300 {
-		return nil, fmt.Errorf(string(body))
+		err = fmt.Errorf(string(response.Body))
+		return
 	}
 
-	log.Trace("HTTP GET %s: %s %s", url, r.Status, string(body))
+	log.Trace("HTTP GET %s: %s %s", url, r.Status, string(response.Body))
 
-	return body, nil
+	err = nil
+	return
 }

+ 9 - 9
pkg/social/generic_oauth.go

@@ -83,20 +83,20 @@ func (s *GenericOAuth) FetchPrivateEmail(client *http.Client) (string, error) {
 		IsConfirmed bool   `json:"is_confirmed"`
 	}
 
-	body, err := HttpGet(client, fmt.Sprintf(s.apiUrl+"/emails"))
+	response, err := HttpGet(client, fmt.Sprintf(s.apiUrl+"/emails"))
 	if err != nil {
 		return "", fmt.Errorf("Error getting email address: %s", err)
 	}
 
 	var records []Record
 
-	err = json.Unmarshal(body, &records)
+	err = json.Unmarshal(response.Body, &records)
 	if err != nil {
 		var data struct {
 			Values []Record `json:"values"`
 		}
 
-		err = json.Unmarshal(body, &data)
+		err = json.Unmarshal(response.Body, &data)
 		if err != nil {
 			return "", fmt.Errorf("Error getting email address: %s", err)
 		}
@@ -120,14 +120,14 @@ func (s *GenericOAuth) FetchTeamMemberships(client *http.Client) ([]int, error)
 		Id int `json:"id"`
 	}
 
-	body, err := HttpGet(client, fmt.Sprintf(s.apiUrl+"/teams"))
+	response, err := HttpGet(client, fmt.Sprintf(s.apiUrl+"/teams"))
 	if err != nil {
 		return nil, fmt.Errorf("Error getting team memberships: %s", err)
 	}
 
 	var records []Record
 
-	err = json.Unmarshal(body, &records)
+	err = json.Unmarshal(response.Body, &records)
 	if err != nil {
 		return nil, fmt.Errorf("Error getting team memberships: %s", err)
 	}
@@ -145,14 +145,14 @@ func (s *GenericOAuth) FetchOrganizations(client *http.Client) ([]string, error)
 		Login string `json:"login"`
 	}
 
-	body, err := HttpGet(client, fmt.Sprintf(s.apiUrl+"/orgs"))
+	response, err := HttpGet(client, fmt.Sprintf(s.apiUrl+"/orgs"))
 	if err != nil {
 		return nil, fmt.Errorf("Error getting organizations: %s", err)
 	}
 
 	var records []Record
 
-	err = json.Unmarshal(body, &records)
+	err = json.Unmarshal(response.Body, &records)
 	if err != nil {
 		return nil, fmt.Errorf("Error getting organizations: %s", err)
 	}
@@ -175,12 +175,12 @@ func (s *GenericOAuth) UserInfo(client *http.Client) (*BasicUserInfo, error) {
 		Attributes  map[string][]string `json:"attributes"`
 	}
 
-	body, err := HttpGet(client, s.apiUrl)
+	response, err := HttpGet(client, s.apiUrl)
 	if err != nil {
 		return nil, fmt.Errorf("Error getting user info: %s", err)
 	}
 
-	err = json.Unmarshal(body, &data)
+	err = json.Unmarshal(response.Body, &data)
 	if err != nil {
 		return nil, fmt.Errorf("Error getting user info: %s", err)
 	}

+ 52 - 17
pkg/social/github_oauth.go

@@ -4,6 +4,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"net/http"
+	"regexp"
 
 	"github.com/grafana/grafana/pkg/models"
 
@@ -85,14 +86,14 @@ func (s *SocialGithub) FetchPrivateEmail(client *http.Client) (string, error) {
 		Verified bool   `json:"verified"`
 	}
 
-	body, err := HttpGet(client, fmt.Sprintf(s.apiUrl+"/emails"))
+	response, err := HttpGet(client, fmt.Sprintf(s.apiUrl+"/emails"))
 	if err != nil {
 		return "", fmt.Errorf("Error getting email address: %s", err)
 	}
 
 	var records []Record
 
-	err = json.Unmarshal(body, &records)
+	err = json.Unmarshal(response.Body, &records)
 	if err != nil {
 		return "", fmt.Errorf("Error getting email address: %s", err)
 	}
@@ -112,24 +113,58 @@ func (s *SocialGithub) FetchTeamMemberships(client *http.Client) ([]int, error)
 		Id int `json:"id"`
 	}
 
-	body, err := HttpGet(client, fmt.Sprintf(s.apiUrl+"/teams"))
-	if err != nil {
-		return nil, fmt.Errorf("Error getting team memberships: %s", err)
+	url := fmt.Sprintf(s.apiUrl + "/teams?per_page=100")
+	hasMore := true
+	ids := make([]int, 0)
+
+	for hasMore {
+
+		response, err := HttpGet(client, url)
+		if err != nil {
+			return nil, fmt.Errorf("Error getting team memberships: %s", err)
+		}
+
+		var records []Record
+
+		err = json.Unmarshal(response.Body, &records)
+		if err != nil {
+			return nil, fmt.Errorf("Error getting team memberships: %s", err)
+		}
+
+		newRecords := len(records)
+		existingRecords := len(ids)
+		tempIds := make([]int, (newRecords + existingRecords))
+		copy(tempIds, ids)
+		ids = tempIds
+
+		for i, record := range records {
+			ids[i] = record.Id
+		}
+
+		url, hasMore = s.HasMoreRecords(response.Headers)
 	}
 
-	var records []Record
+	return ids, nil
+}
 
-	err = json.Unmarshal(body, &records)
-	if err != nil {
-		return nil, fmt.Errorf("Error getting team memberships: %s", err)
+func (s *SocialGithub) HasMoreRecords(headers http.Header) (string, bool) {
+
+	value, exists := headers["Link"]
+	if !exists {
+		return "", false
 	}
 
-	var ids = make([]int, len(records))
-	for i, record := range records {
-		ids[i] = record.Id
+	pattern := regexp.MustCompile(`<([^>]+)>; rel="next"`)
+	matches := pattern.FindStringSubmatch(value[0])
+
+	if matches == nil {
+		return "", false
 	}
 
-	return ids, nil
+	url := matches[1]
+
+	return url, true
+
 }
 
 func (s *SocialGithub) FetchOrganizations(client *http.Client) ([]string, error) {
@@ -137,14 +172,14 @@ func (s *SocialGithub) FetchOrganizations(client *http.Client) ([]string, error)
 		Login string `json:"login"`
 	}
 
-	body, err := HttpGet(client, fmt.Sprintf(s.apiUrl+"/orgs"))
+	response, err := HttpGet(client, fmt.Sprintf(s.apiUrl+"/orgs"))
 	if err != nil {
 		return nil, fmt.Errorf("Error getting organizations: %s", err)
 	}
 
 	var records []Record
 
-	err = json.Unmarshal(body, &records)
+	err = json.Unmarshal(response.Body, &records)
 	if err != nil {
 		return nil, fmt.Errorf("Error getting organizations: %s", err)
 	}
@@ -164,12 +199,12 @@ func (s *SocialGithub) UserInfo(client *http.Client) (*BasicUserInfo, error) {
 		Email string `json:"email"`
 	}
 
-	body, err := HttpGet(client, s.apiUrl)
+	response, err := HttpGet(client, s.apiUrl)
 	if err != nil {
 		return nil, fmt.Errorf("Error getting user info: %s", err)
 	}
 
-	err = json.Unmarshal(body, &data)
+	err = json.Unmarshal(response.Body, &data)
 	if err != nil {
 		return nil, fmt.Errorf("Error getting user info: %s", err)
 	}

+ 2 - 2
pkg/social/google_oauth.go

@@ -36,12 +36,12 @@ func (s *SocialGoogle) UserInfo(client *http.Client) (*BasicUserInfo, error) {
 		Email string `json:"email"`
 	}
 
-	body, err := HttpGet(client, s.apiUrl)
+	response, err := HttpGet(client, s.apiUrl)
 	if err != nil {
 		return nil, fmt.Errorf("Error getting user info: %s", err)
 	}
 
-	err = json.Unmarshal(body, &data)
+	err = json.Unmarshal(response.Body, &data)
 	if err != nil {
 		return nil, fmt.Errorf("Error getting user info: %s", err)
 	}

+ 2 - 2
pkg/social/grafana_com_oauth.go

@@ -58,12 +58,12 @@ func (s *SocialGrafanaCom) UserInfo(client *http.Client) (*BasicUserInfo, error)
 		Orgs  []OrgRecord `json:"orgs"`
 	}
 
-	body, err := HttpGet(client, s.url+"/api/oauth2/user")
+	response, err := HttpGet(client, s.url+"/api/oauth2/user")
 	if err != nil {
 		return nil, fmt.Errorf("Error getting user info: %s", err)
 	}
 
-	err = json.Unmarshal(body, &data)
+	err = json.Unmarshal(response.Body, &data)
 	if err != nil {
 		return nil, fmt.Errorf("Error getting user info: %s", err)
 	}