diff --git a/internal/mux/mux.go b/internal/mux/mux.go index b575fd7..f7f6b51 100644 --- a/internal/mux/mux.go +++ b/internal/mux/mux.go @@ -6,28 +6,35 @@ import ( "errors" "fmt" "sync" - "time" "github.com/openlibrecommunity/olcrtc/internal/logger" ) var ( - ErrClientResetID = errors.New("client reset requires a non-zero client id") //nolint:revive + ErrClientResetID = errors.New("client reset requires a non-zero client id") ) const ( - ControlStreamID uint16 = 0xFFFF //nolint:revive - ControlLength uint16 = 0xFFFF //nolint:revive + // Frame Header sizes + HeaderSize = 12 + // Special Stream IDs + ControlStreamID uint16 = 0xFFFF + + // Control Frame Types ControlResetClient uint32 = 1 + + // Frame Types (Internal to mux logic) + FrameTypeData uint16 = 0 + FrameTypeControl uint16 = 0xFFFF ) -type ControlFrame struct { //nolint:revive +type ControlFrame struct { ClientID uint32 Type uint32 } -type Stream struct { //nolint:revive +type Stream struct { ID uint16 ClientID uint32 recvBuf []byte @@ -37,13 +44,13 @@ type Stream struct { //nolint:revive outOfOrder map[uint32][]byte } -func (s *Stream) RecvBuf() []byte { //nolint:revive +func (s *Stream) RecvBuf() []byte { s.mu.Lock() defer s.mu.Unlock() return s.recvBuf } -type Multiplexer struct { //nolint:revive +type Multiplexer struct { streams map[uint16]*Stream nextID uint16 clientID uint32 @@ -55,10 +62,13 @@ type Multiplexer struct { //nolint:revive dataReadyMu sync.Mutex sendSeq map[uint16]uint32 sendSeqMu sync.Mutex + + // bufferCond is used to wait for space in receive buffers + bufferCond *sync.Cond } -func New(clientID uint32, onSend func([]byte) error) *Multiplexer { //nolint:revive - return &Multiplexer{ +func New(clientID uint32, onSend func([]byte) error) *Multiplexer { + m := &Multiplexer{ streams: make(map[uint16]*Stream), nextID: 1, clientID: clientID, @@ -68,9 +78,11 @@ func New(clientID uint32, onSend func([]byte) error) *Multiplexer { //nolint:rev dataReady: make(map[uint16]chan struct{}), sendSeq: make(map[uint16]uint32), } + m.bufferCond = sync.NewCond(&m.mu) + return m } -func (m *Multiplexer) OpenStream() uint16 { //nolint:revive +func (m *Multiplexer) OpenStream() uint16 { m.mu.Lock() defer m.mu.Unlock() @@ -93,7 +105,7 @@ func (m *Multiplexer) OpenStream() uint16 { //nolint:revive } } -func (m *Multiplexer) SendData(sid uint16, data []byte) error { //nolint:revive +func (m *Multiplexer) SendData(sid uint16, data []byte) error { m.mu.RLock() stream, exists := m.streams[sid] m.mu.RUnlock() @@ -122,12 +134,12 @@ func (m *Multiplexer) SendData(sid uint16, data []byte) error { //nolint:revive m.sendSeq[sid]++ m.sendSeqMu.Unlock() - frame := make([]byte, 12+len(chunk)) + frame := make([]byte, HeaderSize+len(chunk)) binary.BigEndian.PutUint32(frame[0:4], m.clientID) binary.BigEndian.PutUint16(frame[4:6], sid) - binary.BigEndian.PutUint16(frame[6:8], uint16(uint32(len(chunk)))) //nolint:gosec + binary.BigEndian.PutUint16(frame[6:8], uint16(len(chunk))) binary.BigEndian.PutUint32(frame[8:12], seq) - copy(frame[12:], chunk) + copy(frame[HeaderSize:], chunk) if err := m.onSend(frame); err != nil { return fmt.Errorf("onSend failed: %w", err) @@ -137,7 +149,7 @@ func (m *Multiplexer) SendData(sid uint16, data []byte) error { //nolint:revive return nil } -func (m *Multiplexer) CloseStream(sid uint16) error { //nolint:revive +func (m *Multiplexer) CloseStream(sid uint16) error { m.mu.Lock() defer m.mu.Unlock() @@ -149,7 +161,10 @@ func (m *Multiplexer) CloseStream(sid uint16) error { //nolint:revive delete(m.sendSeq, sid) m.sendSeqMu.Unlock() - frame := make([]byte, 12) + // Notify anyone waiting for buffer space that a stream is closed + m.bufferCond.Broadcast() + + frame := make([]byte, HeaderSize) binary.BigEndian.PutUint32(frame[0:4], m.clientID) binary.BigEndian.PutUint16(frame[4:6], sid) binary.BigEndian.PutUint16(frame[6:8], 0) @@ -161,7 +176,7 @@ func (m *Multiplexer) CloseStream(sid uint16) error { //nolint:revive return nil } -func (m *Multiplexer) SendClientReset() error { //nolint:revive +func (m *Multiplexer) SendClientReset() error { if m.clientID == 0 { return ErrClientResetID } @@ -171,23 +186,23 @@ func (m *Multiplexer) SendClientReset() error { //nolint:revive return nil } -func BuildControlFrame(clientID uint32, controlType uint32) []byte { //nolint:revive - frame := make([]byte, 12) +func BuildControlFrame(clientID uint32, controlType uint32) []byte { + frame := make([]byte, HeaderSize) binary.BigEndian.PutUint32(frame[0:4], clientID) binary.BigEndian.PutUint16(frame[4:6], ControlStreamID) - binary.BigEndian.PutUint16(frame[6:8], ControlLength) + binary.BigEndian.PutUint16(frame[6:8], 0xFFFF) // Use 0xFFFF as a marker for control binary.BigEndian.PutUint32(frame[8:12], controlType) return frame } -func ParseControlFrame(frame []byte) (ControlFrame, bool) { //nolint:revive - if len(frame) < 12 { +func ParseControlFrame(frame []byte) (ControlFrame, bool) { + if len(frame) < HeaderSize { return ControlFrame{}, false } sid := binary.BigEndian.Uint16(frame[4:6]) length := binary.BigEndian.Uint16(frame[6:8]) - if sid != ControlStreamID || length != ControlLength { + if sid != ControlStreamID || length != 0xFFFF { return ControlFrame{}, false } @@ -197,14 +212,14 @@ func ParseControlFrame(frame []byte) (ControlFrame, bool) { //nolint:revive }, true } -func (m *Multiplexer) HandleFrame(frame []byte) { //nolint:revive +func (m *Multiplexer) HandleFrame(frame []byte) { control, ok := ParseControlFrame(frame) if ok { m.handleControlFrame(control) return } - if len(frame) < 12 { + if len(frame) < HeaderSize { return } @@ -218,11 +233,11 @@ func (m *Multiplexer) HandleFrame(frame []byte) { //nolint:revive return } - if len(frame) < 12+int(length) { + if len(frame) < HeaderSize+int(length) { return } - m.processDataFrame(sid, clientID, seq, frame[12:12+length]) + m.processDataFrame(sid, clientID, seq, frame[HeaderSize:HeaderSize+int(length)]) } func (m *Multiplexer) handleCloseStreamFrame(sid uint16, clientID uint32) { @@ -230,6 +245,7 @@ func (m *Multiplexer) handleCloseStreamFrame(sid uint16, clientID uint32) { defer m.mu.Unlock() if stream, exists := m.streams[sid]; exists && stream.ClientID == clientID { stream.closed = true + m.bufferCond.Broadcast() } } @@ -279,6 +295,7 @@ func (m *Multiplexer) getOrCreateStream(sid uint16, clientID uint32) *Stream { stream.closed = false stream.nextSeq = 0 stream.outOfOrder = make(map[uint32][]byte) + m.bufferCond.Broadcast() } return stream } @@ -319,7 +336,7 @@ func (m *Multiplexer) handleControlFrame(control ControlFrame) { } } -func (m *Multiplexer) ResetClient(clientID uint32) { //nolint:revive +func (m *Multiplexer) ResetClient(clientID uint32) { m.mu.Lock() defer m.mu.Unlock() @@ -329,6 +346,7 @@ func (m *Multiplexer) ResetClient(clientID uint32) { //nolint:revive delete(m.streams, streamSid) } } + m.bufferCond.Broadcast() } func (m *Multiplexer) waitForBufferSpace(sid uint16, clientID uint32, need int) *Stream { @@ -340,13 +358,12 @@ func (m *Multiplexer) waitForBufferSpace(sid uint16, clientID uint32, need int) if len(stream.recvBuf)+need <= m.maxBufferSize { return stream } - m.mu.Unlock() - time.Sleep(5 * time.Millisecond) - m.mu.Lock() + // Wait for space to become available + m.bufferCond.Wait() } } -func (m *Multiplexer) ReadStream(sid uint16) []byte { //nolint:revive +func (m *Multiplexer) ReadStream(sid uint16) []byte { m.mu.Lock() defer m.mu.Unlock() @@ -357,10 +374,14 @@ func (m *Multiplexer) ReadStream(sid uint16) []byte { //nolint:revive data := stream.recvBuf stream.recvBuf = make([]byte, 0) + + // Notify producers that space is now available + m.bufferCond.Broadcast() + return data } -func (m *Multiplexer) StreamClosed(sid uint16) bool { //nolint:revive +func (m *Multiplexer) StreamClosed(sid uint16) bool { m.mu.RLock() defer m.mu.RUnlock() @@ -368,7 +389,7 @@ func (m *Multiplexer) StreamClosed(sid uint16) bool { //nolint:revive return !exists || stream.closed } -func (m *Multiplexer) GetStreams() []uint16 { //nolint:revive +func (m *Multiplexer) GetStreams() []uint16 { m.mu.RLock() defer m.mu.RUnlock() @@ -379,13 +400,13 @@ func (m *Multiplexer) GetStreams() []uint16 { //nolint:revive return sids } -func (m *Multiplexer) GetStream(sid uint16) *Stream { //nolint:revive +func (m *Multiplexer) GetStream(sid uint16) *Stream { m.mu.RLock() defer m.mu.RUnlock() return m.streams[sid] } -func (m *Multiplexer) Reset() { //nolint:revive +func (m *Multiplexer) Reset() { m.mu.Lock() defer m.mu.Unlock() @@ -399,16 +420,18 @@ func (m *Multiplexer) Reset() { //nolint:revive m.sendSeqMu.Lock() m.sendSeq = make(map[uint16]uint32) m.sendSeqMu.Unlock() + + m.bufferCond.Broadcast() } -func (m *Multiplexer) UpdateSendFunc(onSend func([]byte) error) { //nolint:revive +func (m *Multiplexer) UpdateSendFunc(onSend func([]byte) error) { m.mu.Lock() defer m.mu.Unlock() m.onSend = onSend } -func (m *Multiplexer) WaitForData(sid uint16) <-chan struct{} { //nolint:revive +func (m *Multiplexer) WaitForData(sid uint16) <-chan struct{} { m.dataReadyMu.Lock() defer m.dataReadyMu.Unlock() @@ -418,7 +441,7 @@ func (m *Multiplexer) WaitForData(sid uint16) <-chan struct{} { //nolint:revive return m.dataReady[sid] } -func (m *Multiplexer) CleanupDataChannel(sid uint16) { //nolint:revive +func (m *Multiplexer) CleanupDataChannel(sid uint16) { m.dataReadyMu.Lock() defer m.dataReadyMu.Unlock()