generic_oauth.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  1. package social
  2. import (
  3. "encoding/base64"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "net/http"
  8. "net/mail"
  9. "regexp"
  10. "github.com/grafana/grafana/pkg/models"
  11. "github.com/jmespath/go-jmespath"
  12. "golang.org/x/oauth2"
  13. )
  14. type SocialGenericOAuth struct {
  15. *SocialBase
  16. allowedDomains []string
  17. allowedOrganizations []string
  18. apiUrl string
  19. allowSignup bool
  20. emailAttributeName string
  21. emailAttributePath string
  22. teamIds []int
  23. }
  24. func (s *SocialGenericOAuth) Type() int {
  25. return int(models.GENERIC)
  26. }
  27. func (s *SocialGenericOAuth) IsEmailAllowed(email string) bool {
  28. return isEmailAllowed(email, s.allowedDomains)
  29. }
  30. func (s *SocialGenericOAuth) IsSignupAllowed() bool {
  31. return s.allowSignup
  32. }
  33. func (s *SocialGenericOAuth) IsTeamMember(client *http.Client) bool {
  34. if len(s.teamIds) == 0 {
  35. return true
  36. }
  37. teamMemberships, err := s.FetchTeamMemberships(client)
  38. if err != nil {
  39. return false
  40. }
  41. for _, teamId := range s.teamIds {
  42. for _, membershipId := range teamMemberships {
  43. if teamId == membershipId {
  44. return true
  45. }
  46. }
  47. }
  48. return false
  49. }
  50. func (s *SocialGenericOAuth) IsOrganizationMember(client *http.Client) bool {
  51. if len(s.allowedOrganizations) == 0 {
  52. return true
  53. }
  54. organizations, err := s.FetchOrganizations(client)
  55. if err != nil {
  56. return false
  57. }
  58. for _, allowedOrganization := range s.allowedOrganizations {
  59. for _, organization := range organizations {
  60. if organization == allowedOrganization {
  61. return true
  62. }
  63. }
  64. }
  65. return false
  66. }
  67. // searchJSONForEmail searches the provided JSON response for an e-mail address
  68. // using the configured e-mail attribute path associated with the generic OAuth
  69. // provider.
  70. // Returns an empty string if an e-mail address is not found.
  71. func (s *SocialGenericOAuth) searchJSONForEmail(data []byte) string {
  72. if s.emailAttributePath == "" {
  73. s.log.Error("No e-mail attribute path specified")
  74. return ""
  75. }
  76. if len(data) == 0 {
  77. s.log.Error("Empty user info JSON response provided")
  78. return ""
  79. }
  80. var buf interface{}
  81. if err := json.Unmarshal(data, &buf); err != nil {
  82. s.log.Error("Failed to unmarshal user info JSON response", "err", err.Error())
  83. return ""
  84. }
  85. val, err := jmespath.Search(s.emailAttributePath, buf)
  86. if err != nil {
  87. s.log.Error("Failed to search user info JSON response with provided path", "emailAttributePath", s.emailAttributePath, "err", err.Error())
  88. return ""
  89. }
  90. strVal, ok := val.(string)
  91. if ok {
  92. return strVal
  93. }
  94. s.log.Error("E-mail not found when searching JSON with provided path", "emailAttributePath", s.emailAttributePath)
  95. return ""
  96. }
  97. func (s *SocialGenericOAuth) FetchPrivateEmail(client *http.Client) (string, error) {
  98. type Record struct {
  99. Email string `json:"email"`
  100. Primary bool `json:"primary"`
  101. IsPrimary bool `json:"is_primary"`
  102. Verified bool `json:"verified"`
  103. IsConfirmed bool `json:"is_confirmed"`
  104. }
  105. response, err := HttpGet(client, fmt.Sprintf(s.apiUrl+"/emails"))
  106. if err != nil {
  107. return "", fmt.Errorf("Error getting email address: %s", err)
  108. }
  109. var records []Record
  110. err = json.Unmarshal(response.Body, &records)
  111. if err != nil {
  112. var data struct {
  113. Values []Record `json:"values"`
  114. }
  115. err = json.Unmarshal(response.Body, &data)
  116. if err != nil {
  117. return "", fmt.Errorf("Error getting email address: %s", err)
  118. }
  119. records = data.Values
  120. }
  121. var email = ""
  122. for _, record := range records {
  123. if record.Primary || record.IsPrimary {
  124. email = record.Email
  125. break
  126. }
  127. }
  128. return email, nil
  129. }
  130. func (s *SocialGenericOAuth) FetchTeamMemberships(client *http.Client) ([]int, error) {
  131. type Record struct {
  132. Id int `json:"id"`
  133. }
  134. response, err := HttpGet(client, fmt.Sprintf(s.apiUrl+"/teams"))
  135. if err != nil {
  136. return nil, fmt.Errorf("Error getting team memberships: %s", err)
  137. }
  138. var records []Record
  139. err = json.Unmarshal(response.Body, &records)
  140. if err != nil {
  141. return nil, fmt.Errorf("Error getting team memberships: %s", err)
  142. }
  143. var ids = make([]int, len(records))
  144. for i, record := range records {
  145. ids[i] = record.Id
  146. }
  147. return ids, nil
  148. }
  149. func (s *SocialGenericOAuth) FetchOrganizations(client *http.Client) ([]string, error) {
  150. type Record struct {
  151. Login string `json:"login"`
  152. }
  153. response, err := HttpGet(client, fmt.Sprintf(s.apiUrl+"/orgs"))
  154. if err != nil {
  155. return nil, fmt.Errorf("Error getting organizations: %s", err)
  156. }
  157. var records []Record
  158. err = json.Unmarshal(response.Body, &records)
  159. if err != nil {
  160. return nil, fmt.Errorf("Error getting organizations: %s", err)
  161. }
  162. var logins = make([]string, len(records))
  163. for i, record := range records {
  164. logins[i] = record.Login
  165. }
  166. return logins, nil
  167. }
// UserInfoJson is the subset of user claims understood by the generic OAuth
// connector. It is decoded either from the id_token claims (extractToken) or
// from the user-info API response (UserInfo).
type UserInfoJson struct {
	// Name, falling back to DisplayName, becomes the user's name (extractName).
	Name        string `json:"name"`
	DisplayName string `json:"display_name"`
	// Login, falling back to Username, becomes the login (extractLogin); the
	// e-mail address is used when both are empty.
	Login    string `json:"login"`
	Username string `json:"username"`
	// Email is the first e-mail candidate checked by extractEmail; Upn is the
	// last and is only used if it parses as a mail address.
	Email string `json:"email"`
	Upn   string `json:"upn"`
	// Attributes maps attribute names to value lists; extractEmail takes the
	// first value stored under emailAttributeName.
	Attributes map[string][]string `json:"attributes"`
}
  177. func (s *SocialGenericOAuth) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
  178. var data UserInfoJson
  179. var rawUserInfoResponse HttpGetResponse
  180. var err error
  181. if !s.extractToken(&data, token) {
  182. rawUserInfoResponse, err = HttpGet(client, s.apiUrl)
  183. if err != nil {
  184. return nil, fmt.Errorf("Error getting user info: %s", err)
  185. }
  186. err = json.Unmarshal(rawUserInfoResponse.Body, &data)
  187. if err != nil {
  188. return nil, fmt.Errorf("Error decoding user info JSON: %s", err)
  189. }
  190. }
  191. name := s.extractName(&data)
  192. email := s.extractEmail(&data, rawUserInfoResponse.Body)
  193. if email == "" {
  194. email, err = s.FetchPrivateEmail(client)
  195. if err != nil {
  196. return nil, err
  197. }
  198. }
  199. login := s.extractLogin(&data, email)
  200. userInfo := &BasicUserInfo{
  201. Name: name,
  202. Login: login,
  203. Email: email,
  204. }
  205. if !s.IsTeamMember(client) {
  206. return nil, errors.New("User not a member of one of the required teams")
  207. }
  208. if !s.IsOrganizationMember(client) {
  209. return nil, errors.New("User not a member of one of the required organizations")
  210. }
  211. return userInfo, nil
  212. }
  213. func (s *SocialGenericOAuth) extractToken(data *UserInfoJson, token *oauth2.Token) bool {
  214. idToken := token.Extra("id_token")
  215. if idToken == nil {
  216. s.log.Debug("No id_token found", "token", token)
  217. return false
  218. }
  219. jwtRegexp := regexp.MustCompile("^([-_a-zA-Z0-9=]+)[.]([-_a-zA-Z0-9=]+)[.]([-_a-zA-Z0-9=]+)$")
  220. matched := jwtRegexp.FindStringSubmatch(idToken.(string))
  221. if matched == nil {
  222. s.log.Debug("id_token is not in JWT format", "id_token", idToken.(string))
  223. return false
  224. }
  225. payload, err := base64.RawURLEncoding.DecodeString(matched[2])
  226. if err != nil {
  227. s.log.Error("Error base64 decoding id_token", "raw_payload", matched[2], "err", err)
  228. return false
  229. }
  230. err = json.Unmarshal(payload, data)
  231. if err != nil {
  232. s.log.Error("Error decoding id_token JSON", "payload", string(payload), "err", err)
  233. return false
  234. }
  235. if email := s.extractEmail(data, payload); email == "" {
  236. s.log.Debug("No email found in id_token", "json", string(payload), "data", data)
  237. return false
  238. }
  239. s.log.Debug("Received id_token", "json", string(payload), "data", data)
  240. return true
  241. }
  242. func (s *SocialGenericOAuth) extractEmail(data *UserInfoJson, userInfoResp []byte) string {
  243. if data.Email != "" {
  244. return data.Email
  245. }
  246. if s.emailAttributePath != "" {
  247. email := s.searchJSONForEmail(userInfoResp)
  248. if email != "" {
  249. return email
  250. }
  251. }
  252. emails, ok := data.Attributes[s.emailAttributeName]
  253. if ok && len(emails) != 0 {
  254. return emails[0]
  255. }
  256. if data.Upn != "" {
  257. emailAddr, emailErr := mail.ParseAddress(data.Upn)
  258. if emailErr == nil {
  259. return emailAddr.Address
  260. }
  261. s.log.Debug("Failed to parse e-mail address", "err", emailErr.Error())
  262. }
  263. return ""
  264. }
  265. func (s *SocialGenericOAuth) extractLogin(data *UserInfoJson, email string) string {
  266. if data.Login != "" {
  267. return data.Login
  268. }
  269. if data.Username != "" {
  270. return data.Username
  271. }
  272. return email
  273. }
  274. func (s *SocialGenericOAuth) extractName(data *UserInfoJson) string {
  275. if data.Name != "" {
  276. return data.Name
  277. }
  278. if data.DisplayName != "" {
  279. return data.DisplayName
  280. }
  281. return ""
  282. }