Skip to content

Commit 0000a5b

Browse files
committed
sqlite: conn init mechanism
Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
1 parent ad8ab94 commit 0000a5b

File tree

2 files changed

+62
-31
lines changed

2 files changed

+62
-31
lines changed

sqlite.go

Lines changed: 36 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,23 @@
88
// For details see https://sqlite.org/c3ref/open.html#urifilenames.
99
//
1010
//
11+
// Initializing connections or tracing
12+
//
13+
// If you want to do initial configuration of a connection, or enable
14+
// tracing, use the Connector function:
15+
//
16+
// connInitFunc := func(ctx context.Context, conn driver.ConnPrepareContext) error {
17+
// called++
18+
// stmt, err := conn.PrepareContext(ctx, "PRAGMA journal_mode=WAL;")
19+
// if err != nil {
20+
// return err
21+
// }
22+
// _, err = stmt.(driver.StmtExecContext).ExecContext(ctx, nil)
23+
// return err
24+
// }
25+
// db, err = sql.OpenDB(sqlite.Connector(sqliteURI, connInitFunc, nil))
26+
//
27+
//
1128
// Memory Mode
1229
//
1330
// In-memory databases are popular for tests.
@@ -93,6 +110,12 @@ var Open sqliteh.OpenFunc = func(string, sqliteh.OpenFlags, string) (sqliteh.DB,
93110
// code between calls to rows.Next.
94111
type TraceFunc func(prepCtx context.Context, query string, duration time.Duration, err error)
95112

113+
// ConnInitFunc is a function called by the driver on new connections.
114+
//
115+
// The conn can be used to execute queries.
116+
// Any error return closes the conn and passes the error to database/sql.
117+
type ConnInitFunc func(ctx context.Context, conn driver.ConnPrepareContext) error
118+
96119
// TimeFormat is the string format this driver uses to store
97120
// microsecond-precision time in SQLite in text format.
98121
const TimeFormat = "2006-01-02 15:04:05.000-0700"
@@ -108,16 +131,18 @@ func (d drv) OpenConnector(name string) (driver.Connector, error) {
108131
return &connector{name: name}, nil
109132
}
110133

111-
func Connector(sqliteURI string, traceFunc TraceFunc) driver.Connector {
134+
func Connector(sqliteURI string, connInitFunc ConnInitFunc, traceFunc TraceFunc) driver.Connector {
112135
return &connector{
113-
name: sqliteURI,
114-
traceFunc: traceFunc,
136+
name: sqliteURI,
137+
traceFunc: traceFunc,
138+
connInitFunc: connInitFunc,
115139
}
116140
}
117141

118142
type connector struct {
119-
name string
120-
traceFunc TraceFunc
143+
name string
144+
traceFunc TraceFunc
145+
connInitFunc ConnInitFunc
121146
}
122147

123148
func (p *connector) Driver() driver.Driver { return drv{} }
@@ -140,33 +165,14 @@ func (p *connector) Connect(ctx context.Context) (driver.Conn, error) {
140165
return nil, err
141166
}
142167

143-
db.BusyTimeout(2 * time.Second) // TODO: justify choice; make configurable?
144-
145-
if err := db.AutoCheckpoint(0); err != nil {
146-
db.Close()
147-
return nil, fmt.Errorf("sqlite.Open: wal_autocheckpoint: %w", err)
148-
}
149-
150-
if err := pragmaSynchronousNormal(db); err != nil {
151-
db.Close()
152-
return nil, fmt.Errorf("sqlite.open: %w", err)
153-
}
154-
155168
c := &conn{db: db, traceFunc: p.traceFunc}
156-
return c, nil
157-
}
158-
159-
func pragmaSynchronousNormal(db sqliteh.DB) error {
160-
const query = "PRAGMA synchronous=NORMAL;"
161-
cstmt, _, err := db.Prepare(query, 0)
162-
if err != nil {
163-
return reserr(db, "Open", query, err)
164-
}
165-
defer cstmt.Finalize()
166-
if _, _, _, _, err := cstmt.StepResult(); err != nil {
167-
return reserr(db, "Open", query, err)
169+
if p.connInitFunc != nil {
170+
if err := p.connInitFunc(ctx, c); err != nil {
171+
db.Close()
172+
return nil, fmt.Errorf("sqlite.ConnInitFunc: %w", err)
173+
}
168174
}
169-
return nil
175+
return c, nil
170176
}
171177

172178
type conn struct {

sqlite_test.go

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package sqlite
77
import (
88
"context"
99
"database/sql"
10+
"database/sql/driver"
1011
"fmt"
1112
"os"
1213
"runtime"
@@ -65,7 +66,7 @@ func openTestDB(t testing.TB) *sql.DB {
6566

6667
func openTestDBTrace(t testing.TB, traceFunc TraceFunc) *sql.DB {
6768
t.Helper()
68-
db := sql.OpenDB(Connector("file:"+t.TempDir()+"/test.db", traceFunc))
69+
db := sql.OpenDB(Connector("file:"+t.TempDir()+"/test.db", nil, traceFunc))
6970
configDB(t, db)
7071
return db
7172
}
@@ -605,6 +606,30 @@ func TestTxnState(t *testing.T) {
605606
}
606607
}
607608

609+
func TestConnInit(t *testing.T) {
610+
called := 0
611+
uri := "file:" + t.TempDir() + "/test.db"
612+
connInitFunc := func(ctx context.Context, conn driver.ConnPrepareContext) error {
613+
called++
614+
stmt, err := conn.PrepareContext(ctx, "PRAGMA journal_mode=WAL;")
615+
if err != nil {
616+
return err
617+
}
618+
_, err = stmt.(driver.StmtExecContext).ExecContext(ctx, nil)
619+
return err
620+
}
621+
db := sql.OpenDB(Connector(uri, connInitFunc, nil))
622+
conn, err := db.Conn(context.Background())
623+
if err != nil {
624+
t.Fatal(err)
625+
}
626+
if called == 0 {
627+
t.Fatal("called=0, want non-zero")
628+
}
629+
conn.Close()
630+
db.Close()
631+
}
632+
608633
func BenchmarkPersist(b *testing.B) {
609634
ctx := context.Background()
610635
db := openTestDB(b)

0 commit comments

Comments
 (0)