Explorar o código

extract parsing of datasource tls config to method

Marcus Efraimsson %!s(int64=7) %!d(string=hai) anos
pai
achega
f157c19e16
Modificáronse 2 ficheiros con 35 adicións e 56 borrados
  1. 30 18
      pkg/models/datasource_cache.go
  2. 5 38
      pkg/tsdb/mysql/mysql.go

+ 30 - 18
pkg/models/datasource_cache.go

@@ -46,19 +46,16 @@ func (ds *DataSource) GetHttpTransport() (*http.Transport, error) {
 		return t.Transport, nil
 		return t.Transport, nil
 	}
 	}
 
 
-	var tlsSkipVerify, tlsClientAuth, tlsAuthWithCACert bool
-	if ds.JsonData != nil {
-		tlsClientAuth = ds.JsonData.Get("tlsAuth").MustBool(false)
-		tlsAuthWithCACert = ds.JsonData.Get("tlsAuthWithCACert").MustBool(false)
-		tlsSkipVerify = ds.JsonData.Get("tlsSkipVerify").MustBool(false)
+	tlsConfig, err := ds.GetTLSConfig()
+	if err != nil {
+		return nil, err
 	}
 	}
 
 
+	tlsConfig.Renegotiation = tls.RenegotiateFreelyAsClient
+
 	transport := &http.Transport{
 	transport := &http.Transport{
-		TLSClientConfig: &tls.Config{
-			InsecureSkipVerify: tlsSkipVerify,
-			Renegotiation:      tls.RenegotiateFreelyAsClient,
-		},
-		Proxy: http.ProxyFromEnvironment,
+		TLSClientConfig: tlsConfig,
+		Proxy:           http.ProxyFromEnvironment,
 		Dial: (&net.Dialer{
 		Dial: (&net.Dialer{
 			Timeout:   30 * time.Second,
 			Timeout:   30 * time.Second,
 			KeepAlive: 30 * time.Second,
 			KeepAlive: 30 * time.Second,
@@ -70,6 +67,26 @@ func (ds *DataSource) GetHttpTransport() (*http.Transport, error) {
 		IdleConnTimeout:       90 * time.Second,
 		IdleConnTimeout:       90 * time.Second,
 	}
 	}
 
 
+	ptc.cache[ds.Id] = cachedTransport{
+		Transport: transport,
+		updated:   ds.Updated,
+	}
+
+	return transport, nil
+}
+
+func (ds *DataSource) GetTLSConfig() (*tls.Config, error) {
+	var tlsSkipVerify, tlsClientAuth, tlsAuthWithCACert bool
+	if ds.JsonData != nil {
+		tlsClientAuth = ds.JsonData.Get("tlsAuth").MustBool(false)
+		tlsAuthWithCACert = ds.JsonData.Get("tlsAuthWithCACert").MustBool(false)
+		tlsSkipVerify = ds.JsonData.Get("tlsSkipVerify").MustBool(false)
+	}
+
+	tlsConfig := &tls.Config{
+		InsecureSkipVerify: tlsSkipVerify,
+	}
+
 	if tlsClientAuth || tlsAuthWithCACert {
 	if tlsClientAuth || tlsAuthWithCACert {
 		decrypted := ds.SecureJsonData.Decrypt()
 		decrypted := ds.SecureJsonData.Decrypt()
 		if tlsAuthWithCACert && len(decrypted["tlsCACert"]) > 0 {
 		if tlsAuthWithCACert && len(decrypted["tlsCACert"]) > 0 {
@@ -78,7 +95,7 @@ func (ds *DataSource) GetHttpTransport() (*http.Transport, error) {
 			if !ok {
 			if !ok {
 				return nil, errors.New("Failed to parse TLS CA PEM certificate")
 				return nil, errors.New("Failed to parse TLS CA PEM certificate")
 			}
 			}
-			transport.TLSClientConfig.RootCAs = caPool
+			tlsConfig.RootCAs = caPool
 		}
 		}
 
 
 		if tlsClientAuth {
 		if tlsClientAuth {
@@ -86,14 +103,9 @@ func (ds *DataSource) GetHttpTransport() (*http.Transport, error) {
 			if err != nil {
 			if err != nil {
 				return nil, err
 				return nil, err
 			}
 			}
-			transport.TLSClientConfig.Certificates = []tls.Certificate{cert}
+			tlsConfig.Certificates = []tls.Certificate{cert}
 		}
 		}
 	}
 	}
 
 
-	ptc.cache[ds.Id] = cachedTransport{
-		Transport: transport,
-		updated:   ds.Updated,
-	}
-
-	return transport, nil
+	return tlsConfig, nil
 }
 }

+ 5 - 38
pkg/tsdb/mysql/mysql.go

@@ -2,15 +2,11 @@ package mysql
 
 
 import (
 import (
 	"database/sql"
 	"database/sql"
-	"errors"
 	"fmt"
 	"fmt"
 	"reflect"
 	"reflect"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
 
 
-	"crypto/tls"
-	"crypto/x509"
-
 	"github.com/go-sql-driver/mysql"
 	"github.com/go-sql-driver/mysql"
 	"github.com/go-xorm/core"
 	"github.com/go-xorm/core"
 	"github.com/grafana/grafana/pkg/log"
 	"github.com/grafana/grafana/pkg/log"
@@ -37,42 +33,13 @@ func newMysqlQueryEndpoint(datasource *models.DataSource) (tsdb.TsdbQueryEndpoin
 		datasource.Database,
 		datasource.Database,
 	)
 	)
 
 
-	var tlsSkipVerify, tlsAuth, tlsAuthWithCACert bool
-	if datasource.JsonData != nil {
-		tlsAuth = datasource.JsonData.Get("tlsAuth").MustBool(false)
-		tlsAuthWithCACert = datasource.JsonData.Get("tlsAuthWithCACert").MustBool(false)
-		tlsSkipVerify = datasource.JsonData.Get("tlsSkipVerify").MustBool(false)
+	tlsConfig, err := datasource.GetTLSConfig()
+	if err != nil {
+		return nil, err
 	}
 	}
 
 
-	if tlsAuth || tlsAuthWithCACert {
-
-		secureJsonData := datasource.SecureJsonData.Decrypt()
-		tlsConfig := tls.Config{
-			InsecureSkipVerify: tlsSkipVerify,
-		}
-
-		if tlsAuthWithCACert && len(secureJsonData["tlsCACert"]) > 0 {
-
-			caPool := x509.NewCertPool()
-			if ok := caPool.AppendCertsFromPEM([]byte(secureJsonData["tlsCACert"])); !ok {
-				return nil, errors.New("Failed to parse TLS CA PEM certificate")
-			}
-
-			tlsConfig.RootCAs = caPool
-		}
-
-		if tlsAuth {
-			certs, err := tls.X509KeyPair([]byte(secureJsonData["tlsClientCert"]), []byte(secureJsonData["tlsClientKey"]))
-			if err != nil {
-				return nil, err
-			}
-			clientCert := make([]tls.Certificate, 0, 1)
-			clientCert = append(clientCert, certs)
-
-			tlsConfig.Certificates = clientCert
-		}
-
-		mysql.RegisterTLSConfig(datasource.Name, &tlsConfig)
+	if tlsConfig.RootCAs != nil || len(tlsConfig.Certificates) > 0 {
+		mysql.RegisterTLSConfig(datasource.Name, tlsConfig)
 		cnnstr += "&tls=" + datasource.Name
 		cnnstr += "&tls=" + datasource.Name
 	}
 	}