فهرست منبع

Merge pull request #15051 from ellisvlad/13711_parse_database_config_ipv6_host

Parse database host correctly when using IPv6
Marcus Efraimsson 7 سال پیش
والد
کامیت
ed6cca61c9
5فایلهای تغییر یافته به همراه182 افزوده شده و 15 حذف شده
  1. 4 7
      pkg/services/sqlstore/sqlstore.go
  2. 101 0
      pkg/services/sqlstore/sqlstore_test.go
  3. 10 8
      pkg/tsdb/mssql/mssql.go
  4. 24 0
      pkg/util/ip.go
  5. 43 0
      pkg/util/ip_test.go

+ 4 - 7
pkg/services/sqlstore/sqlstore.go

@@ -21,6 +21,7 @@ import (
 	"github.com/grafana/grafana/pkg/services/sqlstore/migrator"
 	"github.com/grafana/grafana/pkg/services/sqlstore/migrator"
 	"github.com/grafana/grafana/pkg/services/sqlstore/sqlutil"
 	"github.com/grafana/grafana/pkg/services/sqlstore/sqlutil"
 	"github.com/grafana/grafana/pkg/setting"
 	"github.com/grafana/grafana/pkg/setting"
+	"github.com/grafana/grafana/pkg/util"
 
 
 	"github.com/go-sql-driver/mysql"
 	"github.com/go-sql-driver/mysql"
 	"github.com/go-xorm/xorm"
 	"github.com/go-xorm/xorm"
@@ -222,13 +223,9 @@ func (ss *SqlStore) buildConnectionString() (string, error) {
 			cnnstr += "&tls=custom"
 			cnnstr += "&tls=custom"
 		}
 		}
 	case migrator.POSTGRES:
 	case migrator.POSTGRES:
-		var host, port = "127.0.0.1", "5432"
-		fields := strings.Split(ss.dbCfg.Host, ":")
-		if len(fields) > 0 && len(strings.TrimSpace(fields[0])) > 0 {
-			host = fields[0]
-		}
-		if len(fields) > 1 && len(strings.TrimSpace(fields[1])) > 0 {
-			port = fields[1]
+		host, port, err := util.SplitIpPort(ss.dbCfg.Host, "5432")
+		if err != nil {
+			return "", err
 		}
 		}
 		if ss.dbCfg.Pwd == "" {
 		if ss.dbCfg.Pwd == "" {
 			ss.dbCfg.Pwd = "''"
 			ss.dbCfg.Pwd = "''"

+ 101 - 0
pkg/services/sqlstore/sqlstore_test.go

@@ -0,0 +1,101 @@
+package sqlstore
+
+import (
+	"testing"
+
+	. "github.com/smartystreets/goconvey/convey"
+
+	"github.com/grafana/grafana/pkg/setting"
+)
+
+type sqlStoreTest struct {
+	name          string
+	dbType        string
+	dbHost        string
+	connStrValues []string
+}
+
+var sqlStoreTestCases = []sqlStoreTest{
+	{
+		name:          "MySQL IPv4",
+		dbType:        "mysql",
+		dbHost:        "1.2.3.4:5678",
+		connStrValues: []string{"tcp(1.2.3.4:5678)"},
+	},
+	{
+		name:          "Postgres IPv4",
+		dbType:        "postgres",
+		dbHost:        "1.2.3.4:5678",
+		connStrValues: []string{"host=1.2.3.4", "port=5678"},
+	},
+	{
+		name:          "Postgres IPv4 (Default Port)",
+		dbType:        "postgres",
+		dbHost:        "1.2.3.4",
+		connStrValues: []string{"host=1.2.3.4", "port=5432"},
+	},
+	{
+		name:          "MySQL IPv4 (Default Port)",
+		dbType:        "mysql",
+		dbHost:        "1.2.3.4",
+		connStrValues: []string{"tcp(1.2.3.4)"},
+	},
+	{
+		name:          "MySQL IPv6",
+		dbType:        "mysql",
+		dbHost:        "[fe80::24e8:31b2:91df:b177]:1234",
+		connStrValues: []string{"tcp([fe80::24e8:31b2:91df:b177]:1234)"},
+	},
+	{
+		name:          "Postgres IPv6",
+		dbType:        "postgres",
+		dbHost:        "[fe80::24e8:31b2:91df:b177]:1234",
+		connStrValues: []string{"host=fe80::24e8:31b2:91df:b177", "port=1234"},
+	},
+	{
+		name:          "MySQL IPv6 (Default Port)",
+		dbType:        "mysql",
+		dbHost:        "::1",
+		connStrValues: []string{"tcp(::1)"},
+	},
+	{
+		name:          "Postgres IPv6 (Default Port)",
+		dbType:        "postgres",
+		dbHost:        "::1",
+		connStrValues: []string{"host=::1", "port=5432"},
+	},
+}
+
+func TestSqlConnectionString(t *testing.T) {
+	Convey("Testing SQL Connection Strings", t, func() {
+		t.Helper()
+
+		for _, testCase := range sqlStoreTestCases {
+			Convey(testCase.name, func() {
+				sqlstore := &SqlStore{}
+				sqlstore.Cfg = makeSqlStoreTestConfig(testCase.dbType, testCase.dbHost)
+				sqlstore.readConfig()
+
+				connStr, err := sqlstore.buildConnectionString()
+
+				So(err, ShouldBeNil)
+				for _, connSubStr := range testCase.connStrValues {
+					So(connStr, ShouldContainSubstring, connSubStr)
+				}
+			})
+		}
+	})
+}
+
+func makeSqlStoreTestConfig(dbType string, host string) *setting.Cfg {
+	cfg := setting.NewCfg()
+
+	sec, _ := cfg.Raw.NewSection("database")
+	sec.NewKey("type", dbType)
+	sec.NewKey("host", host)
+	sec.NewKey("user", "user")
+	sec.NewKey("name", "test_db")
+	sec.NewKey("password", "pass")
+
+	return cfg
+}

+ 10 - 8
pkg/tsdb/mssql/mssql.go

@@ -4,13 +4,13 @@ import (
 	"database/sql"
 	"database/sql"
 	"fmt"
 	"fmt"
 	"strconv"
 	"strconv"
-	"strings"
 
 
 	_ "github.com/denisenkom/go-mssqldb"
 	_ "github.com/denisenkom/go-mssqldb"
 	"github.com/go-xorm/core"
 	"github.com/go-xorm/core"
 	"github.com/grafana/grafana/pkg/log"
 	"github.com/grafana/grafana/pkg/log"
 	"github.com/grafana/grafana/pkg/models"
 	"github.com/grafana/grafana/pkg/models"
 	"github.com/grafana/grafana/pkg/tsdb"
 	"github.com/grafana/grafana/pkg/tsdb"
+	"github.com/grafana/grafana/pkg/util"
 )
 )
 
 
 func init() {
 func init() {
@@ -20,7 +20,10 @@ func init() {
 func newMssqlQueryEndpoint(datasource *models.DataSource) (tsdb.TsdbQueryEndpoint, error) {
 func newMssqlQueryEndpoint(datasource *models.DataSource) (tsdb.TsdbQueryEndpoint, error) {
 	logger := log.New("tsdb.mssql")
 	logger := log.New("tsdb.mssql")
 
 
-	cnnstr := generateConnectionString(datasource)
+	cnnstr, err := generateConnectionString(datasource)
+	if err != nil {
+		return nil, err
+	}
 	logger.Debug("getEngine", "connection", cnnstr)
 	logger.Debug("getEngine", "connection", cnnstr)
 
 
 	config := tsdb.SqlQueryEndpointConfiguration{
 	config := tsdb.SqlQueryEndpointConfiguration{
@@ -37,7 +40,7 @@ func newMssqlQueryEndpoint(datasource *models.DataSource) (tsdb.TsdbQueryEndpoin
 	return tsdb.NewSqlQueryEndpoint(&config, &rowTransformer, newMssqlMacroEngine(), logger)
 	return tsdb.NewSqlQueryEndpoint(&config, &rowTransformer, newMssqlMacroEngine(), logger)
 }
 }
 
 
-func generateConnectionString(datasource *models.DataSource) string {
+func generateConnectionString(datasource *models.DataSource) (string, error) {
 	password := ""
 	password := ""
 	for key, value := range datasource.SecureJsonData.Decrypt() {
 	for key, value := range datasource.SecureJsonData.Decrypt() {
 		if key == "password" {
 		if key == "password" {
@@ -46,12 +49,11 @@ func generateConnectionString(datasource *models.DataSource) string {
 		}
 		}
 	}
 	}
 
 
-	hostParts := strings.Split(datasource.Url, ":")
-	if len(hostParts) < 2 {
-		hostParts = append(hostParts, "1433")
+	server, port, err := util.SplitIpPort(datasource.Url, "1433")
+	if err != nil {
+		return "", err
 	}
 	}
 
 
-	server, port := hostParts[0], hostParts[1]
 	encrypt := datasource.JsonData.Get("encrypt").MustString("false")
 	encrypt := datasource.JsonData.Get("encrypt").MustString("false")
 	connStr := fmt.Sprintf("server=%s;port=%s;database=%s;user id=%s;password=%s;",
 	connStr := fmt.Sprintf("server=%s;port=%s;database=%s;user id=%s;password=%s;",
 		server,
 		server,
@@ -63,7 +65,7 @@ func generateConnectionString(datasource *models.DataSource) string {
 	if encrypt != "false" {
 	if encrypt != "false" {
 		connStr += fmt.Sprintf("encrypt=%s;", encrypt)
 		connStr += fmt.Sprintf("encrypt=%s;", encrypt)
 	}
 	}
-	return connStr
+	return connStr, nil
 }
 }
 
 
 type mssqlRowTransformer struct {
 type mssqlRowTransformer struct {

+ 24 - 0
pkg/util/ip.go

@@ -0,0 +1,24 @@
+package util
+
+import (
+	"net"
+)
+
+func SplitIpPort(ipStr string, portDefault string) (ip string, port string, err error) {
+	ipAddr := net.ParseIP(ipStr)
+
+	if ipAddr == nil {
+		// Port was included
+		ip, port, err = net.SplitHostPort(ipStr)
+
+		if err != nil {
+			return "", "", err
+		}
+	} else {
+		// No port was included
+		ip = ipAddr.String()
+		port = portDefault
+	}
+
+	return ip, port, nil
+}

+ 43 - 0
pkg/util/ip_test.go

@@ -0,0 +1,43 @@
+package util
+
+import (
+	"testing"
+
+	. "github.com/smartystreets/goconvey/convey"
+)
+
+func TestSplitIpPort(t *testing.T) {
+
+	Convey("When parsing an IPv4 without explicit port", t, func() {
+		ip, port, err := SplitIpPort("1.2.3.4", "5678")
+
+		So(err, ShouldEqual, nil)
+		So(ip, ShouldEqual, "1.2.3.4")
+		So(port, ShouldEqual, "5678")
+	})
+
+	Convey("When parsing an IPv6 without explicit port", t, func() {
+		ip, port, err := SplitIpPort("::1", "5678")
+
+		So(err, ShouldEqual, nil)
+		So(ip, ShouldEqual, "::1")
+		So(port, ShouldEqual, "5678")
+	})
+
+	Convey("When parsing an IPv4 with explicit port", t, func() {
+		ip, port, err := SplitIpPort("1.2.3.4:56", "78")
+
+		So(err, ShouldEqual, nil)
+		So(ip, ShouldEqual, "1.2.3.4")
+		So(port, ShouldEqual, "56")
+	})
+
+	Convey("When parsing an IPv6 with explicit port", t, func() {
+		ip, port, err := SplitIpPort("[::1]:56", "78")
+
+		So(err, ShouldEqual, nil)
+		So(ip, ShouldEqual, "::1")
+		So(port, ShouldEqual, "56")
+	})
+
+}