Browse Source

Batch disable users (#17254)

* batch disable users

* batch revoke users tokens

* split batch disable user and revoke token

* fix tests for batch disable users

* Chore: add BatchDisableUsers() to the bus
Alexander Zobnin 6 years ago
parent
commit
60ddad8fdb

+ 5 - 0
pkg/models/user.go

@@ -94,6 +94,11 @@ type DisableUserCommand struct {
 	IsDisabled bool
 }
 
+type BatchDisableUsersCommand struct {
+	UserIds    []int64
+	IsDisabled bool
+}
+
 type DeleteUserCommand struct {
 	UserId int64
 }

+ 31 - 0
pkg/services/auth/auth_token.go

@@ -4,6 +4,7 @@ import (
 	"context"
 	"crypto/sha256"
 	"encoding/hex"
+	"strings"
 	"time"
 
 	"github.com/grafana/grafana/pkg/infra/serverlock"
@@ -305,6 +306,36 @@ func (s *UserAuthTokenService) RevokeAllUserTokens(ctx context.Context, userId i
 	})
 }
 
+func (s *UserAuthTokenService) BatchRevokeAllUserTokens(ctx context.Context, userIds []int64) error {
+	return s.SQLStore.WithTransactionalDbSession(ctx, func(dbSession *sqlstore.DBSession) error {
+		if len(userIds) == 0 {
+			return nil
+		}
+
+		user_id_params := strings.Repeat(",?", len(userIds)-1)
+		sql := "DELETE from user_auth_token WHERE user_id IN (?" + user_id_params + ")"
+
+		params := []interface{}{sql}
+		for _, v := range userIds {
+			params = append(params, v)
+		}
+
+		res, err := dbSession.Exec(params...)
+		if err != nil {
+			return err
+		}
+
+		affected, err := res.RowsAffected()
+		if err != nil {
+			return err
+		}
+
+		s.log.Debug("all user tokens for given users revoked", "usersCount", len(userIds), "count", affected)
+
+		return err
+	})
+}
+
 func (s *UserAuthTokenService) GetUserToken(ctx context.Context, userId, userTokenId int64) (*models.UserToken, error) {
 
 	var result models.UserToken

+ 20 - 0
pkg/services/auth/auth_token_test.go

@@ -117,6 +117,26 @@ func TestUserAuthToken(t *testing.T) {
 					So(model2, ShouldBeNil)
 				})
 			})
+
+			Convey("When revoking users tokens in a batch", func() {
+				Convey("Can revoke all users tokens", func() {
+					userIds := []int64{}
+					for i := 0; i < 3; i++ {
+						userId := userID + int64(i+1)
+						userIds = append(userIds, userId)
+						userAuthTokenService.CreateToken(context.Background(), userId, "192.168.10.11:1234", "some user agent")
+					}
+
+					err := userAuthTokenService.BatchRevokeAllUserTokens(context.Background(), userIds)
+					So(err, ShouldBeNil)
+
+					for _, v := range userIds {
+						tokens, err := userAuthTokenService.GetUserTokens(context.Background(), v)
+						So(err, ShouldBeNil)
+						So(len(tokens), ShouldEqual, 0)
+					}
+				})
+			})
 		})
 
 		Convey("expires correctly", func() {

+ 26 - 0
pkg/services/sqlstore/user.go

@@ -28,6 +28,7 @@ func (ss *SqlStore) addUserQueryAndCommandHandlers() {
 	bus.AddHandler("sql", SearchUsers)
 	bus.AddHandler("sql", GetUserOrgList)
 	bus.AddHandler("sql", DisableUser)
+	bus.AddHandler("sql", BatchDisableUsers)
 	bus.AddHandler("sql", DeleteUser)
 	bus.AddHandler("sql", UpdateUserPermissions)
 	bus.AddHandler("sql", SetUserHelpFlag)
@@ -487,6 +488,31 @@ func DisableUser(cmd *m.DisableUserCommand) error {
 	return err
 }
 
+func BatchDisableUsers(cmd *m.BatchDisableUsersCommand) error {
+	return inTransaction(func(sess *DBSession) error {
+		userIds := cmd.UserIds
+
+		if len(userIds) == 0 {
+			return nil
+		}
+
+		user_id_params := strings.Repeat(",?", len(userIds)-1)
+		disableSQL := "UPDATE " + dialect.Quote("user") + " SET is_disabled=? WHERE Id IN (?" + user_id_params + ")"
+
+		disableParams := []interface{}{disableSQL, cmd.IsDisabled}
+		for _, v := range userIds {
+			disableParams = append(disableParams, v)
+		}
+
+		_, err := sess.Exec(disableParams...)
+		if err != nil {
+			return err
+		}
+
+		return nil
+	})
+}
+
 func DeleteUser(cmd *m.DeleteUserCommand) error {
 	return inTransaction(func(sess *DBSession) error {
 		return deleteUserInTransaction(sess, cmd)

+ 34 - 0
pkg/services/sqlstore/user_test.go

@@ -175,6 +175,40 @@ func TestUserDataAccess(t *testing.T) {
 					So(found, ShouldBeTrue)
 				})
 			})
+
+			Convey("When batch disabling users", func() {
+				userIdsToDisable := []int64{}
+				for i := 0; i < 3; i++ {
+					userIdsToDisable = append(userIdsToDisable, users[i].Id)
+				}
+				disableCmd := m.BatchDisableUsersCommand{UserIds: userIdsToDisable, IsDisabled: true}
+
+				err = BatchDisableUsers(&disableCmd)
+				So(err, ShouldBeNil)
+
+				Convey("Should disable all provided users", func() {
+					query := m.SearchUsersQuery{}
+					err = SearchUsers(&query)
+
+					So(query.Result.TotalCount, ShouldEqual, 5)
+					for _, user := range query.Result.Users {
+						shouldBeDisabled := false
+
+						// Check if user id is in the userIdsToDisable list
+						for _, disabledUserId := range userIdsToDisable {
+							if user.Id == disabledUserId {
+								So(user.IsDisabled, ShouldBeTrue)
+								shouldBeDisabled = true
+							}
+						}
+
+						// Otherwise user shouldn't be disabled
+						if !shouldBeDisabled {
+							So(user.IsDisabled, ShouldBeFalse)
+						}
+					}
+				})
+			})
 		})
 
 		Convey("Given one grafana admin user", func() {