88// For details see https://sqlite.org/c3ref/open.html#urifilenames.
99//
1010//
11+ // Initializing connections or tracing
12+ //
13+ // If you want to do initial configuration of a connection, or enable
14+ // tracing, use the Connector function:
15+ //
16+ // connInitFunc := func(ctx context.Context, conn driver.ConnPrepareContext) error {
17+ // called++
18+ // stmt, err := conn.PrepareContext(ctx, "PRAGMA journal_mode=WAL;")
19+ // if err != nil {
20+ // return err
21+ // }
22+ // _, err = stmt.(driver.StmtExecContext).ExecContext(ctx, nil)
23+ // return err
24+ // }
25+ // db, err = sql.OpenDB(sqlite.Connector(sqliteURI, connInitFunc, nil))
26+ //
27+ //
1128// Memory Mode
1229//
1330// In-memory databases are popular for tests.
@@ -93,6 +110,12 @@ var Open sqliteh.OpenFunc = func(string, sqliteh.OpenFlags, string) (sqliteh.DB,
93110// code between calls to rows.Next.
94111type TraceFunc func (prepCtx context.Context , query string , duration time.Duration , err error )
95112
113+ // ConnInitFunc is a function called by the driver on new connections.
114+ //
115+ // The conn can be used to execute queries.
116+ // Any error return closes the conn and passes the error to database/sql.
117+ type ConnInitFunc func (ctx context.Context , conn driver.ConnPrepareContext ) error
118+
96119// TimeFormat is the string format this driver uses to store
97120// microsecond-precision time in SQLite in text format.
98121const TimeFormat = "2006-01-02 15:04:05.000-0700"
@@ -108,16 +131,18 @@ func (d drv) OpenConnector(name string) (driver.Connector, error) {
108131 return & connector {name : name }, nil
109132}
110133
111- func Connector (sqliteURI string , traceFunc TraceFunc ) driver.Connector {
134+ func Connector (sqliteURI string , connInitFunc ConnInitFunc , traceFunc TraceFunc ) driver.Connector {
112135 return & connector {
113- name : sqliteURI ,
114- traceFunc : traceFunc ,
136+ name : sqliteURI ,
137+ traceFunc : traceFunc ,
138+ connInitFunc : connInitFunc ,
115139 }
116140}
117141
118142type connector struct {
119- name string
120- traceFunc TraceFunc
143+ name string
144+ traceFunc TraceFunc
145+ connInitFunc ConnInitFunc
121146}
122147
123148func (p * connector ) Driver () driver.Driver { return drv {} }
@@ -140,33 +165,14 @@ func (p *connector) Connect(ctx context.Context) (driver.Conn, error) {
140165 return nil , err
141166 }
142167
143- db .BusyTimeout (2 * time .Second ) // TODO: justify choice; make configurable?
144-
145- if err := db .AutoCheckpoint (0 ); err != nil {
146- db .Close ()
147- return nil , fmt .Errorf ("sqlite.Open: wal_autocheckpoint: %w" , err )
148- }
149-
150- if err := pragmaSynchronousNormal (db ); err != nil {
151- db .Close ()
152- return nil , fmt .Errorf ("sqlite.open: %w" , err )
153- }
154-
155168 c := & conn {db : db , traceFunc : p .traceFunc }
156- return c , nil
157- }
158-
159- func pragmaSynchronousNormal (db sqliteh.DB ) error {
160- const query = "PRAGMA synchronous=NORMAL;"
161- cstmt , _ , err := db .Prepare (query , 0 )
162- if err != nil {
163- return reserr (db , "Open" , query , err )
164- }
165- defer cstmt .Finalize ()
166- if _ , _ , _ , _ , err := cstmt .StepResult (); err != nil {
167- return reserr (db , "Open" , query , err )
169+ if p .connInitFunc != nil {
170+ if err := p .connInitFunc (ctx , c ); err != nil {
171+ db .Close ()
172+ return nil , fmt .Errorf ("sqlite.ConnInitFunc: %w" , err )
173+ }
168174 }
169- return nil
175+ return c , nil
170176}
171177
172178type conn struct {
0 commit comments