feat: add session epoch tracking to detect peer restarts

This commit is contained in:
zarazaex69
2026-05-03 15:53:11 +03:00
parent 254613fb51
commit bca50fa7c9
4 changed files with 175 additions and 44 deletions

View File

@@ -56,8 +56,8 @@ type kcpRuntime struct {
closeOnce sync.Once
}
func startKCP(out chan<- []byte, onData func([]byte)) (*kcpRuntime, error) {
c := newKCPConn(out, inboundQueueSize)
func startKCP(out chan<- []byte, onData func([]byte), epochHdr [epochHdrLen]byte) (*kcpRuntime, error) {
c := newKCPConn(out, inboundQueueSize, epochHdr)
sess, err := kcp.NewConn3(kcpConvID, fakeUDPAddr(), nil, 0, 0, c)
if err != nil {

View File

@@ -24,19 +24,25 @@ type kcpConn struct {
closed chan struct{}
closeOnce sync.Once
// epochHdr is prepended to every outgoing KCP packet so that the peer
// can detect a session restart on our side (see transport.go for the
// layout). Stable for the lifetime of this kcpConn.
epochHdr [epochHdrLen]byte
mu sync.Mutex
rDeadline time.Time
wDeadline time.Time
}
func newKCPConn(out chan<- []byte, inboundCap int) *kcpConn {
func newKCPConn(out chan<- []byte, inboundCap int, epochHdr [epochHdrLen]byte) *kcpConn {
if inboundCap <= 0 {
inboundCap = 1024
}
return &kcpConn{
out: out,
in: make(chan []byte, inboundCap),
closed: make(chan struct{}),
out: out,
in: make(chan []byte, inboundCap),
closed: make(chan struct{}),
epochHdr: epochHdr,
}
}
@@ -80,8 +86,9 @@ func (c *kcpConn) ReadFrom(p []byte) (int, net.Addr, error) {
}
func (c *kcpConn) WriteTo(p []byte, _ net.Addr) (int, error) {
buf := make([]byte, len(p))
copy(buf, p)
buf := make([]byte, epochHdrLen+len(p))
copy(buf, c.epochHdr[:])
copy(buf[epochHdrLen:], p)
c.mu.Lock()
deadline := c.wDeadline

View File

@@ -2,6 +2,8 @@ package vp8channel
import (
"context"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"sync"
@@ -18,7 +20,7 @@ import (
const (
defaultMaxPayloadSize = 60 * 1024
defaultConnectTimeout = 30 * time.Second
defaultConnectTimeout = 60 * time.Second
rtpBufSize = 65536
outboundQueueSize = 1024
inboundQueueSize = 1024
@@ -40,11 +42,22 @@ var vp8Keepalive = []byte{
0x99, 0x84, 0x88, 0xfc,
}
// kcpMagic is the little-endian first byte of a KCP packet (low byte of
// kcpConvID = 0xC0FFEE01). Anything that does not match is treated as
// non-KCP traffic (idle keepalives, stray frames after reconnect) and
// dropped before reaching the protocol stack.
const kcpMagic = byte(0x01)
// kcpFrameMagic marks a VP8 frame as carrying a KCP segment with our
// session-epoch header. The wire layout inside the VP8 frame is:
//
// [0] = kcpFrameMagic (0x4B = 'K')
// [1..5] = sender's session epoch (big-endian uint32)
// [5..] = raw KCP packet bytes
//
// The epoch lets a receiver detect that the peer has restarted its KCP
// session - typical when the SFU keeps forwarding the same remote video
// track across our process restarts, so handleRemoteTrack never fires
// again. On any epoch change we reset the local KCP session so both ends
// converge on fresh state.
const (
kcpFrameMagic = byte(0x4B)
epochHdrLen = 5
)
type streamTransport struct {
stream carrier.VideoTrack
@@ -55,12 +68,22 @@ type streamTransport struct {
writerDone chan struct{}
closed atomic.Bool
writerUp atomic.Bool
startOnce sync.Once
writerOnce sync.Once
kcpOnce sync.Once
frameInterval time.Duration
batchSize int
kcp *kcpRuntime
kcpMu sync.RWMutex
// localEpoch is bumped on every KCP session restart and stamped into
// every outgoing VP8 frame. peerEpoch tracks the last epoch we observed
// from the remote so we can detect their restart and reset locally.
localEpoch uint32
peerEpoch atomic.Uint32
hadPeer atomic.Bool
kcp *kcpRuntime
kcpMu sync.RWMutex
reconnectMu sync.Mutex
reconnectFn func()
}
// New creates a vp8channel transport backed by a carrier-specific provider.
@@ -111,6 +134,7 @@ func New(ctx context.Context, cfg transport.Config) (transport.Transport, error)
writerDone: make(chan struct{}),
frameInterval: time.Second / time.Duration(fps),
batchSize: batchSize,
localEpoch: randomEpoch(),
}
if err := stream.AddTrack(track); err != nil {
@@ -129,12 +153,19 @@ func (p *streamTransport) Connect(ctx context.Context) error {
return fmt.Errorf("connect stream: %w", err)
}
p.writerOnce.Do(func() {
p.writerUp.Store(true)
go p.writerLoop()
})
// Start KCP immediately. Don't wait for the peer's video track:
// the server may legitimately come up before any client joins the
// room, and KCP itself does not need a handshake. Once the peer
// shows up, handleRemoteTrack starts pumping their RTP into our
// session and the epoch-change detector handles peer restarts.
var startErr error
p.startOnce.Do(func() {
// Start KCP first so the writerLoop has packets to forward as soon
// as it begins ticking. KCP's own update goroutine drives keepalives
// and ACKs once the session is up.
rt, err := startKCP(p.outbound, p.onData)
p.kcpOnce.Do(func() {
rt, err := startKCP(p.outbound, p.onData, p.epochHeader())
if err != nil {
startErr = err
return
@@ -142,14 +173,35 @@ func (p *streamTransport) Connect(ctx context.Context) error {
p.kcpMu.Lock()
p.kcp = rt
p.kcpMu.Unlock()
p.writerUp.Store(true)
go p.writerLoop()
})
return startErr
}
// epochHeader returns the 5-byte VP8-frame header used to tag every KCP
// packet sent in the current local session.
func (p *streamTransport) epochHeader() [epochHdrLen]byte {
var hdr [epochHdrLen]byte
hdr[0] = kcpFrameMagic
binary.BigEndian.PutUint32(hdr[1:], p.localEpoch)
return hdr
}
func randomEpoch() uint32 {
var b [4]byte
if _, err := rand.Read(b[:]); err != nil {
// rand.Read on Linux essentially never fails; fall back to a
// time-derived value rather than panic.
//nolint:gosec // intentional uint32 truncation of a nanosecond timestamp
return uint32(time.Now().UnixNano())
}
e := binary.BigEndian.Uint32(b[:])
if e == 0 {
e = 1
}
return e
}
func (p *streamTransport) Send(data []byte) error {
if p.closed.Load() {
return ErrTransportClosed
@@ -197,12 +249,11 @@ func (p *streamTransport) drainOutbound() {
}
func (p *streamTransport) SetReconnectCallback(cb func()) {
p.reconnectMu.Lock()
p.reconnectFn = cb
p.reconnectMu.Unlock()
p.stream.SetReconnectCallback(func() {
// Drain stale KCP segments queued for the old wire. KCP will
// retransmit anything that mattered after the link is back up,
// so dropping the queue here only saves us from sending obsolete
// data that the peer would discard anyway.
p.drainOutbound()
p.resetKCP()
if cb != nil {
cb()
}
@@ -286,12 +337,38 @@ func (p *streamTransport) writerLoop() {
}
}
func (p *streamTransport) resetKCP() {
p.drainOutbound()
p.kcpMu.Lock()
old := p.kcp
p.kcp = nil
p.kcpMu.Unlock()
if old != nil {
old.close()
}
// Note: localEpoch is intentionally NOT bumped here. The epoch is a
// per-process identifier set once in New(). If we changed it on every
// peer-triggered reset, the peer would see a "new" epoch from us, reset
// itself, send back its (unchanged) epoch which we'd then see as "new"
// again - and the two sides would loop forever tearing down smux.
rt, err := startKCP(p.outbound, p.onData, p.epochHeader())
if err != nil {
return
}
p.kcpMu.Lock()
p.kcp = rt
p.kcpMu.Unlock()
}
func (p *streamTransport) handleRemoteTrack(track *webrtc.TrackRemote, _ *webrtc.RTPReceiver) {
if track.Codec().MimeType != webrtc.MimeTypeVP8 {
go p.drainTrack(track)
return
}
// We don't reset KCP here. Peer restarts are detected by the epoch
// header on incoming frames, which works even when the SFU keeps
// forwarding the same track across our restarts.
go p.readVP8Track(track)
}
@@ -312,8 +389,9 @@ type vp8FrameState struct {
frameValid bool
}
// processRTPPacket returns a complete KCP frame when the VP8 frame is fully assembled, nil otherwise.
// Detects packet loss/reordering to avoid silently corrupting fragmented VP8 frames.
// processRTPPacket returns a complete VP8 frame payload when fully assembled,
// nil otherwise. Detects packet loss/reordering to avoid silently corrupting
// fragmented VP8 frames.
func (s *vp8FrameState) processRTPPacket(pkt *rtp.Packet) []byte {
if s.haveLastSeq && pkt.SequenceNumber != s.lastSeq+1 {
s.frameValid = false
@@ -349,7 +427,7 @@ func (s *vp8FrameState) processRTPPacket(pkt *rtp.Packet) []byte {
s.frameValid = false
}()
if len(s.frameBuf) >= 4 && s.frameBuf[0] == kcpMagic {
if len(s.frameBuf) >= epochHdrLen && s.frameBuf[0] == kcpFrameMagic {
frame := make([]byte, len(s.frameBuf))
copy(frame, s.frameBuf)
return frame
@@ -377,11 +455,43 @@ func (p *streamTransport) readVP8Track(track *webrtc.TrackRemote) {
continue
}
p.kcpMu.RLock()
rt := p.kcp
p.kcpMu.RUnlock()
if rt != nil {
rt.deliver(frame)
}
p.handleIncomingFrame(frame)
}
}
// handleIncomingFrame parses the epoch header and either delivers the KCP
// payload to the local session or triggers a reset when the peer's epoch
// changes (peer process restart).
func (p *streamTransport) handleIncomingFrame(frame []byte) {
peerEpoch := binary.BigEndian.Uint32(frame[1:epochHdrLen])
kcpPayload := frame[epochHdrLen:]
if len(kcpPayload) == 0 {
return
}
if !p.hadPeer.Swap(true) {
p.peerEpoch.Store(peerEpoch)
} else if prev := p.peerEpoch.Load(); prev != peerEpoch {
// Peer restarted its KCP session. Reset ours so the conv state
// machines re-converge. CAS guards against double-reset when
// fragmented frames straddle the epoch boundary.
if p.peerEpoch.CompareAndSwap(prev, peerEpoch) {
p.resetKCP()
p.reconnectMu.Lock()
fn := p.reconnectFn
p.reconnectMu.Unlock()
if fn != nil {
fn()
}
}
// Drop this packet: it predates our fresh KCP session.
return
}
p.kcpMu.RLock()
rt := p.kcp
p.kcpMu.RUnlock()
if rt != nil {
rt.deliver(kcpPayload)
}
}

View File

@@ -13,7 +13,11 @@ func pumpPackets(stop <-chan struct{}, from <-chan []byte, to *kcpRuntime) {
case <-stop:
return
case pkt := <-from:
to.deliver(pkt)
// Strip the on-wire epoch header that kcpConn prepends;
// the real receive path does this before calling deliver().
if len(pkt) > epochHdrLen {
to.deliver(pkt[epochHdrLen:])
}
}
}
}
@@ -66,13 +70,13 @@ func TestKCPLoopback(t *testing.T) {
cb, doneB, getRecv := buildReceiver(len(msgs))
rtA, err := startKCP(a2b, nil)
rtA, err := startKCP(a2b, nil, testEpochHdr(1))
if err != nil {
t.Fatalf("startKCP A: %v", err)
}
defer rtA.close()
rtB, err := startKCP(b2a, cb)
rtB, err := startKCP(b2a, cb, testEpochHdr(2))
if err != nil {
t.Fatalf("startKCP B: %v", err)
}
@@ -100,7 +104,17 @@ func TestKCPLoopback(t *testing.T) {
}
func TestVP8KeepaliveDoesNotLookLikeKCP(t *testing.T) {
if len(vp8Keepalive) >= 1 && vp8Keepalive[0] == kcpMagic {
t.Errorf("keepalive collides with kcp magic byte 0x%02x", kcpMagic)
if len(vp8Keepalive) >= 1 && vp8Keepalive[0] == kcpFrameMagic {
t.Errorf("keepalive collides with kcp magic byte 0x%02x", kcpFrameMagic)
}
}
func testEpochHdr(epoch uint32) [epochHdrLen]byte {
var hdr [epochHdrLen]byte
hdr[0] = kcpFrameMagic
hdr[1] = byte(epoch >> 24)
hdr[2] = byte(epoch >> 16) //nolint:gosec
hdr[3] = byte(epoch >> 8) //nolint:gosec
hdr[4] = byte(epoch) //nolint:gosec
return hdr
}