login_test.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. package api
  2. import (
  3. "encoding/hex"
  4. "errors"
  5. "github.com/grafana/grafana/pkg/api/dtos"
  6. "github.com/grafana/grafana/pkg/models"
  7. "github.com/grafana/grafana/pkg/setting"
  8. "github.com/grafana/grafana/pkg/util"
  9. "github.com/stretchr/testify/assert"
  10. "io/ioutil"
  11. "net/http"
  12. "net/http/httptest"
  13. "strings"
  14. "testing"
  15. )
  16. func mockSetIndexViewData() {
  17. setIndexViewData = func(*HTTPServer, *models.ReqContext) (*dtos.IndexViewData, error) {
  18. data := &dtos.IndexViewData{
  19. User: &dtos.CurrentUser{},
  20. Settings: map[string]interface{}{},
  21. NavTree: []*dtos.NavLink{},
  22. }
  23. return data, nil
  24. }
  25. }
  26. func resetSetIndexViewData() {
  27. setIndexViewData = (*HTTPServer).setIndexViewData
  28. }
  29. func mockViewIndex() {
  30. getViewIndex = func() string {
  31. return "index-template"
  32. }
  33. }
  34. func resetViewIndex() {
  35. getViewIndex = func() string {
  36. return ViewIndex
  37. }
  38. }
  // getBody reads everything recorded in resp.Body and returns it as a string.
// Reading drains the recorder's buffer, so a second call returns "".
// On a read error it returns "" together with that error.
func getBody(resp *httptest.ResponseRecorder) (string, error) {
	raw, readErr := ioutil.ReadAll(resp.Body)
	if readErr == nil {
		return string(raw), nil
	}
	return "", readErr
}
  // TestLoginErrorCookieApiEndpoint checks that a GET /login request carrying an
// encrypted login-error cookie is answered with 200. The decrypted error
// message must appear in the rendered page, even though OAuth auto-login is
// enabled with a single configured provider.
func TestLoginErrorCookieApiEndpoint(t *testing.T) {
	mockSetIndexViewData()
	defer resetSetIndexViewData()

	mockViewIndex()
	defer resetViewIndex()

	// The OAuth and cookie configuration used below lives in package-level
	// settings. Snapshot it so the values forced by this test do not leak
	// into tests that run afterwards.
	origOAuthService := setting.OAuthService
	origOAuthAutoLogin := setting.OAuthAutoLogin
	origLoginCookieName := setting.LoginCookieName
	origSecretKey := setting.SecretKey
	defer func() {
		setting.OAuthService = origOAuthService
		setting.OAuthAutoLogin = origOAuthAutoLogin
		setting.LoginCookieName = origLoginCookieName
		setting.SecretKey = origSecretKey
	}()

	sc := setupScenarioContext("/login")
	hs := &HTTPServer{
		Cfg: setting.NewCfg(),
	}

	sc.defaultHandler = Wrap(func(c *models.ReqContext) {
		hs.LoginView(c)
	})

	setting.LoginCookieName = "grafana_session"
	setting.SecretKey = "login_testing"

	// One enabled provider plus auto-login: the same setup that makes
	// TestLoginOAuthRedirect expect a 307. Here the error cookie is present,
	// so the page is expected to render instead of redirecting.
	setting.OAuthService = &setting.OAuther{}
	setting.OAuthService.OAuthInfos = make(map[string]*setting.OAuthInfo)
	setting.OAuthService.OAuthInfos["github"] = &setting.OAuthInfo{
		ClientId:     "fake",
		ClientSecret: "fakefake",
		Enabled:      true,
		AllowSignup:  true,
		Name:         "github",
	}
	setting.OAuthAutoLogin = true

	oauthError := errors.New("User not a member of one of the required organizations")
	encryptedError, err := util.Encrypt([]byte(oauthError.Error()), setting.SecretKey)
	if err != nil {
		// Without a valid ciphertext the cookie is meaningless; stop here
		// rather than asserting against a request built from garbage.
		t.Fatalf("encrypting login error: %v", err)
	}
	cookie := http.Cookie{
		Name:     LoginErrorCookieName,
		MaxAge:   60,
		Value:    hex.EncodeToString(encryptedError),
		HttpOnly: true,
		Path:     setting.AppSubUrl + "/",
		Secure:   hs.Cfg.CookieSecure,
		SameSite: hs.Cfg.CookieSameSite,
	}

	sc.m.Get(sc.url, sc.defaultHandler)
	sc.fakeReqNoAssertionsWithCookie("GET", sc.url, cookie).exec()
	assert.Equal(t, http.StatusOK, sc.resp.Code)

	responseString, err := getBody(sc.resp)
	assert.NoError(t, err)
	assert.True(t, strings.Contains(responseString, oauthError.Error()),
		"expected login page to contain %q", oauthError.Error())
}
  // TestLoginOAuthRedirect configures OAuth auto-login with exactly one enabled
// provider (github). It checks that GET /login is answered with a 307
// Temporary Redirect whose Location is "/login/github".
func TestLoginOAuthRedirect(t *testing.T) {
	mockSetIndexViewData()
	defer resetSetIndexViewData()

	// The OAuth configuration below is package-level state. Snapshot it so
	// the forced auto-login does not leak into tests that run afterwards.
	origOAuthService := setting.OAuthService
	origOAuthAutoLogin := setting.OAuthAutoLogin
	defer func() {
		setting.OAuthService = origOAuthService
		setting.OAuthAutoLogin = origOAuthAutoLogin
	}()

	sc := setupScenarioContext("/login")
	hs := &HTTPServer{
		Cfg: setting.NewCfg(),
	}

	sc.defaultHandler = Wrap(func(c *models.ReqContext) {
		hs.LoginView(c)
	})

	setting.OAuthService = &setting.OAuther{}
	setting.OAuthService.OAuthInfos = make(map[string]*setting.OAuthInfo)
	setting.OAuthService.OAuthInfos["github"] = &setting.OAuthInfo{
		ClientId:     "fake",
		ClientSecret: "fakefake",
		Enabled:      true,
		AllowSignup:  true,
		Name:         "github",
	}
	setting.OAuthAutoLogin = true

	sc.m.Get(sc.url, sc.defaultHandler)
	sc.fakeReqNoAssertions("GET", sc.url).exec()

	assert.Equal(t, http.StatusTemporaryRedirect, sc.resp.Code)
	// Header().Get yields "" when Location is absent, so a missing header
	// fails this assertion instead of panicking on an out-of-range index.
	assert.Equal(t, "/login/github", sc.resp.Header().Get("Location"))
}