浏览代码

Validation check for not removing the last account admin

Torkel Ödegaard 11 年之前
父节点
当前提交
eec178458b

+ 4 - 0
pkg/api/account_users.go

@@ -54,6 +54,10 @@ func RemoveAccountUser(c *middleware.Context) {
 	cmd := m.RemoveAccountUserCommand{AccountId: c.AccountId, UserId: userId}
 
 	if err := bus.Dispatch(&cmd); err != nil {
+		if err == m.ErrLastAccountAdmin {
+			c.JsonApiErr(400, "Cannot remove last account admin", nil)
+			return
+		}
 		c.JsonApiErr(500, "Failed to remove user from account", err)
 	}
 

+ 2 - 1
pkg/models/account_user.go

@@ -7,7 +7,8 @@ import (
 
 // Typed errors
 var (
-	ErrInvalidRoleType = errors.New("Invalid role type")
+	ErrInvalidRoleType  = errors.New("Invalid role type")
+	ErrLastAccountAdmin = errors.New("Cannot remove last account admin")
 )
 
 type RoleType string

+ 6 - 0
pkg/services/sqlstore/account_test.go

@@ -103,6 +103,12 @@ func TestAccountDataAccess(t *testing.T) {
 						So(query.Result.AccountRole, ShouldEqual, "Viewer")
 					})
 				})
+
+				Convey("Cannot delete last admin account user", func() {
+					cmd := m.RemoveAccountUserCommand{AccountId: ac1.AccountId, UserId: ac1.Id}
+					err := RemoveAccountUser(&cmd)
+					So(err, ShouldEqual, m.ErrLastAccountAdmin)
+				})
 			})
 		})
 	})

+ 14 - 0
pkg/services/sqlstore/account_users.go

@@ -47,6 +47,20 @@ func RemoveAccountUser(cmd *m.RemoveAccountUserCommand) error {
 	return inTransaction(func(sess *xorm.Session) error {
 		var rawSql = "DELETE FROM account_user WHERE account_id=? and user_id=?"
 		_, err := sess.Exec(rawSql, cmd.AccountId, cmd.UserId)
+		if err != nil {
+			return err
+		}
+
+		// validate that there is an admin user left
+		res, err := sess.Query("SELECT 1 from account_user WHERE account_id=? and role='Admin'", cmd.AccountId)
+		if err != nil {
+			return err
+		}
+
+		if len(res) == 0 {
+			return m.ErrLastAccountAdmin
+		}
+
 		return err
 	})
 }

+ 2 - 2
pkg/services/sqlstore/migrations_test.go

@@ -20,8 +20,8 @@ func TestMigrations(t *testing.T) {
 
 	testDBs := []sqlutil.TestDB{
 		sqlutil.TestDB_Sqlite3,
-		sqlutil.TestDB_Mysql,
-		sqlutil.TestDB_Postgres,
+		//	sqlutil.TestDB_Mysql,
+		//		sqlutil.TestDB_Postgres,
 	}
 
 	for _, testDB := range testDBs {