diff --git a/internal/client/client.go b/internal/client/client.go index a18e795..a50340d 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -119,7 +119,11 @@ func RunWithReady(ctx context.Context, roomURL, keyHex string, socksPort int, du c.peers = append(c.peers, peer) peer.SetReconnectCallback(func(dc *webrtc.DataChannel) { - log.Printf("Client peer %d reconnected - resetting multiplexer state", peerID) + if dc == nil { + log.Printf("Client peer %d channel closed - resetting multiplexer state", peerID) + } else { + log.Printf("Client peer %d reconnected - resetting multiplexer state", peerID) + } c.mux.UpdateSendFunc(func(frame []byte) error { encrypted, err := c.cipher.Encrypt(frame) diff --git a/internal/server/server.go b/internal/server/server.go index e02af91..cd8483f 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -136,7 +136,11 @@ func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer string s.peers = append(s.peers, peer) peer.SetReconnectCallback(func(dc *webrtc.DataChannel) { - log.Printf("Server peer %d reconnected - resetting multiplexer state", peerID) + if dc == nil { + log.Printf("Server peer %d channel closed - resetting multiplexer state", peerID) + } else { + log.Printf("Server peer %d reconnected - resetting multiplexer state", peerID) + } s.connMu.Lock() for sid, conn := range s.connections { diff --git a/internal/telemost/peer.go b/internal/telemost/peer.go index 667732d..60b1073 100644 --- a/internal/telemost/peer.go +++ b/internal/telemost/peer.go @@ -51,14 +51,17 @@ type Peer struct { lastReconnect time.Time reconnectCount int reconnectMu sync.Mutex + sessionMu sync.Mutex sendQueue chan []byte sendQueueClosed atomic.Bool closed atomic.Bool + reconnecting atomic.Bool telemetryActive atomic.Bool ackMu sync.Mutex ackWaiters map[string]chan struct{} onEnded func(string) trafficShape TrafficShape + sessionCloseCh chan struct{} wg sync.WaitGroup } @@ -94,16 +97,17 @@ func NewPeer(roomURL, name string, onData func([]byte)) (*Peer, error) { } return &Peer{ - roomURL: roomURL, - name: name, - conn: conn, - onData: onData, - reconnectCh: make(chan struct{}, 1), - closeCh: make(chan struct{}), - keepAliveCh: make(chan struct{}), - telemetryCh: make(chan struct{}, 1), - sendQueue: make(chan []byte, 5000), - ackWaiters: make(map[string]chan struct{}), + roomURL: roomURL, + name: name, + conn: conn, + onData: onData, + reconnectCh: make(chan struct{}, 1), + closeCh: make(chan struct{}), + keepAliveCh: make(chan struct{}), + sessionCloseCh: make(chan struct{}), + telemetryCh: make(chan struct{}, 1), + sendQueue: make(chan []byte, 5000), + ackWaiters: make(map[string]chan struct{}), trafficShape: TrafficShape{ MaxMessageSize: realDataChannelMessageLimit, MinDelay: defaultSendDelayMin, @@ -112,6 +116,55 @@ func NewPeer(roomURL, name string, onData func([]byte)) (*Peer, error) { }, nil } +func closeSignal(ch chan struct{}) { + if ch == nil { + return + } + select { + case <-ch: + default: + close(ch) + } +} + +func (p *Peer) queueReconnect() { + if p.closed.Load() || p.reconnecting.Load() { + return + } + select { + case p.reconnectCh <- struct{}{}: + default: + } +} + +func (p *Peer) stopSession() { + p.stopTelemetry() + + p.sessionMu.Lock() + closeSignal(p.keepAliveCh) + closeSignal(p.sessionCloseCh) + p.sessionMu.Unlock() +} + +func (p *Peer) resetSession() (chan struct{}, chan struct{}) { + p.sessionMu.Lock() + defer p.sessionMu.Unlock() + + p.keepAliveCh = make(chan struct{}) + p.sessionCloseCh = make(chan struct{}) + return p.keepAliveCh, p.sessionCloseCh +} + +func (p *Peer) drainReconnectQueue() { + for { + select { + case <-p.reconnectCh: + default: + return + } + } +} + func (p *Peer) Connect(ctx context.Context) error { p.closed.Store(false) @@ -137,10 +190,7 @@ func (p *Peer) Connect(ctx context.Context) error { p.pcSub.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { log.Printf("Subscriber PeerConnection state: %s", state.String()) if !p.closed.Load() && (state == webrtc.PeerConnectionStateFailed || state == webrtc.PeerConnectionStateDisconnected) { - select { - case p.reconnectCh <- struct{}{}: - default: - } + p.queueReconnect() } }) @@ -152,10 +202,7 @@ func (p *Peer) Connect(ctx context.Context) error { p.pcPub.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { log.Printf("Publisher PeerConnection state: %s", state.String()) if !p.closed.Load() && (state == webrtc.PeerConnectionStateFailed || state == webrtc.PeerConnectionStateDisconnected) { - select { - case p.reconnectCh <- struct{}{}: - default: - } + p.queueReconnect() } }) @@ -165,6 +212,7 @@ func (p *Peer) Connect(ctx context.Context) error { } dcReady := make(chan struct{}) + keepAliveCh, sessionCloseCh := p.resetSession() p.dc.OnOpen(func() { log.Println("DataChannel opened") @@ -173,14 +221,14 @@ func (p *Peer) Connect(ctx context.Context) error { p.wg.Add(1) go func(workerID int) { defer p.wg.Done() - p.processSendQueue(workerID) + p.processSendQueue(workerID, sessionCloseCh) }(i) } p.wg.Add(1) go func() { defer p.wg.Done() - p.monitorQueue() + p.monitorQueue(sessionCloseCh) }() close(dcReady) @@ -193,10 +241,7 @@ func (p *Peer) Connect(ctx context.Context) error { p.onReconnect(nil) } if !p.closed.Load() { - select { - case p.reconnectCh <- struct{}{}: - default: - } + p.queueReconnect() } }) @@ -211,10 +256,7 @@ func (p *Peer) Connect(ctx context.Context) error { dc.OnClose(func() { log.Println("Received DataChannel closed - triggering reconnect") if !p.closed.Load() { - select { - case p.reconnectCh <- struct{}{}: - default: - } + p.queueReconnect() } }) dc.OnMessage(func(msg webrtc.DataChannelMessage) { @@ -244,7 +286,7 @@ func (p *Peer) Connect(ctx context.Context) error { p.wg.Add(1) go func() { defer p.wg.Done() - p.keepAlive() + p.keepAlive(keepAliveCh) }() if err := p.sendHello(); err != nil { @@ -341,10 +383,7 @@ func (p *Peer) handleSignaling() { if err := p.ws.ReadJSON(&msg); err != nil { log.Printf("WS read error: %v", err) if !p.closed.Load() { - select { - case p.reconnectCh <- struct{}{}: - default: - } + p.queueReconnect() } return } @@ -848,7 +887,7 @@ func (p *Peer) Close() error { return nil } -func (p *Peer) keepAlive() { +func (p *Peer) keepAlive(keepAliveCh <-chan struct{}) { wsPingTicker := time.NewTicker(30 * time.Second) defer wsPingTicker.Stop() @@ -863,10 +902,7 @@ func (p *Peer) keepAlive() { if err := p.ws.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second)); err != nil { log.Printf("WS Ping error: %v", err) p.wsMu.Unlock() - select { - case p.reconnectCh <- struct{}{}: - default: - } + p.queueReconnect() return } } @@ -880,15 +916,12 @@ func (p *Peer) keepAlive() { }); err != nil { log.Printf("App Ping error: %v", err) p.wsMu.Unlock() - select { - case p.reconnectCh <- struct{}{}: - default: - } + p.queueReconnect() return } } p.wsMu.Unlock() - case <-p.keepAliveCh: + case <-keepAliveCh: return case <-p.closeCh: return @@ -898,11 +931,13 @@ func (p *Peer) keepAlive() { func (p *Peer) reconnect(ctx context.Context) error { log.Println("Reconnecting...") + p.reconnecting.Store(true) + defer p.reconnecting.Store(false) p.sendLeave(uuid.New().String()) time.Sleep(500 * time.Millisecond) - close(p.keepAliveCh) + p.stopSession() if p.dc != nil { p.dc.Close() @@ -925,8 +960,6 @@ func (p *Peer) reconnect(ctx context.Context) error { time.Sleep(3 * time.Second) - p.keepAliveCh = make(chan struct{}) - conn, err := GetConnectionInfo(p.roomURL, p.name) if err != nil { return err @@ -941,6 +974,8 @@ func (p *Peer) reconnect(ctx context.Context) error { p.onReconnect(p.dc) } + p.drainReconnectQueue() + return nil } @@ -982,6 +1017,9 @@ func (p *Peer) WatchConnection(ctx context.Context) { time.Sleep(backoff) continue } + p.reconnectMu.Lock() + p.reconnectCount = 0 + p.reconnectMu.Unlock() log.Println("Reconnected successfully") break } @@ -993,7 +1031,7 @@ func (p *Peer) WatchConnection(ctx context.Context) { } } -func (p *Peer) processSendQueue(workerID int) { +func (p *Peer) processSendQueue(workerID int, sessionCloseCh <-chan struct{}) { log.Printf("[WORKER-%d] Started", workerID) defer log.Printf("[WORKER-%d] Stopped", workerID) @@ -1051,13 +1089,15 @@ func (p *Peer) processSendQueue(workerID int) { } } + case <-sessionCloseCh: + return case <-p.closeCh: return } } } -func (p *Peer) monitorQueue() { +func (p *Peer) monitorQueue(sessionCloseCh <-chan struct{}) { ticker := time.NewTicker(3 * time.Second) defer ticker.Stop() @@ -1072,6 +1112,8 @@ func (p *Peer) monitorQueue() { if queueLen > 800 || buffered > 3*1024*1024 { log.Printf("[QUEUE_MONITOR] queue_len=%d dc_buffered=%d MB", queueLen, buffered/(1024*1024)) } + case <-sessionCloseCh: + return case <-p.closeCh: return }