Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ Daniel Montoya <dsmontoyam at gmail.com>
Daniel Nichter <nil at codenode.com>
Daniël van Eeden <git at myname.nl>
Dave Protasowski <dprotaso at gmail.com>
Diego Dupin <diego.dupin at gmail.com>
Dirkjan Bussink <d.bussink at gmail.com>
DisposaBoy <disposaboy at dby.me>
Egor Smolyakov <egorsmkv at gmail.com>
Expand Down
41 changes: 41 additions & 0 deletions benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,47 @@ func benchmarkQueryHelper(b *testing.B, compr bool) {
}
}

func BenchmarkSelect10000rows(b *testing.B) {
db := initDB(b, false)
defer db.Close()

// Check if we're using MariaDB
var version string
err := db.QueryRow("SELECT @@version").Scan(&version)
if err != nil {
b.Fatalf("Failed to get server version: %v", err)
}

if !strings.Contains(strings.ToLower(version), "mariadb") {
b.Skip("Skipping benchmark as it requires MariaDB sequence table")
return
}

b.StartTimer()
stmt, err := db.Prepare("SELECT * FROM seq_1_to_10000")
if err != nil {
b.Fatalf("Failed to prepare statement: %v", err)
}
defer stmt.Close()
for n := 0; n < b.N; n++ {
rows, err := stmt.Query()
if err != nil {
b.Fatalf("Failed to query 10000rows: %v", err)
}

var id int64
for rows.Next() {
err = rows.Scan(&id)
if err != nil {
rows.Close()
b.Fatalf("Failed to scan row: %v", err)
}
}
rows.Close()
}
b.StopTimer()
}

func BenchmarkExec(b *testing.B) {
tb := (*TB)(b)
b.StopTimer()
Expand Down
2 changes: 2 additions & 0 deletions compress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte) []by
conn := new(mockConn)
conn.data = compressedPacket
mc.netConn = conn
mc.readNextFunc = mc.compIO.readNext
mc.readFunc = conn.Read

uncompressedPacket, err := mc.readPacket()
if err != nil {
Expand Down
14 changes: 3 additions & 11 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ type mysqlConn struct {
compressSequence uint8
parseTime bool
compress bool
readFunc func([]byte) (int, error)
readNextFunc func(int, readerFunc) ([]byte, error)

// for context support (Go 1.8+)
watching bool
Expand All @@ -64,16 +66,6 @@ func (mc *mysqlConn) log(v ...any) {
mc.cfg.Logger.Print(v...)
}

func (mc *mysqlConn) readWithTimeout(b []byte) (int, error) {
to := mc.cfg.ReadTimeout
if to > 0 {
if err := mc.netConn.SetReadDeadline(time.Now().Add(to)); err != nil {
return 0, err
}
}
return mc.netConn.Read(b)
}

func (mc *mysqlConn) writeWithTimeout(b []byte) (int, error) {
to := mc.cfg.WriteTimeout
if to > 0 {
Expand Down Expand Up @@ -247,7 +239,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
// can not take the buffer. Something must be wrong with the connection
mc.cleanup()
// interpolateParams would be called before sending any query.
// So its safe to retry.
// So it's safe to retry.
return "", driver.ErrBadConn
}
buf = buf[:0]
Expand Down
7 changes: 6 additions & 1 deletion connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,17 @@ import (
)

func TestInterpolateParams(t *testing.T) {
buf := newBuffer()
nc := &net.TCPConn{}
mc := &mysqlConn{
buf: newBuffer(),
buf: buf,
netConn: nc,
maxAllowedPacket: maxPacketSize,
cfg: &Config{
InterpolateParams: true,
},
readNextFunc: buf.readNext,
readFunc: nc.Read,
}

q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42), "gopher"})
Expand Down
18 changes: 18 additions & 0 deletions connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"os"
"strconv"
"strings"
"time"
)

type connector struct {
Expand Down Expand Up @@ -130,6 +131,22 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {

mc.buf = newBuffer()

// setting readNext/read functions
mc.readNextFunc = mc.buf.readNext

// Initialize read function based on configuration
if mc.cfg.ReadTimeout > 0 {
mc.readFunc = func(b []byte) (int, error) {
deadline := time.Now().Add(mc.cfg.ReadTimeout)
if err := mc.netConn.SetReadDeadline(deadline); err != nil {
return 0, err
}
return mc.netConn.Read(b)
}
} else {
mc.readFunc = mc.netConn.Read
}

// Reading Handshake Initialization Packet
authData, plugin, err := mc.readHandshakePacket()
if err != nil {
Expand Down Expand Up @@ -170,6 +187,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
if mc.cfg.compress && mc.flags&clientCompress == clientCompress {
mc.compress = true
mc.compIO = newCompIO(mc)
mc.readNextFunc = mc.compIO.readNext
}
if mc.cfg.MaxAllowedPacket > 0 {
mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket
Expand Down
52 changes: 47 additions & 5 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1630,13 +1630,46 @@ func TestCollation(t *testing.T) {
}

runTests(t, tdsn, func(dbt *DBTest) {
// see https://mariadb.com/kb/en/setting-character-sets-and-collations/#changing-default-collation
// when character_set_collations is set for the charset, it overrides the default collation
// so we need to check if the default collation is overridden
forceExpected := expected
var defaultCollations string
err := dbt.db.QueryRow("SELECT @@character_set_collations").Scan(&defaultCollations)
if err == nil {
// Query succeeded, need to check if we should override expected collation
collationMap := make(map[string]string)
pairs := strings.Split(defaultCollations, ",")
for _, pair := range pairs {
parts := strings.Split(pair, "=")
if len(parts) == 2 {
collationMap[parts[0]] = parts[1]
}
}

// Get charset prefix from expected collation
parts := strings.Split(expected, "_")
if len(parts) > 0 {
charset := parts[0]
if newCollation, ok := collationMap[charset]; ok {
forceExpected = newCollation
}
}
}

var got string
if err := dbt.db.QueryRow("SELECT @@collation_connection").Scan(&got); err != nil {
dbt.Fatal(err)
}

if got != expected {
dbt.Fatalf("expected connection collation %s but got %s", expected, got)
if forceExpected != expected {
if got != forceExpected {
dbt.Fatalf("expected forced connection collation %s but got %s", forceExpected, got)
}
} else {
dbt.Fatalf("expected connection collation %s but got %s", expected, got)
}
}
})
}
Expand Down Expand Up @@ -1685,16 +1718,16 @@ func TestRawBytesResultExceedsBuffer(t *testing.T) {
}

func TestTimezoneConversion(t *testing.T) {
zones := []string{"UTC", "US/Central", "US/Pacific", "Local"}
zones := []string{"UTC", "America/New_York", "Asia/Hong_Kong", "Local"}

// Regression test for timezone handling
tzTest := func(dbt *DBTest) {
// Create table
dbt.mustExec("CREATE TABLE test (ts TIMESTAMP)")

// Insert local time into database (should be converted)
usCentral, _ := time.LoadLocation("US/Central")
reftime := time.Date(2014, 05, 30, 18, 03, 17, 0, time.UTC).In(usCentral)
newYorkTz, _ := time.LoadLocation("America/New_York")
reftime := time.Date(2014, 05, 30, 18, 03, 17, 0, time.UTC).In(newYorkTz)
dbt.mustExec("INSERT INTO test VALUE (?)", reftime)

// Retrieve time from DB
Expand All @@ -1713,7 +1746,7 @@ func TestTimezoneConversion(t *testing.T) {
// Check that dates match
if reftime.Unix() != dbTime.Unix() {
dbt.Errorf("times do not match.\n")
dbt.Errorf(" Now(%v)=%v\n", usCentral, reftime)
dbt.Errorf(" Now(%v)=%v\n", newYorkTz, reftime)
dbt.Errorf(" Now(UTC)=%v\n", dbTime)
}
}
Expand Down Expand Up @@ -3541,6 +3574,15 @@ func TestConnectionAttributes(t *testing.T) {

dbt := &DBTest{t, db}

var varName string
var varValue string
err := dbt.db.QueryRow("SHOW VARIABLES LIKE 'performance_schema'").Scan(&varName, &varValue)
if err != nil {
t.Fatalf("error: %s", err.Error())
}
if varValue != "ON" {
t.Skipf("Performance schema is not enabled. skipping")
}
queryString := "SELECT ATTR_NAME, ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID()"
rows := dbt.mustQuery(queryString)
defer rows.Close()
Expand Down
10 changes: 3 additions & 7 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,9 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
var prevData []byte
invalidSequence := false

readNext := mc.buf.readNext
if mc.compress {
readNext = mc.compIO.readNext
}

for {
// read packet header
data, err := readNext(4, mc.readWithTimeout)
data, err := mc.readNextFunc(4, mc.readFunc)
if err != nil {
mc.close()
if cerr := mc.canceled.Value(); cerr != nil {
Expand Down Expand Up @@ -85,7 +80,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
}

// read packet body [pktLen bytes]
data, err = readNext(pktLen, mc.readWithTimeout)
data, err = mc.readNextFunc(pktLen, mc.readFunc)
if err != nil {
mc.close()
if cerr := mc.canceled.Value(); cerr != nil {
Expand Down Expand Up @@ -369,6 +364,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
return err
}
mc.netConn = tlsConn
mc.readFunc = mc.netConn.Read
}

// User [null terminated string]
Expand Down
47 changes: 31 additions & 16 deletions packets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,24 +97,30 @@ var _ net.Conn = new(mockConn)
func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) {
conn := new(mockConn)
connector := newConnector(NewConfig())
buf := newBuffer()
mc := &mysqlConn{
buf: newBuffer(),
buf: buf,
cfg: connector.cfg,
connector: connector,
netConn: conn,
closech: make(chan struct{}),
maxAllowedPacket: defaultMaxAllowedPacket,
sequence: sequence,
readNextFunc: buf.readNext,
readFunc: conn.Read,
}
return conn, mc
}

func TestReadPacketSingleByte(t *testing.T) {
conn := new(mockConn)
buf := newBuffer()
mc := &mysqlConn{
netConn: conn,
buf: newBuffer(),
cfg: NewConfig(),
netConn: conn,
buf: buf,
cfg: NewConfig(),
readNextFunc: buf.readNext,
readFunc: conn.Read,
}

conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff}
Expand Down Expand Up @@ -165,10 +171,13 @@ func TestReadPacketWrongSequenceID(t *testing.T) {

func TestReadPacketSplit(t *testing.T) {
conn := new(mockConn)
buf := newBuffer()
mc := &mysqlConn{
netConn: conn,
buf: newBuffer(),
cfg: NewConfig(),
netConn: conn,
buf: buf,
cfg: NewConfig(),
readNextFunc: buf.readNext,
readFunc: conn.Read,
}

data := make([]byte, maxPacketSize*2+4*3)
Expand Down Expand Up @@ -272,11 +281,14 @@ func TestReadPacketSplit(t *testing.T) {

func TestReadPacketFail(t *testing.T) {
conn := new(mockConn)
buf := newBuffer()
mc := &mysqlConn{
netConn: conn,
buf: newBuffer(),
closech: make(chan struct{}),
cfg: NewConfig(),
netConn: conn,
buf: buf,
closech: make(chan struct{}),
cfg: NewConfig(),
readNextFunc: buf.readNext,
readFunc: conn.Read,
}

// illegal empty (stand-alone) packet
Expand Down Expand Up @@ -317,12 +329,15 @@ func TestReadPacketFail(t *testing.T) {
// not-NUL terminated plugin_name in init packet
func TestRegression801(t *testing.T) {
conn := new(mockConn)
buf := newBuffer()
mc := &mysqlConn{
netConn: conn,
buf: newBuffer(),
cfg: new(Config),
sequence: 42,
closech: make(chan struct{}),
netConn: conn,
buf: buf,
cfg: new(Config),
sequence: 42,
closech: make(chan struct{}),
readNextFunc: buf.readNext,
readFunc: conn.Read,
}

conn.data = []byte{72, 0, 0, 42, 10, 53, 46, 53, 46, 56, 0, 165, 0, 0, 0,
Expand Down