diff --git a/internal/muxconn/conn.go b/internal/muxconn/conn.go index f2d3856..b2500ad 100644 --- a/internal/muxconn/conn.go +++ b/internal/muxconn/conn.go @@ -7,8 +7,9 @@ // on the peer. smux operates on a pure byte stream (header + payload may be // glued or split across reads). We bridge by: // -// - Treating each Push as an opaque chunk appended to an internal byte -// buffer that Read drains in arbitrary slices. +// - Treating each Push as an opaque chunk handed off via a channel that +// Read drains in arbitrary slices, retaining any tail bytes that did +// not fit the caller's buffer for the next Read. // - Letting smux's sendLoop call Write once per frame; we encrypt and hand // the whole buffer to the link as a single message. Length boundaries // are preserved end-to-end by the transport (KCP length-prefix framing @@ -21,6 +22,7 @@ import ( "io" "runtime" "sync" + "sync/atomic" "time" "github.com/openlibrecommunity/olcrtc/internal/crypto" @@ -31,80 +33,131 @@ import ( // ErrClosed is returned from Read/Write after the conn has been closed. var ErrClosed = errors.New("muxconn: closed") +// inboundQueue is the buffered capacity of the Push -> Read pipeline. +// It absorbs short Read stalls without applying back-pressure to the +// transport callback. Frames are typically smux-sized (well under +// defaultMaxPayloadSize == 12 KiB), so 256 amounts to a few MiB of +// in-flight data, which is enough for sustained throughput on every +// transport we have without unbounded growth on a stuck reader. +const inboundQueue = 256 + // Conn is an io.ReadWriteCloser over a [transport.Transport] with optional AEAD wrapping. +// +// Push produces decrypted plaintext frames into an internal channel; Read +// drains the channel and slices each frame across as many caller buffers +// as needed. 
The hot path takes no explicit lock: a single producer (the transport +// callback) and a single consumer (smux's read loop) communicate via a +// buffered channel without any cond/mutex ping-pong. type Conn struct { ln transport.Transport send func([]byte) error cipher *crypto.Cipher - mu sync.Mutex - cond *sync.Cond - buf []byte - closed bool + in chan []byte + closeOnce sync.Once + closeCh chan struct{} + closed atomic.Bool + + // leftover holds the unread tail of the most recent frame popped + // from `in`. It is touched only by Read and so needs no + // synchronization. + leftover []byte } // New wires a Conn over the given transport. Push must be set as the // transport's OnData callback before this conn is used. func New(ln transport.Transport, cipher *crypto.Cipher) *Conn { - c := &Conn{ln: ln, send: ln.Send, cipher: cipher} - c.cond = sync.NewCond(&c.mu) - return c + return &Conn{ + ln: ln, + send: ln.Send, + cipher: cipher, + in: make(chan []byte, inboundQueue), + closeCh: make(chan struct{}), + } } // NewPeer wires a Conn whose writes are addressed to a specific transport peer. func NewPeer(ln transport.PeerTransport, cipher *crypto.Cipher, peerID string) *Conn { - c := &Conn{ + return &Conn{ ln: ln, send: func(data []byte) error { return ln.SendTo(peerID, data) }, - cipher: cipher, + cipher: cipher, + in: make(chan []byte, inboundQueue), + closeCh: make(chan struct{}), } - c.cond = sync.NewCond(&c.mu) - return c -} - -// Reset clears any buffered inbound bytes, re-arms a closed conn for writes, -// and unblocks pending Reads so the smux session on top of it exits cleanly. -// Use it when the link stays up but the peer's smux session has been rebuilt: -// the inbound byte stream (now indistinguishable random-looking data) must be -// parsed by the fresh smux state, not the old one. -func (c *Conn) Reset() { - c.mu.Lock() - c.buf = nil - c.closed = false - c.cond.Broadcast() - c.mu.Unlock() } // Push hands an encrypted wire payload (one OnData event) to the conn.
+// +// On the producer side: decrypt, then deliver via the inbound channel, +// blocking while the queue is full so back-pressure reaches the +// transport callback. A decrypted frame is dropped only if the conn is +// closed before it can be queued; closeCh, not a timer, bounds the wait. func (c *Conn) Push(ciphertext []byte) { pt, err := c.cipher.Decrypt(ciphertext) if err != nil { logger.Debugf("muxconn: decrypt failed, dropping frame: %v", err) return } - c.mu.Lock() - defer c.mu.Unlock() - if c.closed { + if c.closed.Load() { return } - c.buf = append(c.buf, pt...) - c.cond.Broadcast() + select { + case c.in <- pt: + case <-c.closeCh: + } } -// Read implements io.Reader. Blocks until at least one byte is available. +// Read implements io.Reader. Blocks until at least one byte is available; +// after that, drains additional ready frames without blocking to fill p, so +// a single Read can absorb several queued frames in one go. This matches +// the prior cond/append-based implementation's concatenation behaviour +// and lets smux's bufio reader pull large chunks at a time. func (c *Conn) Read(p []byte) (int, error) { - c.mu.Lock() - defer c.mu.Unlock() - for !c.closed && len(c.buf) == 0 { - c.cond.Wait() + if len(p) == 0 { + return 0, nil } - if len(c.buf) == 0 { - return 0, io.EOF + if len(c.leftover) == 0 { + select { + case data, ok := <-c.in: + if !ok { + return 0, io.EOF + } + c.leftover = data + case <-c.closeCh: + // Drain any bytes that landed before close so a peer that + // shut us down right after a final write doesn't lose data. + select { + case data := <-c.in: + c.leftover = data + default: + return 0, io.EOF + } + } + } + n := copy(p, c.leftover) + c.leftover = c.leftover[n:] + + // Greedily pull additional frames already sitting in the queue, + // without blocking. This keeps the channel from accumulating a + // backlog when the consumer asks for a large buffer.
+ for n < len(p) && len(c.leftover) == 0 { + select { + case data, ok := <-c.in: + if !ok { + return n, nil + } + m := copy(p[n:], data) + n += m + if m < len(data) { + c.leftover = data[m:] + } + default: + return n, nil + } } - n := copy(p, c.buf) - c.buf = c.buf[n:] return n, nil } @@ -120,7 +173,7 @@ func (c *Conn) Write(p []byte) (int, error) { slowPollDelay = 2 * time.Millisecond ) for attempt := 0; ; attempt++ { - if c.isClosed() { + if c.closed.Load() { return 0, ErrClosed } if c.ln.CanSend() { @@ -145,18 +198,9 @@ func (c *Conn) Write(p []byte) (int, error) { // Close unblocks any pending Read with io.EOF. func (c *Conn) Close() error { - c.mu.Lock() - defer c.mu.Unlock() - if c.closed { - return nil - } - c.closed = true - c.cond.Broadcast() + c.closeOnce.Do(func() { + c.closed.Store(true) + close(c.closeCh) + }) return nil } - -func (c *Conn) isClosed() bool { - c.mu.Lock() - defer c.mu.Unlock() - return c.closed -}