From 913cabe222bba1b35902b010aedb3185780cd7d6 Mon Sep 17 00:00:00 2001 From: Qtozdec <56160254+qtozdec@users.noreply.github.com> Date: Fri, 10 Apr 2026 16:26:15 +0300 Subject: [PATCH] Add mux control frames --- internal/client/client.go | 30 +++--- internal/mux/mux.go | 103 +++++++++++++------ internal/mux/mux_test.go | 61 +++++++++++ internal/server/server.go | 209 +++++++++++++++++++++++--------------- 4 files changed, 279 insertions(+), 124 deletions(-) create mode 100644 internal/mux/mux_test.go diff --git a/internal/client/client.go b/internal/client/client.go index 91de118..6b91218 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -103,18 +103,19 @@ func Run(ctx context.Context, roomURL, keyHex string, socksPort int, duo bool, s }) for i := 0; i < peerCount; i++ { + peerID := i peer, err := telemost.NewPeer(roomURL, names.Generate(), c.onData) if err != nil { return err } peer.SetEndedCallback(func(reason string) { - log.Printf("Client peer %d reported conference end: %s", i, reason) + log.Printf("Client peer %d reported conference end: %s", peerID, reason) cancel() }) c.peers = append(c.peers, peer) peer.SetReconnectCallback(func(dc *webrtc.DataChannel) { - log.Printf("Client peer %d reconnected - resetting multiplexer state", i) + log.Printf("Client peer %d reconnected - resetting multiplexer state", peerID) c.mux.UpdateSendFunc(func(frame []byte) error { encrypted, err := c.cipher.Encrypt(frame) @@ -130,11 +131,11 @@ func Run(ctx context.Context, roomURL, keyHex string, socksPort int, duo bool, s log.Println("Client multiplexer reset complete") }) - log.Printf("Connecting peer %d to Telemost...", i) + log.Printf("Connecting peer %d to Telemost...", peerID) if err := peer.Connect(runCtx); err != nil { return err } - log.Printf("Peer %d connected", i) + log.Printf("Peer %d connected", peerID) c.wg.Add(1) go func() { @@ -145,17 +146,18 @@ func Run(ctx context.Context, roomURL, keyHex string, socksPort int, duo bool, s time.Sleep(100 * time.Millisecond) - resetFrame := make([]byte, 12) - binary.BigEndian.PutUint32(resetFrame[0:4], c.clientID) - binary.BigEndian.PutUint16(resetFrame[4:6], 0xFFFF) - binary.BigEndian.PutUint16(resetFrame[6:8], 0xFFFF) - binary.BigEndian.PutUint32(resetFrame[8:12], 0) - encrypted, _ := cipher.Encrypt(resetFrame) - - for _, peer := range c.peers { - peer.Send(encrypted) + resetFrame := mux.BuildControlFrame(c.clientID, mux.ControlResetClient) + encrypted, err := cipher.Encrypt(resetFrame) + if err != nil { + log.Printf("Failed to encrypt reset signal: %v", err) + } else { + for _, peer := range c.peers { + if err := peer.Send(encrypted); err != nil { + log.Printf("Failed to send reset signal to server: %v", err) + } + } + log.Printf("Sent reset signal to server (clientID=%d)", c.clientID) } - log.Printf("Sent reset signal to server (clientID=%d)", c.clientID) err = c.runSOCKS5(runCtx, socksPort, socksUser, socksPass) diff --git a/internal/mux/mux.go b/internal/mux/mux.go index 3b625cf..83bee25 100644 --- a/internal/mux/mux.go +++ b/internal/mux/mux.go @@ -6,12 +6,25 @@ package mux import ( "encoding/binary" + "errors" "sync" "time" "github.com/openlibrecommunity/olcrtc/internal/logger" ) +const ( + ControlStreamID uint16 = 0xFFFF + ControlLength uint16 = 0xFFFF + + ControlResetClient uint32 = 1 +) + +type ControlFrame struct { + ClientID uint32 + Type uint32 +} + type Stream struct { ID uint16 ClientID uint32 @@ -144,24 +157,47 @@ func (m *Multiplexer) CloseStream(sid uint16) error { return m.onSend(frame) } -func (m *Multiplexer) HandleFrame(frame []byte) { - if len(frame) < 12 { - if len(frame) >= 8 { - clientID := binary.BigEndian.Uint32(frame[0:4]) - sid := binary.BigEndian.Uint16(frame[4:6]) - length := binary.BigEndian.Uint16(frame[6:8]) +func (m *Multiplexer) SendClientReset() error { + if m.clientID == 0 { + return errors.New("client reset requires a non-zero client id") + } + return m.onSend(BuildControlFrame(m.clientID, ControlResetClient)) +} - if sid == 0xFFFF && length == 0xFFFF { - m.mu.Lock() - for streamSid, stream := range m.streams { - if stream.ClientID == clientID { - stream.closed = true - delete(m.streams, streamSid) - } - } - m.mu.Unlock() - } - } +func BuildControlFrame(clientID uint32, controlType uint32) []byte { + frame := make([]byte, 12) + binary.BigEndian.PutUint32(frame[0:4], clientID) + binary.BigEndian.PutUint16(frame[4:6], ControlStreamID) + binary.BigEndian.PutUint16(frame[6:8], ControlLength) + binary.BigEndian.PutUint32(frame[8:12], controlType) + return frame +} + +func ParseControlFrame(frame []byte) (ControlFrame, bool) { + if len(frame) < 12 { + return ControlFrame{}, false + } + + sid := binary.BigEndian.Uint16(frame[4:6]) + length := binary.BigEndian.Uint16(frame[6:8]) + if sid != ControlStreamID || length != ControlLength { + return ControlFrame{}, false + } + + return ControlFrame{ + ClientID: binary.BigEndian.Uint32(frame[0:4]), + Type: binary.BigEndian.Uint32(frame[8:12]), + }, true +} + +func (m *Multiplexer) HandleFrame(frame []byte) { + control, ok := ParseControlFrame(frame) + if ok { + m.handleControlFrame(control) + return + } + + if len(frame) < 12 { return } @@ -170,18 +206,6 @@ func (m *Multiplexer) HandleFrame(frame []byte) { length := binary.BigEndian.Uint16(frame[6:8]) seq := binary.BigEndian.Uint32(frame[8:12]) - if sid == 0xFFFF && length == 0xFFFF { - m.mu.Lock() - for streamSid, stream := range m.streams { - if stream.ClientID == clientID { - stream.closed = true - delete(m.streams, streamSid) - } - } - m.mu.Unlock() - return - } - if length == 0 { m.mu.Lock() if stream, exists := m.streams[sid]; exists && stream.ClientID == clientID { @@ -270,6 +294,27 @@ func (m *Multiplexer) HandleFrame(frame []byte) { } } +func (m *Multiplexer) handleControlFrame(control ControlFrame) { + switch control.Type { + case ControlResetClient: + m.ResetClient(control.ClientID) + default: + logger.Debug("Unknown mux control frame type=%d clientID=%d", control.Type, control.ClientID) + } +} + +func (m *Multiplexer) ResetClient(clientID uint32) { + m.mu.Lock() + defer m.mu.Unlock() + + for streamSid, stream := range m.streams { + if stream.ClientID == clientID { + stream.closed = true + delete(m.streams, streamSid) + } + } +} + // waitForBufferSpace releases m.mu and waits until the stream's recvBuf has // room for `need` more bytes, then re-acquires the lock. Returns the (possibly // re-fetched) stream, or nil if the stream disappeared / was reset / closed. diff --git a/internal/mux/mux_test.go b/internal/mux/mux_test.go new file mode 100644 index 0000000..5a15612 --- /dev/null +++ b/internal/mux/mux_test.go @@ -0,0 +1,61 @@ +package mux + +import ( + "encoding/binary" + "testing" +) + +func TestParseControlFrame(t *testing.T) { + frame := BuildControlFrame(42, ControlResetClient) + + control, ok := ParseControlFrame(frame) + if !ok { + t.Fatal("expected control frame") + } + if control.ClientID != 42 { + t.Fatalf("ClientID = %d, want 42", control.ClientID) + } + if control.Type != ControlResetClient { + t.Fatalf("Type = %d, want %d", control.Type, ControlResetClient) + } +} + +func TestHandleControlResetClient(t *testing.T) { + m := New(0, func([]byte) error { return nil }) + + dataFrame := make([]byte, 13) + binary.BigEndian.PutUint32(dataFrame[0:4], 42) + binary.BigEndian.PutUint16(dataFrame[4:6], 7) + binary.BigEndian.PutUint16(dataFrame[6:8], 1) + binary.BigEndian.PutUint32(dataFrame[8:12], 0) + dataFrame[12] = 0xAA + + m.HandleFrame(dataFrame) + if stream := m.GetStream(7); stream == nil { + t.Fatal("expected data stream before reset") + } + + m.HandleFrame(BuildControlFrame(42, ControlResetClient)) + if stream := m.GetStream(7); stream != nil { + t.Fatal("expected data stream to be removed by client reset") + } +} + +func TestSendClientReset(t *testing.T) { + var sent []byte + m := New(99, func(frame []byte) error { + sent = append([]byte(nil), frame...) + return nil + }) + + if err := m.SendClientReset(); err != nil { + t.Fatalf("SendClientReset failed: %v", err) + } + control, ok := ParseControlFrame(sent) + if !ok { + t.Fatal("expected sent control frame") + } + if control.ClientID != 99 || control.Type != ControlResetClient { + t.Fatalf("control = %#v", control) + } +} diff --git a/internal/server/server.go b/internal/server/server.go index 23725e2..e02af91 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -3,7 +3,6 @@ package server import ( "context" "crypto/rand" - "encoding/binary" "encoding/hex" "encoding/json" "fmt" @@ -27,6 +26,8 @@ type Server struct { mux *mux.Multiplexer connections map[uint16]net.Conn connMu sync.RWMutex + streamPumps map[uint16]net.Conn + pumpMu sync.Mutex peerIdx atomic.Uint32 wg sync.WaitGroup dnsServer string @@ -76,6 +77,7 @@ func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer string s := &Server{ cipher: cipher, connections: make(map[uint16]net.Conn), + streamPumps: make(map[uint16]net.Conn), peers: make([]*telemost.Peer, 0), dnsServer: dnsServer, } @@ -122,18 +124,19 @@ func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer string }) for i := 0; i < peerCount; i++ { + peerID := i peer, err := telemost.NewPeer(roomURL, names.Generate(), s.onData) if err != nil { return err } peer.SetEndedCallback(func(reason string) { - log.Printf("Server peer %d reported conference end: %s", i, reason) + log.Printf("Server peer %d reported conference end: %s", peerID, reason) cancel() }) s.peers = append(s.peers, peer) peer.SetReconnectCallback(func(dc *webrtc.DataChannel) { - log.Printf("Server peer %d reconnected - resetting multiplexer state", i) + log.Printf("Server peer %d reconnected - resetting multiplexer state", peerID) s.connMu.Lock() for sid, conn := range s.connections { @@ -160,11 +163,11 @@ func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer string log.Println("Server multiplexer reset complete") }) - log.Printf("Connecting peer %d to Telemost...", i) + log.Printf("Connecting peer %d to Telemost...", peerID) if err := peer.Connect(runCtx); err != nil { return err } - log.Printf("Peer %d connected", i) + log.Printf("Peer %d connected", peerID) s.wg.Add(1) go func() { @@ -189,49 +192,29 @@ func (s *Server) onData(data []byte) { return } - if len(plaintext) >= 12 { - clientID := binary.BigEndian.Uint32(plaintext[0:4]) - sid := binary.BigEndian.Uint16(plaintext[4:6]) - length := binary.BigEndian.Uint16(plaintext[6:8]) - - if sid == 0xFFFF && length == 0xFFFF { - log.Printf("Received reset signal from client (clientID=%d) - cleaning up", clientID) - s.connMu.Lock() - for streamSid, conn := range s.connections { - stream := s.mux.GetStream(streamSid) - if stream != nil && stream.ClientID == clientID { - if conn != nil { - conn.Close() - } - delete(s.connections, streamSid) - } - } - s.connMu.Unlock() - } - } else if len(plaintext) >= 8 { - clientID := binary.BigEndian.Uint32(plaintext[0:4]) - sid := binary.BigEndian.Uint16(plaintext[4:6]) - length := binary.BigEndian.Uint16(plaintext[6:8]) - - if sid == 0xFFFF && length == 0xFFFF { - log.Printf("Received reset signal from client (clientID=%d) - cleaning up", clientID) - s.connMu.Lock() - for streamSid, conn := range s.connections { - stream := s.mux.GetStream(streamSid) - if stream != nil && stream.ClientID == clientID { - if conn != nil { - conn.Close() - } - delete(s.connections, streamSid) - } - } - s.connMu.Unlock() - } + if control, ok := mux.ParseControlFrame(plaintext); ok && control.Type == mux.ControlResetClient { + log.Printf("Received reset signal from client (clientID=%d) - cleaning up", control.ClientID) + s.closeClientConnections(control.ClientID) } s.mux.HandleFrame(plaintext) } +func (s *Server) closeClientConnections(clientID uint32) { + s.connMu.Lock() + defer s.connMu.Unlock() + + for streamSid, conn := range s.connections { + stream := s.mux.GetStream(streamSid) + if stream != nil && stream.ClientID == clientID { + if conn != nil { + conn.Close() + } + delete(s.connections, streamSid) + } + } +} + func (s *Server) run(ctx context.Context) error { ticker := time.NewTicker(10 * time.Millisecond) defer ticker.Stop() @@ -263,50 +246,78 @@ func (s *Server) run(ctx context.Context) error { sids := s.mux.GetStreams() for _, sid := range sids { - go func(sid uint16) { - data := s.mux.ReadStream(sid) - if len(data) > 0 { - s.connMu.RLock() - conn, exists := s.connections[sid] - s.connMu.RUnlock() + if s.mux.StreamClosed(sid) { + s.closeStreamConnection(sid) + continue + } - if exists && conn != nil { - if _, err := conn.Write(data); err != nil { - s.mux.CloseStream(sid) - conn.Close() - s.connMu.Lock() - delete(s.connections, sid) - s.connMu.Unlock() - } - } else { - var req ConnectRequest - if err := json.Unmarshal(data, &req); err == nil && req.Cmd == "connect" { - log.Printf("[SERVER] sid=%d RECEIVED_CONNECT_REQUEST %s:%d", sid, req.Addr, req.Port) - s.connMu.Lock() - if oldConn, exists := s.connections[sid]; exists && oldConn != nil { - oldConn.Close() - } - s.connMu.Unlock() - go s.handleConnect(sid, req) - } - } - } + if s.hasConnection(sid) { + continue + } - if s.mux.StreamClosed(sid) { - s.connMu.Lock() - conn, exists := s.connections[sid] - if exists && conn != nil { - conn.Close() - delete(s.connections, sid) - } - s.connMu.Unlock() - } - }(sid) + data := s.mux.ReadStream(sid) + if len(data) == 0 { + continue + } + + var req ConnectRequest + if err := json.Unmarshal(data, &req); err == nil && req.Cmd == "connect" { + log.Printf("[SERVER] sid=%d RECEIVED_CONNECT_REQUEST %s:%d", sid, req.Addr, req.Port) + s.closeStreamConnection(sid) + go s.handleConnect(ctx, sid, req) + } } } } -func (s *Server) handleConnect(sid uint16, req ConnectRequest) { +func (s *Server) hasConnection(sid uint16) bool { + s.connMu.RLock() + defer s.connMu.RUnlock() + conn := s.connections[sid] + return conn != nil +} + +func (s *Server) closeStreamConnection(sid uint16) { + s.connMu.Lock() + conn := s.connections[sid] + if conn != nil { + conn.Close() + delete(s.connections, sid) + } + s.connMu.Unlock() +} + +func (s *Server) closeStreamConnectionIfCurrent(sid uint16, expected net.Conn) { + s.connMu.Lock() + conn := s.connections[sid] + if conn == expected { + conn.Close() + delete(s.connections, sid) + } + s.connMu.Unlock() +} + +func (s *Server) markStreamPump(sid uint16, conn net.Conn) bool { + s.pumpMu.Lock() + defer s.pumpMu.Unlock() + if current := s.streamPumps[sid]; current == conn { + return false + } else if current != nil { + current.Close() + } + s.streamPumps[sid] = conn + return true +} + +func (s *Server) unmarkStreamPump(sid uint16, conn net.Conn) { + s.pumpMu.Lock() + if s.streamPumps[sid] == conn { + delete(s.streamPumps, sid) + } + s.pumpMu.Unlock() +} + +func (s *Server) handleConnect(ctx context.Context, sid uint16, req ConnectRequest) { startTime := time.Now() addr := fmt.Sprintf("%s:%d", req.Addr, req.Port) logger.Verbose("Handling connect request sid=%d to %s", sid, addr) @@ -347,6 +358,7 @@ func (s *Server) handleConnect(sid uint16, req ConnectRequest) { log.Printf("[SERVER] sid=%d CONNECT_SUCCESS dial_time=%v", sid, dialElapsed) s.mux.SendData(sid, []byte{0x00}) + s.startStreamPump(ctx, sid, conn) go func() { defer func() { @@ -386,6 +398,41 @@ func (s *Server) handleConnect(sid uint16, req ConnectRequest) { }() } +func (s *Server) startStreamPump(ctx context.Context, sid uint16, conn net.Conn) { + if !s.markStreamPump(sid, conn) { + return + } + + s.wg.Add(1) + go func() { + defer s.wg.Done() + defer s.unmarkStreamPump(sid, conn) + + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + data := s.mux.ReadStream(sid) + if len(data) > 0 { + if _, err := conn.Write(data); err != nil { + s.mux.CloseStream(sid) + s.closeStreamConnectionIfCurrent(sid, conn) + return + } + } + if s.mux.StreamClosed(sid) { + s.closeStreamConnectionIfCurrent(sid, conn) + return + } + } + } + }() +} + func (s *Server) canSendData() bool { for _, peer := range s.peers { if !peer.CanSend() {