user_auth.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. package sqlstore
  2. import (
  3. "encoding/base64"
  4. "time"
  5. "github.com/grafana/grafana/pkg/bus"
  6. m "github.com/grafana/grafana/pkg/models"
  7. "github.com/grafana/grafana/pkg/setting"
  8. "github.com/grafana/grafana/pkg/util"
  9. )
  10. func init() {
  11. bus.AddHandler("sql", GetUserByAuthInfo)
  12. bus.AddHandler("sql", GetAuthInfo)
  13. bus.AddHandler("sql", SetAuthInfo)
  14. bus.AddHandler("sql", UpdateAuthInfo)
  15. bus.AddHandler("sql", DeleteAuthInfo)
  16. }
  17. func GetUserByAuthInfo(query *m.GetUserByAuthInfoQuery) error {
  18. user := &m.User{}
  19. has := false
  20. var err error
  21. authQuery := &m.GetAuthInfoQuery{}
  22. // Try to find the user by auth module and id first
  23. if query.AuthModule != "" && query.AuthId != "" {
  24. authQuery.AuthModule = query.AuthModule
  25. authQuery.AuthId = query.AuthId
  26. err = GetAuthInfo(authQuery)
  27. if err != m.ErrUserNotFound {
  28. if err != nil {
  29. return err
  30. }
  31. // if user id was specified and doesn't match the user_auth entry, remove it
  32. if query.UserId != 0 && query.UserId != authQuery.Result.UserId {
  33. err = DeleteAuthInfo(&m.DeleteAuthInfoCommand{
  34. UserAuth: authQuery.Result,
  35. })
  36. if err != nil {
  37. sqlog.Error("Error removing user_auth entry", "error", err)
  38. }
  39. authQuery.Result = nil
  40. } else {
  41. has, err = x.Id(authQuery.Result.UserId).Get(user)
  42. if err != nil {
  43. return err
  44. }
  45. if !has {
  46. // if the user has been deleted then remove the entry
  47. err = DeleteAuthInfo(&m.DeleteAuthInfoCommand{
  48. UserAuth: authQuery.Result,
  49. })
  50. if err != nil {
  51. sqlog.Error("Error removing user_auth entry", "error", err)
  52. }
  53. authQuery.Result = nil
  54. }
  55. }
  56. }
  57. }
  58. // If not found, try to find the user by id
  59. if !has && query.UserId != 0 {
  60. has, err = x.Id(query.UserId).Get(user)
  61. if err != nil {
  62. return err
  63. }
  64. }
  65. // If not found, try to find the user by email address
  66. if !has && query.Email != "" {
  67. user = &m.User{Email: query.Email}
  68. has, err = x.Get(user)
  69. if err != nil {
  70. return err
  71. }
  72. }
  73. // If not found, try to find the user by login
  74. if !has && query.Login != "" {
  75. user = &m.User{Login: query.Login}
  76. has, err = x.Get(user)
  77. if err != nil {
  78. return err
  79. }
  80. }
  81. // No user found
  82. if !has {
  83. return m.ErrUserNotFound
  84. }
  85. // create authInfo record to link accounts
  86. if authQuery.Result == nil && query.AuthModule != "" {
  87. cmd2 := &m.SetAuthInfoCommand{
  88. UserId: user.Id,
  89. AuthModule: query.AuthModule,
  90. AuthId: query.AuthId,
  91. }
  92. if err := SetAuthInfo(cmd2); err != nil {
  93. return err
  94. }
  95. }
  96. query.Result = user
  97. return nil
  98. }
  99. func GetAuthInfo(query *m.GetAuthInfoQuery) error {
  100. userAuth := &m.UserAuth{
  101. UserId: query.UserId, // TODO this doesn't have an index in the db
  102. AuthModule: query.AuthModule,
  103. AuthId: query.AuthId,
  104. }
  105. has, err := x.Get(userAuth)
  106. if err != nil {
  107. return err
  108. }
  109. if !has {
  110. return m.ErrUserNotFound
  111. }
  112. if userAuth.OAuthAccessToken != "" {
  113. decodedAccessToken, err := base64.StdEncoding.DecodeString(userAuth.OAuthAccessToken)
  114. if err != nil {
  115. return err
  116. }
  117. decryptedAccessToken, err := util.Decrypt(decodedAccessToken, setting.SecretKey)
  118. if err != nil {
  119. return err
  120. }
  121. userAuth.OAuthAccessToken = string(decryptedAccessToken)
  122. }
  123. if userAuth.OAuthRefreshToken != "" {
  124. decodedRefreshToken, err := base64.StdEncoding.DecodeString(userAuth.OAuthRefreshToken)
  125. if err != nil {
  126. return err
  127. }
  128. decryptedRefreshToken, err := util.Decrypt(decodedRefreshToken, setting.SecretKey)
  129. if err != nil {
  130. return err
  131. }
  132. userAuth.OAuthRefreshToken = string(decryptedRefreshToken)
  133. }
  134. if userAuth.OAuthTokenType != "" {
  135. decodedTokenType, err := base64.StdEncoding.DecodeString(userAuth.OAuthTokenType)
  136. if err != nil {
  137. return err
  138. }
  139. decryptedTokenType, err := util.Decrypt(decodedTokenType, setting.SecretKey)
  140. if err != nil {
  141. return err
  142. }
  143. userAuth.OAuthTokenType = string(decryptedTokenType)
  144. }
  145. query.Result = userAuth
  146. return nil
  147. }
  148. func SetAuthInfo(cmd *m.SetAuthInfoCommand) error {
  149. return inTransaction(func(sess *DBSession) error {
  150. authUser := &m.UserAuth{
  151. UserId: cmd.UserId,
  152. AuthModule: cmd.AuthModule,
  153. AuthId: cmd.AuthId,
  154. Created: time.Now(),
  155. }
  156. if cmd.OAuthToken != nil {
  157. secretAccessToken, err := util.Encrypt([]byte(cmd.OAuthToken.AccessToken), setting.SecretKey)
  158. if err != nil {
  159. return err
  160. }
  161. secretRefreshToken, err := util.Encrypt([]byte(cmd.OAuthToken.RefreshToken), setting.SecretKey)
  162. if err != nil {
  163. return err
  164. }
  165. secretTokenType, err := util.Encrypt([]byte(cmd.OAuthToken.TokenType), setting.SecretKey)
  166. if err != nil {
  167. return err
  168. }
  169. authUser.OAuthAccessToken = base64.StdEncoding.EncodeToString(secretAccessToken)
  170. authUser.OAuthRefreshToken = base64.StdEncoding.EncodeToString(secretRefreshToken)
  171. authUser.OAuthTokenType = base64.StdEncoding.EncodeToString(secretTokenType)
  172. authUser.OAuthExpiry = cmd.OAuthToken.Expiry
  173. }
  174. _, err := sess.Insert(authUser)
  175. return err
  176. })
  177. }
  178. func UpdateAuthInfo(cmd *m.UpdateAuthInfoCommand) error {
  179. return inTransaction(func(sess *DBSession) error {
  180. authUser := &m.UserAuth{
  181. UserId: cmd.UserId,
  182. AuthModule: cmd.AuthModule,
  183. AuthId: cmd.AuthId,
  184. Created: time.Now(),
  185. }
  186. if cmd.OAuthToken != nil {
  187. secretAccessToken, err := util.Encrypt([]byte(cmd.OAuthToken.AccessToken), setting.SecretKey)
  188. if err != nil {
  189. return err
  190. }
  191. secretRefreshToken, err := util.Encrypt([]byte(cmd.OAuthToken.RefreshToken), setting.SecretKey)
  192. if err != nil {
  193. return err
  194. }
  195. secretTokenType, err := util.Encrypt([]byte(cmd.OAuthToken.TokenType), setting.SecretKey)
  196. if err != nil {
  197. return err
  198. }
  199. authUser.OAuthAccessToken = base64.StdEncoding.EncodeToString(secretAccessToken)
  200. authUser.OAuthRefreshToken = base64.StdEncoding.EncodeToString(secretRefreshToken)
  201. authUser.OAuthTokenType = base64.StdEncoding.EncodeToString(secretTokenType)
  202. authUser.OAuthExpiry = cmd.OAuthToken.Expiry
  203. }
  204. cond := &m.UserAuth{
  205. UserId: cmd.UserId,
  206. AuthModule: cmd.AuthModule,
  207. }
  208. _, err := sess.Update(authUser, cond)
  209. return err
  210. })
  211. }
  212. func DeleteAuthInfo(cmd *m.DeleteAuthInfoCommand) error {
  213. return inTransaction(func(sess *DBSession) error {
  214. _, err := sess.Delete(cmd.UserAuth)
  215. return err
  216. })
  217. }