refactor(vp8channel): add ResetPeer with epoch rotation and mutex #60

This commit is contained in:
zarazaex69
2026-05-19 21:58:12 +03:00
parent 085aadcad7
commit 2b6f77f0f6
2 changed files with 97 additions and 18 deletions

View File

@@ -96,10 +96,11 @@ type streamTransport struct {
frameInterval time.Duration
batchSize int
// localEpoch is stamped into every outgoing VP8 frame. Explicit
// upper-layer resets (ResetPeer) rotate it so the peer can reset its KCP
// state too. Peer-triggered resets keep it stable to avoid reset ping-pong.
// epochMu guards localEpoch. peerEpoch tracks the last epoch we observed
// from the remote so we can detect their restart and reset locally.
bindingToken uint32
epochMu sync.RWMutex
localEpoch uint32
peerEpoch atomic.Uint32
hadPeer atomic.Bool
@@ -204,7 +205,7 @@ func (p *streamTransport) Connect(ctx context.Context) error {
p.kcpMu.Lock()
p.kcp = rt
p.kcpMu.Unlock()
logger.Infof("vp8channel: KCP started localEpoch=0x%08x", p.localEpoch)
logger.Infof("vp8channel: KCP started localEpoch=0x%08x", p.localEpochValue())
})
p.writerOnce.Do(func() {
@@ -218,14 +219,41 @@ func (p *streamTransport) Connect(ctx context.Context) error {
// epochHeader builds the epochHdrLen-byte header that tags every KCP packet
// sent in the current local session. The epoch is obtained through
// localEpochValue, which reads localEpoch under epochMu.
func (p *streamTransport) epochHeader() [epochHdrLen]byte {
	return buildEpochHeader(p.bindingToken, p.localEpochValue())
}
// buildEpochHeader assembles an epoch header from explicit values. The header
// starts with the vp8Keepalive prefix, followed by token, epoch, and
// epochCRC(token, epoch), each written big-endian at tokenOff, epochOff, and
// crcOff respectively. It is a free function and touches no transport state.
func buildEpochHeader(token, epoch uint32) [epochHdrLen]byte {
	var hdr [epochHdrLen]byte
	copy(hdr[:], vp8Keepalive)
	binary.BigEndian.PutUint32(hdr[tokenOff:epochOff], token)
	binary.BigEndian.PutUint32(hdr[epochOff:crcOff], epoch)
	binary.BigEndian.PutUint32(hdr[crcOff:epochHdrLen], epochCRC(token, epoch))
	return hdr
}
// rotateEpochHeader replaces localEpoch with a value drawn from randomEpoch
// that differs from the current one, then returns the header built for the
// new epoch. The draw-and-store happens under the epochMu write lock; the
// header is built after the lock is released.
func (p *streamTransport) rotateEpochHeader() [epochHdrLen]byte {
	p.epochMu.Lock()
	fresh := randomEpoch()
	// Redraw until the candidate differs from the epoch currently in use.
	for fresh == p.localEpoch {
		fresh = randomEpoch()
	}
	p.localEpoch = fresh
	p.epochMu.Unlock()
	return buildEpochHeader(p.bindingToken, fresh)
}
// localEpochValue returns the current localEpoch, read under the epochMu
// read lock.
func (p *streamTransport) localEpochValue() uint32 {
	p.epochMu.RLock()
	epoch := p.localEpoch
	p.epochMu.RUnlock()
	return epoch
}
func epochCRC(token, epoch uint32) uint32 {
var buf [8]byte
binary.BigEndian.PutUint32(buf[0:4], token)
@@ -313,6 +341,14 @@ func (p *streamTransport) drainOutbound() {
}
}
// ResetPeer rotates the local epoch and restarts KCP with the new epoch
// header. It drops queued KCP traffic and starts a fresh KCP state machine
// while keeping the carrier connection alive. The client/server liveness layer
// calls this before rebuilding smux so replacement handshakes are not parsed
// behind stale bytes from streams that were active when the old session died.
func (p *streamTransport) ResetPeer() {
	hdr := p.rotateEpochHeader()
	p.restartKCP(hdr)
}
func (p *streamTransport) SetReconnectCallback(cb func()) {
p.reconnectMu.Lock()
p.reconnectFn = cb
@@ -407,6 +443,10 @@ func (p *streamTransport) sampleInterval() time.Duration {
}
// resetKCP restarts KCP with the header for the current local epoch. Unlike
// ResetPeer it does not rotate localEpoch: keeping the epoch stable on
// peer-triggered resets avoids the two sides resetting each other in a loop.
func (p *streamTransport) resetKCP() {
	hdr := p.epochHeader()
	p.restartKCP(hdr)
}
func (p *streamTransport) restartKCP(epochHdr [epochHdrLen]byte) {
p.drainOutbound()
p.kcpMu.Lock()
old := p.kcp
@@ -415,12 +455,7 @@ func (p *streamTransport) resetKCP() {
if old != nil {
old.close()
}
// Note: localEpoch is never changed here; the header to use is chosen by
// the caller. resetKCP passes the current epoch unchanged: if we rotated it
// on every peer-triggered reset, the peer would see a "new" epoch from us,
// reset itself, send back its (unchanged) epoch which we'd then see as "new"
// again - and the two sides would loop forever tearing down smux. Only the
// explicit ResetPeer path passes a rotated epoch.
rt, err := startKCP(p.outbound, p.onData, p.epochHeader())
rt, err := startKCP(p.outbound, p.onData, epochHdr)
if err != nil {
return
}
@@ -552,7 +587,7 @@ func (p *streamTransport) handleIncomingFrame(frame []byte) {
// remote track. Those frames carry our local epoch, not the peer's. If we
// treat them as peer traffic, epoch tracking toggles between "self" and
// "peer" and both sides loop forever resetting smux/KCP.
if peerEpoch == p.localEpoch {
if peerEpoch == p.localEpochValue() {
logger.Debugf("vp8channel: self-echo detected epoch=0x%08x (SFU reflects our own track)", peerEpoch)
return
}

View File

@@ -90,10 +90,10 @@ func (s *fakeEngineSession) SetEndedCallback(cb func(string)) { s.stream.SetEnd
// WatchConnection forwards ctx to the wrapped stream's WatchConnection.
func (s *fakeEngineSession) WatchConnection(ctx context.Context) {
s.stream.WatchConnection(ctx)
}
// CanSend reports whatever the wrapped stream's CanSend returns.
func (s *fakeEngineSession) CanSend() bool { return s.stream.CanSend() }

// GetSendQueue always returns nil; the fake session exposes no send queue.
func (s *fakeEngineSession) GetSendQueue() chan []byte { return nil }

// GetBufferedAmount always returns 0.
func (s *fakeEngineSession) GetBufferedAmount() uint64 { return 0 }

// AddVideoTrack forwards the track to the wrapped stream's AddTrack.
func (s *fakeEngineSession) AddVideoTrack(t webrtc.TrackLocal) error { return s.stream.AddTrack(t) }
// SetVideoTrackHandler passes cb to the wrapped stream's SetTrackHandler.
func (s *fakeEngineSession) SetVideoTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) {
s.stream.SetTrackHandler(cb)
}
@@ -230,6 +230,50 @@ func TestEpochHeaderTokenAndOutboundCapacity(t *testing.T) {
}
}
// TestResetPeerRestartsKCPAndDrainsOutbound verifies that ResetPeer installs a
// new KCP runtime, leaves the outbound queue empty, changes localEpoch, and
// stops the runtime it replaced.
func TestResetPeerRestartsKCPAndDrainsOutbound(t *testing.T) {
	tr := &streamTransport{
		stream:       &fakeVideoStream{canSend: true},
		outbound:     make(chan []byte, 10),
		closeCh:      make(chan struct{}),
		writerDone:   make(chan struct{}),
		bindingToken: bindingToken("client"),
		localEpoch:   0x01020304,
	}
	defer func() { _ = tr.Close() }()

	// Install an initial runtime by hand so there is something to replace.
	prev, err := startKCP(tr.outbound, nil, tr.epochHeader())
	if err != nil {
		t.Fatalf("startKCP: %v", err)
	}
	tr.kcpMu.Lock()
	tr.kcp = prev
	tr.kcpMu.Unlock()

	// Queue a frame that the reset is expected to discard.
	tr.outbound <- []byte("stale")
	epochBefore := tr.localEpoch

	tr.ResetPeer()

	tr.kcpMu.RLock()
	cur := tr.kcp
	tr.kcpMu.RUnlock()

	if cur == nil || cur == prev {
		t.Fatalf("ResetPeer kcp = %p, want fresh non-nil runtime distinct from %p", cur, prev)
	}
	if pending := len(tr.outbound); pending != 0 {
		t.Fatalf("ResetPeer left %d outbound frame(s), want 0", pending)
	}
	if epochAfter := tr.localEpoch; epochAfter == epochBefore {
		t.Fatalf("ResetPeer localEpoch = %#x, want different epoch", epochAfter)
	}

	// The replaced runtime must shut down: readDone has to become readable
	// within one second.
	stopTimeout := time.After(time.Second)
	select {
	case <-prev.readDone:
	case <-stopTimeout:
		t.Fatal("old KCP runtime did not stop")
	}
}
func TestVP8FrameStateAssemblesAndRejectsCorruptFrames(t *testing.T) {
frame := append(append([]byte(nil), vp8Keepalive...), bytes.Repeat([]byte{0x01}, epochHdrLen-len(vp8Keepalive))...)
var state vp8FrameState