user_auth.go 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. package sqlstore
  2. import (
  3. "encoding/base64"
  4. "time"
  5. "github.com/grafana/grafana/pkg/bus"
  6. m "github.com/grafana/grafana/pkg/models"
  7. "github.com/grafana/grafana/pkg/setting"
  8. "github.com/grafana/grafana/pkg/util"
  9. )
  var (
	// getTime is the clock used to stamp the Created field of user_auth rows
	// in SetAuthInfo and UpdateAuthInfo. It is a variable rather than a direct
	// time.Now call so it can be swapped out (presumably by tests — confirm).
	getTime = time.Now
)
  11. func init() {
  12. bus.AddHandler("sql", GetUserByAuthInfo)
  13. bus.AddHandler("sql", GetAuthInfo)
  14. bus.AddHandler("sql", SetAuthInfo)
  15. bus.AddHandler("sql", UpdateAuthInfo)
  16. bus.AddHandler("sql", DeleteAuthInfo)
  17. }
  18. func GetUserByAuthInfo(query *m.GetUserByAuthInfoQuery) error {
  19. user := &m.User{}
  20. has := false
  21. var err error
  22. authQuery := &m.GetAuthInfoQuery{}
  23. // Try to find the user by auth module and id first
  24. if query.AuthModule != "" && query.AuthId != "" {
  25. authQuery.AuthModule = query.AuthModule
  26. authQuery.AuthId = query.AuthId
  27. err = GetAuthInfo(authQuery)
  28. if err != m.ErrUserNotFound {
  29. if err != nil {
  30. return err
  31. }
  32. // if user id was specified and doesn't match the user_auth entry, remove it
  33. if query.UserId != 0 && query.UserId != authQuery.Result.UserId {
  34. err = DeleteAuthInfo(&m.DeleteAuthInfoCommand{
  35. UserAuth: authQuery.Result,
  36. })
  37. if err != nil {
  38. sqlog.Error("Error removing user_auth entry", "error", err)
  39. }
  40. authQuery.Result = nil
  41. } else {
  42. has, err = x.Id(authQuery.Result.UserId).Get(user)
  43. if err != nil {
  44. return err
  45. }
  46. if !has {
  47. // if the user has been deleted then remove the entry
  48. err = DeleteAuthInfo(&m.DeleteAuthInfoCommand{
  49. UserAuth: authQuery.Result,
  50. })
  51. if err != nil {
  52. sqlog.Error("Error removing user_auth entry", "error", err)
  53. }
  54. authQuery.Result = nil
  55. }
  56. }
  57. }
  58. }
  59. // If not found, try to find the user by id
  60. if !has && query.UserId != 0 {
  61. has, err = x.Id(query.UserId).Get(user)
  62. if err != nil {
  63. return err
  64. }
  65. }
  66. // If not found, try to find the user by email address
  67. if !has && query.Email != "" {
  68. user = &m.User{Email: query.Email}
  69. has, err = x.Get(user)
  70. if err != nil {
  71. return err
  72. }
  73. }
  74. // If not found, try to find the user by login
  75. if !has && query.Login != "" {
  76. user = &m.User{Login: query.Login}
  77. has, err = x.Get(user)
  78. if err != nil {
  79. return err
  80. }
  81. }
  82. // No user found
  83. if !has {
  84. return m.ErrUserNotFound
  85. }
  86. // create authInfo record to link accounts
  87. if authQuery.Result == nil && query.AuthModule != "" {
  88. cmd2 := &m.SetAuthInfoCommand{
  89. UserId: user.Id,
  90. AuthModule: query.AuthModule,
  91. AuthId: query.AuthId,
  92. }
  93. if err := SetAuthInfo(cmd2); err != nil {
  94. return err
  95. }
  96. }
  97. query.Result = user
  98. return nil
  99. }
  100. func GetAuthInfo(query *m.GetAuthInfoQuery) error {
  101. userAuth := &m.UserAuth{
  102. UserId: query.UserId,
  103. AuthModule: query.AuthModule,
  104. AuthId: query.AuthId,
  105. }
  106. has, err := x.Desc("created").Get(userAuth)
  107. if err != nil {
  108. return err
  109. }
  110. if !has {
  111. return m.ErrUserNotFound
  112. }
  113. secretAccessToken, err := decodeAndDecrypt(userAuth.OAuthAccessToken)
  114. if err != nil {
  115. return err
  116. }
  117. secretRefreshToken, err := decodeAndDecrypt(userAuth.OAuthRefreshToken)
  118. if err != nil {
  119. return err
  120. }
  121. secretTokenType, err := decodeAndDecrypt(userAuth.OAuthTokenType)
  122. if err != nil {
  123. return err
  124. }
  125. userAuth.OAuthAccessToken = secretAccessToken
  126. userAuth.OAuthRefreshToken = secretRefreshToken
  127. userAuth.OAuthTokenType = secretTokenType
  128. query.Result = userAuth
  129. return nil
  130. }
  131. func SetAuthInfo(cmd *m.SetAuthInfoCommand) error {
  132. return inTransaction(func(sess *DBSession) error {
  133. authUser := &m.UserAuth{
  134. UserId: cmd.UserId,
  135. AuthModule: cmd.AuthModule,
  136. AuthId: cmd.AuthId,
  137. Created: getTime(),
  138. }
  139. if cmd.OAuthToken != nil {
  140. secretAccessToken, err := encryptAndEncode(cmd.OAuthToken.AccessToken)
  141. if err != nil {
  142. return err
  143. }
  144. secretRefreshToken, err := encryptAndEncode(cmd.OAuthToken.RefreshToken)
  145. if err != nil {
  146. return err
  147. }
  148. secretTokenType, err := encryptAndEncode(cmd.OAuthToken.TokenType)
  149. if err != nil {
  150. return err
  151. }
  152. authUser.OAuthAccessToken = secretAccessToken
  153. authUser.OAuthRefreshToken = secretRefreshToken
  154. authUser.OAuthTokenType = secretTokenType
  155. authUser.OAuthExpiry = cmd.OAuthToken.Expiry
  156. }
  157. _, err := sess.Insert(authUser)
  158. return err
  159. })
  160. }
  161. func UpdateAuthInfo(cmd *m.UpdateAuthInfoCommand) error {
  162. return inTransaction(func(sess *DBSession) error {
  163. authUser := &m.UserAuth{
  164. UserId: cmd.UserId,
  165. AuthModule: cmd.AuthModule,
  166. AuthId: cmd.AuthId,
  167. Created: getTime(),
  168. }
  169. if cmd.OAuthToken != nil {
  170. secretAccessToken, err := encryptAndEncode(cmd.OAuthToken.AccessToken)
  171. if err != nil {
  172. return err
  173. }
  174. secretRefreshToken, err := encryptAndEncode(cmd.OAuthToken.RefreshToken)
  175. if err != nil {
  176. return err
  177. }
  178. secretTokenType, err := encryptAndEncode(cmd.OAuthToken.TokenType)
  179. if err != nil {
  180. return err
  181. }
  182. authUser.OAuthAccessToken = secretAccessToken
  183. authUser.OAuthRefreshToken = secretRefreshToken
  184. authUser.OAuthTokenType = secretTokenType
  185. authUser.OAuthExpiry = cmd.OAuthToken.Expiry
  186. }
  187. cond := &m.UserAuth{
  188. UserId: cmd.UserId,
  189. AuthModule: cmd.AuthModule,
  190. }
  191. _, err := sess.Update(authUser, cond)
  192. return err
  193. })
  194. }
  195. func DeleteAuthInfo(cmd *m.DeleteAuthInfoCommand) error {
  196. return inTransaction(func(sess *DBSession) error {
  197. _, err := sess.Delete(cmd.UserAuth)
  198. return err
  199. })
  200. }
  201. // decodeAndDecrypt will decode the string with the standard bas64 decoder
  202. // and then decrypt it with grafana's secretKey
  203. func decodeAndDecrypt(s string) (string, error) {
  204. // Bail out if empty string since it'll cause a segfault in util.Decrypt
  205. if s == "" {
  206. return "", nil
  207. }
  208. decoded, err := base64.StdEncoding.DecodeString(s)
  209. if err != nil {
  210. return "", err
  211. }
  212. decrypted, err := util.Decrypt(decoded, setting.SecretKey)
  213. if err != nil {
  214. return "", err
  215. }
  216. return string(decrypted), nil
  217. }
  218. // encryptAndEncode will encrypt a string with grafana's secretKey, and
  219. // then encode it with the standard bas64 encoder
  220. func encryptAndEncode(s string) (string, error) {
  221. encrypted, err := util.Encrypt([]byte(s), setting.SecretKey)
  222. if err != nil {
  223. return "", err
  224. }
  225. return base64.StdEncoding.EncodeToString(encrypted), nil
  226. }