diff --git a/web/websocket/hub.go b/web/websocket/hub.go index 2b8773cf..7299f403 100644 --- a/web/websocket/hub.go +++ b/web/websocket/hub.go @@ -29,11 +29,23 @@ const ( enqueueTimeout = 100 * time.Millisecond clientSendQueue = 512 // ~50s of buffering for a momentarily slow browser. hubBroadcastQueue = 2048 // Headroom for cron-storm + admin-mutation bursts. - hubControlQueue = 64 // Backlog for register/unregister bursts (page reloads, disconnect storms). + hubOpsQueue = 128 // Backlog for register+unregister bursts (page reloads, disconnect storms). minBroadcastInterval = 250 * time.Millisecond hubRestartAttempts = 3 ) +type clientOpKind int + +const ( + opRegister clientOpKind = iota + opUnregister +) + +type clientOp struct { + kind clientOpKind + c *Client +} + // NewClient builds a Client ready for hub registration. func NewClient(id string) *Client { return &Client{ @@ -58,13 +70,12 @@ type Client struct { // Hub fan-outs messages to all connected clients. type Hub struct { - clients map[*Client]struct{} - broadcast chan []byte - register chan *Client - unregister chan *Client - mu sync.RWMutex - ctx context.Context - cancel context.CancelFunc + clients map[*Client]struct{} + broadcast chan []byte + ops chan clientOp + mu sync.RWMutex + ctx context.Context + cancel context.CancelFunc throttleMu sync.Mutex lastBroadcast map[MessageType]time.Time @@ -76,8 +87,7 @@ func NewHub() *Hub { return &Hub{ clients: make(map[*Client]struct{}), broadcast: make(chan []byte, hubBroadcastQueue), - register: make(chan *Client, hubControlQueue), - unregister: make(chan *Client, hubControlQueue), + ops: make(chan clientOp, hubOpsQueue), ctx: ctx, cancel: cancel, lastBroadcast: make(map[MessageType]time.Time), @@ -145,21 +155,20 @@ func (h *Hub) runOnce() (stopped bool) { h.shutdown() return true - case c := <-h.register: - if c == nil { + case op := <-h.ops: + if op.c == nil { continue } - h.mu.Lock() - h.clients[c] = struct{}{} - n := len(h.clients) - h.mu.Unlock() - logger.Debugf("WebSocket client connected: %s (total: %d)", c.ID, n) - - case c := <-h.unregister: - if c == nil { - continue + switch op.kind { + case opRegister: + h.mu.Lock() + h.clients[op.c] = struct{}{} + n := len(h.clients) + h.mu.Unlock() + logger.Debugf("WebSocket client connected: %s (total: %d)", op.c.ID, n) + case opUnregister: + h.removeClient(op.c) } - h.removeClient(c) case msg := <-h.broadcast: h.fanout(msg) @@ -321,29 +330,29 @@ func (h *Hub) Register(c *Client) { return } select { - case h.register <- c: + case h.ops <- clientOp{kind: opRegister, c: c}: case <-h.ctx.Done(): } } -// Unregister removes a client from the hub. Fast path queues for the hub -// goroutine; if the channel is saturated (disconnect storm) we fall back -// to a direct removal under the write lock so dead clients aren't left in -// the registry waiting for their Send buffer to fill (minutes of wasted -// fanout work at low broadcast rates). +// Unregister removes a client from the hub. Sends through the same ordered +// ops channel as Register so a register-then-unregister sequence from one +// goroutine is processed in program order — otherwise an unregister could +// land in the map before its register and silently no-op, leaking the entry. // -// Direct removal is safe from any caller: external goroutines (read/write -// pumps) hold no hub locks, and the hub goroutine itself never holds h.mu -// when it calls Unregister — fanout releases its RLock before per-client -// sends, so we can't self-deadlock here. +// On a saturated ops channel (disconnect storm) we fall back to a bounded +// timeout drop rather than direct removal: a direct delete on a not-yet- +// registered client is precisely the ordering bug we fix here. Stragglers +// get evicted by fanout when their Send buffer fills. func (h *Hub) Unregister(c *Client) { if h == nil || c == nil { return } select { - case h.unregister <- c: - default: - h.removeClient(c) + case h.ops <- clientOp{kind: opUnregister, c: c}: + case <-time.After(enqueueTimeout): + logger.Warningf("WebSocket ops channel full, dropping unregister for %s", c.ID) + case <-h.ctx.Done(): } }