user_auth.go 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. package sqlstore
  2. import (
  3. "encoding/base64"
  4. "time"
  5. "github.com/grafana/grafana/pkg/bus"
  6. "github.com/grafana/grafana/pkg/models"
  7. "github.com/grafana/grafana/pkg/setting"
  8. "github.com/grafana/grafana/pkg/util"
  9. )
  10. var getTime = time.Now
  11. func init() {
  12. bus.AddHandler("sql", GetUserByAuthInfo)
  13. bus.AddHandler("sql", GetExternalUserInfoByLogin)
  14. bus.AddHandler("sql", GetAuthInfo)
  15. bus.AddHandler("sql", SetAuthInfo)
  16. bus.AddHandler("sql", UpdateAuthInfo)
  17. bus.AddHandler("sql", DeleteAuthInfo)
  18. }
  19. func GetUserByAuthInfo(query *models.GetUserByAuthInfoQuery) error {
  20. user := &models.User{}
  21. has := false
  22. var err error
  23. authQuery := &models.GetAuthInfoQuery{}
  24. // Try to find the user by auth module and id first
  25. if query.AuthModule != "" && query.AuthId != "" {
  26. authQuery.AuthModule = query.AuthModule
  27. authQuery.AuthId = query.AuthId
  28. err = GetAuthInfo(authQuery)
  29. if err != models.ErrUserNotFound {
  30. if err != nil {
  31. return err
  32. }
  33. // if user id was specified and doesn't match the user_auth entry, remove it
  34. if query.UserId != 0 && query.UserId != authQuery.Result.UserId {
  35. err = DeleteAuthInfo(&models.DeleteAuthInfoCommand{
  36. UserAuth: authQuery.Result,
  37. })
  38. if err != nil {
  39. sqlog.Error("Error removing user_auth entry", "error", err)
  40. }
  41. authQuery.Result = nil
  42. } else {
  43. has, err = x.Id(authQuery.Result.UserId).Get(user)
  44. if err != nil {
  45. return err
  46. }
  47. if !has {
  48. // if the user has been deleted then remove the entry
  49. err = DeleteAuthInfo(&models.DeleteAuthInfoCommand{
  50. UserAuth: authQuery.Result,
  51. })
  52. if err != nil {
  53. sqlog.Error("Error removing user_auth entry", "error", err)
  54. }
  55. authQuery.Result = nil
  56. }
  57. }
  58. }
  59. }
  60. // If not found, try to find the user by id
  61. if !has && query.UserId != 0 {
  62. has, err = x.Id(query.UserId).Get(user)
  63. if err != nil {
  64. return err
  65. }
  66. }
  67. // If not found, try to find the user by email address
  68. if !has && query.Email != "" {
  69. user = &models.User{Email: query.Email}
  70. has, err = x.Get(user)
  71. if err != nil {
  72. return err
  73. }
  74. }
  75. // If not found, try to find the user by login
  76. if !has && query.Login != "" {
  77. user = &models.User{Login: query.Login}
  78. has, err = x.Get(user)
  79. if err != nil {
  80. return err
  81. }
  82. }
  83. // No user found
  84. if !has {
  85. return models.ErrUserNotFound
  86. }
  87. // create authInfo record to link accounts
  88. if authQuery.Result == nil && query.AuthModule != "" {
  89. cmd2 := &models.SetAuthInfoCommand{
  90. UserId: user.Id,
  91. AuthModule: query.AuthModule,
  92. AuthId: query.AuthId,
  93. }
  94. if err := SetAuthInfo(cmd2); err != nil {
  95. return err
  96. }
  97. }
  98. query.Result = user
  99. return nil
  100. }
  101. func GetExternalUserInfoByLogin(query *models.GetExternalUserInfoByLoginQuery) error {
  102. userQuery := models.GetUserByLoginQuery{LoginOrEmail: query.LoginOrEmail}
  103. err := bus.Dispatch(&userQuery)
  104. if err != nil {
  105. return err
  106. }
  107. authInfoQuery := &models.GetAuthInfoQuery{UserId: userQuery.Result.Id}
  108. if err := bus.Dispatch(authInfoQuery); err != nil {
  109. return err
  110. }
  111. query.Result = &models.ExternalUserInfo{
  112. UserId: userQuery.Result.Id,
  113. Login: userQuery.Result.Login,
  114. Email: userQuery.Result.Email,
  115. Name: userQuery.Result.Name,
  116. IsDisabled: userQuery.Result.IsDisabled,
  117. AuthModule: authInfoQuery.Result.AuthModule,
  118. AuthId: authInfoQuery.Result.AuthId,
  119. }
  120. return nil
  121. }
  122. func GetAuthInfo(query *models.GetAuthInfoQuery) error {
  123. userAuth := &models.UserAuth{
  124. UserId: query.UserId,
  125. AuthModule: query.AuthModule,
  126. AuthId: query.AuthId,
  127. }
  128. has, err := x.Desc("created").Get(userAuth)
  129. if err != nil {
  130. return err
  131. }
  132. if !has {
  133. return models.ErrUserNotFound
  134. }
  135. secretAccessToken, err := decodeAndDecrypt(userAuth.OAuthAccessToken)
  136. if err != nil {
  137. return err
  138. }
  139. secretRefreshToken, err := decodeAndDecrypt(userAuth.OAuthRefreshToken)
  140. if err != nil {
  141. return err
  142. }
  143. secretTokenType, err := decodeAndDecrypt(userAuth.OAuthTokenType)
  144. if err != nil {
  145. return err
  146. }
  147. userAuth.OAuthAccessToken = secretAccessToken
  148. userAuth.OAuthRefreshToken = secretRefreshToken
  149. userAuth.OAuthTokenType = secretTokenType
  150. query.Result = userAuth
  151. return nil
  152. }
  153. func SetAuthInfo(cmd *models.SetAuthInfoCommand) error {
  154. return inTransaction(func(sess *DBSession) error {
  155. authUser := &models.UserAuth{
  156. UserId: cmd.UserId,
  157. AuthModule: cmd.AuthModule,
  158. AuthId: cmd.AuthId,
  159. Created: getTime(),
  160. }
  161. if cmd.OAuthToken != nil {
  162. secretAccessToken, err := encryptAndEncode(cmd.OAuthToken.AccessToken)
  163. if err != nil {
  164. return err
  165. }
  166. secretRefreshToken, err := encryptAndEncode(cmd.OAuthToken.RefreshToken)
  167. if err != nil {
  168. return err
  169. }
  170. secretTokenType, err := encryptAndEncode(cmd.OAuthToken.TokenType)
  171. if err != nil {
  172. return err
  173. }
  174. authUser.OAuthAccessToken = secretAccessToken
  175. authUser.OAuthRefreshToken = secretRefreshToken
  176. authUser.OAuthTokenType = secretTokenType
  177. authUser.OAuthExpiry = cmd.OAuthToken.Expiry
  178. }
  179. _, err := sess.Insert(authUser)
  180. return err
  181. })
  182. }
  183. func UpdateAuthInfo(cmd *models.UpdateAuthInfoCommand) error {
  184. return inTransaction(func(sess *DBSession) error {
  185. authUser := &models.UserAuth{
  186. UserId: cmd.UserId,
  187. AuthModule: cmd.AuthModule,
  188. AuthId: cmd.AuthId,
  189. Created: getTime(),
  190. }
  191. if cmd.OAuthToken != nil {
  192. secretAccessToken, err := encryptAndEncode(cmd.OAuthToken.AccessToken)
  193. if err != nil {
  194. return err
  195. }
  196. secretRefreshToken, err := encryptAndEncode(cmd.OAuthToken.RefreshToken)
  197. if err != nil {
  198. return err
  199. }
  200. secretTokenType, err := encryptAndEncode(cmd.OAuthToken.TokenType)
  201. if err != nil {
  202. return err
  203. }
  204. authUser.OAuthAccessToken = secretAccessToken
  205. authUser.OAuthRefreshToken = secretRefreshToken
  206. authUser.OAuthTokenType = secretTokenType
  207. authUser.OAuthExpiry = cmd.OAuthToken.Expiry
  208. }
  209. cond := &models.UserAuth{
  210. UserId: cmd.UserId,
  211. AuthModule: cmd.AuthModule,
  212. }
  213. _, err := sess.Update(authUser, cond)
  214. return err
  215. })
  216. }
  217. func DeleteAuthInfo(cmd *models.DeleteAuthInfoCommand) error {
  218. return inTransaction(func(sess *DBSession) error {
  219. _, err := sess.Delete(cmd.UserAuth)
  220. return err
  221. })
  222. }
  223. // decodeAndDecrypt will decode the string with the standard bas64 decoder
  224. // and then decrypt it with grafana's secretKey
  225. func decodeAndDecrypt(s string) (string, error) {
  226. // Bail out if empty string since it'll cause a segfault in util.Decrypt
  227. if s == "" {
  228. return "", nil
  229. }
  230. decoded, err := base64.StdEncoding.DecodeString(s)
  231. if err != nil {
  232. return "", err
  233. }
  234. decrypted, err := util.Decrypt(decoded, setting.SecretKey)
  235. if err != nil {
  236. return "", err
  237. }
  238. return string(decrypted), nil
  239. }
  240. // encryptAndEncode will encrypt a string with grafana's secretKey, and
  241. // then encode it with the standard bas64 encoder
  242. func encryptAndEncode(s string) (string, error) {
  243. encrypted, err := util.Encrypt([]byte(s), setting.SecretKey)
  244. if err != nil {
  245. return "", err
  246. }
  247. return base64.StdEncoding.EncodeToString(encrypted), nil
  248. }