|
@@ -3,9 +3,11 @@ package api
|
|
|
import (
|
|
import (
|
|
|
"context"
|
|
"context"
|
|
|
"crypto/rand"
|
|
"crypto/rand"
|
|
|
|
|
+ "crypto/sha256"
|
|
|
"crypto/tls"
|
|
"crypto/tls"
|
|
|
"crypto/x509"
|
|
"crypto/x509"
|
|
|
"encoding/base64"
|
|
"encoding/base64"
|
|
|
|
|
+ "encoding/hex"
|
|
|
"fmt"
|
|
"fmt"
|
|
|
"io/ioutil"
|
|
"io/ioutil"
|
|
|
"net/http"
|
|
"net/http"
|
|
@@ -18,12 +20,14 @@ import (
|
|
|
"github.com/grafana/grafana/pkg/login"
|
|
"github.com/grafana/grafana/pkg/login"
|
|
|
"github.com/grafana/grafana/pkg/metrics"
|
|
"github.com/grafana/grafana/pkg/metrics"
|
|
|
m "github.com/grafana/grafana/pkg/models"
|
|
m "github.com/grafana/grafana/pkg/models"
|
|
|
- "github.com/grafana/grafana/pkg/services/session"
|
|
|
|
|
"github.com/grafana/grafana/pkg/setting"
|
|
"github.com/grafana/grafana/pkg/setting"
|
|
|
"github.com/grafana/grafana/pkg/social"
|
|
"github.com/grafana/grafana/pkg/social"
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
-var oauthLogger = log.New("oauth")
|
|
|
|
|
|
|
+var (
|
|
|
|
|
+ oauthLogger = log.New("oauth")
|
|
|
|
|
+ OauthStateCookieName = "oauth_state"
|
|
|
|
|
+)
|
|
|
|
|
|
|
|
func GenStateString() string {
|
|
func GenStateString() string {
|
|
|
rnd := make([]byte, 32)
|
|
rnd := make([]byte, 32)
|
|
@@ -55,7 +59,9 @@ func (hs *HTTPServer) OAuthLogin(ctx *m.ReqContext) {
|
|
|
code := ctx.Query("code")
|
|
code := ctx.Query("code")
|
|
|
if code == "" {
|
|
if code == "" {
|
|
|
state := GenStateString()
|
|
state := GenStateString()
|
|
|
- ctx.Session.Set(session.SESS_KEY_OAUTH_STATE, state)
|
|
|
|
|
|
|
+ hashedState := hashStatecode(state, setting.OAuthService.OAuthInfos[name].ClientSecret)
|
|
|
|
|
+ hs.writeOauthStateCookie(ctx, hashedState, 60)
|
|
|
|
|
+
|
|
|
if setting.OAuthService.OAuthInfos[name].HostedDomain == "" {
|
|
if setting.OAuthService.OAuthInfos[name].HostedDomain == "" {
|
|
|
ctx.Redirect(connect.AuthCodeURL(state, oauth2.AccessTypeOnline))
|
|
ctx.Redirect(connect.AuthCodeURL(state, oauth2.AccessTypeOnline))
|
|
|
} else {
|
|
} else {
|
|
@@ -64,13 +70,18 @@ func (hs *HTTPServer) OAuthLogin(ctx *m.ReqContext) {
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- savedState, ok := ctx.Session.Get(session.SESS_KEY_OAUTH_STATE).(string)
|
|
|
|
|
- if !ok {
|
|
|
|
|
|
|
+ savedState := ctx.GetCookie(OauthStateCookieName)
|
|
|
|
|
+
|
|
|
|
|
+ // delete cookie
|
|
|
|
|
+ ctx.Resp.Header().Del("Set-Cookie")
|
|
|
|
|
+ hs.writeOauthStateCookie(ctx, "", -1)
|
|
|
|
|
+
|
|
|
|
|
+ if savedState == "" {
|
|
|
ctx.Handle(500, "login.OAuthLogin(missing saved state)", nil)
|
|
ctx.Handle(500, "login.OAuthLogin(missing saved state)", nil)
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- queryState := ctx.Query("state")
|
|
|
|
|
|
|
+ queryState := hashStatecode(ctx.Query("state"), setting.OAuthService.OAuthInfos[name].ClientSecret)
|
|
|
if savedState != queryState {
|
|
if savedState != queryState {
|
|
|
ctx.Handle(500, "login.OAuthLogin(state mismatch)", nil)
|
|
ctx.Handle(500, "login.OAuthLogin(state mismatch)", nil)
|
|
|
return
|
|
return
|
|
@@ -191,6 +202,22 @@ func (hs *HTTPServer) OAuthLogin(ctx *m.ReqContext) {
|
|
|
ctx.Redirect(setting.AppSubUrl + "/")
|
|
ctx.Redirect(setting.AppSubUrl + "/")
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+func (hs *HTTPServer) writeOauthStateCookie(ctx *m.ReqContext, value string, maxAge int) {
|
|
|
|
|
+ http.SetCookie(ctx.Resp, &http.Cookie{
|
|
|
|
|
+ Name: OauthStateCookieName,
|
|
|
|
|
+ MaxAge: maxAge,
|
|
|
|
|
+ Value: value,
|
|
|
|
|
+ HttpOnly: true,
|
|
|
|
|
+ Path: setting.AppSubUrl + "/",
|
|
|
|
|
+ Secure: hs.Cfg.LoginCookieSecure,
|
|
|
|
|
+ })
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func hashStatecode(code, seed string) string {
|
|
|
|
|
+ hashBytes := sha256.Sum256([]byte(code + setting.SecretKey + seed))
|
|
|
|
|
+ return hex.EncodeToString(hashBytes[:])
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
func redirectWithError(ctx *m.ReqContext, err error, v ...interface{}) {
|
|
func redirectWithError(ctx *m.ReqContext, err error, v ...interface{}) {
|
|
|
ctx.Logger.Error(err.Error(), v...)
|
|
ctx.Logger.Error(err.Error(), v...)
|
|
|
ctx.Session.Set("loginError", err.Error())
|
|
ctx.Session.Set("loginError", err.Error())
|