ssl_test.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. package pq
  2. // This file contains SSL tests
  import (
  	_ "crypto/sha256"
  	"crypto/x509"
  	"database/sql"
  	"errors"
  	"fmt"
  	"os"
  	"path/filepath"
  	"testing"
  )
  // maybeSkipSSLTests gates the SSL tests on the environment. The calling test
  // is skipped when PQSSLCERTTEST_PATH is empty, or when PQGOSSLTESTS is empty
  // or "0". It runs when PQGOSSLTESTS is "1"; any other value fails the test.
  func maybeSkipSSLTests(t *testing.T) {
  	// The certificate tests require this variable to be set.
  	if certDir := os.Getenv("PQSSLCERTTEST_PATH"); certDir == "" {
  		t.Skip("PQSSLCERTTEST_PATH not set, skipping SSL tests")
  	}

  	switch enabled := os.Getenv("PQGOSSLTESTS"); enabled {
  	case "1":
  		// Explicitly enabled: let the test run.
  	case "", "0":
  		t.Skip("PQGOSSLTESTS not enabled, skipping SSL tests")
  	default:
  		t.Fatalf("unexpected value %q for PQGOSSLTESTS", enabled)
  	}
  }
  24. func openSSLConn(t *testing.T, conninfo string) (*sql.DB, error) {
  25. db, err := openTestConnConninfo(conninfo)
  26. if err != nil {
  27. // should never fail
  28. t.Fatal(err)
  29. }
  30. // Do something with the connection to see whether it's working or not.
  31. tx, err := db.Begin()
  32. if err == nil {
  33. return db, tx.Rollback()
  34. }
  35. _ = db.Close()
  36. return nil, err
  37. }
  38. func checkSSLSetup(t *testing.T, conninfo string) {
  39. db, err := openSSLConn(t, conninfo)
  40. if err == nil {
  41. db.Close()
  42. t.Fatalf("expected error with conninfo=%q", conninfo)
  43. }
  44. }
  45. // Connect over SSL and run a simple query to test the basics
  46. func TestSSLConnection(t *testing.T) {
  47. maybeSkipSSLTests(t)
  48. // Environment sanity check: should fail without SSL
  49. checkSSLSetup(t, "sslmode=disable user=pqgossltest")
  50. db, err := openSSLConn(t, "sslmode=require user=pqgossltest")
  51. if err != nil {
  52. t.Fatal(err)
  53. }
  54. rows, err := db.Query("SELECT 1")
  55. if err != nil {
  56. t.Fatal(err)
  57. }
  58. rows.Close()
  59. }
  60. // Test sslmode=verify-full
  61. func TestSSLVerifyFull(t *testing.T) {
  62. maybeSkipSSLTests(t)
  63. // Environment sanity check: should fail without SSL
  64. checkSSLSetup(t, "sslmode=disable user=pqgossltest")
  65. // Not OK according to the system CA
  66. _, err := openSSLConn(t, "host=postgres sslmode=verify-full user=pqgossltest")
  67. if err == nil {
  68. t.Fatal("expected error")
  69. }
  70. _, ok := err.(x509.UnknownAuthorityError)
  71. if !ok {
  72. t.Fatalf("expected x509.UnknownAuthorityError, got %#+v", err)
  73. }
  74. rootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "root.crt")
  75. rootCert := "sslrootcert=" + rootCertPath + " "
  76. // No match on Common Name
  77. _, err = openSSLConn(t, rootCert+"host=127.0.0.1 sslmode=verify-full user=pqgossltest")
  78. if err == nil {
  79. t.Fatal("expected error")
  80. }
  81. _, ok = err.(x509.HostnameError)
  82. if !ok {
  83. t.Fatalf("expected x509.HostnameError, got %#+v", err)
  84. }
  85. // OK
  86. _, err = openSSLConn(t, rootCert+"host=postgres sslmode=verify-full user=pqgossltest")
  87. if err != nil {
  88. t.Fatal(err)
  89. }
  90. }
  91. // Test sslmode=verify-ca
  92. func TestSSLVerifyCA(t *testing.T) {
  93. maybeSkipSSLTests(t)
  94. // Environment sanity check: should fail without SSL
  95. checkSSLSetup(t, "sslmode=disable user=pqgossltest")
  96. // Not OK according to the system CA
  97. _, err := openSSLConn(t, "host=postgres sslmode=verify-ca user=pqgossltest")
  98. if err == nil {
  99. t.Fatal("expected error")
  100. }
  101. _, ok := err.(x509.UnknownAuthorityError)
  102. if !ok {
  103. t.Fatalf("expected x509.UnknownAuthorityError, got %#+v", err)
  104. }
  105. rootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "root.crt")
  106. rootCert := "sslrootcert=" + rootCertPath + " "
  107. // No match on Common Name, but that's OK
  108. _, err = openSSLConn(t, rootCert+"host=127.0.0.1 sslmode=verify-ca user=pqgossltest")
  109. if err != nil {
  110. t.Fatal(err)
  111. }
  112. // Everything OK
  113. _, err = openSSLConn(t, rootCert+"host=postgres sslmode=verify-ca user=pqgossltest")
  114. if err != nil {
  115. t.Fatal(err)
  116. }
  117. }
  // getCertConninfo builds a client-certificate conninfo string for the named
  // scenario. Certificate files are looked up under PQSSLCERTTEST_PATH.
  //
  // Recognized sources:
  //   - "valid":       the real key and the real certificate
  //   - "missingkey":  a nonexistent key path with the real certificate
  //   - "missingcert": the real key with a nonexistent certificate path
  //   - "certtwice":   the certificate file used as both key and certificate
  //
  // Any other source fails the test.
  func getCertConninfo(t *testing.T, source string) string {
  	const missingFile = "/tmp/filedoesnotexist"

  	dir := os.Getenv("PQSSLCERTTEST_PATH")
  	keyFile := filepath.Join(dir, "postgresql.key")
  	certFile := filepath.Join(dir, "postgresql.crt")

  	// Start from the valid pair; each broken scenario overrides one side.
  	key, cert := keyFile, certFile
  	switch source {
  	case "valid":
  		// Nothing to override.
  	case "missingkey":
  		key = missingFile
  	case "missingcert":
  		cert = missingFile
  	case "certtwice":
  		key = certFile
  	default:
  		t.Fatalf("invalid source %q", source)
  	}

  	return fmt.Sprintf("sslmode=require user=pqgosslcert sslkey=%s sslcert=%s", key, cert)
  }
  140. // Authenticate over SSL using client certificates
  141. func TestSSLClientCertificates(t *testing.T) {
  142. maybeSkipSSLTests(t)
  143. // Environment sanity check: should fail without SSL
  144. checkSSLSetup(t, "sslmode=disable user=pqgossltest")
  145. // Should also fail without a valid certificate
  146. db, err := openSSLConn(t, "sslmode=require user=pqgosslcert")
  147. if err == nil {
  148. db.Close()
  149. t.Fatal("expected error")
  150. }
  151. pge, ok := err.(*Error)
  152. if !ok {
  153. t.Fatal("expected pq.Error")
  154. }
  155. if pge.Code.Name() != "invalid_authorization_specification" {
  156. t.Fatalf("unexpected error code %q", pge.Code.Name())
  157. }
  158. // Should work
  159. db, err = openSSLConn(t, getCertConninfo(t, "valid"))
  160. if err != nil {
  161. t.Fatal(err)
  162. }
  163. rows, err := db.Query("SELECT 1")
  164. if err != nil {
  165. t.Fatal(err)
  166. }
  167. rows.Close()
  168. }
  169. // Test errors with ssl certificates
  170. func TestSSLClientCertificatesMissingFiles(t *testing.T) {
  171. maybeSkipSSLTests(t)
  172. // Environment sanity check: should fail without SSL
  173. checkSSLSetup(t, "sslmode=disable user=pqgossltest")
  174. // Key missing, should fail
  175. _, err := openSSLConn(t, getCertConninfo(t, "missingkey"))
  176. if err == nil {
  177. t.Fatal("expected error")
  178. }
  179. // should be a PathError
  180. _, ok := err.(*os.PathError)
  181. if !ok {
  182. t.Fatalf("expected PathError, got %#+v", err)
  183. }
  184. // Cert missing, should fail
  185. _, err = openSSLConn(t, getCertConninfo(t, "missingcert"))
  186. if err == nil {
  187. t.Fatal("expected error")
  188. }
  189. // should be a PathError
  190. _, ok = err.(*os.PathError)
  191. if !ok {
  192. t.Fatalf("expected PathError, got %#+v", err)
  193. }
  194. // Key has wrong permissions, should fail
  195. _, err = openSSLConn(t, getCertConninfo(t, "certtwice"))
  196. if err == nil {
  197. t.Fatal("expected error")
  198. }
  199. if err != ErrSSLKeyHasWorldPermissions {
  200. t.Fatalf("expected ErrSSLKeyHasWorldPermissions, got %#+v", err)
  201. }
  202. }