diff --git a/database/db.go b/database/db.go index 3a13845e..85e0eb75 100644 --- a/database/db.go +++ b/database/db.go @@ -89,6 +89,12 @@ func initModels() error { if err := pruneOrphanedClientInbounds(); err != nil { return err } + if IsPostgres() { + if err := resyncPostgresSequences(db, models); err != nil { + log.Printf("Error resyncing postgres sequences: %v", err) + return err + } + } return nil } diff --git a/database/migrate_data.go b/database/migrate_data.go index 11f1a40f..d76ff35a 100644 --- a/database/migrate_data.go +++ b/database/migrate_data.go @@ -123,21 +123,32 @@ func copyTable(src, dst *gorm.DB, mdl any) (int, error) { return total, err } -// resetPostgresSequences advances each table's id sequence past MAX(id), +// resetPostgresSequences advances each migrated table's id sequence past MAX(id), // otherwise the next INSERT-without-id would clash with copied rows. func resetPostgresSequences(dst *gorm.DB) error { - tables := []string{ - "users", "inbounds", "outbound_traffics", "settings", "inbound_client_ips", - "client_traffics", "history_of_seeders", "custom_geo_resources", "nodes", - "api_tokens", "client_records", "client_inbounds", "inbound_fallback_children", - } - for _, t := range tables { - // setval is a no-op if the table or its id sequence doesn't exist; we ignore errors per-table. - _ = dst.Exec(fmt.Sprintf( - `SELECT setval(pg_get_serial_sequence('%s','id'), COALESCE((SELECT MAX(id) FROM "%s"), 1), true) - WHERE pg_get_serial_sequence('%s','id') IS NOT NULL`, - t, t, t, - )).Error + return resyncPostgresSequences(dst, migrationModels()) +} + +// resyncPostgresSequences sets each model's id sequence to MAX(id) so the next +// auto-increment INSERT won't collide with an existing row. Table names are +// resolved from the models themselves (not hardcoded), so they always match the +// migrated tables. The statement is a no-op for tables without an id sequence +// (e.g. composite-PK tables), and idempotent on a healthy DB, so it is safe to +// run both after migration and on every Postgres startup. +func resyncPostgresSequences(db *gorm.DB, models []any) error { + for _, m := range models { + stmt := &gorm.Statement{DB: db} + if err := stmt.Parse(m); err != nil { + continue + } + t := stmt.Table + // t comes from the trusted model set parsed by GORM, not user input, so + // interpolating it as an identifier is safe. We ignore errors per-table. + _ = db.Exec( + `SELECT setval(pg_get_serial_sequence(?, 'id'), COALESCE((SELECT MAX(id) FROM "`+t+`"), 1), true) + WHERE pg_get_serial_sequence(?, 'id') IS NOT NULL`, + t, t, + ).Error } return nil }