login_test.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. package api
  2. import (
  3. "encoding/hex"
  4. "encoding/json"
  5. "errors"
  6. "io/ioutil"
  7. "net/http"
  8. "net/http/httptest"
  9. "strings"
  10. "testing"
  11. "github.com/grafana/grafana/pkg/api/dtos"
  12. "github.com/grafana/grafana/pkg/models"
  13. "github.com/grafana/grafana/pkg/setting"
  14. "github.com/grafana/grafana/pkg/util"
  15. "github.com/stretchr/testify/assert"
  16. )
  17. func mockSetIndexViewData() {
  18. setIndexViewData = func(*HTTPServer, *models.ReqContext) (*dtos.IndexViewData, error) {
  19. data := &dtos.IndexViewData{
  20. User: &dtos.CurrentUser{},
  21. Settings: map[string]interface{}{},
  22. NavTree: []*dtos.NavLink{},
  23. }
  24. return data, nil
  25. }
  26. }
  27. func resetSetIndexViewData() {
  28. setIndexViewData = (*HTTPServer).setIndexViewData
  29. }
  30. func mockViewIndex() {
  31. getViewIndex = func() string {
  32. return "index-template"
  33. }
  34. }
  35. func resetViewIndex() {
  36. getViewIndex = func() string {
  37. return ViewIndex
  38. }
  39. }
  40. func getBody(resp *httptest.ResponseRecorder) (string, error) {
  41. responseData, err := ioutil.ReadAll(resp.Body)
  42. if err != nil {
  43. return "", err
  44. }
  45. return string(responseData), nil
  46. }
  47. func getJSONbody(resp *httptest.ResponseRecorder) (interface{}, error) {
  48. var j interface{}
  49. err := json.Unmarshal(resp.Body.Bytes(), &j)
  50. if err != nil {
  51. return nil, err
  52. }
  53. return j, nil
  54. }
  55. func TestLoginErrorCookieApiEndpoint(t *testing.T) {
  56. mockSetIndexViewData()
  57. defer resetSetIndexViewData()
  58. mockViewIndex()
  59. defer resetViewIndex()
  60. sc := setupScenarioContext("/login")
  61. hs := &HTTPServer{
  62. Cfg: setting.NewCfg(),
  63. }
  64. sc.defaultHandler = Wrap(func(w http.ResponseWriter, c *models.ReqContext) {
  65. hs.LoginView(c)
  66. })
  67. setting.OAuthService = &setting.OAuther{}
  68. setting.OAuthService.OAuthInfos = make(map[string]*setting.OAuthInfo)
  69. setting.LoginCookieName = "grafana_session"
  70. setting.SecretKey = "login_testing"
  71. setting.OAuthService = &setting.OAuther{}
  72. setting.OAuthService.OAuthInfos = make(map[string]*setting.OAuthInfo)
  73. setting.OAuthService.OAuthInfos["github"] = &setting.OAuthInfo{
  74. ClientId: "fake",
  75. ClientSecret: "fakefake",
  76. Enabled: true,
  77. AllowSignup: true,
  78. Name: "github",
  79. }
  80. setting.OAuthAutoLogin = true
  81. oauthError := errors.New("User not a member of one of the required organizations")
  82. encryptedError, _ := util.Encrypt([]byte(oauthError.Error()), setting.SecretKey)
  83. cookie := http.Cookie{
  84. Name: LoginErrorCookieName,
  85. MaxAge: 60,
  86. Value: hex.EncodeToString(encryptedError),
  87. HttpOnly: true,
  88. Path: setting.AppSubUrl + "/",
  89. Secure: hs.Cfg.CookieSecure,
  90. SameSite: hs.Cfg.CookieSameSite,
  91. }
  92. sc.m.Get(sc.url, sc.defaultHandler)
  93. sc.fakeReqNoAssertionsWithCookie("GET", sc.url, cookie).exec()
  94. assert.Equal(t, sc.resp.Code, 200)
  95. responseString, err := getBody(sc.resp)
  96. assert.Nil(t, err)
  97. assert.True(t, strings.Contains(responseString, oauthError.Error()))
  98. }
  99. func TestLoginOAuthRedirect(t *testing.T) {
  100. mockSetIndexViewData()
  101. defer resetSetIndexViewData()
  102. sc := setupScenarioContext("/login")
  103. hs := &HTTPServer{
  104. Cfg: setting.NewCfg(),
  105. }
  106. sc.defaultHandler = Wrap(func(c *models.ReqContext) {
  107. hs.LoginView(c)
  108. })
  109. setting.OAuthService = &setting.OAuther{}
  110. setting.OAuthService.OAuthInfos = make(map[string]*setting.OAuthInfo)
  111. setting.OAuthService.OAuthInfos["github"] = &setting.OAuthInfo{
  112. ClientId: "fake",
  113. ClientSecret: "fakefake",
  114. Enabled: true,
  115. AllowSignup: true,
  116. Name: "github",
  117. }
  118. setting.OAuthAutoLogin = true
  119. sc.m.Get(sc.url, sc.defaultHandler)
  120. sc.fakeReqNoAssertions("GET", sc.url).exec()
  121. assert.Equal(t, sc.resp.Code, 307)
  122. location, ok := sc.resp.Header()["Location"]
  123. assert.True(t, ok)
  124. assert.Equal(t, location[0], "/login/github")
  125. }