@@ -9,9 +9,12 @@ import (
99
1010 _ "github.com/go-sql-driver/mysql"
1111 "github.com/jackc/pgx/v5/pgxpool"
12+ _ "github.com/ncruces/go-sqlite3/driver"
13+ _ "github.com/ncruces/go-sqlite3/embed"
1214
1315 "github.com/sqlc-dev/sqlc/internal/engine/dolphin"
1416 "github.com/sqlc-dev/sqlc/internal/engine/postgresql"
17+ "github.com/sqlc-dev/sqlc/internal/engine/sqlite"
1518)
1619
1720// PostgreSQLColumnGetter implements ColumnGetter for PostgreSQL using pgxpool.
@@ -39,16 +42,13 @@ func (g *PostgreSQLColumnGetter) GetColumnNames(ctx context.Context, query strin
3942 return columns , nil
4043}
4144
42- // MySQLColumnGetter implements ColumnGetter for MySQL using database/sql.
43- type MySQLColumnGetter struct {
45+ // SQLColumnGetter implements ColumnGetter for MySQL and SQLite using database/sql.
46+ type SQLColumnGetter struct {
4447 db * sql.DB
4548}
4649
47- func (g * MySQLColumnGetter ) GetColumnNames (ctx context.Context , query string ) ([]string , error ) {
48- // Use LIMIT 0 to get column metadata without fetching rows
49- limitedQuery := query
50- // For SELECT queries, add LIMIT 0 if not already present
51- rows , err := g .db .QueryContext (ctx , limitedQuery )
50+ func (g * SQLColumnGetter ) GetColumnNames (ctx context.Context , query string ) ([]string , error ) {
51+ rows , err := g .db .QueryContext (ctx , query )
5252 if err != nil {
5353 return nil , err
5454 }
@@ -242,7 +242,7 @@ func TestExpandMySQL(t *testing.T) {
242242 parser := dolphin .NewParser ()
243243
244244 // Create the expander
245- colGetter := & MySQLColumnGetter {db : db }
245+ colGetter := & SQLColumnGetter {db : db }
246246 exp := New (colGetter , parser , parser )
247247
248248 tests := []struct {
@@ -304,3 +304,92 @@ func TestExpandMySQL(t *testing.T) {
304304 })
305305 }
306306}
307+
308+ func TestExpandSQLite (t * testing.T ) {
309+ ctx := context .Background ()
310+
311+ // Create an in-memory SQLite database
312+ db , err := sql .Open ("sqlite3" , ":memory:" )
313+ if err != nil {
314+ t .Fatalf ("could not open SQLite: %v" , err )
315+ }
316+ defer db .Close ()
317+
318+ // Create a test table
319+ _ , err = db .ExecContext (ctx , `
320+ CREATE TABLE authors (
321+ id INTEGER PRIMARY KEY AUTOINCREMENT,
322+ name TEXT NOT NULL,
323+ bio TEXT
324+ )
325+ ` )
326+ if err != nil {
327+ t .Fatalf ("failed to create test table: %v" , err )
328+ }
329+
330+ // Create the parser which also implements format.Dialect
331+ parser := sqlite .NewParser ()
332+
333+ // Create the expander
334+ colGetter := & SQLColumnGetter {db : db }
335+ exp := New (colGetter , parser , parser )
336+
337+ tests := []struct {
338+ name string
339+ query string
340+ expected string
341+ }{
342+ {
343+ name : "simple select star" ,
344+ query : "SELECT * FROM authors" ,
345+ expected : "SELECT id,name,bio FROM authors;" ,
346+ },
347+ {
348+ name : "select with no star" ,
349+ query : "SELECT id, name FROM authors" ,
350+ expected : "SELECT id, name FROM authors" , // No change, returns original
351+ },
352+ {
353+ name : "select star with where clause" ,
354+ query : "SELECT * FROM authors WHERE id = 1" ,
355+ expected : "SELECT id,name,bio FROM authors WHERE id = 1;" ,
356+ },
357+ {
358+ name : "double star" ,
359+ query : "SELECT *, * FROM authors" ,
360+ expected : "SELECT id,name,bio,id,name,bio FROM authors;" ,
361+ },
362+ {
363+ name : "table qualified star" ,
364+ query : "SELECT authors.* FROM authors" ,
365+ expected : "SELECT authors.id,authors.name,authors.bio FROM authors;" ,
366+ },
367+ {
368+ name : "star in middle of columns" ,
369+ query : "SELECT id, *, name FROM authors" ,
370+ expected : "SELECT id,id,name,bio,name FROM authors;" ,
371+ },
372+ {
373+ name : "count star not expanded" ,
374+ query : "SELECT COUNT(*) FROM authors" ,
375+ expected : "SELECT COUNT(*) FROM authors" , // No change - COUNT(*) should not be expanded
376+ },
377+ {
378+ name : "count star with other columns" ,
379+ query : "SELECT COUNT(*), name FROM authors GROUP BY name" ,
380+ expected : "SELECT COUNT(*), name FROM authors GROUP BY name" , // No change
381+ },
382+ }
383+
384+ for _ , tc := range tests {
385+ t .Run (tc .name , func (t * testing.T ) {
386+ result , err := exp .Expand (ctx , tc .query )
387+ if err != nil {
388+ t .Fatalf ("Expand failed: %v" , err )
389+ }
390+ if result != tc .expected {
391+ t .Errorf ("expected %q, got %q" , tc .expected , result )
392+ }
393+ })
394+ }
395+ }
0 commit comments