瀏覽代碼

Add oauth pass-thru option for datasources

Sean Lafferty 6 年之前
父節點
當前提交
5a59cdf0ef

+ 1 - 0
pkg/api/login_oauth.go

@@ -165,6 +165,7 @@ func (hs *HTTPServer) OAuthLogin(ctx *m.ReqContext) {
 
 	extUser := &m.ExternalUserInfo{
 		AuthModule: "oauth_" + name,
+		OAuthToken: token,
 		AuthId:     userInfo.Id,
 		Name:       userInfo.Name,
 		Login:      userInfo.Login,

+ 41 - 0
pkg/api/pluginproxy/ds_proxy.go

@@ -14,11 +14,14 @@ import (
 	"time"
 
 	"github.com/opentracing/opentracing-go"
+	"golang.org/x/oauth2"
 
+	"github.com/grafana/grafana/pkg/bus"
 	"github.com/grafana/grafana/pkg/log"
 	m "github.com/grafana/grafana/pkg/models"
 	"github.com/grafana/grafana/pkg/plugins"
 	"github.com/grafana/grafana/pkg/setting"
+	"github.com/grafana/grafana/pkg/social"
 	"github.com/grafana/grafana/pkg/util"
 )
 
@@ -215,6 +218,44 @@ func (proxy *DataSourceProxy) getDirector() func(req *http.Request) {
 		if proxy.route != nil {
 			ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, proxy.route, proxy.ds)
 		}
+
+		if proxy.ds.JsonData != nil && proxy.ds.JsonData.Get("oauthPassThru").MustBool() {
+			provider := proxy.ds.JsonData.Get("oauthPassThruProvider").MustString()
+			connect, ok := social.SocialMap[strings.TrimPrefix(provider, "oauth_")] // The socialMap keys don't have "oauth_" prefix, but everywhere else in the system does
+			if !ok {
+				logger.Error("Failed to find oauth provider with given name", "provider", provider)
+			}
+			cmd := &m.GetAuthInfoQuery{UserId: proxy.ctx.UserId, AuthModule: provider}
+			if err := bus.Dispatch(cmd); err != nil {
+				logger.Error("Error feching oauth information for user", "error", err)
+			}
+
+			// TokenSource handles refreshing the token if it has expired
+			token, err := connect.TokenSource(proxy.ctx.Req.Context(), &oauth2.Token{
+				AccessToken:  cmd.Result.OAuthAccessToken,
+				Expiry:       cmd.Result.OAuthExpiry,
+				RefreshToken: cmd.Result.OAuthRefreshToken,
+				TokenType:    cmd.Result.OAuthTokenType,
+			}).Token()
+			if err != nil {
+				logger.Error("Failed to retrieve access token from oauth provider", "provider", cmd.Result.AuthModule)
+			}
+
+			// If the tokens are not the same, update the entry in the DB
+			if token.AccessToken != cmd.Result.OAuthAccessToken {
+				cmd2 := &m.UpdateAuthInfoCommand{
+					UserId:     cmd.Result.Id,
+					AuthModule: cmd.Result.AuthModule,
+					AuthId:     cmd.Result.AuthId,
+					OAuthToken: token,
+				}
+				if err := bus.Dispatch(cmd2); err != nil {
+					logger.Error("Failed to update access token during token refresh", "error", err)
+				}
+			}
+			req.Header.Del("Authorization")
+			req.Header.Add("Authorization", fmt.Sprintf("%s %s", token.Type(), token.AccessToken))
+		}
 	}
 }
 

+ 52 - 0
pkg/api/pluginproxy/ds_proxy_test.go

@@ -9,13 +9,16 @@ import (
 	"testing"
 	"time"
 
+	"golang.org/x/oauth2"
 	macaron "gopkg.in/macaron.v1"
 
+	"github.com/grafana/grafana/pkg/bus"
 	"github.com/grafana/grafana/pkg/components/simplejson"
 	"github.com/grafana/grafana/pkg/log"
 	m "github.com/grafana/grafana/pkg/models"
 	"github.com/grafana/grafana/pkg/plugins"
 	"github.com/grafana/grafana/pkg/setting"
+	"github.com/grafana/grafana/pkg/social"
 	"github.com/grafana/grafana/pkg/util"
 	. "github.com/smartystreets/goconvey/convey"
 )
@@ -388,6 +391,55 @@ func TestDSRouteRule(t *testing.T) {
 				So(req.Header.Get("X-Canary"), ShouldEqual, "stillthere")
 			})
 		})
+
+		Convey("When proxying a datasource that has oauth token pass-thru enabled", func() {
+			social.SocialMap["generic_oauth"] = &social.SocialGenericOAuth{
+				SocialBase: &social.SocialBase{
+					Config: &oauth2.Config{},
+				},
+			}
+
+			bus.AddHandler("test", func(query *m.GetAuthInfoQuery) error {
+				query.Result = &m.UserAuth{
+					Id:                1,
+					UserId:            1,
+					AuthModule:        "generic_oauth",
+					OAuthAccessToken:  "testtoken",
+					OAuthRefreshToken: "testrefreshtoken",
+					OAuthTokenType:    "Bearer",
+					OAuthExpiry:       time.Now().AddDate(0, 0, 1),
+				}
+				return nil
+			})
+
+			plugin := &plugins.DataSourcePlugin{}
+			ds := &m.DataSource{
+				Type: "custom-datasource",
+				Url:  "http://host/root/",
+				JsonData: simplejson.NewFromAny(map[string]interface{}{
+					"oauthPassThru":         true,
+					"oauthPassThruProvider": "oauth_generic_oauth",
+				}),
+			}
+
+			req, _ := http.NewRequest("GET", "http://localhost/asd", nil)
+			ctx := &m.ReqContext{
+				SignedInUser: &m.SignedInUser{UserId: 1},
+				Context: &macaron.Context{
+					Req: macaron.Request{Request: req},
+				},
+			}
+			proxy := NewDataSourceProxy(ds, plugin, ctx, "/path/to/folder/")
+			req, err := http.NewRequest(http.MethodGet, "http://grafana.com/sub", nil)
+
+			So(err, ShouldBeNil)
+
+			proxy.getDirector()(req)
+
+			Convey("Should have access token in header", func() {
+				So(req.Header.Get("Authorization"), ShouldEqual, fmt.Sprintf("%s %s", "Bearer", "testtoken"))
+			})
+		})
 	})
 }
 

+ 22 - 1
pkg/login/ext_user.go

@@ -51,11 +51,12 @@ func UpsertUser(cmd *m.UpsertUserCommand) error {
 			return err
 		}
 
-		if extUser.AuthModule != "" && extUser.AuthId != "" {
+		if extUser.AuthModule != "" {
 			cmd2 := &m.SetAuthInfoCommand{
 				UserId:     cmd.Result.Id,
 				AuthModule: extUser.AuthModule,
 				AuthId:     extUser.AuthId,
+				OAuthToken: extUser.OAuthToken,
 			}
 			if err := bus.Dispatch(cmd2); err != nil {
 				return err
@@ -69,6 +70,14 @@ func UpsertUser(cmd *m.UpsertUserCommand) error {
 		if err != nil {
 			return err
 		}
+
+		// Always persist the latest token at log-in
+		if extUser.AuthModule != "" && extUser.OAuthToken != nil {
+			err = updateUserAuth(cmd.Result, extUser)
+			if err != nil {
+				return err
+			}
+		}
 	}
 
 	err = syncOrgRoles(cmd.Result, extUser)
@@ -143,6 +152,18 @@ func updateUser(user *m.User, extUser *m.ExternalUserInfo) error {
 	return bus.Dispatch(updateCmd)
 }
 
+func updateUserAuth(user *m.User, extUser *m.ExternalUserInfo) error {
+	updateCmd := &m.UpdateAuthInfoCommand{
+		AuthModule: extUser.AuthModule,
+		AuthId:     extUser.AuthId,
+		UserId:     user.Id,
+		OAuthToken: extUser.OAuthToken,
+	}
+
+	log.Debug("Updating user_auth info for user_id %d", user.Id)
+	return bus.Dispatch(updateCmd)
+}
+
 func syncOrgRoles(user *m.User, extUser *m.ExternalUserInfo) error {
 	// don't sync org roles if none are specified
 	if len(extUser.OrgRoles) == 0 {

+ 21 - 5
pkg/models/user_auth.go

@@ -2,17 +2,24 @@ package models
 
 import (
 	"time"
+
+	"golang.org/x/oauth2"
 )
 
 type UserAuth struct {
-	Id         int64
-	UserId     int64
-	AuthModule string
-	AuthId     string
-	Created    time.Time
+	Id                int64
+	UserId            int64
+	AuthModule        string
+	AuthId            string
+	Created           time.Time
+	OAuthAccessToken  string
+	OAuthRefreshToken string
+	OAuthTokenType    string
+	OAuthExpiry       time.Time
 }
 
 type ExternalUserInfo struct {
+	OAuthToken     *oauth2.Token
 	AuthModule     string
 	AuthId         string
 	UserId         int64
@@ -39,6 +46,14 @@ type SetAuthInfoCommand struct {
 	AuthModule string
 	AuthId     string
 	UserId     int64
+	OAuthToken *oauth2.Token
+}
+
+type UpdateAuthInfoCommand struct {
+	AuthModule string
+	AuthId     string
+	UserId     int64
+	OAuthToken *oauth2.Token
 }
 
 type DeleteAuthInfoCommand struct {
@@ -67,6 +82,7 @@ type GetUserByAuthInfoQuery struct {
 }
 
 type GetAuthInfoQuery struct {
+	UserId     int64
 	AuthModule string
 	AuthId     string
 

+ 1 - 0
pkg/services/sqlstore/migrations/migrations.go

@@ -33,6 +33,7 @@ func AddMigrations(mg *Migrator) {
 	addUserAuthMigrations(mg)
 	addServerlockMigrations(mg)
 	addUserAuthTokenMigrations(mg)
+	addUserAuthOAuthMigrations(mg)
 }
 
 func addMigrationLogMigrations(mg *Migrator) {

+ 25 - 0
pkg/services/sqlstore/migrations/user_auth_oauth_mig.go

@@ -0,0 +1,25 @@
+package migrations
+
+import . "github.com/grafana/grafana/pkg/services/sqlstore/migrator"
+
+func addUserAuthOAuthMigrations(mg *Migrator) {
+	userAuthV2 := Table{Name: "user_auth"}
+
+	mg.AddMigration("Add OAuth access token to user_auth", NewAddColumnMigration(userAuthV2, &Column{
+		Name: "o_auth_access_token", Type: DB_Text, Nullable: true, Length: 255,
+	}))
+	mg.AddMigration("Add OAuth refresh token to user_auth", NewAddColumnMigration(userAuthV2, &Column{
+		Name: "o_auth_refresh_token", Type: DB_Text, Nullable: true, Length: 255,
+	}))
+	mg.AddMigration("Add OAuth token type to user_auth", NewAddColumnMigration(userAuthV2, &Column{
+		Name: "o_auth_token_type", Type: DB_Text, Nullable: true, Length: 255,
+	}))
+	mg.AddMigration("Add OAuth expiry to user_auth", NewAddColumnMigration(userAuthV2, &Column{
+		Name: "o_auth_expiry", Type: DB_DateTime, Nullable: true,
+	}))
+
+	mg.AddMigration("Add index to user_id column in user_auth", NewAddIndexMigration(userAuthV2, &Index{
+		Cols: []string{"user_id"},
+	}))
+
+}

+ 85 - 1
pkg/services/sqlstore/user_auth.go

@@ -5,12 +5,15 @@ import (
 
 	"github.com/grafana/grafana/pkg/bus"
 	m "github.com/grafana/grafana/pkg/models"
+	"github.com/grafana/grafana/pkg/setting"
+	"github.com/grafana/grafana/pkg/util"
 )
 
 func init() {
 	bus.AddHandler("sql", GetUserByAuthInfo)
 	bus.AddHandler("sql", GetAuthInfo)
 	bus.AddHandler("sql", SetAuthInfo)
+	bus.AddHandler("sql", UpdateAuthInfo)
 	bus.AddHandler("sql", DeleteAuthInfo)
 }
 
@@ -94,7 +97,7 @@ func GetUserByAuthInfo(query *m.GetUserByAuthInfoQuery) error {
 	}
 
 	// create authInfo record to link accounts
-	if authQuery.Result == nil && query.AuthModule != "" && query.AuthId != "" {
+	if authQuery.Result == nil && query.AuthModule != "" {
 		cmd2 := &m.SetAuthInfoCommand{
 			UserId:     user.Id,
 			AuthModule: query.AuthModule,
@@ -111,6 +114,7 @@ func GetUserByAuthInfo(query *m.GetUserByAuthInfoQuery) error {
 
 func GetAuthInfo(query *m.GetAuthInfoQuery) error {
 	userAuth := &m.UserAuth{
+		UserId:     query.UserId, // TODO this doesn't have an index in the db
 		AuthModule: query.AuthModule,
 		AuthId:     query.AuthId,
 	}
@@ -122,6 +126,28 @@ func GetAuthInfo(query *m.GetAuthInfoQuery) error {
 		return m.ErrUserNotFound
 	}
 
+	if userAuth.OAuthAccessToken != "" {
+		accessToken, err := util.Decrypt([]byte(userAuth.OAuthAccessToken), setting.SecretKey)
+		if err != nil {
+			return err
+		}
+		userAuth.OAuthAccessToken = string(accessToken)
+	}
+	if userAuth.OAuthRefreshToken != "" {
+		refreshToken, err := util.Decrypt([]byte(userAuth.OAuthRefreshToken), setting.SecretKey)
+		if err != nil {
+			return err
+		}
+		userAuth.OAuthRefreshToken = string(refreshToken)
+	}
+	if userAuth.OAuthTokenType != "" {
+		tokenType, err := util.Decrypt([]byte(userAuth.OAuthTokenType), setting.SecretKey)
+		if err != nil {
+			return err
+		}
+		userAuth.OAuthTokenType = string(tokenType)
+	}
+
 	query.Result = userAuth
 	return nil
 }
@@ -135,11 +161,69 @@ func SetAuthInfo(cmd *m.SetAuthInfoCommand) error {
 			Created:    time.Now(),
 		}
 
+		if cmd.OAuthToken != nil {
+			secretAccessToken, err := util.Encrypt([]byte(cmd.OAuthToken.AccessToken), setting.SecretKey)
+			if err != nil {
+				return err
+			}
+			secretRefreshToken, err := util.Encrypt([]byte(cmd.OAuthToken.RefreshToken), setting.SecretKey)
+			if err != nil {
+				return err
+			}
+			secretTokenType, err := util.Encrypt([]byte(cmd.OAuthToken.TokenType), setting.SecretKey)
+			if err != nil {
+				return err
+			}
+
+			authUser.OAuthAccessToken = string(secretAccessToken)
+			authUser.OAuthRefreshToken = string(secretRefreshToken)
+			authUser.OAuthTokenType = string(secretTokenType)
+			authUser.OAuthExpiry = cmd.OAuthToken.Expiry
+		}
+
 		_, err := sess.Insert(authUser)
 		return err
 	})
 }
 
+func UpdateAuthInfo(cmd *m.UpdateAuthInfoCommand) error {
+	return inTransaction(func(sess *DBSession) error {
+		authUser := &m.UserAuth{
+			UserId:     cmd.UserId,
+			AuthModule: cmd.AuthModule,
+			AuthId:     cmd.AuthId,
+			Created:    time.Now(),
+		}
+
+		if cmd.OAuthToken != nil {
+			secretAccessToken, err := util.Encrypt([]byte(cmd.OAuthToken.AccessToken), setting.SecretKey)
+			if err != nil {
+				return err
+			}
+			secretRefreshToken, err := util.Encrypt([]byte(cmd.OAuthToken.RefreshToken), setting.SecretKey)
+			if err != nil {
+				return err
+			}
+			secretTokenType, err := util.Encrypt([]byte(cmd.OAuthToken.TokenType), setting.SecretKey)
+			if err != nil {
+				return err
+			}
+			authUser.OAuthAccessToken = string(secretAccessToken)
+			authUser.OAuthRefreshToken = string(secretRefreshToken)
+			authUser.OAuthTokenType = string(secretTokenType)
+			authUser.OAuthExpiry = cmd.OAuthToken.Expiry
+		}
+
+		cond := &m.UserAuth{
+			UserId:     cmd.UserId,
+			AuthModule: cmd.AuthModule,
+		}
+
+		_, err := sess.Update(authUser, cond)
+		return err
+	})
+}
+
 func DeleteAuthInfo(cmd *m.DeleteAuthInfoCommand) error {
 	return inTransaction(func(sess *DBSession) error {
 		_, err := sess.Delete(cmd.UserAuth)

+ 43 - 0
pkg/services/sqlstore/user_auth_test.go

@@ -4,8 +4,10 @@ import (
 	"context"
 	"fmt"
 	"testing"
+	"time"
 
 	. "github.com/smartystreets/goconvey/convey"
+	"golang.org/x/oauth2"
 
 	m "github.com/grafana/grafana/pkg/models"
 )
@@ -126,5 +128,46 @@ func TestUserAuth(t *testing.T) {
 			So(err, ShouldEqual, m.ErrUserNotFound)
 			So(query.Result, ShouldBeNil)
 		})
+
+		Convey("Can set & retrieve oauth token information", func() {
+			token := &oauth2.Token{
+				AccessToken:  "testaccess",
+				RefreshToken: "testrefresh",
+				Expiry:       time.Now(),
+				TokenType:    "Bearer",
+			}
+
+			// Find a user to set tokens on
+			login := "loginuser0"
+
+			// Calling GetUserByAuthInfoQuery on an existing user will populate an entry in the user_auth table
+			query := &m.GetUserByAuthInfoQuery{Login: login, AuthModule: "test", AuthId: "test"}
+			err = GetUserByAuthInfo(query)
+
+			So(err, ShouldBeNil)
+			So(query.Result.Login, ShouldEqual, login)
+
+			cmd := &m.UpdateAuthInfoCommand{
+				UserId:     query.Result.Id,
+				AuthId:     query.AuthId,
+				AuthModule: query.AuthModule,
+				OAuthToken: token,
+			}
+			err = UpdateAuthInfo(cmd)
+
+			So(err, ShouldBeNil)
+
+			getAuthQuery := &m.GetAuthInfoQuery{
+				UserId: query.Result.Id,
+			}
+
+			err = GetAuthInfo(getAuthQuery)
+
+			So(err, ShouldBeNil)
+			So(getAuthQuery.Result.OAuthAccessToken, ShouldEqual, token.AccessToken)
+			So(getAuthQuery.Result.OAuthRefreshToken, ShouldEqual, token.RefreshToken)
+			So(getAuthQuery.Result.OAuthTokenType, ShouldEqual, token.TokenType)
+
+		})
 	})
 }

+ 1 - 0
pkg/social/social.go

@@ -31,6 +31,7 @@ type SocialConnector interface {
 	AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string
 	Exchange(ctx context.Context, code string) (*oauth2.Token, error)
 	Client(ctx context.Context, t *oauth2.Token) *http.Client
+	TokenSource(ctx context.Context, t *oauth2.Token) oauth2.TokenSource
 }
 
 type SocialBase struct {

+ 13 - 0
public/app/features/datasources/partials/http_settings.html

@@ -87,6 +87,19 @@
 			<gf-form-checkbox class="gf-form" ng-if="current.access=='proxy'" label="Skip TLS Verify" label-class="width-10"
 			  checked="current.jsonData.tlsSkipVerify" switch-class="max-width-6"></gf-form-checkbox>
 		</div>
+		<div class="gf-form-inline">
+			<gf-form-switch class="gf-form" ng-if="current.access=='proxy'" label="Forward OAuth Identity" label-class="width-13" tooltip="Forward the user's upstream OAuth identity to the datasource (Their access token gets passed along)." label-class="width-10" checked="current.jsonData.oauthPassThru" switch-class="max-width-6"></gf-form-switch>
+		</div>
+	</div>
+
+	<div class="gf-form-group" ng-if="current.jsonData.oauthPassThru">
+		<h6>OAuth Identity Forwarding Details</h6>
+		<div class="gf-form max-width-30">
+			<span class="gf-form-label width-10">OAuth Source</span>
+			<div class="gf-form-select-wrapper max-width-24">
+				<select class="gf-form-input" ng-model="current.jsonData.oauthPassThruProvider" ng-options="f.key as f.value for f in oauthProviders"></select>
+			</div>
+		</div>
 	</div>
 
 	<div class="gf-form-group" ng-if="current.basicAuth">

+ 7 - 0
public/app/features/datasources/settings/HttpSettingsCtrl.ts

@@ -20,6 +20,13 @@ coreModule.directive('datasourceHttpSettings', () => {
         $scope.getSuggestUrls = () => {
           return [$scope.suggestUrl];
         };
+        $scope.oauthProviders = [
+          { key: 'oauth_google', value: 'Google OAuth' },
+          { key: 'oauth_gitlab', value: 'GitLab OAuth' },
+          { key: 'oauth_generic_oauth', value: 'Generic OAuth' },
+          { key: 'oauth_grafana_com', value: 'Grafana OAuth' },
+          { key: 'oauth_github', value: 'GitHub OAuth' },
+        ];
       },
     },
   };