From 5b9ed34009d6deb78b55f2f0bd436ed78b184dbb Mon Sep 17 00:00:00 2001 From: MHSanaei Date: Mon, 1 Jun 2026 22:54:56 +0200 Subject: [PATCH] fix(nodes): sum client traffic across nodes instead of overwriting A client shared across multiple nodes has a single email-keyed client_traffics row, but each node reports its cumulative up/down. setRemoteTrafficLocked overwrote the row with one node's cumulative, so non-owning nodes hit the create branch and OnConflict-DoNothing, silently dropping their traffic and under-counting the client. Make the shared row a pure accumulator (like the local path): a new node_client_traffics(node_id, email) baseline table stores each node's last cumulative; the node path converts cumulative to a per-node delta (clamped to the post-reset value on a negative delta) and does up = up + delta. First observation seeds the baseline and adds 0 so upgrades and newly-shared clients are not double-counted. Create-vs-accumulate now keys off global email existence. Baselines are cleaned in DelClientStat, the node sweeps, and NodeService.Delete. --- database/db.go | 1 + database/migrate_data.go | 1 + database/model/node_client_traffic.go | 9 + web/service/inbound.go | 109 +++++++--- web/service/node.go | 3 + web/service/node_client_traffic_sum_test.go | 209 ++++++++++++++++++++ 6 files changed, 306 insertions(+), 26 deletions(-) create mode 100644 database/model/node_client_traffic.go create mode 100644 web/service/node_client_traffic_sum_test.go diff --git a/database/db.go b/database/db.go index c43ffb9e..fc5a9739 100644 --- a/database/db.go +++ b/database/db.go @@ -72,6 +72,7 @@ func initModels() error { &model.ClientInbound{}, &model.ClientGroup{}, &model.InboundFallback{}, + &model.NodeClientTraffic{}, } for _, mdl := range models { if err := db.AutoMigrate(mdl); err != nil { diff --git a/database/migrate_data.go b/database/migrate_data.go index d76ff35a..0a29ffb4 100644 --- a/database/migrate_data.go +++ b/database/migrate_data.go @@ -36,6 +36,7 @@ func migrationModels() []any { &model.ClientRecord{}, &model.ClientInbound{}, &model.InboundFallback{}, + &model.NodeClientTraffic{}, } } diff --git a/database/model/node_client_traffic.go b/database/model/node_client_traffic.go new file mode 100644 index 00000000..08a0952f --- /dev/null +++ b/database/model/node_client_traffic.go @@ -0,0 +1,9 @@ +package model + +type NodeClientTraffic struct { + Id int `json:"id" gorm:"primaryKey;autoIncrement"` + NodeId int `json:"nodeId" gorm:"uniqueIndex:idx_node_email,priority:1;not null"` + Email string `json:"email" gorm:"uniqueIndex:idx_node_email,priority:2;not null"` + Up int64 `json:"up"` + Down int64 `json:"down"` +} diff --git a/web/service/inbound.go b/web/service/inbound.go index a5d17681..5aebf3b1 100644 --- a/web/service/inbound.go +++ b/web/service/inbound.go @@ -1251,6 +1251,18 @@ const resetGracePeriodMs int64 = 30000 // long after a real disconnect. const onlineGracePeriodMs int64 = 20000 +type nodeTrafficCounter struct { + Up int64 + Down int64 +} + +func (s *InboundService) upsertNodeBaseline(tx *gorm.DB, nodeID int, email string, up, down int64) error { + return tx.Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "node_id"}, {Name: "email"}}, + DoUpdates: clause.AssignmentColumns([]string{"up", "down"}), + }).Create(&model.NodeClientTraffic{NodeId: nodeID, Email: email, Up: up, Down: down}).Error +} + func (s *InboundService) SetRemoteTraffic(nodeID int, snap *runtime.TrafficSnapshot) (bool, error) { var structuralChange bool err := submitTrafficWrite(func() error { @@ -1313,6 +1325,26 @@ func (s *InboundService) setRemoteTrafficLocked(nodeID int, snap *runtime.Traffi centralCSByEmail[centralClientStats[i].Email] = ¢ralClientStats[i] } + nodeBaselines := make(map[string]nodeTrafficCounter) + var baselineRows []model.NodeClientTraffic + if err := db.Model(&model.NodeClientTraffic{}). + Where("node_id = ?", nodeID). + Find(&baselineRows).Error; err != nil { + return false, err + } + for i := range baselineRows { + nodeBaselines[baselineRows[i].Email] = nodeTrafficCounter{Up: baselineRows[i].Up, Down: baselineRows[i].Down} + } + + var existingEmailsList []string + if err := db.Model(xray.ClientTraffic{}).Pluck("email", &existingEmailsList).Error; err != nil { + return false, err + } + existingEmails := make(map[string]struct{}, len(existingEmailsList)) + for _, e := range existingEmailsList { + existingEmails[e] = struct{}{} + } + var defaultUserId int if len(central) > 0 { defaultUserId = central[0].UserId @@ -1458,6 +1490,18 @@ func (s *InboundService) setRemoteTrafficLocked(nodeID int, snap *runtime.Traffi if _, kept := snapTags[c.Tag]; kept { continue } + var goneEmails []string + if err := tx.Model(xray.ClientTraffic{}). + Where("inbound_id = ?", c.Id). + Pluck("email", &goneEmails).Error; err != nil { + return false, err + } + if len(goneEmails) > 0 { + if err := tx.Where("node_id = ? AND email IN ?", nodeID, goneEmails). + Delete(&model.NodeClientTraffic{}).Error; err != nil { + return false, err + } + } if err := tx.Where("inbound_id = ?", c.Id). Delete(&xray.ClientTraffic{}).Error; err != nil { return false, err @@ -1481,17 +1525,22 @@ func (s *InboundService) setRemoteTrafficLocked(nodeID int, snap *runtime.Traffi if !ok { continue } - inGrace := c.LastTrafficResetTime > 0 && now-c.LastTrafficResetTime < resetGracePeriodMs - snapEmails := make(map[string]struct{}, len(snapIb.ClientStats)) for _, cs := range snapIb.ClientStats { snapEmails[cs.Email] = struct{}{} - existing := centralCS[csKey{c.Id, cs.Email}] - if existing == nil { - existing = centralCSByEmail[cs.Email] + base, seen := nodeBaselines[cs.Email] + var deltaUp, deltaDown int64 + if seen { + if deltaUp = cs.Up - base.Up; deltaUp < 0 { + deltaUp = cs.Up + } + if deltaDown = cs.Down - base.Down; deltaDown < 0 { + deltaDown = cs.Down + } } - if existing == nil { + + if _, rowExists := existingEmails[cs.Email]; !rowExists { row := &xray.ClientTraffic{ InboundId: c.Id, Email: cs.Email, @@ -1509,42 +1558,40 @@ func (s *InboundService) setRemoteTrafficLocked(nodeID int, snap *runtime.Traffi } centralCS[csKey{c.Id, cs.Email}] = row centralCSByEmail[cs.Email] = row + existingEmails[cs.Email] = struct{}{} structuralChange = true - continue - } - - if existing.Enable != cs.Enable || - existing.Total != cs.Total || - existing.ExpiryTime != cs.ExpiryTime || - existing.Reset != cs.Reset { - structuralChange = true - } - - if inGrace && cs.Up+cs.Down > 0 { - if err := tx.Exec( - `UPDATE client_traffics - SET enable = ?, total = ?, expiry_time = ?, reset = ? - WHERE email = ?`, - cs.Enable, cs.Total, cs.ExpiryTime, cs.Reset, cs.Email, - ).Error; err != nil { + if err := s.upsertNodeBaseline(tx, nodeID, cs.Email, cs.Up, cs.Down); err != nil { return false, err } + nodeBaselines[cs.Email] = nodeTrafficCounter{Up: cs.Up, Down: cs.Down} continue } + if existing := centralCSByEmail[cs.Email]; existing != nil && + (existing.Enable != cs.Enable || + existing.Total != cs.Total || + existing.ExpiryTime != cs.ExpiryTime || + existing.Reset != cs.Reset) { + structuralChange = true + } + if err := tx.Exec( fmt.Sprintf( `UPDATE client_traffics - SET up = ?, down = ?, enable = ?, total = ?, expiry_time = ?, reset = ?, + SET up = up + ?, down = down + ?, enable = ?, total = ?, expiry_time = ?, reset = ?, last_online = %s WHERE email = ?`, database.GreatestExpr("last_online", "?"), ), - cs.Up, cs.Down, cs.Enable, cs.Total, cs.ExpiryTime, cs.Reset, + deltaUp, deltaDown, cs.Enable, cs.Total, cs.ExpiryTime, cs.Reset, cs.LastOnline, cs.Email, ).Error; err != nil { return false, err } + if err := s.upsertNodeBaseline(tx, nodeID, cs.Email, cs.Up, cs.Down); err != nil { + return false, err + } + nodeBaselines[cs.Email] = nodeTrafficCounter{Up: cs.Up, Down: cs.Down} } for k, existing := range centralCS { @@ -1554,6 +1601,10 @@ func (s *InboundService) setRemoteTrafficLocked(nodeID int, snap *runtime.Traffi if _, kept := snapEmails[k.email]; kept { continue } + if err := tx.Where("node_id = ? AND email = ?", nodeID, existing.Email). + Delete(&model.NodeClientTraffic{}).Error; err != nil { + return false, err + } if err := tx.Where("inbound_id = ? AND email = ?", c.Id, existing.Email). Delete(&xray.ClientTraffic{}).Error; err != nil { return false, err @@ -1671,6 +1722,9 @@ func (s *InboundService) setRemoteTrafficLocked(nodeID int, snap *runtime.Traffi if err := tx.Where("email = ?", email).Delete(&xray.ClientTraffic{}).Error; err != nil { logger.Warningf("setRemoteTraffic: delete ClientTraffic %q failed: %v", email, err) } + if err := tx.Where("email = ?", email).Delete(&model.NodeClientTraffic{}).Error; err != nil { + logger.Warningf("setRemoteTraffic: delete NodeClientTraffic %q failed: %v", email, err) + } structuralChange = true } } @@ -2329,7 +2383,10 @@ func (s *InboundService) UpdateClientIPs(tx *gorm.DB, oldEmail string, newEmail } func (s *InboundService) DelClientStat(tx *gorm.DB, email string) error { - return tx.Where("email = ?", email).Delete(xray.ClientTraffic{}).Error + if err := tx.Where("email = ?", email).Delete(xray.ClientTraffic{}).Error; err != nil { + return err + } + return tx.Where("email = ?", email).Delete(&model.NodeClientTraffic{}).Error } func (s *InboundService) DelClientIPs(tx *gorm.DB, email string) error { diff --git a/web/service/node.go b/web/service/node.go index b6d2613f..0273ba08 100644 --- a/web/service/node.go +++ b/web/service/node.go @@ -233,6 +233,9 @@ func (s *NodeService) Delete(id int) error { if err := db.Where("id = ?", id).Delete(model.Node{}).Error; err != nil { return err } + if err := db.Where("node_id = ?", id).Delete(&model.NodeClientTraffic{}).Error; err != nil { + return err + } if mgr := runtime.GetManager(); mgr != nil { mgr.InvalidateNode(id) } diff --git a/web/service/node_client_traffic_sum_test.go b/web/service/node_client_traffic_sum_test.go new file mode 100644 index 00000000..0450dbfd --- /dev/null +++ b/web/service/node_client_traffic_sum_test.go @@ -0,0 +1,209 @@ +package service + +import ( + "path/filepath" + "testing" + + "github.com/mhsanaei/3x-ui/v3/database" + "github.com/mhsanaei/3x-ui/v3/database/model" + "github.com/mhsanaei/3x-ui/v3/web/runtime" + "github.com/mhsanaei/3x-ui/v3/xray" + "gorm.io/gorm" +) + +func initTrafficTestDB(t *testing.T) *gorm.DB { + t.Helper() + dbDir := t.TempDir() + t.Setenv("XUI_DB_FOLDER", dbDir) + if err := database.InitDB(filepath.Join(dbDir, "x-ui.db")); err != nil { + t.Fatalf("InitDB: %v", err) + } + t.Cleanup(func() { _ = database.CloseDB() }) + return database.GetDB() +} + +func createNodeInbound(t *testing.T, db *gorm.DB, nodeID int, tag string, port int) { + t.Helper() + nid := nodeID + ib := &model.Inbound{UserId: 1, Tag: tag, Enable: true, Port: port, Protocol: model.VLESS, NodeID: &nid} + if err := db.Create(ib).Error; err != nil { + t.Fatalf("create node inbound %q: %v", tag, err) + } +} + +func syncNode(t *testing.T, svc *InboundService, nodeID int, tag string, stats ...xray.ClientTraffic) { + t.Helper() + snap := &runtime.TrafficSnapshot{ + Inbounds: []*model.Inbound{{Tag: tag, ClientStats: stats}}, + } + if _, err := svc.setRemoteTrafficLocked(nodeID, snap); err != nil { + t.Fatalf("setRemoteTrafficLocked node %d: %v", nodeID, err) + } +} + +func readTraffic(t *testing.T, db *gorm.DB, email string) xray.ClientTraffic { + t.Helper() + var ct xray.ClientTraffic + if err := db.Model(xray.ClientTraffic{}).Where("email = ?", email).First(&ct).Error; err != nil { + t.Fatalf("read client_traffics %q: %v", email, err) + } + return ct +} + +func assertUpDown(t *testing.T, ct xray.ClientTraffic, wantUp, wantDown int64, when string) { + t.Helper() + if ct.Up != wantUp || ct.Down != wantDown { + t.Errorf("%s: up=%d down=%d, want %d/%d", when, ct.Up, ct.Down, wantUp, wantDown) + } +} + +func TestTwoNodesShareEmail_SumsCorrectly(t *testing.T) { + db := initTrafficTestDB(t) + createNodeInbound(t, db, 1, "n1-in", 41001) + createNodeInbound(t, db, 2, "n2-in", 41002) + svc := &InboundService{} + + const email = "shared" + + syncNode(t, svc, 1, "n1-in", xray.ClientTraffic{Email: email, Up: 100, Down: 100, Enable: true}) + syncNode(t, svc, 2, "n2-in", xray.ClientTraffic{Email: email, Up: 200, Down: 200, Enable: true}) + + assertUpDown(t, readTraffic(t, db, email), 100, 100, "after baselines") + + syncNode(t, svc, 1, "n1-in", xray.ClientTraffic{Email: email, Up: 150, Down: 150, Enable: true}) + syncNode(t, svc, 2, "n2-in", xray.ClientTraffic{Email: email, Up: 260, Down: 260, Enable: true}) + + assertUpDown(t, readTraffic(t, db, email), 210, 210, "after both nodes grow") +} + +func TestSingleNode_MirrorsCorrectly(t *testing.T) { + db := initTrafficTestDB(t) + createNodeInbound(t, db, 1, "n1-in", 41001) + svc := &InboundService{} + + const email = "solo" + syncNode(t, svc, 1, "n1-in", xray.ClientTraffic{Email: email, Up: 500, Down: 600, Enable: true}) + assertUpDown(t, readTraffic(t, db, email), 500, 600, "first sync") + + syncNode(t, svc, 1, "n1-in", xray.ClientTraffic{Email: email, Up: 700, Down: 800, Enable: true}) + assertUpDown(t, readTraffic(t, db, email), 700, 800, "second sync mirrors cumulative") +} + +func TestUpgrade_PreExistingRow_NoDoubleCount(t *testing.T) { + db := initTrafficTestDB(t) + createNodeInbound(t, db, 1, "n1-in", 41001) + svc := &InboundService{} + + const email = "legacy" + var ib model.Inbound + if err := db.Where("tag = ?", "n1-in").First(&ib).Error; err != nil { + t.Fatalf("load inbound: %v", err) + } + if err := db.Create(&xray.ClientTraffic{InboundId: ib.Id, Email: email, Up: 1000, Down: 2000, Enable: true}).Error; err != nil { + t.Fatalf("seed pre-existing row: %v", err) + } + + syncNode(t, svc, 1, "n1-in", xray.ClientTraffic{Email: email, Up: 1000, Down: 2000, Enable: true}) + assertUpDown(t, readTraffic(t, db, email), 1000, 2000, "first snapshot must not double-count") + + syncNode(t, svc, 1, "n1-in", xray.ClientTraffic{Email: email, Up: 1100, Down: 2100, Enable: true}) + assertUpDown(t, readTraffic(t, db, email), 1100, 2100, "growth after upgrade accrues") +} + +func TestNodeCounterReset_Clamped(t *testing.T) { + db := initTrafficTestDB(t) + createNodeInbound(t, db, 1, "n1-in", 41001) + svc := &InboundService{} + + const email = "restart" + syncNode(t, svc, 1, "n1-in", xray.ClientTraffic{Email: email, Up: 900, Down: 900, Enable: true}) + syncNode(t, svc, 1, "n1-in", xray.ClientTraffic{Email: email, Up: 950, Down: 950, Enable: true}) + assertUpDown(t, readTraffic(t, db, email), 950, 950, "before node reset") + + syncNode(t, svc, 1, "n1-in", xray.ClientTraffic{Email: email, Up: 50, Down: 50, Enable: true}) + ct := readTraffic(t, db, email) + if ct.Up < 0 || ct.Down < 0 { + t.Fatalf("row went negative after node reset: up=%d down=%d", ct.Up, ct.Down) + } + assertUpDown(t, ct, 1000, 1000, "after node counter reset (clamped)") +} + +func TestCentralReset_NoReAdd(t *testing.T) { + db := initTrafficTestDB(t) + createNodeInbound(t, db, 1, "n1-in", 41001) + createNodeInbound(t, db, 2, "n2-in", 41002) + svc := &InboundService{} + + const email = "reset" + syncNode(t, svc, 1, "n1-in", xray.ClientTraffic{Email: email, Up: 100, Down: 100, Enable: true}) + syncNode(t, svc, 2, "n2-in", xray.ClientTraffic{Email: email, Up: 100, Down: 100, Enable: true}) + syncNode(t, svc, 1, "n1-in", xray.ClientTraffic{Email: email, Up: 200, Down: 200, Enable: true}) + syncNode(t, svc, 2, "n2-in", xray.ClientTraffic{Email: email, Up: 200, Down: 200, Enable: true}) + + if err := db.Model(xray.ClientTraffic{}).Where("email = ?", email). + Updates(map[string]any{"up": 0, "down": 0}).Error; err != nil { + t.Fatalf("simulate central reset: %v", err) + } + + syncNode(t, svc, 1, "n1-in", xray.ClientTraffic{Email: email, Up: 210, Down: 210, Enable: true}) + syncNode(t, svc, 2, "n2-in", xray.ClientTraffic{Email: email, Up: 205, Down: 205, Enable: true}) + + assertUpDown(t, readTraffic(t, db, email), 15, 15, "after central reset only increments accrue") +} + +func TestDelClientStat_CleansNodeBaselines(t *testing.T) { + db := initTrafficTestDB(t) + svc := &InboundService{} + + const email = "gone" + if err := db.Create(&xray.ClientTraffic{InboundId: 1, Email: email, Enable: true}).Error; err != nil { + t.Fatalf("seed client_traffics: %v", err) + } + if err := db.Create(&model.NodeClientTraffic{NodeId: 1, Email: email, Up: 10, Down: 10}).Error; err != nil { + t.Fatalf("seed node baseline 1: %v", err) + } + if err := db.Create(&model.NodeClientTraffic{NodeId: 2, Email: email, Up: 20, Down: 20}).Error; err != nil { + t.Fatalf("seed node baseline 2: %v", err) + } + + if err := svc.DelClientStat(db, email); err != nil { + t.Fatalf("DelClientStat: %v", err) + } + + var cnt int64 + if err := db.Model(&model.NodeClientTraffic{}).Where("email = ?", email).Count(&cnt).Error; err != nil { + t.Fatalf("count baselines: %v", err) + } + if cnt != 0 { + t.Errorf("expected node baselines cleaned, found %d", cnt) + } +} + +func TestNodeDelete_CleansNodeBaselines(t *testing.T) { + db := initTrafficTestDB(t) + nodeSvc := NodeService{} + + if err := db.Create(&model.NodeClientTraffic{NodeId: 7, Email: "a", Up: 1, Down: 1}).Error; err != nil { + t.Fatalf("seed node 7 a: %v", err) + } + if err := db.Create(&model.NodeClientTraffic{NodeId: 7, Email: "b", Up: 2, Down: 2}).Error; err != nil { + t.Fatalf("seed node 7 b: %v", err) + } + if err := db.Create(&model.NodeClientTraffic{NodeId: 8, Email: "c", Up: 3, Down: 3}).Error; err != nil { + t.Fatalf("seed node 8 c: %v", err) + } + + if err := nodeSvc.Delete(7); err != nil { + t.Fatalf("NodeService.Delete(7): %v", err) + } + + var sevenCnt, eightCnt int64 + db.Model(&model.NodeClientTraffic{}).Where("node_id = ?", 7).Count(&sevenCnt) + db.Model(&model.NodeClientTraffic{}).Where("node_id = ?", 8).Count(&eightCnt) + if sevenCnt != 0 { + t.Errorf("node 7 baselines not cleaned: %d remain", sevenCnt) + } + if eightCnt != 1 { + t.Errorf("node 8 baseline should survive, found %d", eightCnt) + } +}