@@ -5,11 +5,13 @@ import (
55 "encoding/json"
66 "errors"
77 "fmt"
8+ "log"
89 "strings"
910 "time"
1011
1112 "github.com/google/uuid"
1213 "github.com/jackc/pgx/v5"
14+ "github.com/jackc/pgx/v5/pgconn"
1315 "github.com/jackc/pgx/v5/pgxpool"
1416
1517 apiv0 "github.com/modelcontextprotocol/registry/pkg/api/v0"
@@ -20,6 +22,21 @@ type PostgreSQL struct {
2022 pool * pgxpool.Pool
2123}
2224
25+ // Executor is an interface for executing queries (satisfied by both pgx.Tx and pgxpool.Pool)
26+ type Executor interface {
27+ Exec (ctx context.Context , sql string , arguments ... any ) (pgconn.CommandTag , error )
28+ Query (ctx context.Context , sql string , args ... any ) (pgx.Rows , error )
29+ QueryRow (ctx context.Context , sql string , args ... any ) pgx.Row
30+ }
31+
32+ // getExecutor returns the appropriate executor (transaction or pool)
33+ func (db * PostgreSQL ) getExecutor (tx pgx.Tx ) Executor {
34+ if tx != nil {
35+ return tx
36+ }
37+ return db .pool
38+ }
39+
2340// NewPostgreSQL creates a new instance of the PostgreSQL database
2441func NewPostgreSQL (ctx context.Context , connectionURI string ) (* PostgreSQL , error ) {
2542 // Parse connection config for pool settings
@@ -65,6 +82,7 @@ func NewPostgreSQL(ctx context.Context, connectionURI string) (*PostgreSQL, erro
6582//nolint:cyclop // Database filtering logic is inherently complex but clear
6683func (db * PostgreSQL ) List (
6784 ctx context.Context ,
85+ tx pgx.Tx ,
6886 filter * ServerFilter ,
6987 cursor string ,
7088 limit int ,
@@ -142,7 +160,7 @@ func (db *PostgreSQL) List(
142160 ` , whereClause , argIndex )
143161 args = append (args , limit )
144162
145- rows , err := db .pool .Query (ctx , query , args ... )
163+ rows , err := db .getExecutor ( tx ) .Query (ctx , query , args ... )
146164 if err != nil {
147165 return nil , "" , fmt .Errorf ("failed to query servers: %w" , err )
148166 }
@@ -182,7 +200,7 @@ func (db *PostgreSQL) List(
182200 return results , nextCursor , nil
183201}
184202
185- func (db * PostgreSQL ) GetByVersionID (ctx context.Context , versionID string ) (* apiv0.ServerJSON , error ) {
203+ func (db * PostgreSQL ) GetByVersionID (ctx context.Context , tx pgx. Tx , versionID string ) (* apiv0.ServerJSON , error ) {
186204 if ctx .Err () != nil {
187205 return nil , ctx .Err ()
188206 }
@@ -194,8 +212,7 @@ func (db *PostgreSQL) GetByVersionID(ctx context.Context, versionID string) (*ap
194212 `
195213
196214 var valueJSON []byte
197-
198- err := db .pool .QueryRow (ctx , query , versionID ).Scan (& valueJSON )
215+ err := db .getExecutor (tx ).QueryRow (ctx , query , versionID ).Scan (& valueJSON )
199216
200217 if err != nil {
201218 if errors .Is (err , pgx .ErrNoRows ) {
@@ -214,7 +231,7 @@ func (db *PostgreSQL) GetByVersionID(ctx context.Context, versionID string) (*ap
214231}
215232
216233// GetByServerID retrieves the latest version of a server by server ID
217- func (db * PostgreSQL ) GetByServerID (ctx context.Context , serverID string ) (* apiv0.ServerJSON , error ) {
234+ func (db * PostgreSQL ) GetByServerID (ctx context.Context , tx pgx. Tx , serverID string ) (* apiv0.ServerJSON , error ) {
218235 if ctx .Err () != nil {
219236 return nil , ctx .Err ()
220237 }
@@ -228,8 +245,7 @@ func (db *PostgreSQL) GetByServerID(ctx context.Context, serverID string) (*apiv
228245 `
229246
230247 var valueJSON []byte
231-
232- err := db .pool .QueryRow (ctx , query , serverID ).Scan (& valueJSON )
248+ err := db .getExecutor (tx ).QueryRow (ctx , query , serverID ).Scan (& valueJSON )
233249
234250 if err != nil {
235251 if errors .Is (err , pgx .ErrNoRows ) {
@@ -248,7 +264,7 @@ func (db *PostgreSQL) GetByServerID(ctx context.Context, serverID string) (*apiv
248264}
249265
250266// GetByServerIDAndVersion retrieves a specific version of a server by server ID and version
251- func (db * PostgreSQL ) GetByServerIDAndVersion (ctx context.Context , serverID string , version string ) (* apiv0.ServerJSON , error ) {
267+ func (db * PostgreSQL ) GetByServerIDAndVersion (ctx context.Context , tx pgx. Tx , serverID string , version string ) (* apiv0.ServerJSON , error ) {
252268 if ctx .Err () != nil {
253269 return nil , ctx .Err ()
254270 }
@@ -261,8 +277,7 @@ func (db *PostgreSQL) GetByServerIDAndVersion(ctx context.Context, serverID stri
261277 `
262278
263279 var valueJSON []byte
264-
265- err := db .pool .QueryRow (ctx , query , serverID , version ).Scan (& valueJSON )
280+ err := db .getExecutor (tx ).QueryRow (ctx , query , serverID , version ).Scan (& valueJSON )
266281
267282 if err != nil {
268283 if errors .Is (err , pgx .ErrNoRows ) {
@@ -281,7 +296,7 @@ func (db *PostgreSQL) GetByServerIDAndVersion(ctx context.Context, serverID stri
281296}
282297
283298// GetAllVersionsByServerID retrieves all versions of a server by server ID
284- func (db * PostgreSQL ) GetAllVersionsByServerID (ctx context.Context , serverID string ) ([]* apiv0.ServerJSON , error ) {
299+ func (db * PostgreSQL ) GetAllVersionsByServerID (ctx context.Context , tx pgx. Tx , serverID string ) ([]* apiv0.ServerJSON , error ) {
285300 if ctx .Err () != nil {
286301 return nil , ctx .Err ()
287302 }
@@ -293,7 +308,7 @@ func (db *PostgreSQL) GetAllVersionsByServerID(ctx context.Context, serverID str
293308 ORDER BY (value->'_meta'->'io.modelcontextprotocol.registry/official'->>'publishedAt')::timestamp DESC
294309 `
295310
296- rows , err := db .pool .Query (ctx , query , serverID )
311+ rows , err := db .getExecutor ( tx ) .Query (ctx , query , serverID )
297312 if err != nil {
298313 return nil , fmt .Errorf ("failed to query server versions: %w" , err )
299314 }
@@ -328,8 +343,8 @@ func (db *PostgreSQL) GetAllVersionsByServerID(ctx context.Context, serverID str
328343 return results , nil
329344}
330345
331- // CreateServer adds a new server to the database
332- func (db * PostgreSQL ) CreateServer (ctx context.Context , server * apiv0.ServerJSON ) (* apiv0.ServerJSON , error ) {
346+ // CreateServer inserts a new server version
347+ func (db * PostgreSQL ) CreateServer (ctx context.Context , tx pgx. Tx , server * apiv0.ServerJSON ) (* apiv0.ServerJSON , error ) {
333348 if ctx .Err () != nil {
334349 return nil , ctx .Err ()
335350 }
@@ -339,11 +354,9 @@ func (db *PostgreSQL) CreateServer(ctx context.Context, server *apiv0.ServerJSON
339354 return nil , fmt .Errorf ("server must have registry metadata with ServerID and VersionID" )
340355 }
341356
342- serverID := server .Meta .Official .ServerID
343357 versionID := server .Meta .Official .VersionID
344-
345- if serverID == "" || versionID == "" {
346- return nil , fmt .Errorf ("server must have both ServerID and VersionID in registry metadata" )
358+ if versionID == "" {
359+ return nil , fmt .Errorf ("server must have VersionID in registry metadata" )
347360 }
348361
349362 // Marshal the complete server to JSONB
@@ -352,13 +365,14 @@ func (db *PostgreSQL) CreateServer(ctx context.Context, server *apiv0.ServerJSON
352365 return nil , fmt .Errorf ("failed to marshal server JSON: %w" , err )
353366 }
354367
355- // Insert into servers table with new schema (only version_id column, serverId is in JSON)
356- query := `
368+ // Insert the new version
369+ insertQuery := `
357370 INSERT INTO servers (version_id, value)
358371 VALUES ($1, $2)
359372 `
360373
361- _ , err = db .pool .Exec (ctx , query , versionID , valueJSON )
374+ _ , err = db .getExecutor (tx ).Exec (ctx , insertQuery , versionID , valueJSON )
375+
362376 if err != nil {
363377 return nil , fmt .Errorf ("failed to insert server: %w" , err )
364378 }
@@ -367,7 +381,7 @@ func (db *PostgreSQL) CreateServer(ctx context.Context, server *apiv0.ServerJSON
367381}
368382
369383// UpdateServer updates an existing server record with new server details
370- func (db * PostgreSQL ) UpdateServer (ctx context.Context , id string , server * apiv0.ServerJSON ) (* apiv0.ServerJSON , error ) {
384+ func (db * PostgreSQL ) UpdateServer (ctx context.Context , tx pgx. Tx , id string , server * apiv0.ServerJSON ) (* apiv0.ServerJSON , error ) {
371385 if ctx .Err () != nil {
372386 return nil , ctx .Err ()
373387 }
@@ -390,7 +404,7 @@ func (db *PostgreSQL) UpdateServer(ctx context.Context, id string, server *apiv0
390404 WHERE version_id = $2
391405 `
392406
393- result , err := db .pool .Exec (ctx , query , valueJSON , id )
407+ result , err := db .getExecutor ( tx ) .Exec (ctx , query , valueJSON , id )
394408 if err != nil {
395409 return nil , fmt .Errorf ("failed to update server: %w" , err )
396410 }
@@ -402,47 +416,56 @@ func (db *PostgreSQL) UpdateServer(ctx context.Context, id string, server *apiv0
402416 return server , nil
403417}
404418
405- // WithPublishLock executes a function with an exclusive advisory lock for publishing a server
406- // This prevents race conditions when multiple versions are published concurrently
407- func (db * PostgreSQL ) WithPublishLock (ctx context.Context , serverName string , fn func (ctx context.Context ) error ) error {
419+ // InTransaction executes a function within a database transaction
420+ func (db * PostgreSQL ) InTransaction (ctx context.Context , fn func (ctx context.Context , tx pgx.Tx ) error ) error {
408421 if ctx .Err () != nil {
409422 return ctx .Err ()
410423 }
411424
412- // Begin a transaction
413425 tx , err := db .pool .Begin (ctx )
414426 if err != nil {
415427 return fmt .Errorf ("failed to begin transaction: %w" , err )
416428 }
429+ //nolint:contextcheck // Intentionally using separate context for rollback to ensure cleanup even if request is cancelled
417430 defer func () {
418- _ = tx .Rollback (ctx )
431+ rollbackCtx , cancel := context .WithTimeout (context .Background (), 1 * time .Second )
432+ defer cancel ()
433+ if rbErr := tx .Rollback (rollbackCtx ); rbErr != nil && ! errors .Is (rbErr , pgx .ErrTxClosed ) {
434+ log .Printf ("failed to rollback transaction: %v" , rbErr )
435+ }
419436 }()
420437
421- // Acquire advisory lock based on server name hash
422- // Using pg_advisory_xact_lock which auto-releases on transaction end
423- lockID := hashServerName (serverName )
424- _ , err = tx .Exec (ctx , "SELECT pg_advisory_xact_lock($1)" , lockID )
425- if err != nil {
426- return fmt .Errorf ("failed to acquire publish lock: %w" , err )
427- }
428-
429- // Execute the function
430- if err := fn (ctx ); err != nil {
438+ if err := fn (ctx , tx ); err != nil {
431439 return err
432440 }
433441
434- // Commit the transaction (which also releases the lock)
435442 if err := tx .Commit (ctx ); err != nil {
436443 return fmt .Errorf ("failed to commit transaction: %w" , err )
437444 }
438445
439446 return nil
440447}
441448
449+ // AcquirePublishLock acquires an exclusive advisory lock for publishing a server
450+ // This prevents race conditions when multiple versions are published concurrently
451+ // Using pg_advisory_xact_lock which auto-releases on transaction end
452+ func (db * PostgreSQL ) AcquirePublishLock (ctx context.Context , tx pgx.Tx , serverName string ) error {
453+ if ctx .Err () != nil {
454+ return ctx .Err ()
455+ }
456+
457+ lockID := hashServerName (serverName )
458+
459+ if _ , err := db .getExecutor (tx ).Exec (ctx , "SELECT pg_advisory_xact_lock($1)" , lockID ); err != nil {
460+ return fmt .Errorf ("failed to acquire publish lock: %w" , err )
461+ }
462+
463+ return nil
464+ }
465+
442466// hashServerName creates a consistent hash of the server name for advisory locking
443467// We use FNV-1a hash and mask to 63 bits to fit in PostgreSQL's bigint range
444468func hashServerName (name string ) int64 {
445- // FNV-1a 64-bit hash
446469 const (
447470 offset64 = 14695981039346656037
448471 prime64 = 1099511628211
@@ -452,7 +475,6 @@ func hashServerName(name string) int64 {
452475 hash ^= uint64 (name [i ])
453476 hash *= prime64
454477 }
455- // Use only 63 bits to ensure positive int64
456478 //nolint:gosec // Intentional conversion with masking to 63 bits
457479 return int64 (hash & 0x7FFFFFFFFFFFFFFF )
458480}
0 commit comments