diff --git a/internal/transport/vp8channel/transport.go b/internal/transport/vp8channel/transport.go index ed36bed..e549895 100644 --- a/internal/transport/vp8channel/transport.go +++ b/internal/transport/vp8channel/transport.go @@ -96,10 +96,11 @@ type streamTransport struct { frameInterval time.Duration batchSize int - // localEpoch is bumped on every KCP session restart and stamped into - // every outgoing VP8 frame. peerEpoch tracks the last epoch we observed - // from the remote so we can detect their restart and reset locally. + // localEpoch is stamped into every outgoing VP8 frame. Explicit + // upper-layer resets rotate it so the peer can reset its KCP state too. + // Peer-triggered resets keep it stable to avoid reset ping-pong. bindingToken uint32 + epochMu sync.RWMutex localEpoch uint32 peerEpoch atomic.Uint32 hadPeer atomic.Bool @@ -204,7 +205,7 @@ func (p *streamTransport) Connect(ctx context.Context) error { p.kcpMu.Lock() p.kcp = rt p.kcpMu.Unlock() - logger.Infof("vp8channel: KCP started localEpoch=0x%08x", p.localEpoch) + logger.Infof("vp8channel: KCP started localEpoch=0x%08x", p.localEpochValue()) }) p.writerOnce.Do(func() { @@ -218,14 +219,41 @@ func (p *streamTransport) Connect(ctx context.Context) error { // epochHeader returns the 5-byte VP8-frame header used to tag every KCP // packet sent in the current local session. func (p *streamTransport) epochHeader() [epochHdrLen]byte { + p.epochMu.RLock() + epoch := p.localEpoch + p.epochMu.RUnlock() + return buildEpochHeader(p.bindingToken, epoch) +} + +func buildEpochHeader(token, epoch uint32) [epochHdrLen]byte { var hdr [epochHdrLen]byte copy(hdr[:], vp8Keepalive) - binary.BigEndian.PutUint32(hdr[tokenOff:epochOff], p.bindingToken) - binary.BigEndian.PutUint32(hdr[epochOff:crcOff], p.localEpoch) - binary.BigEndian.PutUint32(hdr[crcOff:epochHdrLen], epochCRC(p.bindingToken, p.localEpoch)) + binary.BigEndian.PutUint32(hdr[tokenOff:epochOff], token) + binary.BigEndian.PutUint32(hdr[epochOff:crcOff], epoch) + binary.BigEndian.PutUint32(hdr[crcOff:epochHdrLen], epochCRC(token, epoch)) return hdr } +func (p *streamTransport) rotateEpochHeader() [epochHdrLen]byte { + p.epochMu.Lock() + for { + next := randomEpoch() + if next != p.localEpoch { + p.localEpoch = next + break + } + } + epoch := p.localEpoch + p.epochMu.Unlock() + return buildEpochHeader(p.bindingToken, epoch) +} + +func (p *streamTransport) localEpochValue() uint32 { + p.epochMu.RLock() + defer p.epochMu.RUnlock() + return p.localEpoch +} + func epochCRC(token, epoch uint32) uint32 { var buf [8]byte binary.BigEndian.PutUint32(buf[0:4], token) @@ -313,6 +341,14 @@ func (p *streamTransport) drainOutbound() { } } +// ResetPeer drops queued KCP traffic and starts a fresh KCP state machine while +// keeping the carrier connection alive. The client/server liveness layer calls +// this before rebuilding smux so replacement handshakes are not parsed behind +// stale bytes from streams that were active when the old session died. +func (p *streamTransport) ResetPeer() { + p.restartKCP(p.rotateEpochHeader()) +} + func (p *streamTransport) SetReconnectCallback(cb func()) { p.reconnectMu.Lock() p.reconnectFn = cb @@ -407,6 +443,10 @@ func (p *streamTransport) sampleInterval() time.Duration { } func (p *streamTransport) resetKCP() { + p.restartKCP(p.epochHeader()) +} + +func (p *streamTransport) restartKCP(epochHdr [epochHdrLen]byte) { p.drainOutbound() p.kcpMu.Lock() old := p.kcp @@ -415,12 +455,7 @@ func (p *streamTransport) resetKCP() { if old != nil { old.close() } - // Note: localEpoch is intentionally NOT bumped here. The epoch is a - // per-process identifier set once in New(). If we changed it on every - // peer-triggered reset, the peer would see a "new" epoch from us, reset - // itself, send back its (unchanged) epoch which we'd then see as "new" - // again - and the two sides would loop forever tearing down smux. - rt, err := startKCP(p.outbound, p.onData, p.epochHeader()) + rt, err := startKCP(p.outbound, p.onData, epochHdr) if err != nil { return } @@ -552,7 +587,7 @@ func (p *streamTransport) handleIncomingFrame(frame []byte) { // remote track. Those frames carry our local epoch, not the peer's. If we // treat them as peer traffic, epoch tracking toggles between "self" and // "peer" and both sides loop forever resetting smux/KCP. - if peerEpoch == p.localEpoch { + if peerEpoch == p.localEpochValue() { logger.Debugf("vp8channel: self-echo detected epoch=0x%08x (SFU reflects our own track)", peerEpoch) return } diff --git a/internal/transport/vp8channel/transport_unit_test.go b/internal/transport/vp8channel/transport_unit_test.go index 6cd97a5..98ce099 100644 --- a/internal/transport/vp8channel/transport_unit_test.go +++ b/internal/transport/vp8channel/transport_unit_test.go @@ -90,10 +90,10 @@ func (s *fakeEngineSession) SetEndedCallback(cb func(string)) { s.stream.SetEnd func (s *fakeEngineSession) WatchConnection(ctx context.Context) { s.stream.WatchConnection(ctx) } -func (s *fakeEngineSession) CanSend() bool { return s.stream.CanSend() } -func (s *fakeEngineSession) GetSendQueue() chan []byte { return nil } -func (s *fakeEngineSession) GetBufferedAmount() uint64 { return 0 } -func (s *fakeEngineSession) AddVideoTrack(t webrtc.TrackLocal) error { return s.stream.AddTrack(t) } +func (s *fakeEngineSession) CanSend() bool { return s.stream.CanSend() } +func (s *fakeEngineSession) GetSendQueue() chan []byte { return nil } +func (s *fakeEngineSession) GetBufferedAmount() uint64 { return 0 } +func (s *fakeEngineSession) AddVideoTrack(t webrtc.TrackLocal) error { return s.stream.AddTrack(t) } func (s *fakeEngineSession) SetVideoTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) { s.stream.SetTrackHandler(cb) } @@ -230,6 +230,50 @@ func TestEpochHeaderTokenAndOutboundCapacity(t *testing.T) { } } +func TestResetPeerRestartsKCPAndDrainsOutbound(t *testing.T) { + tr := &streamTransport{ + stream: &fakeVideoStream{canSend: true}, + outbound: make(chan []byte, 10), + closeCh: make(chan struct{}), + writerDone: make(chan struct{}), + bindingToken: bindingToken("client"), + localEpoch: 0x01020304, + } + defer func() { + _ = tr.Close() + }() + + rt, err := startKCP(tr.outbound, nil, tr.epochHeader()) + if err != nil { + t.Fatalf("startKCP: %v", err) + } + tr.kcpMu.Lock() + tr.kcp = rt + tr.kcpMu.Unlock() + tr.outbound <- []byte("stale") + oldEpoch := tr.localEpoch + + tr.ResetPeer() + + tr.kcpMu.RLock() + got := tr.kcp + tr.kcpMu.RUnlock() + if got == nil || got == rt { + t.Fatalf("ResetPeer kcp = %p, want fresh non-nil runtime distinct from %p", got, rt) + } + if len(tr.outbound) != 0 { + t.Fatalf("ResetPeer left %d outbound frame(s), want 0", len(tr.outbound)) + } + if tr.localEpoch == oldEpoch { + t.Fatalf("ResetPeer localEpoch = %#x, want different epoch", tr.localEpoch) + } + select { + case <-rt.readDone: + case <-time.After(time.Second): + t.Fatal("old KCP runtime did not stop") + } +} + func TestVP8FrameStateAssemblesAndRejectsCorruptFrames(t *testing.T) { frame := append(append([]byte(nil), vp8Keepalive...), bytes.Repeat([]byte{0x01}, epochHdrLen-len(vp8Keepalive))...) var state vp8FrameState