user_auth.go 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. package sqlstore
  2. import (
  3. "encoding/base64"
  4. "time"
  5. "github.com/grafana/grafana/pkg/bus"
  6. m "github.com/grafana/grafana/pkg/models"
  7. "github.com/grafana/grafana/pkg/setting"
  8. "github.com/grafana/grafana/pkg/util"
  9. )
  10. func init() {
  11. bus.AddHandler("sql", GetUserByAuthInfo)
  12. bus.AddHandler("sql", GetAuthInfo)
  13. bus.AddHandler("sql", SetAuthInfo)
  14. bus.AddHandler("sql", UpdateAuthInfo)
  15. bus.AddHandler("sql", DeleteAuthInfo)
  16. }
  17. func GetUserByAuthInfo(query *m.GetUserByAuthInfoQuery) error {
  18. user := &m.User{}
  19. has := false
  20. var err error
  21. authQuery := &m.GetAuthInfoQuery{}
  22. // Try to find the user by auth module and id first
  23. if query.AuthModule != "" && query.AuthId != "" {
  24. authQuery.AuthModule = query.AuthModule
  25. authQuery.AuthId = query.AuthId
  26. err = GetAuthInfo(authQuery)
  27. if err != m.ErrUserNotFound {
  28. if err != nil {
  29. return err
  30. }
  31. // if user id was specified and doesn't match the user_auth entry, remove it
  32. if query.UserId != 0 && query.UserId != authQuery.Result.UserId {
  33. err = DeleteAuthInfo(&m.DeleteAuthInfoCommand{
  34. UserAuth: authQuery.Result,
  35. })
  36. if err != nil {
  37. sqlog.Error("Error removing user_auth entry", "error", err)
  38. }
  39. authQuery.Result = nil
  40. } else {
  41. has, err = x.Id(authQuery.Result.UserId).Get(user)
  42. if err != nil {
  43. return err
  44. }
  45. if !has {
  46. // if the user has been deleted then remove the entry
  47. err = DeleteAuthInfo(&m.DeleteAuthInfoCommand{
  48. UserAuth: authQuery.Result,
  49. })
  50. if err != nil {
  51. sqlog.Error("Error removing user_auth entry", "error", err)
  52. }
  53. authQuery.Result = nil
  54. }
  55. }
  56. }
  57. }
  58. // If not found, try to find the user by id
  59. if !has && query.UserId != 0 {
  60. has, err = x.Id(query.UserId).Get(user)
  61. if err != nil {
  62. return err
  63. }
  64. }
  65. // If not found, try to find the user by email address
  66. if !has && query.Email != "" {
  67. user = &m.User{Email: query.Email}
  68. has, err = x.Get(user)
  69. if err != nil {
  70. return err
  71. }
  72. }
  73. // If not found, try to find the user by login
  74. if !has && query.Login != "" {
  75. user = &m.User{Login: query.Login}
  76. has, err = x.Get(user)
  77. if err != nil {
  78. return err
  79. }
  80. }
  81. // No user found
  82. if !has {
  83. return m.ErrUserNotFound
  84. }
  85. // create authInfo record to link accounts
  86. if authQuery.Result == nil && query.AuthModule != "" {
  87. cmd2 := &m.SetAuthInfoCommand{
  88. UserId: user.Id,
  89. AuthModule: query.AuthModule,
  90. AuthId: query.AuthId,
  91. }
  92. if err := SetAuthInfo(cmd2); err != nil {
  93. return err
  94. }
  95. }
  96. query.Result = user
  97. return nil
  98. }
  99. func GetAuthInfo(query *m.GetAuthInfoQuery) error {
  100. userAuth := &m.UserAuth{
  101. UserId: query.UserId,
  102. AuthModule: query.AuthModule,
  103. AuthId: query.AuthId,
  104. }
  105. has, err := x.Desc("created").Get(userAuth)
  106. if err != nil {
  107. return err
  108. }
  109. if !has {
  110. return m.ErrUserNotFound
  111. }
  112. secretAccessToken, err := decodeAndDecrypt(userAuth.OAuthAccessToken)
  113. if err != nil {
  114. return err
  115. }
  116. secretRefreshToken, err := decodeAndDecrypt(userAuth.OAuthRefreshToken)
  117. if err != nil {
  118. return err
  119. }
  120. secretTokenType, err := decodeAndDecrypt(userAuth.OAuthTokenType)
  121. if err != nil {
  122. return err
  123. }
  124. userAuth.OAuthAccessToken = secretAccessToken
  125. userAuth.OAuthRefreshToken = secretRefreshToken
  126. userAuth.OAuthTokenType = secretTokenType
  127. query.Result = userAuth
  128. return nil
  129. }
  130. func SetAuthInfo(cmd *m.SetAuthInfoCommand) error {
  131. return inTransaction(func(sess *DBSession) error {
  132. authUser := &m.UserAuth{
  133. UserId: cmd.UserId,
  134. AuthModule: cmd.AuthModule,
  135. AuthId: cmd.AuthId,
  136. Created: time.Now(),
  137. }
  138. if cmd.OAuthToken != nil {
  139. secretAccessToken, err := encryptAndEncode(cmd.OAuthToken.AccessToken)
  140. if err != nil {
  141. return err
  142. }
  143. secretRefreshToken, err := encryptAndEncode(cmd.OAuthToken.RefreshToken)
  144. if err != nil {
  145. return err
  146. }
  147. secretTokenType, err := encryptAndEncode(cmd.OAuthToken.TokenType)
  148. if err != nil {
  149. return err
  150. }
  151. authUser.OAuthAccessToken = secretAccessToken
  152. authUser.OAuthRefreshToken = secretRefreshToken
  153. authUser.OAuthTokenType = secretTokenType
  154. authUser.OAuthExpiry = cmd.OAuthToken.Expiry
  155. }
  156. _, err := sess.Insert(authUser)
  157. return err
  158. })
  159. }
  160. func UpdateAuthInfo(cmd *m.UpdateAuthInfoCommand) error {
  161. return inTransaction(func(sess *DBSession) error {
  162. authUser := &m.UserAuth{
  163. UserId: cmd.UserId,
  164. AuthModule: cmd.AuthModule,
  165. AuthId: cmd.AuthId,
  166. Created: time.Now(),
  167. }
  168. if cmd.OAuthToken != nil {
  169. secretAccessToken, err := encryptAndEncode(cmd.OAuthToken.AccessToken)
  170. if err != nil {
  171. return err
  172. }
  173. secretRefreshToken, err := encryptAndEncode(cmd.OAuthToken.RefreshToken)
  174. if err != nil {
  175. return err
  176. }
  177. secretTokenType, err := encryptAndEncode(cmd.OAuthToken.TokenType)
  178. if err != nil {
  179. return err
  180. }
  181. authUser.OAuthAccessToken = secretAccessToken
  182. authUser.OAuthRefreshToken = secretRefreshToken
  183. authUser.OAuthTokenType = secretTokenType
  184. authUser.OAuthExpiry = cmd.OAuthToken.Expiry
  185. }
  186. cond := &m.UserAuth{
  187. UserId: cmd.UserId,
  188. AuthModule: cmd.AuthModule,
  189. }
  190. _, err := sess.Update(authUser, cond)
  191. return err
  192. })
  193. }
  194. func DeleteAuthInfo(cmd *m.DeleteAuthInfoCommand) error {
  195. return inTransaction(func(sess *DBSession) error {
  196. _, err := sess.Delete(cmd.UserAuth)
  197. return err
  198. })
  199. }
  200. // decodeAndDecrypt will decode the string with the standard bas64 decoder
  201. // and then decrypt it with grafana's secretKey
  202. func decodeAndDecrypt(s string) (string, error) {
  203. // Bail out if empty string since it'll cause a segfault in util.Decrypt
  204. if s == "" {
  205. return "", nil
  206. }
  207. decoded, err := base64.StdEncoding.DecodeString(s)
  208. if err != nil {
  209. return "", err
  210. }
  211. decrypted, err := util.Decrypt(decoded, setting.SecretKey)
  212. if err != nil {
  213. return "", err
  214. }
  215. return string(decrypted), nil
  216. }
  217. // encryptAndEncode will encrypt a string with grafana's secretKey, and
  218. // then encode it with the standard bas64 encoder
  219. func encryptAndEncode(s string) (string, error) {
  220. encrypted, err := util.Encrypt([]byte(s), setting.SecretKey)
  221. if err != nil {
  222. return "", err
  223. }
  224. return base64.StdEncoding.EncodeToString(encrypted), nil
  225. }