diff --git a/database/migrate_data.go b/database/migrate_data.go index f6f6e43f..49918c5b 100644 --- a/database/migrate_data.go +++ b/database/migrate_data.go @@ -1,6 +1,7 @@ package database import ( + "context" "errors" "fmt" "log" @@ -109,14 +110,15 @@ func copyTable(src, dst *gorm.DB, mdl any) (int, error) { sliceType := reflect.SliceOf(reflect.PointerTo(reflect.TypeOf(mdl).Elem())) - // Resolve primary-key columns so paging is deterministic across successive - // LIMIT/OFFSET reads. The model set is trusted (not user input). stmt := &gorm.Statement{DB: src} if err := stmt.Parse(mdl); err != nil { return 0, err } order := strings.Join(stmt.Schema.PrimaryFieldDBNames, ", ") + table := stmt.Schema.Table + columns := stmt.Schema.DBNames + ctx := context.Background() total := 0 for offset := 0; ; offset += batchSize { batchPtr := reflect.New(sliceType) @@ -127,11 +129,24 @@ func copyTable(src, dst *gorm.DB, mdl any) (int, error) { if err := q.Find(batchPtr.Interface()).Error; err != nil { return total, err } - n := batchPtr.Elem().Len() + slice := batchPtr.Elem() + n := slice.Len() if n == 0 { break } - if err := dst.CreateInBatches(batchPtr.Interface(), 200).Error; err != nil { + + rows := make([]map[string]any, n) + for i := 0; i < n; i++ { + rv := reflect.Indirect(slice.Index(i)) + row := make(map[string]any, len(columns)) + for _, name := range columns { + value, _ := stmt.Schema.FieldsByDBName[name].ValueOf(ctx, rv) + row[name] = value + } + rows[i] = row + } + + if err := dst.Table(table).CreateInBatches(rows, 200).Error; err != nil { return total, err } total += n diff --git a/database/migrate_data_test.go b/database/migrate_data_test.go index 081d77d9..5c1d0c62 100644 --- a/database/migrate_data_test.go +++ b/database/migrate_data_test.go @@ -62,3 +62,78 @@ func TestMigrateData_CompositeKeyTableLargerThanBatch(t *testing.T) { t.Fatalf("client_inbounds rows = %d, want %d", got, n) } } + +func TestMigrateData_PreservesFalseDefaultedColumns(t *testing.T) { + dsn := os.Getenv("XUI_TEST_PG_DSN") + if dsn == "" { + t.Skip("set XUI_TEST_PG_DSN to a reachable Postgres to run this test") + } + + srcPath := t.TempDir() + "/x-ui.db" + src, err := gorm.Open(sqlite.Open(srcPath), &gorm.Config{Logger: logger.Discard}) + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + for _, m := range migrationModels() { + if err := src.AutoMigrate(m); err != nil { + t.Fatalf("automigrate %T: %v", m, err) + } + } + + if err := src.Create([]*model.ClientRecord{ + {Email: "on@example.com"}, + {Email: "off@example.com"}, + }).Error; err != nil { + t.Fatalf("seed clients: %v", err) + } + if err := src.Model(&model.ClientRecord{}).Where("email = ?", "off@example.com"). + Update("enable", false).Error; err != nil { + t.Fatalf("disable client: %v", err) + } + if err := src.Create(&model.Node{Name: "n-off", Address: "1.2.3.4", Port: 1, ApiToken: "tok"}).Error; err != nil { + t.Fatalf("seed node: %v", err) + } + if err := src.Model(&model.Node{}).Where("name = ?", "n-off"). + Update("enable", false).Error; err != nil { + t.Fatalf("disable node: %v", err) + } + if sqlDB, err := src.DB(); err == nil { + sqlDB.Close() + } + + dst, err := gorm.Open(postgres.Open(dsn), &gorm.Config{Logger: logger.Discard}) + if err != nil { + t.Fatalf("open postgres: %v", err) + } + if err := dst.Migrator().DropTable(migrationModels()...); err != nil { + t.Fatalf("drop tables: %v", err) + } + + if err := MigrateData(srcPath, dsn); err != nil { + t.Fatalf("MigrateData: %v", err) + } + + var off model.ClientRecord + if err := dst.Where("email = ?", "off@example.com").First(&off).Error; err != nil { + t.Fatalf("load disabled client: %v", err) + } + if off.Enable { + t.Fatalf("disabled client re-enabled after migration (enable=%v)", off.Enable) + } + + var on model.ClientRecord + if err := dst.Where("email = ?", "on@example.com").First(&on).Error; err != nil { + t.Fatalf("load enabled client: %v", err) + } + if !on.Enable { + t.Fatalf("enabled client wrongly disabled after migration") + } + + var node model.Node + if err := dst.Where("name = ?", "n-off").First(&node).Error; err != nil { + t.Fatalf("load node: %v", err) + } + if node.Enable { + t.Fatalf("disabled node re-enabled after migration") + } +}