Browse Source

generalized database connection cert support and added to postgres

Tom Kozlowski 9 years ago
parent
commit
c21ffcc6c9
2 changed files with 16 additions and 22 deletions
  1. 15 21
      pkg/services/sqlstore/sqlstore.go
  2. 1 1
      pkg/services/sqlstore/tls_mysql.go

+ 15 - 21
pkg/services/sqlstore/sqlstore.go

@@ -23,12 +23,13 @@ import (
 	_ "github.com/mattn/go-sqlite3"
 )
 
-type MySQLConfig struct {
-	SslMode        string
-	CaCertPath     string
-	ClientKeyPath  string
-	ClientCertPath string
-	ServerCertName string
+
+type DatabaseConfig struct {
+  Type, Host, Name, User, Pwd, Path, SslMode string
+  CaCertPath     string
+  ClientKeyPath  string
+  ClientCertPath string
+  ServerCertName string
 }
 
 var (
@@ -37,11 +38,8 @@ var (
 
 	HasEngine bool
 
-	DbCfg struct {
-		Type, Host, Name, User, Pwd, Path, SslMode string
-	}
+  DbCfg DatabaseConfig
 
-	mysqlConfig MySQLConfig
 	UseSQLite3  bool
 	sqlog       log.Logger = log.New("sqlstore")
 )
@@ -118,8 +116,8 @@ func getEngine() (*xorm.Engine, error) {
 		cnnstr = fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8",
 			DbCfg.User, DbCfg.Pwd, protocol, DbCfg.Host, DbCfg.Name)
 
-		if mysqlConfig.SslMode == "true" || mysqlConfig.SslMode == "skip-verify" {
-			tlsCert, err := makeCert("custom", mysqlConfig)
+		if DbCfg.SslMode == "true" || DbCfg.SslMode == "skip-verify" {
+			tlsCert, err := makeCert("custom", DbCfg)
 			if err != nil {
 				return nil, err
 			}
@@ -141,7 +139,7 @@ func getEngine() (*xorm.Engine, error) {
 		if DbCfg.User == "" {
 			DbCfg.User = "''"
 		}
-		cnnstr = fmt.Sprintf("user=%s password=%s host=%s port=%s dbname=%s sslmode=%s", DbCfg.User, DbCfg.Pwd, host, port, DbCfg.Name, DbCfg.SslMode)
+		cnnstr = fmt.Sprintf("user=%s password=%s host=%s port=%s dbname=%s sslmode=%s sslcert=%s sslkey=%s sslrootcert=%s", DbCfg.User, DbCfg.Pwd, host, port, DbCfg.Name, DbCfg.SslMode, DbCfg.ClientCertPath, DbCfg.ClientKeyPath, DbCfg.CaCertPath)
 	case "sqlite3":
 		if !filepath.IsAbs(DbCfg.Path) {
 			DbCfg.Path = filepath.Join(setting.DataPath, DbCfg.Path)
@@ -189,13 +187,9 @@ func LoadConfig() {
 		UseSQLite3 = true
 	}
 	DbCfg.SslMode = sec.Key("ssl_mode").String()
+  DbCfg.CaCertPath = sec.Key("ca_cert_path").String()
+  DbCfg.ClientKeyPath = sec.Key("client_key_path").String()
+  DbCfg.ClientCertPath = sec.Key("client_cert_path").String()
+  DbCfg.ServerCertName = sec.Key("server_cert_name").String()
 	DbCfg.Path = sec.Key("path").MustString("data/grafana.db")
-
-	if DbCfg.Type == "mysql" {
-		mysqlConfig.SslMode = DbCfg.SslMode
-		mysqlConfig.CaCertPath = sec.Key("ca_cert_path").String()
-		mysqlConfig.ClientKeyPath = sec.Key("client_key_path").String()
-		mysqlConfig.ClientCertPath = sec.Key("client_cert_path").String()
-		mysqlConfig.ServerCertName = sec.Key("server_cert_name").String()
-	}
 }

+ 1 - 1
pkg/services/sqlstore/tls_mysql.go

@@ -7,7 +7,7 @@ import (
 	"io/ioutil"
 )
 
-func makeCert(tlsPoolName string, config MySQLConfig) (*tls.Config, error) {
+func makeCert(tlsPoolName string, config DatabaseConfig) (*tls.Config, error) {
 	rootCertPool := x509.NewCertPool()
 	pem, err := ioutil.ReadFile(config.CaCertPath)
 	if err != nil {