Browse Source

sql: added code migration type

Torkel Ödegaard 7 years ago
parent
commit
92ed1f04af

+ 7 - 1
pkg/api/login.go

@@ -78,7 +78,13 @@ func tryLoginUsingRememberCookie(c *m.ReqContext) bool {
 	user := userQuery.Result
 
 	// validate remember me cookie
-	if val, _ := c.GetSuperSecureCookie(user.Rands+user.Password, setting.CookieRememberName); val != user.Login {
+	signingKey := user.Rands + user.Password
+	if len(signingKey) < 10 {
+		c.Logger.Error("Invalid user signingKey")
+		return false
+	}
+
+	if val, _ := c.GetSuperSecureCookie(signingKey, setting.CookieRememberName); val != user.Login {
 		return false
 	}
 

+ 40 - 1
pkg/services/sqlstore/migrations/user_mig.go

@@ -1,6 +1,12 @@
 package migrations
 
-import . "github.com/grafana/grafana/pkg/services/sqlstore/migrator"
+import (
+	"fmt"
+
+	"github.com/go-xorm/xorm"
+	. "github.com/grafana/grafana/pkg/services/sqlstore/migrator"
+	"github.com/grafana/grafana/pkg/util"
+)
 
 func addUserMigrations(mg *Migrator) {
 	userV1 := Table{
@@ -107,4 +113,37 @@ func addUserMigrations(mg *Migrator) {
 	mg.AddMigration("Add last_seen_at column to user", NewAddColumnMigration(userV2, &Column{
 		Name: "last_seen_at", Type: DB_DateTime, Nullable: true,
 	}))
+
+	// Adds salt & rands for old users who used ldap or oauth
+	mg.AddMigration("Add missing user data", &AddMissingUserSaltAndRandsMigration{})
+}
+
+type AddMissingUserSaltAndRandsMigration struct {
+	MigrationBase
+}
+
+func (m *AddMissingUserSaltAndRandsMigration) Sql(dialect Dialect) string {
+	return "code migration"
+}
+
+type TempUserDTO struct {
+	Id    int64
+	Login string
+}
+
+func (m *AddMissingUserSaltAndRandsMigration) Exec(sess *xorm.Session, mg *Migrator) error {
+	users := make([]*TempUserDTO, 0)
+
+	err := sess.Sql(fmt.Sprintf("SELECT id, login from %s WHERE rands = ''", mg.Dialect.Quote("user"))).Find(&users)
+	if err != nil {
+		return err
+	}
+
+	for _, user := range users {
+		_, err := sess.Exec("UPDATE "+mg.Dialect.Quote("user")+" SET salt = ?, rands = ? WHERE id = ?", util.GetRandomString(10), util.GetRandomString(10), user.Id)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
 }

+ 11 - 5
pkg/services/sqlstore/migrator/migrator.go

@@ -12,7 +12,7 @@ import (
 
 type Migrator struct {
 	x          *xorm.Engine
-	dialect    Dialect
+	Dialect    Dialect
 	migrations []Migration
 	Logger     log.Logger
 }
@@ -31,7 +31,7 @@ func NewMigrator(engine *xorm.Engine) *Migrator {
 	mg.x = engine
 	mg.Logger = log.New("migrator")
 	mg.migrations = make([]Migration, 0)
-	mg.dialect = NewDialect(mg.x)
+	mg.Dialect = NewDialect(mg.x)
 	return mg
 }
 
@@ -86,7 +86,7 @@ func (mg *Migrator) Start() error {
 			continue
 		}
 
-		sql := m.Sql(mg.dialect)
+		sql := m.Sql(mg.Dialect)
 
 		record := MigrationLog{
 			MigrationId: m.Id(),
@@ -122,7 +122,7 @@ func (mg *Migrator) exec(m Migration, sess *xorm.Session) error {
 
 	condition := m.GetCondition()
 	if condition != nil {
-		sql, args := condition.Sql(mg.dialect)
+		sql, args := condition.Sql(mg.Dialect)
 		results, err := sess.SQL(sql).Query(args...)
 		if err != nil || len(results) == 0 {
 			mg.Logger.Debug("Skipping migration condition not fulfilled", "id", m.Id())
@@ -130,7 +130,13 @@ func (mg *Migrator) exec(m Migration, sess *xorm.Session) error {
 		}
 	}
 
-	_, err := sess.Exec(m.Sql(mg.dialect))
+	var err error
+	if codeMigration, ok := m.(CodeMigration); ok {
+		err = codeMigration.Exec(sess, mg)
+	} else {
+		_, err = sess.Exec(m.Sql(mg.Dialect))
+	}
+
 	if err != nil {
 		mg.Logger.Error("Executing migration failed", "id", m.Id(), "error", err)
 		return err

+ 7 - 0
pkg/services/sqlstore/migrator/types.go

@@ -3,6 +3,8 @@ package migrator
 import (
 	"fmt"
 	"strings"
+
+	"github.com/go-xorm/xorm"
 )
 
 const (
@@ -19,6 +21,11 @@ type Migration interface {
 	GetCondition() MigrationCondition
 }
 
+type CodeMigration interface {
+	Migration
+	Exec(sess *xorm.Session, migrator *Migrator) error
+}
+
 type SQLType string
 
 type ColumnType string

+ 3 - 2
pkg/services/sqlstore/user.go

@@ -113,9 +113,10 @@ func CreateUser(ctx context.Context, cmd *m.CreateUserCommand) error {
 			LastSeenAt:    time.Now().AddDate(-10, 0, 0),
 		}
 
+		user.Salt = util.GetRandomString(10)
+		user.Rands = util.GetRandomString(10)
+
 		if len(cmd.Password) > 0 {
-			user.Salt = util.GetRandomString(10)
-			user.Rands = util.GetRandomString(10)
 			user.Password = util.EncodePassword(cmd.Password, user.Salt)
 		}
 

+ 22 - 0
pkg/services/sqlstore/user_test.go

@@ -15,6 +15,28 @@ func TestUserDataAccess(t *testing.T) {
 	Convey("Testing DB", t, func() {
 		InitTestDB(t)
 
+		Convey("Creating a user", func() {
+			cmd := &m.CreateUserCommand{
+				Email: "usertest@test.com",
+				Name:  "user name",
+				Login: "user_test_login",
+			}
+
+			err := CreateUser(context.Background(), cmd)
+			So(err, ShouldBeNil)
+
+			Convey("Loading a user", func() {
+				query := m.GetUserByIdQuery{Id: cmd.Result.Id}
+				err := GetUserById(&query)
+				So(err, ShouldBeNil)
+
+				So(query.Result.Email, ShouldEqual, "usertest@test.com")
+				So(query.Result.Password, ShouldEqual, "")
+				So(query.Result.Rands, ShouldHaveLength, 10)
+				So(query.Result.Salt, ShouldHaveLength, 10)
+			})
+		})
+
 		Convey("Given 5 users", func() {
 			var err error
 			var cmd *m.CreateUserCommand