|
|
@@ -9,12 +9,14 @@
|
|
|
package mysql
|
|
|
|
|
|
import (
|
|
|
+ "bytes"
|
|
|
"crypto/tls"
|
|
|
"database/sql"
|
|
|
"database/sql/driver"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
"io/ioutil"
|
|
|
+ "log"
|
|
|
"net"
|
|
|
"net/url"
|
|
|
"os"
|
|
|
@@ -74,23 +76,75 @@ type DBTest struct {
|
|
|
db *sql.DB
|
|
|
}
|
|
|
|
|
|
+func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
|
|
|
+ if !available {
|
|
|
+ t.Skipf("MySQL server not running on %s", netAddr)
|
|
|
+ }
|
|
|
+
|
|
|
+ dsn += "&multiStatements=true"
|
|
|
+ var db *sql.DB
|
|
|
+ if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation {
|
|
|
+ db, err = sql.Open("mysql", dsn)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("error connecting: %s", err.Error())
|
|
|
+ }
|
|
|
+ defer db.Close()
|
|
|
+ }
|
|
|
+
|
|
|
+ dbt := &DBTest{t, db}
|
|
|
+ for _, test := range tests {
|
|
|
+ test(dbt)
|
|
|
+ dbt.db.Exec("DROP TABLE IF EXISTS test")
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
|
|
|
if !available {
|
|
|
- t.Skipf("MySQL-Server not running on %s", netAddr)
|
|
|
+ t.Skipf("MySQL server not running on %s", netAddr)
|
|
|
}
|
|
|
|
|
|
db, err := sql.Open("mysql", dsn)
|
|
|
if err != nil {
|
|
|
- t.Fatalf("Error connecting: %s", err.Error())
|
|
|
+ t.Fatalf("error connecting: %s", err.Error())
|
|
|
}
|
|
|
defer db.Close()
|
|
|
|
|
|
db.Exec("DROP TABLE IF EXISTS test")
|
|
|
|
|
|
+ dsn2 := dsn + "&interpolateParams=true"
|
|
|
+ var db2 *sql.DB
|
|
|
+ if _, err := ParseDSN(dsn2); err != errInvalidDSNUnsafeCollation {
|
|
|
+ db2, err = sql.Open("mysql", dsn2)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("error connecting: %s", err.Error())
|
|
|
+ }
|
|
|
+ defer db2.Close()
|
|
|
+ }
|
|
|
+
|
|
|
+ dsn3 := dsn + "&multiStatements=true"
|
|
|
+ var db3 *sql.DB
|
|
|
+ if _, err := ParseDSN(dsn3); err != errInvalidDSNUnsafeCollation {
|
|
|
+ db3, err = sql.Open("mysql", dsn3)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("error connecting: %s", err.Error())
|
|
|
+ }
|
|
|
+ defer db3.Close()
|
|
|
+ }
|
|
|
+
|
|
|
dbt := &DBTest{t, db}
|
|
|
+ dbt2 := &DBTest{t, db2}
|
|
|
+ dbt3 := &DBTest{t, db3}
|
|
|
for _, test := range tests {
|
|
|
test(dbt)
|
|
|
dbt.db.Exec("DROP TABLE IF EXISTS test")
|
|
|
+ if db2 != nil {
|
|
|
+ test(dbt2)
|
|
|
+ dbt2.db.Exec("DROP TABLE IF EXISTS test")
|
|
|
+ }
|
|
|
+ if db3 != nil {
|
|
|
+ test(dbt3)
|
|
|
+ dbt3.db.Exec("DROP TABLE IF EXISTS test")
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -98,13 +152,13 @@ func (dbt *DBTest) fail(method, query string, err error) {
|
|
|
if len(query) > 300 {
|
|
|
query = "[query too large to print]"
|
|
|
}
|
|
|
- dbt.Fatalf("Error on %s %s: %s", method, query, err.Error())
|
|
|
+ dbt.Fatalf("error on %s %s: %s", method, query, err.Error())
|
|
|
}
|
|
|
|
|
|
func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) {
|
|
|
res, err := dbt.db.Exec(query, args...)
|
|
|
if err != nil {
|
|
|
- dbt.fail("Exec", query, err)
|
|
|
+ dbt.fail("exec", query, err)
|
|
|
}
|
|
|
return res
|
|
|
}
|
|
|
@@ -112,7 +166,7 @@ func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result)
|
|
|
func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows) {
|
|
|
rows, err := dbt.db.Query(query, args...)
|
|
|
if err != nil {
|
|
|
- dbt.fail("Query", query, err)
|
|
|
+ dbt.fail("query", query, err)
|
|
|
}
|
|
|
return rows
|
|
|
}
|
|
|
@@ -123,7 +177,7 @@ func TestEmptyQuery(t *testing.T) {
|
|
|
rows := dbt.mustQuery("--")
|
|
|
// will hang before #255
|
|
|
if rows.Next() {
|
|
|
- dbt.Errorf("Next on rows must be false")
|
|
|
+ dbt.Errorf("next on rows must be false")
|
|
|
}
|
|
|
})
|
|
|
}
|
|
|
@@ -147,7 +201,7 @@ func TestCRUD(t *testing.T) {
|
|
|
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
|
|
|
}
|
|
|
if count != 1 {
|
|
|
- dbt.Fatalf("Expected 1 affected row, got %d", count)
|
|
|
+ dbt.Fatalf("expected 1 affected row, got %d", count)
|
|
|
}
|
|
|
|
|
|
id, err := res.LastInsertId()
|
|
|
@@ -155,7 +209,7 @@ func TestCRUD(t *testing.T) {
|
|
|
dbt.Fatalf("res.LastInsertId() returned error: %s", err.Error())
|
|
|
}
|
|
|
if id != 0 {
|
|
|
- dbt.Fatalf("Expected InsertID 0, got %d", id)
|
|
|
+ dbt.Fatalf("expected InsertId 0, got %d", id)
|
|
|
}
|
|
|
|
|
|
// Read
|
|
|
@@ -180,7 +234,7 @@ func TestCRUD(t *testing.T) {
|
|
|
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
|
|
|
}
|
|
|
if count != 1 {
|
|
|
- dbt.Fatalf("Expected 1 affected row, got %d", count)
|
|
|
+ dbt.Fatalf("expected 1 affected row, got %d", count)
|
|
|
}
|
|
|
|
|
|
// Check Update
|
|
|
@@ -205,7 +259,7 @@ func TestCRUD(t *testing.T) {
|
|
|
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
|
|
|
}
|
|
|
if count != 1 {
|
|
|
- dbt.Fatalf("Expected 1 affected row, got %d", count)
|
|
|
+ dbt.Fatalf("expected 1 affected row, got %d", count)
|
|
|
}
|
|
|
|
|
|
// Check for unexpected rows
|
|
|
@@ -215,8 +269,52 @@ func TestCRUD(t *testing.T) {
|
|
|
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
|
|
|
}
|
|
|
if count != 0 {
|
|
|
- dbt.Fatalf("Expected 0 affected row, got %d", count)
|
|
|
+ dbt.Fatalf("expected 0 affected row, got %d", count)
|
|
|
+ }
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+func TestMultiQuery(t *testing.T) {
|
|
|
+ runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
|
|
|
+ // Create Table
|
|
|
+ dbt.mustExec("CREATE TABLE `test` (`id` int(11) NOT NULL, `value` int(11) NOT NULL) ")
|
|
|
+
|
|
|
+ // Create Data
|
|
|
+ res := dbt.mustExec("INSERT INTO test VALUES (1, 1)")
|
|
|
+ count, err := res.RowsAffected()
|
|
|
+ if err != nil {
|
|
|
+ dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
|
|
|
+ }
|
|
|
+ if count != 1 {
|
|
|
+ dbt.Fatalf("expected 1 affected row, got %d", count)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Update
|
|
|
+ res = dbt.mustExec("UPDATE test SET value = 3 WHERE id = 1; UPDATE test SET value = 4 WHERE id = 1; UPDATE test SET value = 5 WHERE id = 1;")
|
|
|
+ count, err = res.RowsAffected()
|
|
|
+ if err != nil {
|
|
|
+ dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
|
|
|
+ }
|
|
|
+ if count != 1 {
|
|
|
+ dbt.Fatalf("expected 1 affected row, got %d", count)
|
|
|
}
|
|
|
+
|
|
|
+ // Read
|
|
|
+ var out int
|
|
|
+ rows := dbt.mustQuery("SELECT value FROM test WHERE id=1;")
|
|
|
+ if rows.Next() {
|
|
|
+ rows.Scan(&out)
|
|
|
+ if 5 != out {
|
|
|
+ dbt.Errorf("5 != %t", out)
|
|
|
+ }
|
|
|
+
|
|
|
+ if rows.Next() {
|
|
|
+ dbt.Error("unexpected data")
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ dbt.Error("no data")
|
|
|
+ }
|
|
|
+
|
|
|
})
|
|
|
}
|
|
|
|
|
|
@@ -636,14 +734,14 @@ func TestNULL(t *testing.T) {
|
|
|
dbt.Fatal(err)
|
|
|
}
|
|
|
if nb.Valid {
|
|
|
- dbt.Error("Valid NullBool which should be invalid")
|
|
|
+ dbt.Error("valid NullBool which should be invalid")
|
|
|
}
|
|
|
// Valid
|
|
|
if err = nonNullStmt.QueryRow().Scan(&nb); err != nil {
|
|
|
dbt.Fatal(err)
|
|
|
}
|
|
|
if !nb.Valid {
|
|
|
- dbt.Error("Invalid NullBool which should be valid")
|
|
|
+ dbt.Error("invalid NullBool which should be valid")
|
|
|
} else if nb.Bool != true {
|
|
|
dbt.Errorf("Unexpected NullBool value: %t (should be true)", nb.Bool)
|
|
|
}
|
|
|
@@ -655,16 +753,16 @@ func TestNULL(t *testing.T) {
|
|
|
dbt.Fatal(err)
|
|
|
}
|
|
|
if nf.Valid {
|
|
|
- dbt.Error("Valid NullFloat64 which should be invalid")
|
|
|
+ dbt.Error("valid NullFloat64 which should be invalid")
|
|
|
}
|
|
|
// Valid
|
|
|
if err = nonNullStmt.QueryRow().Scan(&nf); err != nil {
|
|
|
dbt.Fatal(err)
|
|
|
}
|
|
|
if !nf.Valid {
|
|
|
- dbt.Error("Invalid NullFloat64 which should be valid")
|
|
|
+ dbt.Error("invalid NullFloat64 which should be valid")
|
|
|
} else if nf.Float64 != float64(1) {
|
|
|
- dbt.Errorf("Unexpected NullFloat64 value: %f (should be 1.0)", nf.Float64)
|
|
|
+ dbt.Errorf("unexpected NullFloat64 value: %f (should be 1.0)", nf.Float64)
|
|
|
}
|
|
|
|
|
|
// NullInt64
|
|
|
@@ -674,16 +772,16 @@ func TestNULL(t *testing.T) {
|
|
|
dbt.Fatal(err)
|
|
|
}
|
|
|
if ni.Valid {
|
|
|
- dbt.Error("Valid NullInt64 which should be invalid")
|
|
|
+ dbt.Error("valid NullInt64 which should be invalid")
|
|
|
}
|
|
|
// Valid
|
|
|
if err = nonNullStmt.QueryRow().Scan(&ni); err != nil {
|
|
|
dbt.Fatal(err)
|
|
|
}
|
|
|
if !ni.Valid {
|
|
|
- dbt.Error("Invalid NullInt64 which should be valid")
|
|
|
+ dbt.Error("invalid NullInt64 which should be valid")
|
|
|
} else if ni.Int64 != int64(1) {
|
|
|
- dbt.Errorf("Unexpected NullInt64 value: %d (should be 1)", ni.Int64)
|
|
|
+ dbt.Errorf("unexpected NullInt64 value: %d (should be 1)", ni.Int64)
|
|
|
}
|
|
|
|
|
|
// NullString
|
|
|
@@ -693,16 +791,16 @@ func TestNULL(t *testing.T) {
|
|
|
dbt.Fatal(err)
|
|
|
}
|
|
|
if ns.Valid {
|
|
|
- dbt.Error("Valid NullString which should be invalid")
|
|
|
+ dbt.Error("valid NullString which should be invalid")
|
|
|
}
|
|
|
// Valid
|
|
|
if err = nonNullStmt.QueryRow().Scan(&ns); err != nil {
|
|
|
dbt.Fatal(err)
|
|
|
}
|
|
|
if !ns.Valid {
|
|
|
- dbt.Error("Invalid NullString which should be valid")
|
|
|
+ dbt.Error("invalid NullString which should be valid")
|
|
|
} else if ns.String != `1` {
|
|
|
- dbt.Error("Unexpected NullString value:" + ns.String + " (should be `1`)")
|
|
|
+ dbt.Error("unexpected NullString value:" + ns.String + " (should be `1`)")
|
|
|
}
|
|
|
|
|
|
// nil-bytes
|
|
|
@@ -712,14 +810,14 @@ func TestNULL(t *testing.T) {
|
|
|
dbt.Fatal(err)
|
|
|
}
|
|
|
if b != nil {
|
|
|
- dbt.Error("Non-nil []byte wich should be nil")
|
|
|
+ dbt.Error("non-nil []byte wich should be nil")
|
|
|
}
|
|
|
// Read non-nil
|
|
|
if err = nonNullStmt.QueryRow().Scan(&b); err != nil {
|
|
|
dbt.Fatal(err)
|
|
|
}
|
|
|
if b == nil {
|
|
|
- dbt.Error("Nil []byte wich should be non-nil")
|
|
|
+ dbt.Error("nil []byte wich should be non-nil")
|
|
|
}
|
|
|
// Insert nil
|
|
|
b = nil
|
|
|
@@ -728,7 +826,7 @@ func TestNULL(t *testing.T) {
|
|
|
dbt.Fatal(err)
|
|
|
}
|
|
|
if !success {
|
|
|
- dbt.Error("Inserting []byte(nil) as NULL failed")
|
|
|
+ dbt.Error("inserting []byte(nil) as NULL failed")
|
|
|
}
|
|
|
// Check input==output with input==nil
|
|
|
b = nil
|
|
|
@@ -736,7 +834,7 @@ func TestNULL(t *testing.T) {
|
|
|
dbt.Fatal(err)
|
|
|
}
|
|
|
if b != nil {
|
|
|
- dbt.Error("Non-nil echo from nil input")
|
|
|
+ dbt.Error("non-nil echo from nil input")
|
|
|
}
|
|
|
// Check input==output with input!=nil
|
|
|
b = []byte("")
|
|
|
@@ -765,6 +863,49 @@ func TestNULL(t *testing.T) {
|
|
|
})
|
|
|
}
|
|
|
|
|
|
+func TestUint64(t *testing.T) {
|
|
|
+ const (
|
|
|
+ u0 = uint64(0)
|
|
|
+ uall = ^u0
|
|
|
+ uhigh = uall >> 1
|
|
|
+ utop = ^uhigh
|
|
|
+ s0 = int64(0)
|
|
|
+ sall = ^s0
|
|
|
+ shigh = int64(uhigh)
|
|
|
+ stop = ^shigh
|
|
|
+ )
|
|
|
+ runTests(t, dsn, func(dbt *DBTest) {
|
|
|
+ stmt, err := dbt.db.Prepare(`SELECT ?, ?, ? ,?, ?, ?, ?, ?`)
|
|
|
+ if err != nil {
|
|
|
+ dbt.Fatal(err)
|
|
|
+ }
|
|
|
+ defer stmt.Close()
|
|
|
+ row := stmt.QueryRow(
|
|
|
+ u0, uhigh, utop, uall,
|
|
|
+ s0, shigh, stop, sall,
|
|
|
+ )
|
|
|
+
|
|
|
+ var ua, ub, uc, ud uint64
|
|
|
+ var sa, sb, sc, sd int64
|
|
|
+
|
|
|
+ err = row.Scan(&ua, &ub, &uc, &ud, &sa, &sb, &sc, &sd)
|
|
|
+ if err != nil {
|
|
|
+ dbt.Fatal(err)
|
|
|
+ }
|
|
|
+ switch {
|
|
|
+ case ua != u0,
|
|
|
+ ub != uhigh,
|
|
|
+ uc != utop,
|
|
|
+ ud != uall,
|
|
|
+ sa != s0,
|
|
|
+ sb != shigh,
|
|
|
+ sc != stop,
|
|
|
+ sd != sall:
|
|
|
+ dbt.Fatal("unexpected result value")
|
|
|
+ }
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
func TestLongData(t *testing.T) {
|
|
|
runTests(t, dsn, func(dbt *DBTest) {
|
|
|
var maxAllowedPacketSize int
|
|
|
@@ -855,7 +996,7 @@ func TestLoadData(t *testing.T) {
|
|
|
dbt.Fatalf("%d != %d", i, id)
|
|
|
}
|
|
|
if values[i-1] != value {
|
|
|
- dbt.Fatalf("%s != %s", values[i-1], value)
|
|
|
+ dbt.Fatalf("%q != %q", values[i-1], value)
|
|
|
}
|
|
|
}
|
|
|
err = rows.Err()
|
|
|
@@ -864,7 +1005,7 @@ func TestLoadData(t *testing.T) {
|
|
|
}
|
|
|
|
|
|
if i != 4 {
|
|
|
- dbt.Fatalf("Rows count mismatch. Got %d, want 4", i)
|
|
|
+ dbt.Fatalf("rows count mismatch. Got %d, want 4", i)
|
|
|
}
|
|
|
}
|
|
|
file, err := ioutil.TempFile("", "gotest")
|
|
|
@@ -880,13 +1021,13 @@ func TestLoadData(t *testing.T) {
|
|
|
|
|
|
// Local File
|
|
|
RegisterLocalFile(file.Name())
|
|
|
- dbt.mustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE '%q' INTO TABLE test", file.Name()))
|
|
|
+ dbt.mustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE %q INTO TABLE test", file.Name()))
|
|
|
verifyLoadDataResult()
|
|
|
// negative test
|
|
|
_, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'doesnotexist' INTO TABLE test")
|
|
|
if err == nil {
|
|
|
- dbt.Fatal("Load non-existent file didn't fail")
|
|
|
- } else if err.Error() != "Local File 'doesnotexist' is not registered. Use the DSN parameter 'allowAllFiles=true' to allow all files" {
|
|
|
+ dbt.Fatal("load non-existent file didn't fail")
|
|
|
+ } else if err.Error() != "local file 'doesnotexist' is not registered" {
|
|
|
dbt.Fatal(err.Error())
|
|
|
}
|
|
|
|
|
|
@@ -906,7 +1047,7 @@ func TestLoadData(t *testing.T) {
|
|
|
// negative test
|
|
|
_, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'Reader::doesnotexist' INTO TABLE test")
|
|
|
if err == nil {
|
|
|
- dbt.Fatal("Load non-existent Reader didn't fail")
|
|
|
+ dbt.Fatal("load non-existent Reader didn't fail")
|
|
|
} else if err.Error() != "Reader 'doesnotexist' is not registered" {
|
|
|
dbt.Fatal(err.Error())
|
|
|
}
|
|
|
@@ -960,7 +1101,7 @@ func TestFoundRows(t *testing.T) {
|
|
|
|
|
|
func TestStrict(t *testing.T) {
|
|
|
// ALLOW_INVALID_DATES to get rid of stricter modes - we want to test for warnings, not errors
|
|
|
- relaxedDsn := dsn + "&sql_mode=ALLOW_INVALID_DATES"
|
|
|
+ relaxedDsn := dsn + "&sql_mode='ALLOW_INVALID_DATES,NO_AUTO_CREATE_USER'"
|
|
|
// make sure the MySQL version is recent enough with a separate connection
|
|
|
// before running the test
|
|
|
conn, err := MySQLDriver{}.Open(relaxedDsn)
|
|
|
@@ -986,7 +1127,7 @@ func TestStrict(t *testing.T) {
|
|
|
|
|
|
var checkWarnings = func(err error, mode string, idx int) {
|
|
|
if err == nil {
|
|
|
- dbt.Errorf("Expected STRICT error on query [%s] %s", mode, queries[idx].in)
|
|
|
+ dbt.Errorf("expected STRICT error on query [%s] %s", mode, queries[idx].in)
|
|
|
}
|
|
|
|
|
|
if warnings, ok := err.(MySQLWarnings); ok {
|
|
|
@@ -995,18 +1136,18 @@ func TestStrict(t *testing.T) {
|
|
|
codes[i] = warnings[i].Code
|
|
|
}
|
|
|
if len(codes) != len(queries[idx].codes) {
|
|
|
- dbt.Errorf("Unexpected STRICT error count on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes)
|
|
|
+ dbt.Errorf("unexpected STRICT error count on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes)
|
|
|
}
|
|
|
|
|
|
for i := range warnings {
|
|
|
if codes[i] != queries[idx].codes[i] {
|
|
|
- dbt.Errorf("Unexpected STRICT error codes on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes)
|
|
|
+ dbt.Errorf("unexpected STRICT error codes on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes)
|
|
|
return
|
|
|
}
|
|
|
}
|
|
|
|
|
|
} else {
|
|
|
- dbt.Errorf("Unexpected error on query [%s] %s: %s", mode, queries[idx].in, err.Error())
|
|
|
+ dbt.Errorf("unexpected error on query [%s] %s: %s", mode, queries[idx].in, err.Error())
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -1022,7 +1163,7 @@ func TestStrict(t *testing.T) {
|
|
|
for i := range queries {
|
|
|
stmt, err = dbt.db.Prepare(queries[i].in)
|
|
|
if err != nil {
|
|
|
- dbt.Errorf("Error on preparing query %s: %s", queries[i].in, err.Error())
|
|
|
+ dbt.Errorf("error on preparing query %s: %s", queries[i].in, err.Error())
|
|
|
}
|
|
|
|
|
|
_, err = stmt.Exec()
|
|
|
@@ -1030,7 +1171,7 @@ func TestStrict(t *testing.T) {
|
|
|
|
|
|
err = stmt.Close()
|
|
|
if err != nil {
|
|
|
- dbt.Errorf("Error on closing stmt for query %s: %s", queries[i].in, err.Error())
|
|
|
+ dbt.Errorf("error on closing stmt for query %s: %s", queries[i].in, err.Error())
|
|
|
}
|
|
|
}
|
|
|
})
|
|
|
@@ -1040,9 +1181,9 @@ func TestTLS(t *testing.T) {
|
|
|
tlsTest := func(dbt *DBTest) {
|
|
|
if err := dbt.db.Ping(); err != nil {
|
|
|
if err == ErrNoTLS {
|
|
|
- dbt.Skip("Server does not support TLS")
|
|
|
+ dbt.Skip("server does not support TLS")
|
|
|
} else {
|
|
|
- dbt.Fatalf("Error on Ping: %s", err.Error())
|
|
|
+ dbt.Fatalf("error on Ping: %s", err.Error())
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -1055,7 +1196,7 @@ func TestTLS(t *testing.T) {
|
|
|
}
|
|
|
|
|
|
if value == nil {
|
|
|
- dbt.Fatal("No Cipher")
|
|
|
+ dbt.Fatal("no Cipher")
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -1072,42 +1213,42 @@ func TestTLS(t *testing.T) {
|
|
|
func TestReuseClosedConnection(t *testing.T) {
|
|
|
// this test does not use sql.database, it uses the driver directly
|
|
|
if !available {
|
|
|
- t.Skipf("MySQL-Server not running on %s", netAddr)
|
|
|
+ t.Skipf("MySQL server not running on %s", netAddr)
|
|
|
}
|
|
|
|
|
|
md := &MySQLDriver{}
|
|
|
conn, err := md.Open(dsn)
|
|
|
if err != nil {
|
|
|
- t.Fatalf("Error connecting: %s", err.Error())
|
|
|
+ t.Fatalf("error connecting: %s", err.Error())
|
|
|
}
|
|
|
stmt, err := conn.Prepare("DO 1")
|
|
|
if err != nil {
|
|
|
- t.Fatalf("Error preparing statement: %s", err.Error())
|
|
|
+ t.Fatalf("error preparing statement: %s", err.Error())
|
|
|
}
|
|
|
_, err = stmt.Exec(nil)
|
|
|
if err != nil {
|
|
|
- t.Fatalf("Error executing statement: %s", err.Error())
|
|
|
+ t.Fatalf("error executing statement: %s", err.Error())
|
|
|
}
|
|
|
err = conn.Close()
|
|
|
if err != nil {
|
|
|
- t.Fatalf("Error closing connection: %s", err.Error())
|
|
|
+ t.Fatalf("error closing connection: %s", err.Error())
|
|
|
}
|
|
|
|
|
|
defer func() {
|
|
|
if err := recover(); err != nil {
|
|
|
- t.Errorf("Panic after reusing a closed connection: %v", err)
|
|
|
+ t.Errorf("panic after reusing a closed connection: %v", err)
|
|
|
}
|
|
|
}()
|
|
|
_, err = stmt.Exec(nil)
|
|
|
if err != nil && err != driver.ErrBadConn {
|
|
|
- t.Errorf("Unexpected error '%s', expected '%s'",
|
|
|
+ t.Errorf("unexpected error '%s', expected '%s'",
|
|
|
err.Error(), driver.ErrBadConn.Error())
|
|
|
}
|
|
|
}
|
|
|
|
|
|
func TestCharset(t *testing.T) {
|
|
|
if !available {
|
|
|
- t.Skipf("MySQL-Server not running on %s", netAddr)
|
|
|
+ t.Skipf("MySQL server not running on %s", netAddr)
|
|
|
}
|
|
|
|
|
|
mustSetCharset := func(charsetParam, expected string) {
|
|
|
@@ -1116,14 +1257,14 @@ func TestCharset(t *testing.T) {
|
|
|
defer rows.Close()
|
|
|
|
|
|
if !rows.Next() {
|
|
|
- dbt.Fatalf("Error getting connection charset: %s", rows.Err())
|
|
|
+ dbt.Fatalf("error getting connection charset: %s", rows.Err())
|
|
|
}
|
|
|
|
|
|
var got string
|
|
|
rows.Scan(&got)
|
|
|
|
|
|
if got != expected {
|
|
|
- dbt.Fatalf("Expected connection charset %s but got %s", expected, got)
|
|
|
+ dbt.Fatalf("expected connection charset %s but got %s", expected, got)
|
|
|
}
|
|
|
})
|
|
|
}
|
|
|
@@ -1145,14 +1286,14 @@ func TestFailingCharset(t *testing.T) {
|
|
|
_, err := dbt.db.Exec("SELECT 1")
|
|
|
if err == nil {
|
|
|
dbt.db.Close()
|
|
|
- t.Fatalf("Connection must not succeed without a valid charset")
|
|
|
+ t.Fatalf("connection must not succeed without a valid charset")
|
|
|
}
|
|
|
})
|
|
|
}
|
|
|
|
|
|
func TestCollation(t *testing.T) {
|
|
|
if !available {
|
|
|
- t.Skipf("MySQL-Server not running on %s", netAddr)
|
|
|
+ t.Skipf("MySQL server not running on %s", netAddr)
|
|
|
}
|
|
|
|
|
|
defaultCollation := "utf8_general_ci"
|
|
|
@@ -1182,12 +1323,36 @@ func TestCollation(t *testing.T) {
|
|
|
}
|
|
|
|
|
|
if got != expected {
|
|
|
- dbt.Fatalf("Expected connection collation %s but got %s", expected, got)
|
|
|
+ dbt.Fatalf("expected connection collation %s but got %s", expected, got)
|
|
|
}
|
|
|
})
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+func TestColumnsWithAlias(t *testing.T) {
|
|
|
+ runTests(t, dsn+"&columnsWithAlias=true", func(dbt *DBTest) {
|
|
|
+ rows := dbt.mustQuery("SELECT 1 AS A")
|
|
|
+ defer rows.Close()
|
|
|
+ cols, _ := rows.Columns()
|
|
|
+ if len(cols) != 1 {
|
|
|
+ t.Fatalf("expected 1 column, got %d", len(cols))
|
|
|
+ }
|
|
|
+ if cols[0] != "A" {
|
|
|
+ t.Fatalf("expected column name \"A\", got \"%s\"", cols[0])
|
|
|
+ }
|
|
|
+ rows.Close()
|
|
|
+
|
|
|
+ rows = dbt.mustQuery("SELECT * FROM (SELECT 1 AS one) AS A")
|
|
|
+ cols, _ = rows.Columns()
|
|
|
+ if len(cols) != 1 {
|
|
|
+ t.Fatalf("expected 1 column, got %d", len(cols))
|
|
|
+ }
|
|
|
+ if cols[0] != "A.one" {
|
|
|
+ t.Fatalf("expected column name \"A.one\", got \"%s\"", cols[0])
|
|
|
+ }
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
func TestRawBytesResultExceedsBuffer(t *testing.T) {
|
|
|
runTests(t, dsn, func(dbt *DBTest) {
|
|
|
// defaultBufSize from buffer.go
|
|
|
@@ -1223,7 +1388,7 @@ func TestTimezoneConversion(t *testing.T) {
|
|
|
// Retrieve time from DB
|
|
|
rows := dbt.mustQuery("SELECT ts FROM test")
|
|
|
if !rows.Next() {
|
|
|
- dbt.Fatal("Didn't get any rows out")
|
|
|
+ dbt.Fatal("did not get any rows out")
|
|
|
}
|
|
|
|
|
|
var dbTime time.Time
|
|
|
@@ -1234,7 +1399,7 @@ func TestTimezoneConversion(t *testing.T) {
|
|
|
|
|
|
// Check that dates match
|
|
|
if reftime.Unix() != dbTime.Unix() {
|
|
|
- dbt.Errorf("Times don't match.\n")
|
|
|
+ dbt.Errorf("times do not match.\n")
|
|
|
dbt.Errorf(" Now(%v)=%v\n", usCentral, reftime)
|
|
|
dbt.Errorf(" Now(UTC)=%v\n", dbTime)
|
|
|
}
|
|
|
@@ -1260,7 +1425,7 @@ func TestRowsClose(t *testing.T) {
|
|
|
}
|
|
|
|
|
|
if rows.Next() {
|
|
|
- dbt.Fatal("Unexpected row after rows.Close()")
|
|
|
+ dbt.Fatal("unexpected row after rows.Close()")
|
|
|
}
|
|
|
|
|
|
err = rows.Err()
|
|
|
@@ -1292,7 +1457,7 @@ func TestCloseStmtBeforeRows(t *testing.T) {
|
|
|
}
|
|
|
|
|
|
if !rows.Next() {
|
|
|
- dbt.Fatal("Getting row failed")
|
|
|
+ dbt.Fatal("getting row failed")
|
|
|
} else {
|
|
|
err = rows.Err()
|
|
|
if err != nil {
|
|
|
@@ -1302,7 +1467,7 @@ func TestCloseStmtBeforeRows(t *testing.T) {
|
|
|
var out bool
|
|
|
err = rows.Scan(&out)
|
|
|
if err != nil {
|
|
|
- dbt.Fatalf("Error on rows.Scan(): %s", err.Error())
|
|
|
+ dbt.Fatalf("error on rows.Scan(): %s", err.Error())
|
|
|
}
|
|
|
if out != true {
|
|
|
dbt.Errorf("true != %t", out)
|
|
|
@@ -1338,7 +1503,7 @@ func TestStmtMultiRows(t *testing.T) {
|
|
|
|
|
|
// 1
|
|
|
if !rows1.Next() {
|
|
|
- dbt.Fatal("1st rows1.Next failed")
|
|
|
+ dbt.Fatal("first rows1.Next failed")
|
|
|
} else {
|
|
|
err = rows1.Err()
|
|
|
if err != nil {
|
|
|
@@ -1347,7 +1512,7 @@ func TestStmtMultiRows(t *testing.T) {
|
|
|
|
|
|
err = rows1.Scan(&out)
|
|
|
if err != nil {
|
|
|
- dbt.Fatalf("Error on rows.Scan(): %s", err.Error())
|
|
|
+ dbt.Fatalf("error on rows.Scan(): %s", err.Error())
|
|
|
}
|
|
|
if out != true {
|
|
|
dbt.Errorf("true != %t", out)
|
|
|
@@ -1355,7 +1520,7 @@ func TestStmtMultiRows(t *testing.T) {
|
|
|
}
|
|
|
|
|
|
if !rows2.Next() {
|
|
|
- dbt.Fatal("1st rows2.Next failed")
|
|
|
+ dbt.Fatal("first rows2.Next failed")
|
|
|
} else {
|
|
|
err = rows2.Err()
|
|
|
if err != nil {
|
|
|
@@ -1364,7 +1529,7 @@ func TestStmtMultiRows(t *testing.T) {
|
|
|
|
|
|
err = rows2.Scan(&out)
|
|
|
if err != nil {
|
|
|
- dbt.Fatalf("Error on rows.Scan(): %s", err.Error())
|
|
|
+ dbt.Fatalf("error on rows.Scan(): %s", err.Error())
|
|
|
}
|
|
|
if out != true {
|
|
|
dbt.Errorf("true != %t", out)
|
|
|
@@ -1373,7 +1538,7 @@ func TestStmtMultiRows(t *testing.T) {
|
|
|
|
|
|
// 2
|
|
|
if !rows1.Next() {
|
|
|
- dbt.Fatal("2nd rows1.Next failed")
|
|
|
+ dbt.Fatal("second rows1.Next failed")
|
|
|
} else {
|
|
|
err = rows1.Err()
|
|
|
if err != nil {
|
|
|
@@ -1382,14 +1547,14 @@ func TestStmtMultiRows(t *testing.T) {
|
|
|
|
|
|
err = rows1.Scan(&out)
|
|
|
if err != nil {
|
|
|
- dbt.Fatalf("Error on rows.Scan(): %s", err.Error())
|
|
|
+ dbt.Fatalf("error on rows.Scan(): %s", err.Error())
|
|
|
}
|
|
|
if out != false {
|
|
|
dbt.Errorf("false != %t", out)
|
|
|
}
|
|
|
|
|
|
if rows1.Next() {
|
|
|
- dbt.Fatal("Unexpected row on rows1")
|
|
|
+ dbt.Fatal("unexpected row on rows1")
|
|
|
}
|
|
|
err = rows1.Close()
|
|
|
if err != nil {
|
|
|
@@ -1398,7 +1563,7 @@ func TestStmtMultiRows(t *testing.T) {
|
|
|
}
|
|
|
|
|
|
if !rows2.Next() {
|
|
|
- dbt.Fatal("2nd rows2.Next failed")
|
|
|
+ dbt.Fatal("second rows2.Next failed")
|
|
|
} else {
|
|
|
err = rows2.Err()
|
|
|
if err != nil {
|
|
|
@@ -1407,14 +1572,14 @@ func TestStmtMultiRows(t *testing.T) {
|
|
|
|
|
|
err = rows2.Scan(&out)
|
|
|
if err != nil {
|
|
|
- dbt.Fatalf("Error on rows.Scan(): %s", err.Error())
|
|
|
+ dbt.Fatalf("error on rows.Scan(): %s", err.Error())
|
|
|
}
|
|
|
if out != false {
|
|
|
dbt.Errorf("false != %t", out)
|
|
|
}
|
|
|
|
|
|
if rows2.Next() {
|
|
|
- dbt.Fatal("Unexpected row on rows2")
|
|
|
+ dbt.Fatal("unexpected row on rows2")
|
|
|
}
|
|
|
err = rows2.Close()
|
|
|
if err != nil {
|
|
|
@@ -1459,7 +1624,7 @@ func TestConcurrent(t *testing.T) {
|
|
|
if err != nil {
|
|
|
dbt.Fatalf("%s", err.Error())
|
|
|
}
|
|
|
- dbt.Logf("Testing up to %d concurrent connections \r\n", max)
|
|
|
+ dbt.Logf("testing up to %d concurrent connections \r\n", max)
|
|
|
|
|
|
var remaining, succeeded int32 = int32(max), 0
|
|
|
|
|
|
@@ -1483,7 +1648,7 @@ func TestConcurrent(t *testing.T) {
|
|
|
|
|
|
if err != nil {
|
|
|
if err.Error() != "Error 1040: Too many connections" {
|
|
|
- fatalf("Error on Conn %d: %s", id, err.Error())
|
|
|
+ fatalf("error on conn %d: %s", id, err.Error())
|
|
|
}
|
|
|
return
|
|
|
}
|
|
|
@@ -1491,13 +1656,13 @@ func TestConcurrent(t *testing.T) {
|
|
|
// keep the connection busy until all connections are open
|
|
|
for remaining > 0 {
|
|
|
if _, err = tx.Exec("DO 1"); err != nil {
|
|
|
- fatalf("Error on Conn %d: %s", id, err.Error())
|
|
|
+ fatalf("error on conn %d: %s", id, err.Error())
|
|
|
return
|
|
|
}
|
|
|
}
|
|
|
|
|
|
if err = tx.Commit(); err != nil {
|
|
|
- fatalf("Error on Conn %d: %s", id, err.Error())
|
|
|
+ fatalf("error on conn %d: %s", id, err.Error())
|
|
|
return
|
|
|
}
|
|
|
|
|
|
@@ -1513,14 +1678,14 @@ func TestConcurrent(t *testing.T) {
|
|
|
dbt.Fatal(fatalError)
|
|
|
}
|
|
|
|
|
|
- dbt.Logf("Reached %d concurrent connections\r\n", succeeded)
|
|
|
+ dbt.Logf("reached %d concurrent connections\r\n", succeeded)
|
|
|
})
|
|
|
}
|
|
|
|
|
|
// Tests custom dial functions
|
|
|
func TestCustomDial(t *testing.T) {
|
|
|
if !available {
|
|
|
- t.Skipf("MySQL-Server not running on %s", netAddr)
|
|
|
+ t.Skipf("MySQL server not running on %s", netAddr)
|
|
|
}
|
|
|
|
|
|
// our custom dial function which justs wraps net.Dial here
|
|
|
@@ -1530,11 +1695,117 @@ func TestCustomDial(t *testing.T) {
|
|
|
|
|
|
db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s&strict=true", user, pass, addr, dbname))
|
|
|
if err != nil {
|
|
|
- t.Fatalf("Error connecting: %s", err.Error())
|
|
|
+ t.Fatalf("error connecting: %s", err.Error())
|
|
|
}
|
|
|
defer db.Close()
|
|
|
|
|
|
if _, err = db.Exec("DO 1"); err != nil {
|
|
|
- t.Fatalf("Connection failed: %s", err.Error())
|
|
|
+ t.Fatalf("connection failed: %s", err.Error())
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+func TestSQLInjection(t *testing.T) {
|
|
|
+ createTest := func(arg string) func(dbt *DBTest) {
|
|
|
+ return func(dbt *DBTest) {
|
|
|
+ dbt.mustExec("CREATE TABLE test (v INTEGER)")
|
|
|
+ dbt.mustExec("INSERT INTO test VALUES (?)", 1)
|
|
|
+
|
|
|
+ var v int
|
|
|
+ // NULL can't be equal to anything, the idea here is to inject query so it returns row
|
|
|
+ // This test verifies that escapeQuotes and escapeBackslash are working properly
|
|
|
+ err := dbt.db.QueryRow("SELECT v FROM test WHERE NULL = ?", arg).Scan(&v)
|
|
|
+ if err == sql.ErrNoRows {
|
|
|
+ return // success, sql injection failed
|
|
|
+ } else if err == nil {
|
|
|
+ dbt.Errorf("sql injection successful with arg: %s", arg)
|
|
|
+ } else {
|
|
|
+ dbt.Errorf("error running query with arg: %s; err: %s", arg, err.Error())
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ dsns := []string{
|
|
|
+ dsn,
|
|
|
+ dsn + "&sql_mode='NO_BACKSLASH_ESCAPES,NO_AUTO_CREATE_USER'",
|
|
|
+ }
|
|
|
+ for _, testdsn := range dsns {
|
|
|
+ runTests(t, testdsn, createTest("1 OR 1=1"))
|
|
|
+ runTests(t, testdsn, createTest("' OR '1'='1"))
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// Test if inserted data is correctly retrieved after being escaped
|
|
|
+func TestInsertRetrieveEscapedData(t *testing.T) {
|
|
|
+ testData := func(dbt *DBTest) {
|
|
|
+ dbt.mustExec("CREATE TABLE test (v VARCHAR(255))")
|
|
|
+
|
|
|
+ // All sequences that are escaped by escapeQuotes and escapeBackslash
|
|
|
+ v := "foo \x00\n\r\x1a\"'\\"
|
|
|
+ dbt.mustExec("INSERT INTO test VALUES (?)", v)
|
|
|
+
|
|
|
+ var out string
|
|
|
+ err := dbt.db.QueryRow("SELECT v FROM test").Scan(&out)
|
|
|
+ if err != nil {
|
|
|
+ dbt.Fatalf("%s", err.Error())
|
|
|
+ }
|
|
|
+
|
|
|
+ if out != v {
|
|
|
+ dbt.Errorf("%q != %q", out, v)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ dsns := []string{
|
|
|
+ dsn,
|
|
|
+ dsn + "&sql_mode='NO_BACKSLASH_ESCAPES,NO_AUTO_CREATE_USER'",
|
|
|
+ }
|
|
|
+ for _, testdsn := range dsns {
|
|
|
+ runTests(t, testdsn, testData)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestUnixSocketAuthFail(t *testing.T) {
|
|
|
+ runTests(t, dsn, func(dbt *DBTest) {
|
|
|
+ // Save the current logger so we can restore it.
|
|
|
+ oldLogger := errLog
|
|
|
+
|
|
|
+ // Set a new logger so we can capture its output.
|
|
|
+ buffer := bytes.NewBuffer(make([]byte, 0, 64))
|
|
|
+ newLogger := log.New(buffer, "prefix: ", 0)
|
|
|
+ SetLogger(newLogger)
|
|
|
+
|
|
|
+ // Restore the logger.
|
|
|
+ defer SetLogger(oldLogger)
|
|
|
+
|
|
|
+ // Make a new DSN that uses the MySQL socket file and a bad password, which
|
|
|
+ // we can make by simply appending any character to the real password.
|
|
|
+ badPass := pass + "x"
|
|
|
+ socket := ""
|
|
|
+ if prot == "unix" {
|
|
|
+ socket = addr
|
|
|
+ } else {
|
|
|
+ // Get socket file from MySQL.
|
|
|
+ err := dbt.db.QueryRow("SELECT @@socket").Scan(&socket)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("error on SELECT @@socket: %s", err.Error())
|
|
|
+ }
|
|
|
+ }
|
|
|
+ t.Logf("socket: %s", socket)
|
|
|
+ badDSN := fmt.Sprintf("%s:%s@unix(%s)/%s?timeout=30s&strict=true", user, badPass, socket, dbname)
|
|
|
+ db, err := sql.Open("mysql", badDSN)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("error connecting: %s", err.Error())
|
|
|
+ }
|
|
|
+ defer db.Close()
|
|
|
+
|
|
|
+ // Connect to MySQL for real. This will cause an auth failure.
|
|
|
+ err = db.Ping()
|
|
|
+ if err == nil {
|
|
|
+ t.Error("expected Ping() to return an error")
|
|
|
+ }
|
|
|
+
|
|
|
+ // The driver should not log anything.
|
|
|
+ if actual := buffer.String(); actual != "" {
|
|
|
+ t.Errorf("expected no output, got %q", actual)
|
|
|
+ }
|
|
|
+ })
|
|
|
+}
|