From bca50fa7c9a117cc83e4630a2a1a7172e22491d4 Mon Sep 17 00:00:00 2001 From: zarazaex69 Date: Sun, 3 May 2026 15:53:11 +0300 Subject: [PATCH] feat: add session epoch tracking to detect peer restarts --- internal/transport/vp8channel/kcp.go | 4 +- internal/transport/vp8channel/kcpconn.go | 19 +- internal/transport/vp8channel/transport.go | 172 ++++++++++++++---- .../transport/vp8channel/transport_test.go | 24 ++- 4 files changed, 175 insertions(+), 44 deletions(-) diff --git a/internal/transport/vp8channel/kcp.go b/internal/transport/vp8channel/kcp.go index 570f60b..f3988ef 100644 --- a/internal/transport/vp8channel/kcp.go +++ b/internal/transport/vp8channel/kcp.go @@ -56,8 +56,8 @@ type kcpRuntime struct { closeOnce sync.Once } -func startKCP(out chan<- []byte, onData func([]byte)) (*kcpRuntime, error) { - c := newKCPConn(out, inboundQueueSize) +func startKCP(out chan<- []byte, onData func([]byte), epochHdr [epochHdrLen]byte) (*kcpRuntime, error) { + c := newKCPConn(out, inboundQueueSize, epochHdr) sess, err := kcp.NewConn3(kcpConvID, fakeUDPAddr(), nil, 0, 0, c) if err != nil { diff --git a/internal/transport/vp8channel/kcpconn.go b/internal/transport/vp8channel/kcpconn.go index 161f9d5..58e61d0 100644 --- a/internal/transport/vp8channel/kcpconn.go +++ b/internal/transport/vp8channel/kcpconn.go @@ -24,19 +24,25 @@ type kcpConn struct { closed chan struct{} closeOnce sync.Once + // epochHdr is prepended to every outgoing KCP packet so that the peer + // can detect a session restart on our side (see transport.go for the + // layout). Stable for the lifetime of this kcpConn. 
+ epochHdr [epochHdrLen]byte + mu sync.Mutex rDeadline time.Time wDeadline time.Time } -func newKCPConn(out chan<- []byte, inboundCap int) *kcpConn { +func newKCPConn(out chan<- []byte, inboundCap int, epochHdr [epochHdrLen]byte) *kcpConn { if inboundCap <= 0 { inboundCap = 1024 } return &kcpConn{ - out: out, - in: make(chan []byte, inboundCap), - closed: make(chan struct{}), + out: out, + in: make(chan []byte, inboundCap), + closed: make(chan struct{}), + epochHdr: epochHdr, } } @@ -80,8 +86,9 @@ func (c *kcpConn) ReadFrom(p []byte) (int, net.Addr, error) { } func (c *kcpConn) WriteTo(p []byte, _ net.Addr) (int, error) { - buf := make([]byte, len(p)) - copy(buf, p) + buf := make([]byte, epochHdrLen+len(p)) + copy(buf, c.epochHdr[:]) + copy(buf[epochHdrLen:], p) c.mu.Lock() deadline := c.wDeadline diff --git a/internal/transport/vp8channel/transport.go b/internal/transport/vp8channel/transport.go index a10c4ca..294a901 100644 --- a/internal/transport/vp8channel/transport.go +++ b/internal/transport/vp8channel/transport.go @@ -2,6 +2,8 @@ package vp8channel import ( "context" + "crypto/rand" + "encoding/binary" "errors" "fmt" "sync" @@ -18,7 +20,7 @@ import ( const ( defaultMaxPayloadSize = 60 * 1024 - defaultConnectTimeout = 30 * time.Second + defaultConnectTimeout = 60 * time.Second rtpBufSize = 65536 outboundQueueSize = 1024 inboundQueueSize = 1024 @@ -40,11 +42,22 @@ var vp8Keepalive = []byte{ 0x99, 0x84, 0x88, 0xfc, } -// kcpMagic is the little-endian first byte of a KCP packet (low byte of -// kcpConvID = 0xC0FFEE01). Anything that does not match is treated as -// non-KCP traffic (idle keepalives, stray frames after reconnect) and -// dropped before reaching the protocol stack. -const kcpMagic = byte(0x01) +// kcpFrameMagic marks a VP8 frame as carrying a KCP segment with our +// session-epoch header. The wire layout inside the VP8 frame is: +// +// [0] = kcpFrameMagic (0x4B = 'K') +// [1..5] = sender's session epoch (big-endian uint32) +// [5..] 
= raw KCP packet bytes +// +// The epoch lets a receiver detect that the peer has restarted its KCP +// session - typical when the SFU keeps forwarding the same remote video +// track across our process restarts, so handleRemoteTrack never fires +// again. On any epoch change we reset the local KCP session so both ends +// converge on fresh state. +const ( + kcpFrameMagic = byte(0x4B) + epochHdrLen = 5 +) type streamTransport struct { stream carrier.VideoTrack @@ -55,12 +68,22 @@ type streamTransport struct { writerDone chan struct{} closed atomic.Bool writerUp atomic.Bool - startOnce sync.Once + writerOnce sync.Once + kcpOnce sync.Once frameInterval time.Duration batchSize int - kcp *kcpRuntime - kcpMu sync.RWMutex + // localEpoch is set once in New() and deliberately NOT bumped on KCP + // resets (see resetKCP); it is stamped into every outgoing KCP packet. + // peerEpoch tracks the last epoch seen from the remote to detect its restart. + localEpoch uint32 + peerEpoch atomic.Uint32 + hadPeer atomic.Bool + + kcp *kcpRuntime + kcpMu sync.RWMutex + reconnectMu sync.Mutex + reconnectFn func() } // New creates a vp8channel transport backed by a carrier-specific provider. @@ -111,6 +134,7 @@ func New(ctx context.Context, cfg transport.Config) (transport.Transport, error) writerDone: make(chan struct{}), frameInterval: time.Second / time.Duration(fps), batchSize: batchSize, + localEpoch: randomEpoch(), } if err := stream.AddTrack(track); err != nil { @@ -129,12 +153,19 @@ func (p *streamTransport) Connect(ctx context.Context) error { return fmt.Errorf("connect stream: %w", err) } + p.writerOnce.Do(func() { + p.writerUp.Store(true) + go p.writerLoop() + }) + + // Start KCP immediately. Don't wait for the peer's video track: + // the server may legitimately come up before any client joins the + // room, and KCP itself does not need a handshake. 
Once the peer + // shows up, handleRemoteTrack starts pumping their RTP into our + // session and the epoch-change detector handles peer restarts. var startErr error - p.startOnce.Do(func() { - // Start KCP first so the writerLoop has packets to forward as soon - // as it begins ticking. KCP's own update goroutine drives keepalives - // and ACKs once the session is up. - rt, err := startKCP(p.outbound, p.onData) + p.kcpOnce.Do(func() { + rt, err := startKCP(p.outbound, p.onData, p.epochHeader()) if err != nil { startErr = err return @@ -142,14 +173,35 @@ func (p *streamTransport) Connect(ctx context.Context) error { p.kcpMu.Lock() p.kcp = rt p.kcpMu.Unlock() - - p.writerUp.Store(true) - go p.writerLoop() }) return startErr } +// epochHeader returns the 5-byte VP8-frame header used to tag every KCP +// packet sent in the current local session. +func (p *streamTransport) epochHeader() [epochHdrLen]byte { + var hdr [epochHdrLen]byte + hdr[0] = kcpFrameMagic + binary.BigEndian.PutUint32(hdr[1:], p.localEpoch) + return hdr +} + +func randomEpoch() uint32 { + var b [4]byte + if _, err := rand.Read(b[:]); err != nil { + // rand.Read on Linux essentially never fails; fall back to a + // time-derived value rather than panic. + //nolint:gosec // intentional uint32 truncation of a nanosecond timestamp + return uint32(time.Now().UnixNano()) + } + e := binary.BigEndian.Uint32(b[:]) + if e == 0 { + e = 1 + } + return e +} + func (p *streamTransport) Send(data []byte) error { if p.closed.Load() { return ErrTransportClosed @@ -197,12 +249,11 @@ func (p *streamTransport) drainOutbound() { } func (p *streamTransport) SetReconnectCallback(cb func()) { + p.reconnectMu.Lock() + p.reconnectFn = cb + p.reconnectMu.Unlock() p.stream.SetReconnectCallback(func() { - // Drain stale KCP segments queued for the old wire. 
KCP will - // retransmit anything that mattered after the link is back up, - // so dropping the queue here only saves us from sending obsolete - // data that the peer would discard anyway. - p.drainOutbound() + p.resetKCP() if cb != nil { cb() } @@ -286,12 +337,38 @@ func (p *streamTransport) writerLoop() { } } +func (p *streamTransport) resetKCP() { + p.drainOutbound() + p.kcpMu.Lock() + old := p.kcp + p.kcp = nil + p.kcpMu.Unlock() + if old != nil { + old.close() + } + // Note: localEpoch is intentionally NOT bumped here. The epoch is a + // per-process identifier set once in New(). If we changed it on every + // peer-triggered reset, the peer would see a "new" epoch from us, reset + // itself, send back its (unchanged) epoch which we'd then see as "new" + // again - and the two sides would loop forever tearing down smux. + rt, err := startKCP(p.outbound, p.onData, p.epochHeader()) + if err != nil { + return + } + p.kcpMu.Lock() + p.kcp = rt + p.kcpMu.Unlock() +} + func (p *streamTransport) handleRemoteTrack(track *webrtc.TrackRemote, _ *webrtc.RTPReceiver) { if track.Codec().MimeType != webrtc.MimeTypeVP8 { go p.drainTrack(track) return } + // We don't reset KCP here. Peer restarts are detected by the epoch + // header on incoming frames, which works even when the SFU keeps + // forwarding the same track across our restarts. go p.readVP8Track(track) } @@ -312,8 +389,9 @@ type vp8FrameState struct { frameValid bool } -// processRTPPacket returns a complete KCP frame when the VP8 frame is fully assembled, nil otherwise. -// Detects packet loss/reordering to avoid silently corrupting fragmented VP8 frames. +// processRTPPacket returns a complete VP8 frame payload when fully assembled, +// nil otherwise. Detects packet loss/reordering to avoid silently corrupting +// fragmented VP8 frames. 
func (s *vp8FrameState) processRTPPacket(pkt *rtp.Packet) []byte { if s.haveLastSeq && pkt.SequenceNumber != s.lastSeq+1 { s.frameValid = false @@ -349,7 +427,7 @@ func (s *vp8FrameState) processRTPPacket(pkt *rtp.Packet) []byte { s.frameValid = false }() - if len(s.frameBuf) >= 4 && s.frameBuf[0] == kcpMagic { + if len(s.frameBuf) >= epochHdrLen && s.frameBuf[0] == kcpFrameMagic { frame := make([]byte, len(s.frameBuf)) copy(frame, s.frameBuf) return frame @@ -377,11 +455,43 @@ func (p *streamTransport) readVP8Track(track *webrtc.TrackRemote) { continue } - p.kcpMu.RLock() - rt := p.kcp - p.kcpMu.RUnlock() - if rt != nil { - rt.deliver(frame) - } + p.handleIncomingFrame(frame) + } +} + +// handleIncomingFrame parses the epoch header and either delivers the KCP +// payload to the local session or triggers a reset when the peer's epoch +// changes (peer process restart). +func (p *streamTransport) handleIncomingFrame(frame []byte) { + peerEpoch := binary.BigEndian.Uint32(frame[1:epochHdrLen]) + kcpPayload := frame[epochHdrLen:] + if len(kcpPayload) == 0 { + return + } + + if !p.hadPeer.Swap(true) { + p.peerEpoch.Store(peerEpoch) + } else if prev := p.peerEpoch.Load(); prev != peerEpoch { + // Peer restarted its KCP session. Reset ours so the conv state + // machines re-converge. CAS guards against double-reset when + // fragmented frames straddle the epoch boundary. + if p.peerEpoch.CompareAndSwap(prev, peerEpoch) { + p.resetKCP() + p.reconnectMu.Lock() + fn := p.reconnectFn + p.reconnectMu.Unlock() + if fn != nil { + fn() + } + } + // Drop this packet: it predates our fresh KCP session. 
+ return + } + + p.kcpMu.RLock() + rt := p.kcp + p.kcpMu.RUnlock() + if rt != nil { + rt.deliver(kcpPayload) } } diff --git a/internal/transport/vp8channel/transport_test.go b/internal/transport/vp8channel/transport_test.go index 33feb2b..279ef1b 100644 --- a/internal/transport/vp8channel/transport_test.go +++ b/internal/transport/vp8channel/transport_test.go @@ -13,7 +13,11 @@ func pumpPackets(stop <-chan struct{}, from <-chan []byte, to *kcpRuntime) { case <-stop: return case pkt := <-from: - to.deliver(pkt) + // Strip the on-wire epoch header that kcpConn prepends; + // the real receive path does this before calling deliver(). + if len(pkt) > epochHdrLen { + to.deliver(pkt[epochHdrLen:]) + } } } } @@ -66,13 +70,13 @@ func TestKCPLoopback(t *testing.T) { cb, doneB, getRecv := buildReceiver(len(msgs)) - rtA, err := startKCP(a2b, nil) + rtA, err := startKCP(a2b, nil, testEpochHdr(1)) if err != nil { t.Fatalf("startKCP A: %v", err) } defer rtA.close() - rtB, err := startKCP(b2a, cb) + rtB, err := startKCP(b2a, cb, testEpochHdr(2)) if err != nil { t.Fatalf("startKCP B: %v", err) } @@ -100,7 +104,17 @@ func TestKCPLoopback(t *testing.T) { } func TestVP8KeepaliveDoesNotLookLikeKCP(t *testing.T) { - if len(vp8Keepalive) >= 1 && vp8Keepalive[0] == kcpMagic { - t.Errorf("keepalive collides with kcp magic byte 0x%02x", kcpMagic) + if len(vp8Keepalive) >= 1 && vp8Keepalive[0] == kcpFrameMagic { + t.Errorf("keepalive collides with kcp magic byte 0x%02x", kcpFrameMagic) } } + +func testEpochHdr(epoch uint32) [epochHdrLen]byte { + var hdr [epochHdrLen]byte + hdr[0] = kcpFrameMagic + hdr[1] = byte(epoch >> 24) + hdr[2] = byte(epoch >> 16) //nolint:gosec + hdr[3] = byte(epoch >> 8) //nolint:gosec + hdr[4] = byte(epoch) //nolint:gosec + return hdr +}