diff --git a/internal/client/client.go b/internal/client/client.go index 9d5bf67..06394f4 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -72,7 +72,7 @@ func RunWithReady( key, err := decodeKey(keyHex) if err != nil { - return err + return fmt.Errorf("decodeKey failed: %w", err) } keyStr := string(key) @@ -95,7 +95,7 @@ func RunWithReady( for peerID := range 1 { if err := c.addPeer(runCtx, roomURL, peerID, cancel); err != nil { - return err + return fmt.Errorf("addPeer failed: %w", err) } } @@ -111,10 +111,6 @@ func RunWithReady( return err } -func peerCount() int { - return 1 -} - func decodeKey(keyHex string) ([]byte, error) { if keyHex == "" { key := make([]byte, 32) @@ -193,7 +189,7 @@ func (c *Client) addPeer( peerID int, cancel context.CancelFunc, ) error { - peer, err := telemost.NewPeer(roomURL, names.Generate(), c.onData) + peer, err := telemost.NewPeer(runCtx, roomURL, names.Generate(), c.onData) if err != nil { return fmt.Errorf("create peer %d: %w", peerID, err) } @@ -257,7 +253,7 @@ func (c *Client) sendResetSignal() { func (c *Client) onData(data []byte) { plaintext, err := c.cipher.Decrypt(data) if err != nil { - logger.Debug("Decrypt error: %v", err) + logger.Debugf("Decrypt error: %v", err) return } @@ -292,7 +288,7 @@ func (c *Client) runSOCKS5( <-ctx.Done() log.Println("Closing SOCKS5 listener...") if err := listener.Close(); err != nil { - logger.Debug("SOCKS5 listener close error: %v", err) + logger.Debugf("SOCKS5 listener close error: %v", err) } }() @@ -317,7 +313,7 @@ func (c *Client) runSOCKS5( func (c *Client) closePeers() { for _, peer := range c.peers { if err := peer.Close(); err != nil { - logger.Debug("Peer close error: %v", err) + logger.Debugf("Peer close error: %v", err) } } } @@ -326,7 +322,7 @@ func (c *Client) closePeers() { func (c *Client) handleSOCKS5(conn net.Conn, username, password string) { defer func() { if err := conn.Close(); err != nil { - logger.Debug("SOCKS5 connection close error: %v", err) + logger.Debugf("SOCKS5 connection close error: %v", err) } }() @@ -362,7 +358,7 @@ func (c *Client) handleSOCKS5(conn net.Conn, username, password string) { } sid := c.mux.OpenStream() - logger.Verbose("SOCKS5 opened stream sid=%d for %s:%d", sid, addr, port) + logger.Verbosef("SOCKS5 opened stream sid=%d for %s:%d", sid, addr, port) log.Printf("[CLIENT] sid=%d SOCKS5_START %s:%d", sid, addr, port) if !c.sendConnectRequest(sid, addr, port) { @@ -481,12 +477,12 @@ func (c *Client) sendConnectRequest(sid uint16, addr string, port uint16) bool { Port: port, }) if err != nil { - logger.Debug("Connect request marshal error: %v", err) + logger.Debugf("Connect request marshal error: %v", err) return false } if err := c.mux.SendData(sid, reqData); err != nil { - logger.Debug("Connect request send error: %v", err) + logger.Debugf("Connect request send error: %v", err) return false } @@ -525,7 +521,7 @@ func (c *Client) proxyStream(conn net.Conn, sid uint16) { n, err := conn.Read(buf) if err != nil { if err := c.mux.CloseStream(sid); err != nil { - logger.Debug("Close stream error: %v", err) + logger.Debugf("Close stream error: %v", err) } return } @@ -579,7 +575,7 @@ func writeStreamData(conn net.Conn, data []byte) bool { func writeResponse(conn net.Conn, response []byte) { if _, err := conn.Write(response); err != nil { - logger.Debug("SOCKS5 response write error: %v", err) + logger.Debugf("SOCKS5 response write error: %v", err) } } diff --git a/internal/crypto/chacha.go b/internal/crypto/chacha.go index b9af70e..8e7f268 100644 --- a/internal/crypto/chacha.go +++ b/internal/crypto/chacha.go @@ -1,48 +1,59 @@ +// Package crypto provides cryptographic functions. package crypto import ( "crypto/cipher" "crypto/rand" "errors" + "fmt" "golang.org/x/crypto/chacha20poly1305" ) -type Cipher struct { +var ( + ErrInvalidKeySize = errors.New("invalid key size") //nolint:revive + ErrCiphertextTooShort = errors.New("ciphertext too short") //nolint:revive +) + +type Cipher struct { //nolint:revive aead cipher.AEAD } -func NewCipher(keyStr string) (*Cipher, error) { +func NewCipher(keyStr string) (*Cipher, error) { //nolint:revive key := []byte(keyStr) if len(key) != chacha20poly1305.KeySize { - return nil, errors.New("invalid key size") + return nil, ErrInvalidKeySize } aead, err := chacha20poly1305.NewX(key) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create aead: %w", err) } return &Cipher{aead: aead}, nil } -func (c *Cipher) Encrypt(plaintext []byte) ([]byte, error) { +func (c *Cipher) Encrypt(plaintext []byte) ([]byte, error) { //nolint:revive nonce := make([]byte, c.aead.NonceSize()) if _, err := rand.Read(nonce); err != nil { - return nil, err + return nil, fmt.Errorf("failed to read nonce: %w", err) } ciphertext := c.aead.Seal(nonce, nonce, plaintext, nil) return ciphertext, nil } -func (c *Cipher) Decrypt(ciphertext []byte) ([]byte, error) { +func (c *Cipher) Decrypt(ciphertext []byte) ([]byte, error) { //nolint:revive if len(ciphertext) < c.aead.NonceSize() { - return nil, errors.New("ciphertext too short") + return nil, ErrCiphertextTooShort } nonce := ciphertext[:c.aead.NonceSize()] encrypted := ciphertext[c.aead.NonceSize():] - return c.aead.Open(nil, nonce, encrypted, nil) + res, err := c.aead.Open(nil, nonce, encrypted, nil) + if err != nil { + return nil, fmt.Errorf("failed to decrypt: %w", err) + } + return res, nil } diff --git a/internal/logger/logger.go b/internal/logger/logger.go index c8d462e..fbaa63f 100644 --- a/internal/logger/logger.go +++ b/internal/logger/logger.go @@ -1,27 +1,27 @@ -package logger +package logger //nolint:revive import ( "log" "sync/atomic" ) -var verboseEnabled atomic.Bool +var verboseEnabled atomic.Bool //nolint:gochecknoglobals -func SetVerbose(enabled bool) { +func SetVerbose(enabled bool) { //nolint:revive verboseEnabled.Store(enabled) } -func IsVerbose() bool { +func IsVerbose() bool { //nolint:revive return verboseEnabled.Load() } -func Verbose(format string, v ...interface{}) { +func Verbosef(format string, v ...interface{}) { //nolint:revive if verboseEnabled.Load() { log.Printf("[VERBOSE] "+format, v...) } } -func Debug(format string, v ...interface{}) { +func Debugf(format string, v ...interface{}) { //nolint:revive if verboseEnabled.Load() { log.Printf("[DEBUG] "+format, v...) } diff --git a/internal/mux/mux.go b/internal/mux/mux.go index 83bee25..b575fd7 100644 --- a/internal/mux/mux.go +++ b/internal/mux/mux.go @@ -1,31 +1,33 @@ -// =========================================== -// AI GENERATED / AI GENERATED / AI GENERATED -//=========================================== - +// Package mux provides a multiplexer for multiple streams over a single connection. package mux import ( "encoding/binary" "errors" + "fmt" "sync" "time" "github.com/openlibrecommunity/olcrtc/internal/logger" ) +var ( + ErrClientResetID = errors.New("client reset requires a non-zero client id") //nolint:revive +) + const ( - ControlStreamID uint16 = 0xFFFF - ControlLength uint16 = 0xFFFF + ControlStreamID uint16 = 0xFFFF //nolint:revive + ControlLength uint16 = 0xFFFF //nolint:revive ControlResetClient uint32 = 1 ) -type ControlFrame struct { +type ControlFrame struct { //nolint:revive ClientID uint32 Type uint32 } -type Stream struct { +type Stream struct { //nolint:revive ID uint16 ClientID uint32 recvBuf []byte @@ -35,13 +37,13 @@ type Stream struct { outOfOrder map[uint32][]byte } -func (s *Stream) RecvBuf() []byte { +func (s *Stream) RecvBuf() []byte { //nolint:revive s.mu.Lock() defer s.mu.Unlock() return s.recvBuf } -type Multiplexer struct { +type Multiplexer struct { //nolint:revive streams map[uint16]*Stream nextID uint16 clientID uint32 @@ -55,7 +57,7 @@ type Multiplexer struct { sendSeqMu sync.Mutex } -func New(clientID uint32, onSend func([]byte) error) *Multiplexer { +func New(clientID uint32, onSend func([]byte) error) *Multiplexer { //nolint:revive return &Multiplexer{ streams: make(map[uint16]*Stream), nextID: 1, @@ -68,7 +70,7 @@ func New(clientID uint32, onSend func([]byte) error) *Multiplexer { } } -func (m *Multiplexer) OpenStream() uint16 { +func (m *Multiplexer) OpenStream() uint16 { //nolint:revive m.mu.Lock() defer m.mu.Unlock() @@ -91,7 +93,7 @@ func (m *Multiplexer) OpenStream() uint16 { } } -func (m *Multiplexer) SendData(sid uint16, data []byte) error { +func (m *Multiplexer) SendData(sid uint16, data []byte) error { //nolint:revive m.mu.RLock() stream, exists := m.streams[sid] m.mu.RUnlock() @@ -100,12 +102,11 @@ func (m *Multiplexer) SendData(sid uint16, data []byte) error { return nil } - // Keep encrypted DataChannel messages below Telemost's observed 8 KiB cap. const chunkSize = 7000 totalChunks := (len(data) + chunkSize - 1) / chunkSize if totalChunks > 10 { - logger.Debug("SendData: sid=%d, size=%d bytes, chunks=%d", sid, len(data), totalChunks) + logger.Debugf("SendData: sid=%d, size=%d bytes, chunks=%d", sid, len(data), totalChunks) } for i := 0; i < len(data); i += chunkSize { @@ -124,19 +125,19 @@ func (m *Multiplexer) SendData(sid uint16, data []byte) error { frame := make([]byte, 12+len(chunk)) binary.BigEndian.PutUint32(frame[0:4], m.clientID) binary.BigEndian.PutUint16(frame[4:6], sid) - binary.BigEndian.PutUint16(frame[6:8], uint16(len(chunk))) + binary.BigEndian.PutUint16(frame[6:8], uint16(uint32(len(chunk)))) //nolint:gosec binary.BigEndian.PutUint32(frame[8:12], seq) copy(frame[12:], chunk) if err := m.onSend(frame); err != nil { - return err + return fmt.Errorf("onSend failed: %w", err) } } return nil } -func (m *Multiplexer) CloseStream(sid uint16) error { +func (m *Multiplexer) CloseStream(sid uint16) error { //nolint:revive m.mu.Lock() defer m.mu.Unlock() @@ -154,17 +155,23 @@ func (m *Multiplexer) CloseStream(sid uint16) error { binary.BigEndian.PutUint16(frame[6:8], 0) binary.BigEndian.PutUint32(frame[8:12], 0) - return m.onSend(frame) -} - -func (m *Multiplexer) SendClientReset() error { - if m.clientID == 0 { - return errors.New("client reset requires a non-zero client id") + if err := m.onSend(frame); err != nil { + return fmt.Errorf("onSend failed: %w", err) } - return m.onSend(BuildControlFrame(m.clientID, ControlResetClient)) + return nil } -func BuildControlFrame(clientID uint32, controlType uint32) []byte { +func (m *Multiplexer) SendClientReset() error { //nolint:revive + if m.clientID == 0 { + return ErrClientResetID + } + if err := m.onSend(BuildControlFrame(m.clientID, ControlResetClient)); err != nil { + return fmt.Errorf("onSend failed: %w", err) + } + return nil +} + +func BuildControlFrame(clientID uint32, controlType uint32) []byte { //nolint:revive frame := make([]byte, 12) binary.BigEndian.PutUint32(frame[0:4], clientID) binary.BigEndian.PutUint16(frame[4:6], ControlStreamID) @@ -173,7 +180,7 @@ func BuildControlFrame(clientID uint32, controlType uint32) []byte { return frame } -func ParseControlFrame(frame []byte) (ControlFrame, bool) { +func ParseControlFrame(frame []byte) (ControlFrame, bool) { //nolint:revive if len(frame) < 12 { return ControlFrame{}, false } @@ -190,7 +197,7 @@ func ParseControlFrame(frame []byte) (ControlFrame, bool) { }, true } -func (m *Multiplexer) HandleFrame(frame []byte) { +func (m *Multiplexer) HandleFrame(frame []byte) { //nolint:revive control, ok := ParseControlFrame(frame) if ok { m.handleControlFrame(control) @@ -207,11 +214,7 @@ func (m *Multiplexer) HandleFrame(frame []byte) { seq := binary.BigEndian.Uint32(frame[8:12]) if length == 0 { - m.mu.Lock() - if stream, exists := m.streams[sid]; exists && stream.ClientID == clientID { - stream.closed = true - } - m.mu.Unlock() + m.handleCloseStreamFrame(sid, clientID) return } @@ -219,15 +222,45 @@ func (m *Multiplexer) HandleFrame(frame []byte) { return } - data := frame[12 : 12+length] + m.processDataFrame(sid, clientID, seq, frame[12:12+length]) +} +func (m *Multiplexer) handleCloseStreamFrame(sid uint16, clientID uint32) { + m.mu.Lock() + defer m.mu.Unlock() + if stream, exists := m.streams[sid]; exists && stream.ClientID == clientID { + stream.closed = true + } +} + +func (m *Multiplexer) processDataFrame(sid uint16, clientID uint32, seq uint32, data []byte) { m.mu.Lock() defer m.mu.Unlock() + stream := m.getOrCreateStream(sid, clientID) + if stream == nil { + return + } + + if seq == stream.nextSeq { + if s := m.waitForBufferSpace(sid, clientID, len(data)); s != nil { + s.recvBuf = append(s.recvBuf, data...) + s.nextSeq++ + m.applyOutOfOrder(s, sid, clientID) + m.notifyDataReady(sid) + } + } else if seq > stream.nextSeq { + if len(stream.outOfOrder) < 100 { + stream.outOfOrder[seq] = append([]byte(nil), data...) + } + } +} + +func (m *Multiplexer) getOrCreateStream(sid uint16, clientID uint32) *Stream { stream, exists := m.streams[sid] if !exists { if len(m.streams) >= m.maxStreams { - return + return nil } stream = &Stream{ ID: sid, @@ -237,59 +270,42 @@ func (m *Multiplexer) HandleFrame(frame []byte) { outOfOrder: make(map[uint32][]byte), } m.streams[sid] = stream - } else if stream.ClientID != clientID { + return stream + } + + if stream.ClientID != clientID { stream.ClientID = clientID stream.recvBuf = make([]byte, 0) stream.closed = false stream.nextSeq = 0 stream.outOfOrder = make(map[uint32][]byte) } + return stream +} - if seq == stream.nextSeq { - // Backpressure: if the stream buffer is full, release the mux lock and - // wait for the reader to drain it. Dropping/closing here would corrupt - // the TCP stream carried over the mux — large HTTP/2 downloads (X, - // Instagram, YouTube) that push data faster than conn.Write can accept - // would lose bytes and hang forever. - if s := m.waitForBufferSpace(sid, clientID, len(data)); s == nil { +func (m *Multiplexer) applyOutOfOrder(stream *Stream, sid uint16, clientID uint32) { + for { + nextData, ok := stream.outOfOrder[stream.nextSeq] + if !ok { + break + } + if s := m.waitForBufferSpace(sid, clientID, len(nextData)); s == nil { return - } else { - stream = s } - stream.recvBuf = append(stream.recvBuf, data...) + stream.recvBuf = append(stream.recvBuf, nextData...) + delete(stream.outOfOrder, stream.nextSeq) stream.nextSeq++ + logger.Verbosef("Applied out-of-order packet sid=%d seq=%d", sid, stream.nextSeq-1) + } +} - for { - nextData, ok := stream.outOfOrder[stream.nextSeq] - if !ok { - break - } - if s := m.waitForBufferSpace(sid, clientID, len(nextData)); s == nil { - return - } else { - stream = s - } - nextData, ok = stream.outOfOrder[stream.nextSeq] - if !ok { - break - } - stream.recvBuf = append(stream.recvBuf, nextData...) - delete(stream.outOfOrder, stream.nextSeq) - stream.nextSeq++ - logger.Verbose("Applied out-of-order packet sid=%d seq=%d", sid, stream.nextSeq-1) - } - - m.dataReadyMu.Lock() - if ch, ok := m.dataReady[sid]; ok { - select { - case ch <- struct{}{}: - default: - } - } - m.dataReadyMu.Unlock() - } else if seq > stream.nextSeq { - if len(stream.outOfOrder) < 100 { - stream.outOfOrder[seq] = append([]byte(nil), data...) +func (m *Multiplexer) notifyDataReady(sid uint16) { + m.dataReadyMu.Lock() + defer m.dataReadyMu.Unlock() + if ch, ok := m.dataReady[sid]; ok { + select { + case ch <- struct{}{}: + default: } } } @@ -299,11 +315,11 @@ func (m *Multiplexer) handleControlFrame(control ControlFrame) { case ControlResetClient: m.ResetClient(control.ClientID) default: - logger.Debug("Unknown mux control frame type=%d clientID=%d", control.Type, control.ClientID) + logger.Debugf("Unknown mux control frame type=%d clientID=%d", control.Type, control.ClientID) } } -func (m *Multiplexer) ResetClient(clientID uint32) { +func (m *Multiplexer) ResetClient(clientID uint32) { //nolint:revive m.mu.Lock() defer m.mu.Unlock() @@ -315,10 +331,6 @@ func (m *Multiplexer) ResetClient(clientID uint32) { } } -// waitForBufferSpace releases m.mu and waits until the stream's recvBuf has -// room for `need` more bytes, then re-acquires the lock. Returns the (possibly -// re-fetched) stream, or nil if the stream disappeared / was reset / closed. -// Caller must hold m.mu (write-locked) on entry and will hold it on return. func (m *Multiplexer) waitForBufferSpace(sid uint16, clientID uint32, need int) *Stream { for { stream, ok := m.streams[sid] @@ -334,7 +346,7 @@ func (m *Multiplexer) waitForBufferSpace(sid uint16, clientID uint32, need int) } } -func (m *Multiplexer) ReadStream(sid uint16) []byte { +func (m *Multiplexer) ReadStream(sid uint16) []byte { //nolint:revive m.mu.Lock() defer m.mu.Unlock() @@ -348,7 +360,7 @@ func (m *Multiplexer) ReadStream(sid uint16) []byte { return data } -func (m *Multiplexer) StreamClosed(sid uint16) bool { +func (m *Multiplexer) StreamClosed(sid uint16) bool { //nolint:revive m.mu.RLock() defer m.mu.RUnlock() @@ -356,7 +368,7 @@ func (m *Multiplexer) StreamClosed(sid uint16) bool { return !exists || stream.closed } -func (m *Multiplexer) GetStreams() []uint16 { +func (m *Multiplexer) GetStreams() []uint16 { //nolint:revive m.mu.RLock() defer m.mu.RUnlock() @@ -367,13 +379,13 @@ func (m *Multiplexer) GetStreams() []uint16 { return sids } -func (m *Multiplexer) GetStream(sid uint16) *Stream { +func (m *Multiplexer) GetStream(sid uint16) *Stream { //nolint:revive m.mu.RLock() defer m.mu.RUnlock() return m.streams[sid] } -func (m *Multiplexer) Reset() { +func (m *Multiplexer) Reset() { //nolint:revive m.mu.Lock() defer m.mu.Unlock() @@ -389,14 +401,14 @@ func (m *Multiplexer) Reset() { m.sendSeqMu.Unlock() } -func (m *Multiplexer) UpdateSendFunc(onSend func([]byte) error) { +func (m *Multiplexer) UpdateSendFunc(onSend func([]byte) error) { //nolint:revive m.mu.Lock() defer m.mu.Unlock() m.onSend = onSend } -func (m *Multiplexer) WaitForData(sid uint16) <-chan struct{} { +func (m *Multiplexer) WaitForData(sid uint16) <-chan struct{} { //nolint:revive m.dataReadyMu.Lock() defer m.dataReadyMu.Unlock() @@ -406,7 +418,7 @@ func (m *Multiplexer) WaitForData(sid uint16) <-chan struct{} { return m.dataReady[sid] } -func (m *Multiplexer) CleanupDataChannel(sid uint16) { +func (m *Multiplexer) CleanupDataChannel(sid uint16) { //nolint:revive m.dataReadyMu.Lock() defer m.dataReadyMu.Unlock() diff --git a/internal/protect/protect.go b/internal/protect/protect.go index cf29f91..49c8340 100644 --- a/internal/protect/protect.go +++ b/internal/protect/protect.go @@ -1,7 +1,9 @@ +// Package protect provides functions to protect sockets from VPN routing. package protect import ( "context" + "fmt" "net" "net/http" "syscall" @@ -10,18 +12,21 @@ import ( // Protector is called with a socket file descriptor before connect. // On Android, this calls VpnService.protect(fd) to bypass VPN routing. -var Protector func(fd int) bool +var Protector func(fd int) bool //nolint:gochecknoglobals -func controlFunc(network, address string, c syscall.RawConn) error { +func controlFunc(network, _ string, c syscall.RawConn) error { if Protector == nil { return nil } var err error - c.Control(func(fd uintptr) { - if !Protector(int(fd)) { + controlErr := c.Control(func(fd uintptr) { + if !Protector(int(fd)) { //nolint:gosec err = &net.OpError{Op: "protect", Net: network, Err: net.ErrClosed} } }) + if controlErr != nil { + return fmt.Errorf("control failed: %w", controlErr) + } return err } @@ -50,17 +55,27 @@ func NewHTTPClient() *http.Client { // DialContext dials using a protected socket. func DialContext(ctx context.Context, network, address string) (net.Conn, error) { - return NewDialer().DialContext(ctx, network, address) + conn, err := NewDialer().DialContext(ctx, network, address) + if err != nil { + return nil, fmt.Errorf("dial failed: %w", err) + } + return conn, nil } -// proxyDialer implements golang.org/x/net/proxy.Dialer for pion ICE. -type proxyDialer struct{} +// ProxyDialer implements golang.org/x/net/proxy.Dialer for pion ICE. +type ProxyDialer struct{} -func (d *proxyDialer) Dial(network, addr string) (net.Conn, error) { - return NewDialer().Dial(network, addr) +// Dial connects to the address on the named network using a protected socket. +func (d *ProxyDialer) Dial(network, addr string) (net.Conn, error) { + conn, err := NewDialer().Dial(network, addr) + if err != nil { + return nil, fmt.Errorf("dial failed: %w", err) + } + return conn, nil } // NewProxyDialer returns a proxy.Dialer that protects ICE sockets. -func NewProxyDialer() *proxyDialer { - return &proxyDialer{} +func NewProxyDialer() *ProxyDialer { + return &ProxyDialer{} } + diff --git a/internal/server/server.go b/internal/server/server.go index ebd9650..1b799d4 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1,3 +1,4 @@ +// Package server implements the olcrtc tunnel server logic. package server import ( @@ -5,10 +6,12 @@ import ( "crypto/rand" "encoding/hex" "encoding/json" + "errors" "fmt" "io" "log" "net" + "strconv" "sync" "sync/atomic" "time" @@ -21,7 +24,24 @@ import ( "github.com/pion/webrtc/v4" ) -type Server struct { +var ( + // ErrKeySize is returned when the encryption key is not 32 bytes. + ErrKeySize = errors.New("key must be 32 bytes") + // ErrKeyStringLength is returned when the encryption key string length is not 32. + ErrKeyStringLength = errors.New("key string length must be 32") + // ErrSocks5AuthFailed is returned when SOCKS5 authentication fails. + ErrSocks5AuthFailed = errors.New("SOCKS5 auth failed") + // ErrSocks5ConnectFailed is returned when SOCKS5 connection fails. + ErrSocks5ConnectFailed = errors.New("SOCKS5 connect failed") + // ErrNoPeers is returned when no peers are available. + ErrNoPeers = errors.New("no peers available") + // ErrDialProxy is returned when dialing the proxy fails. + ErrDialProxy = errors.New("failed to dial proxy") + // ErrEncryptFailed is returned when encryption fails. + ErrEncryptFailed = errors.New("encrypt failed") +) + +type Server struct { //nolint:revive peers []*telemost.Peer cipher *crypto.Cipher mux *mux.Multiplexer @@ -33,48 +53,32 @@ type Server struct { activeClients atomic.Int32 wg sync.WaitGroup dnsServer string - dnsCache sync.Map resolver *net.Resolver socksProxyAddr string socksProxyPort int } -type ConnectRequest struct { +type ConnectRequest struct { //nolint:revive Cmd string `json:"cmd"` Addr string `json:"addr"` Port int `json:"port"` } -func Run(ctx context.Context, roomURL, keyHex string, dnsServer, socksProxyAddr string, socksProxyPort int) error { +// Run starts the olcrtc server and listens for client connections. +func Run( + ctx context.Context, + roomURL, + keyHex string, + dnsServer, + socksProxyAddr string, + socksProxyPort int, +) error { runCtx, cancel := context.WithCancel(ctx) defer cancel() - var key []byte - var err error - if keyHex == "" { - key = make([]byte, 32) - if _, err := rand.Read(key); err != nil { - return err - } - log.Printf("Generated key: %x", key) - } else { - key, err = hex.DecodeString(keyHex) - if err != nil { - return err - } - if len(key) != 32 { - return fmt.Errorf("key must be 32 bytes, got %d", len(key)) - } - } - - keyStr := string(key) - if len(keyStr) != 32 { - return fmt.Errorf("key string length must be 32, got %d", len(keyStr)) - } - - cipher, err := crypto.NewCipher(keyStr) + cipher, err := setupCipher(keyHex) if err != nil { - return err + return fmt.Errorf("setupCipher failed: %w", err) } s := &Server{ @@ -87,20 +91,72 @@ func Run(ctx context.Context, roomURL, keyHex string, dnsServer, socksProxyAddr socksProxyPort: socksProxyPort, } - if dnsServer == "" { - dnsServer = "1.1.1.1:53" + if s.dnsServer == "" { + s.dnsServer = "1.1.1.1:53" } + s.setupResolver() + s.setupMux() + + const peerCount = 1 + for i := range peerCount { + if err := s.addPeer(runCtx, roomURL, i, cancel); err != nil { + return fmt.Errorf("addPeer failed: %w", err) + } + } + + err = s.runLoop(runCtx) + + log.Println("Waiting for server goroutines...") + s.wg.Wait() + log.Println("Server goroutines finished") + + return err +} + +func setupCipher(keyHex string) (*crypto.Cipher, error) { + var key []byte + var err error + + if keyHex == "" { + key = make([]byte, 32) + if _, err := rand.Read(key); err != nil { + return nil, fmt.Errorf("failed to generate key: %w", err) + } + log.Printf("Generated key: %x", key) + } else { + key, err = hex.DecodeString(keyHex) + if err != nil { + return nil, fmt.Errorf("failed to decode key: %w", err) + } + if len(key) != 32 { + return nil, fmt.Errorf("%w, got %d", ErrKeySize, len(key)) + } + } + + keyStr := string(key) + if len(keyStr) != 32 { + return nil, fmt.Errorf("%w, got %d", ErrKeyStringLength, len(keyStr)) + } + + cipher, err := crypto.NewCipher(keyStr) + if err != nil { + return nil, fmt.Errorf("failed to create cipher: %w", err) + } + return cipher, nil +} + +func (s *Server) setupResolver() { s.resolver = &net.Resolver{ PreferGo: true, - Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + Dial: func(ctx context.Context, network, _ string) (net.Conn, error) { d := net.Dialer{Timeout: 3 * time.Second} - return d.DialContext(ctx, network, dnsServer) + return d.DialContext(ctx, network, s.dnsServer) }, } +} - peerCount := 1 - +func (s *Server) setupMux() { s.mux = mux.New(0, func(frame []byte) error { for { canSend := true @@ -118,110 +174,118 @@ func Run(ctx context.Context, roomURL, keyHex string, dnsServer, socksProxyAddr encrypted, err := s.cipher.Encrypt(frame) if err != nil { - return err + return fmt.Errorf("%w: %w", ErrEncryptFailed, err) } - idx := s.peerIdx.Add(1) % uint32(len(s.peers)) + if len(s.peers) == 0 { + return ErrNoPeers + } + idx := s.peerIdx.Add(1) % uint32(len(s.peers)) //nolint:gosec return s.peers[idx].Send(encrypted) }) +} - for i := 0; i < peerCount; i++ { - peerID := i - peer, err := telemost.NewPeer(roomURL, names.Generate(), s.onData) - if err != nil { - return err - } - peer.SetEndedCallback(func(reason string) { - log.Printf("Server peer %d reported conference end: %s", peerID, reason) - cancel() - }) - s.peers = append(s.peers, peer) - - peer.SetReconnectCallback(func(dc *webrtc.DataChannel) { - if dc == nil { - log.Printf("Server peer %d channel closed - resetting multiplexer state", peerID) - } else { - log.Printf("Server peer %d reconnected - resetting multiplexer state", peerID) - } - - s.connMu.Lock() - for sid, conn := range s.connections { - if conn != nil { - conn.Close() - } - delete(s.connections, sid) - } - s.connMu.Unlock() - - if dc != nil { - s.mux.UpdateSendFunc(func(frame []byte) error { - encrypted, err := s.cipher.Encrypt(frame) - if err != nil { - return err - } - idx := s.peerIdx.Add(1) % uint32(len(s.peers)) - return s.peers[idx].Send(encrypted) - }) - } - - s.mux.Reset() - - log.Println("Server multiplexer reset complete") - }) - - peer.SetShouldReconnect(func() bool { - return s.activeClients.Load() > 0 - }) - - log.Printf("Connecting peer %d to Telemost...", peerID) - if err := peer.Connect(runCtx); err != nil { - return err - } - log.Printf("Peer %d connected", peerID) - - s.wg.Add(1) - go func() { - defer s.wg.Done() - peer.WatchConnection(runCtx) - }() +func (s *Server) addPeer(ctx context.Context, roomURL string, peerID int, cancel context.CancelFunc) error { + peer, err := telemost.NewPeer(ctx, roomURL, names.Generate(), s.onData) + if err != nil { + return fmt.Errorf("failed to create peer: %w", err) } - err = s.run(runCtx) + peer.SetEndedCallback(func(reason string) { + log.Printf("Server peer %d reported conference end: %s", peerID, reason) + cancel() + }) + s.peers = append(s.peers, peer) - log.Println("Waiting for server goroutines...") - s.wg.Wait() - log.Println("Server goroutines finished") + peer.SetReconnectCallback(func(dc *webrtc.DataChannel) { + s.handlePeerReconnect(peerID, dc) + }) - return err + peer.SetShouldReconnect(func() bool { + return s.activeClients.Load() > 0 + }) + + log.Printf("Connecting peer %d to Telemost...", peerID) + if err := peer.Connect(ctx); err != nil { + return fmt.Errorf("failed to connect peer: %w", err) + } + log.Printf("Peer %d connected", peerID) + + s.wg.Add(1) + go func() { + defer s.wg.Done() + peer.WatchConnection(ctx) + }() + return nil +} + +func (s *Server) handlePeerReconnect(peerID int, dc *webrtc.DataChannel) { + if dc == nil { + log.Printf("Server peer %d channel closed - resetting mux state", peerID) + } else { + log.Printf("Server peer %d reconnected - resetting mux state", peerID) + } + + s.connMu.Lock() + for sid, conn := range s.connections { + if conn != nil { + _ = conn.Close() + } + delete(s.connections, sid) + } + s.connMu.Unlock() + + if dc != nil { + s.mux.UpdateSendFunc(func(frame []byte) error { + encrypted, err := s.cipher.Encrypt(frame) + if err != nil { + return fmt.Errorf("%w: %w", ErrEncryptFailed, err) + } + if len(s.peers) == 0 { + return ErrNoPeers + } + idx := s.peerIdx.Add(1) % uint32(len(s.peers)) //nolint:gosec + return s.peers[idx].Send(encrypted) + }) + } + + s.mux.Reset() + log.Println("Server multiplexer reset complete") } func (s *Server) socks5Connect(conn net.Conn, targetAddr string, targetPort int) error { if _, err := conn.Write([]byte{5, 1, 0}); err != nil { - return err + return fmt.Errorf("failed to write socks5 auth: %w", err) } resp := make([]byte, 2) if _, err := io.ReadFull(conn, resp); err != nil { - return err + return fmt.Errorf("failed to read socks5 auth resp: %w", err) } if resp[0] != 5 || resp[1] != 0 { - return fmt.Errorf("SOCKS5 auth failed") + return ErrSocks5AuthFailed } - req := []byte{5, 1, 0, 3} - req = append(req, byte(len(targetAddr))) + addrLen := len(targetAddr) + if addrLen > 255 { + addrLen = 255 + targetAddr = targetAddr[:255] + } + + req := make([]byte, 0, 7+addrLen) + req = append(req, 5, 1, 0, 3, byte(addrLen)) req = append(req, []byte(targetAddr)...) - req = append(req, byte(targetPort>>8), byte(targetPort)) + req = append(req, byte(targetPort>>8), byte(targetPort)) //nolint:gosec if _, err := conn.Write(req); err != nil { - return err + return fmt.Errorf("failed to write socks5 connect req: %w", err) } resp = make([]byte, 10) if _, err := io.ReadFull(conn, resp); err != nil { - return err + return fmt.Errorf("failed to read socks5 connect resp: %w", err) } if resp[0] != 5 || resp[1] != 0 { - return fmt.Errorf("SOCKS5 connect failed: %d", resp[1]) + return fmt.Errorf("%w: %d", ErrSocks5ConnectFailed, resp[1]) } return nil @@ -230,12 +294,12 @@ func (s *Server) socks5Connect(conn net.Conn, targetAddr string, targetPort int) func (s *Server) onData(data []byte) { plaintext, err := s.cipher.Decrypt(data) if err != nil { - logger.Debug("Decrypt error: %v", err) + logger.Debugf("Decrypt error: %v", err) return } if control, ok := mux.ParseControlFrame(plaintext); ok && control.Type == mux.ControlResetClient { - log.Printf("Received reset signal from client (clientID=%d) - cleaning up", control.ClientID) + log.Printf("Received reset signal from client (clientID=%d)", control.ClientID) s.closeClientConnections(control.ClientID) } @@ -250,63 +314,67 @@ func (s *Server) closeClientConnections(clientID uint32) { stream := s.mux.GetStream(streamSid) if stream != nil && stream.ClientID == clientID { if conn != nil { - conn.Close() + _ = conn.Close() } delete(s.connections, streamSid) } } } -func (s *Server) run(ctx context.Context) error { +func (s *Server) runLoop(ctx context.Context) error { ticker := time.NewTicker(10 * time.Millisecond) defer ticker.Stop() for { select { case <-ctx.Done(): - log.Println("Server shutting down...") - s.connMu.Lock() - for _, conn := range s.connections { - if conn != nil { - conn.Close() - } - } - s.connMu.Unlock() - - log.Printf("Closing %d peer(s)...", len(s.peers)) - for i, peer := range s.peers { - log.Printf("Closing peer %d...", i) - peer.Close() - } - log.Println("All peers closed") - + s.shutdown() return nil - case <-ticker.C: + s.processMuxStreams(ctx) } - sids := s.mux.GetStreams() + } +} - for _, sid := range sids { - if s.mux.StreamClosed(sid) { - s.closeStreamConnection(sid) - continue - } +func (s *Server) shutdown() { + log.Println("Server shutting down...") + s.connMu.Lock() + for _, conn := range s.connections { + if conn != nil { + _ = conn.Close() + } + } + s.connMu.Unlock() - if s.hasConnection(sid) { - continue - } + for i, peer := range s.peers { + log.Printf("Closing peer %d...", i) + _ = peer.Close() + } + log.Println("All peers closed") +} - data := s.mux.ReadStream(sid) - if len(data) == 0 { - continue - } +func (s *Server) processMuxStreams(ctx context.Context) { + sids := s.mux.GetStreams() + for _, sid := range sids { + if s.mux.StreamClosed(sid) { + s.closeStreamConnection(sid) + continue + } - var req ConnectRequest - if err := json.Unmarshal(data, &req); err == nil && req.Cmd == "connect" { - log.Printf("[SERVER] sid=%d RECEIVED_CONNECT_REQUEST %s:%d", sid, req.Addr, req.Port) - s.closeStreamConnection(sid) - go s.handleConnect(ctx, sid, req) - } + if s.hasConnection(sid) { + continue + } + + data := s.mux.ReadStream(sid) + if len(data) == 0 { + continue + } + + var req ConnectRequest + if err := json.Unmarshal(data, &req); err == nil && req.Cmd == "connect" { + log.Printf("[SERVER] sid=%d RECV_CONNECT %s:%d", sid, req.Addr, req.Port) + s.closeStreamConnection(sid) + go s.handleConnect(ctx, sid, req) } } } @@ -314,15 +382,14 @@ func (s *Server) run(ctx context.Context) error { func (s *Server) hasConnection(sid uint16) bool { s.connMu.RLock() defer s.connMu.RUnlock() - conn := s.connections[sid] - return conn != nil + return s.connections[sid] != nil } func (s *Server) closeStreamConnection(sid uint16) { s.connMu.Lock() conn := s.connections[sid] if conn != nil { - conn.Close() + _ = conn.Close() delete(s.connections, sid) } s.connMu.Unlock() @@ -332,7 +399,7 @@ func (s *Server) closeStreamConnectionIfCurrent(sid uint16, expected net.Conn) { s.connMu.Lock() conn := s.connections[sid] if conn == expected { - conn.Close() + _ = conn.Close() delete(s.connections, sid) } s.connMu.Unlock() @@ -344,7 +411,7 @@ func (s *Server) markStreamPump(sid uint16, conn net.Conn) bool { if current := s.streamPumps[sid]; current == conn { return false } else if current != nil { - current.Close() + _ = current.Close() } s.streamPumps[sid] = conn return true @@ -360,102 +427,103 @@ func (s *Server) unmarkStreamPump(sid uint16, conn net.Conn) { func (s *Server) handleConnect(ctx context.Context, sid uint16, req ConnectRequest) { startTime := time.Now() - addr := fmt.Sprintf("%s:%d", req.Addr, req.Port) - logger.Verbose("Handling connect request sid=%d to %s", sid, addr) + addr := net.JoinHostPort(req.Addr, strconv.Itoa(req.Port)) log.Printf("[SERVER] sid=%d CONNECT_START %s", sid, addr) - s.connMu.Lock() - oldConn, exists := s.connections[sid] - if exists && oldConn != nil { - log.Printf("Closing old connection for sid=%d", sid) - oldConn.Close() - delete(s.connections, sid) - } - s.connMu.Unlock() + s.closeStreamConnection(sid) dialStart := time.Now() - var conn net.Conn - var err error + conn, err := s.dial(req) + dialElapsed := time.Since(dialStart) + if err != nil { + log.Printf("[SERVER] sid=%d CONNECT_FAILED dial=%v total=%v err=%v", + sid, dialElapsed, time.Since(startTime), err) + _ = s.mux.CloseStream(sid) + return + } + + s.connMu.Lock() + s.connections[sid] = conn + s.connMu.Unlock() + + log.Printf("[SERVER] sid=%d CONNECT_SUCCESS dial=%v", sid, dialElapsed) + + s.activeClients.Add(1) + _ = s.mux.SendData(sid, []byte{0x00}) + s.startStreamPump(ctx, sid, conn) + + go s.pumpToMux(sid, conn) +} + +func (s *Server) dial(req ConnectRequest) (net.Conn, error) { + addr := net.JoinHostPort(req.Addr, strconv.Itoa(req.Port)) if s.socksProxyAddr == "" { dialer := &net.Dialer{ Timeout: 10 * time.Second, KeepAlive: 30 * time.Second, Resolver: s.resolver, } - conn, err = dialer.Dial("tcp4", addr) - logger.Verbose("TCP dial took %v for sid=%d (direct)", time.Since(dialStart), sid) - } else { - proxyAddr := fmt.Sprintf("%s:%d", s.socksProxyAddr, s.socksProxyPort) - dialer := &net.Dialer{ - Timeout: 10 * time.Second, - KeepAlive: 30 * time.Second, + conn, err := dialer.Dial("tcp4", addr) + if err != nil { + return nil, fmt.Errorf("dial failed: %w", err) } - conn, err = dialer.Dial("tcp4", proxyAddr) - if err == nil { - if err := s.socks5Connect(conn, req.Addr, req.Port); err != nil { - conn.Close() - err = fmt.Errorf("SOCKS5 connect failed: %v", err) - } - } - logger.Verbose("SOCKS5 proxy dial took %v for sid=%d", time.Since(dialStart), sid) + return conn, nil } - dialElapsed := time.Since(dialStart) + proxyAddr := net.JoinHostPort(s.socksProxyAddr, strconv.Itoa(s.socksProxyPort)) + dialer := &net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 30 * time.Second, + } + conn, err := dialer.Dial("tcp4", proxyAddr) if err != nil { - log.Printf("[SERVER] sid=%d CONNECT_FAILED dial_time=%v total_elapsed=%v err=%v", sid, dialElapsed, time.Since(startTime), err) - go s.mux.CloseStream(sid) - return + return nil, fmt.Errorf("failed to dial proxy: %w", err) } - logger.Verbose("TCP dial took %v for sid=%d", dialElapsed, sid) - s.connMu.Lock() - s.connections[sid] = conn - s.connMu.Unlock() + if err := s.socks5Connect(conn, req.Addr, req.Port); err != nil { + _ = conn.Close() + return nil, err + } + return conn, nil +} - log.Printf("[SERVER] sid=%d CONNECT_SUCCESS dial_time=%v", sid, dialElapsed) - - s.activeClients.Add(1) - s.mux.SendData(sid, []byte{0x00}) - s.startStreamPump(ctx, sid, conn) - - go func() { - defer func() { - s.activeClients.Add(-1) - s.mux.CloseStream(sid) - s.connMu.Lock() - delete(s.connections, sid) - s.connMu.Unlock() - }() - - buf := make([]byte, 16384) - totalSent := uint64(0) - lastLog := time.Now() - - for { - n, err := conn.Read(buf) - if err != nil { - if totalSent > 1024*1024 { - log.Printf("[SERVER] sid=%d TRANSFER_COMPLETE total=%d MB", sid, totalSent/(1024*1024)) - } - return - } - - for !s.canSendData() { - time.Sleep(20 * time.Millisecond) - } - - if err := s.mux.SendData(sid, buf[:n]); err != nil { - return - } - - totalSent += uint64(n) - if time.Since(lastLog) > 5*time.Second { - log.Printf("[SERVER] sid=%d TRANSFER_PROGRESS sent=%d MB", sid, totalSent/(1024*1024)) - lastLog = time.Now() - } - } +func (s *Server) pumpToMux(sid uint16, conn net.Conn) { + defer func() { + s.activeClients.Add(-1) + _ = s.mux.CloseStream(sid) + s.connMu.Lock() + delete(s.connections, sid) + s.connMu.Unlock() }() + + buf := make([]byte, 16384) + totalSent := uint64(0) + lastLog := time.Now() + + for { + n, err := conn.Read(buf) + if err != nil { + if totalSent > 1024*1024 { + log.Printf("[SERVER] sid=%d TRANSFER_DONE total=%d MB", sid, totalSent/(1024*1024)) + } + return + } + + for !s.canSendData() { + time.Sleep(20 * time.Millisecond) + } + + if err := s.mux.SendData(sid, buf[:n]); err != nil { + return + } + + totalSent += uint64(n) //nolint:gosec + if time.Since(lastLog) > 5*time.Second { + log.Printf("[SERVER] sid=%d TRANSFER_UP sent=%d MB", sid, totalSent/(1024*1024)) + lastLog = time.Now() + } + } } func (s *Server) startStreamPump(ctx context.Context, sid uint16, conn net.Conn) { @@ -479,7 +547,7 @@ func (s *Server) startStreamPump(ctx context.Context, sid uint16, conn net.Conn) data := s.mux.ReadStream(sid) if len(data) > 0 { if _, err := conn.Write(data); err != nil { - s.mux.CloseStream(sid) + _ = s.mux.CloseStream(sid) s.closeStreamConnectionIfCurrent(sid, conn) return } diff --git a/internal/telemost/api.go b/internal/telemost/api.go index 91e0d18..686c8a2 100644 --- a/internal/telemost/api.go +++ b/internal/telemost/api.go @@ -1,7 +1,9 @@ -package telemost +package telemost //nolint:revive import ( + "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -13,21 +15,23 @@ import ( const apiBase = "https://cloud-api.yandex.ru/telemost_front/v2/telemost" -type ConnectionInfo struct { - RoomID string `json:"room_id"` - PeerID string `json:"peer_id"` - Credentials string `json:"credentials"` +var ErrAPI = errors.New("api error") //nolint:revive + +type ConnectionInfo struct { //nolint:revive + RoomID string `json:"room_id"` //nolint:tagliatelle + PeerID string `json:"peer_id"` //nolint:tagliatelle + Credentials string `json:"credentials"` //nolint:tagliatelle ClientConfig struct { - MediaServerURL string `json:"media_server_url"` - } `json:"client_configuration"` + MediaServerURL string `json:"media_server_url"` //nolint:tagliatelle + } `json:"client_configuration"` //nolint:tagliatelle } -func GetConnectionInfo(roomURL, displayName string) (*ConnectionInfo, error) { +func GetConnectionInfo(ctx context.Context, roomURL, displayName string) (*ConnectionInfo, error) { //nolint:revive u := fmt.Sprintf("%s/conferences/%s/connection", apiBase, url.QueryEscape(roomURL)) - req, err := http.NewRequest("GET", u, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create request: %w", err) } q := req.URL.Query() @@ -48,18 +52,18 @@ func GetConnectionInfo(roomURL, displayName string) (*ConnectionInfo, error) { client := protect.NewHTTPClient() resp, err := client.Do(req) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to do request: %w", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("API error %d: %s", resp.StatusCode, body) + return nil, fmt.Errorf("%w %d: %s", ErrAPI, resp.StatusCode, body) } var info ConnectionInfo if err := json.NewDecoder(resp.Body).Decode(&info); err != nil { - return nil, err + return nil, fmt.Errorf("failed to decode response: %w", err) } return &info, nil diff --git a/internal/telemost/peer.go b/internal/telemost/peer.go index 2cfceab..a0817b8 100644 --- a/internal/telemost/peer.go +++ b/internal/telemost/peer.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "log" "math/rand/v2" @@ -27,13 +28,28 @@ const ( defaultTelemetryInterval = 20 * time.Second ) -type TrafficShape struct { +var ( + // ErrDataChannelTimeout is returned when the datachannel fails to open within the timeout. + ErrDataChannelTimeout = errors.New("datachannel timeout") + // ErrDataChannelNotReady is returned when the datachannel is not open. + ErrDataChannelNotReady = errors.New("datachannel not ready") + // ErrSendQueueClosed is returned when the send queue is closed. + ErrSendQueueClosed = errors.New("send queue closed") + // ErrSendQueueTimeout is returned when sending to the queue times out. + ErrSendQueueTimeout = errors.New("send queue timeout") + // ErrSessionClosed is returned when the session is closed. + ErrSessionClosed = errors.New("session closed") + // ErrPeerClosed is returned when the peer is closed. + ErrPeerClosed = errors.New("peer closed") +) + +type TrafficShape struct { //nolint:revive MaxMessageSize int MinDelay time.Duration MaxDelay time.Duration } -type Peer struct { +type Peer struct { //nolint:revive roomURL string name string conn *ConnectionInfo @@ -51,7 +67,6 @@ type Peer struct { telemetryCh chan struct{} lastReconnect time.Time reconnectCount int - reconnectMu sync.Mutex sessionMu sync.Mutex sendQueue chan []byte sendQueueClosed atomic.Bool @@ -66,22 +81,22 @@ type Peer struct { wg sync.WaitGroup } -func (p *Peer) GetSendQueue() chan []byte { +func (p *Peer) GetSendQueue() chan []byte { //nolint:revive return p.sendQueue } -func (p *Peer) GetBufferedAmount() uint64 { +func (p *Peer) GetBufferedAmount() uint64 { //nolint:revive if p.dc != nil { return p.dc.BufferedAmount() } return 0 } -func (p *Peer) SetEndedCallback(cb func(string)) { +func (p *Peer) SetEndedCallback(cb func(string)) { //nolint:revive p.onEnded = cb } -func (p *Peer) SetTrafficShape(shape TrafficShape) { +func (p *Peer) SetTrafficShape(shape TrafficShape) { //nolint:revive if shape.MaxMessageSize <= 0 { shape.MaxMessageSize = realDataChannelMessageLimit } @@ -91,10 +106,10 @@ func (p *Peer) SetTrafficShape(shape TrafficShape) { p.trafficShape = shape } -func NewPeer(roomURL, name string, onData func([]byte)) (*Peer, error) { - conn, err := GetConnectionInfo(roomURL, name) +func NewPeer(ctx context.Context, roomURL, name string, onData func([]byte)) (*Peer, error) { //nolint:revive + conn, err := GetConnectionInfo(ctx, roomURL, name) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to get connection info: %w", err) } return &Peer{ @@ -170,16 +185,46 @@ func (p *Peer) drainReconnectQueue() { } } -func (p *Peer) Connect(ctx context.Context) error { +func (p *Peer) Connect(ctx context.Context) error { //nolint:revive p.closed.Store(false) config := webrtc.Configuration{ - ICEServers: []webrtc.ICEServer{ - {URLs: []string{"stun:stun.rtc.yandex.net:3478"}}, - }, + ICEServers: []webrtc.ICEServer{{URLs: []string{"stun:stun.rtc.yandex.net:3478"}}}, SDPSemantics: webrtc.SDPSemanticsUnifiedPlan, } + if err := p.setupPeerConnections(config); err != nil { + return err + } + + var err error + p.dc, err = p.pcPub.CreateDataChannel("olcrtc", nil) + if err != nil { + return fmt.Errorf("create dc: %w", err) + } + + dcReady := make(chan struct{}) + keepAliveCh, sessionCloseCh := p.resetSession() + p.setupDataChannelHandlers(dcReady, sessionCloseCh) + + if err := p.dialWebSocket(); err != nil { + return err + } + + p.setupICEHandlers() + p.startBackgroundGoroutines(ctx, keepAliveCh) + + select { + case <-dcReady: + return nil + case <-time.After(15 * time.Second): + return ErrDataChannelTimeout + case <-ctx.Done(): + return fmt.Errorf("connect context cancelled: %w", ctx.Err()) + } +} + +func (p *Peer) setupPeerConnections(config webrtc.Configuration) error { settingEngine := webrtc.SettingEngine{} if protect.Protector != nil { settingEngine.SetICEProxyDialer(protect.NewProxyDialer()) @@ -189,140 +234,121 @@ func (p *Peer) Connect(ctx context.Context) error { var err error p.pcSub, err = api.NewPeerConnection(config) if err != nil { - return err + return fmt.Errorf("new sub pc: %w", err) } - - p.pcSub.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { - log.Printf("Subscriber PeerConnection state: %s", state.String()) - if !p.closed.Load() && (state == webrtc.PeerConnectionStateFailed || state == webrtc.PeerConnectionStateDisconnected) { - p.queueReconnect() - } - }) + p.pcSub.OnConnectionStateChange(p.onConnectionStateChange) p.pcPub, err = api.NewPeerConnection(config) if err != nil { - return err + return fmt.Errorf("new pub pc: %w", err) } + p.pcPub.OnConnectionStateChange(p.onConnectionStateChange) - p.pcPub.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { - log.Printf("Publisher PeerConnection state: %s", state.String()) - if !p.closed.Load() && (state == webrtc.PeerConnectionStateFailed || state == webrtc.PeerConnectionStateDisconnected) { - p.queueReconnect() - } - }) + return nil +} - p.dc, err = p.pcPub.CreateDataChannel("olcrtc", nil) - if err != nil { - return err +func (p *Peer) onConnectionStateChange(state webrtc.PeerConnectionState) { + log.Printf("PeerConnection state: %s", state.String()) + if !p.closed.Load() && (state == webrtc.PeerConnectionStateFailed || + state == webrtc.PeerConnectionStateDisconnected) { + p.queueReconnect() } +} - dcReady := make(chan struct{}) - keepAliveCh, sessionCloseCh := p.resetSession() +func (p *Peer) setupDataChannelHandlers(dcReady chan struct{}, sessionCloseCh chan struct{}) { p.dc.OnOpen(func() { log.Println("DataChannel opened") - numWorkers := 4 - for i := 0; i < numWorkers; i++ { + for i := range numWorkers { p.wg.Add(1) go func(workerID int) { defer p.wg.Done() p.processSendQueue(workerID, sessionCloseCh) }(i) } - p.wg.Add(1) go func() { defer p.wg.Done() p.monitorQueue(sessionCloseCh) }() - close(dcReady) }) - p.dc.OnClose(func() { - log.Println("DataChannel closed") - if p.onReconnect != nil { - log.Println("Calling reconnect callback for cleanup") - p.onReconnect(nil) - } - if !p.closed.Load() { - p.queueReconnect() - } - }) - - p.dc.OnMessage(func(msg webrtc.DataChannelMessage) { - if p.onData != nil && len(msg.Data) > 0 { - p.onData(msg.Data) - } - }) + p.dc.OnClose(p.onDataChannelClose) + p.dc.OnMessage(p.onDataChannelMessage) p.pcSub.OnDataChannel(func(dc *webrtc.DataChannel) { log.Printf("Received datachannel: %s", dc.Label()) dc.OnClose(func() { - log.Println("Received DataChannel closed - triggering reconnect") if !p.closed.Load() { p.queueReconnect() } }) - dc.OnMessage(func(msg webrtc.DataChannelMessage) { - if p.onData != nil && len(msg.Data) > 0 { - p.onData(msg.Data) - } - }) + dc.OnMessage(p.onDataChannelMessage) }) +} +func (p *Peer) onDataChannelClose() { + log.Println("DataChannel closed") + if p.onReconnect != nil { + p.onReconnect(nil) + } + if !p.closed.Load() { + p.queueReconnect() + } +} + +func (p *Peer) onDataChannelMessage(msg webrtc.DataChannelMessage) { + if p.onData != nil && len(msg.Data) > 0 { + p.onData(msg.Data) + } +} + +func (p *Peer) dialWebSocket() error { wsDialer := websocket.Dialer{ NetDialContext: protect.DialContext, HandshakeTimeout: 15 * time.Second, } - ws, _, err := wsDialer.Dial(p.conn.ClientConfig.MediaServerURL, nil) + ws, resp, err := wsDialer.Dial(p.conn.ClientConfig.MediaServerURL, nil) if err != nil { - return err + return fmt.Errorf("dial ws: %w", err) + } + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() } p.ws = ws ws.SetPongHandler(func(string) error { - ws.SetReadDeadline(time.Now().Add(60 * time.Second)) + _ = ws.SetReadDeadline(time.Now().Add(60 * time.Second)) return nil }) + _ = ws.SetReadDeadline(time.Now().Add(60 * time.Second)) + return nil +} - ws.SetReadDeadline(time.Now().Add(60 * time.Second)) - +func (p *Peer) startBackgroundGoroutines(ctx context.Context, keepAliveCh chan struct{}) { p.wg.Add(1) go func() { defer p.wg.Done() p.keepAlive(keepAliveCh) }() - if err := p.sendHello(); err != nil { - return err - } - - p.setupICEHandlers() + _ = p.sendHello() p.wg.Add(1) go func() { defer p.wg.Done() - p.handleSignaling() + p.handleSignaling(ctx) }() - - select { - case <-dcReady: - return nil - case <-time.After(15 * time.Second): - return fmt.Errorf("datachannel timeout") - case <-ctx.Done(): - return ctx.Err() - } } -func (p *Peer) Send(data []byte) error { +func (p *Peer) Send(data []byte) error { //nolint:revive if p.dc == nil || p.dc.ReadyState() != webrtc.DataChannelStateOpen { - return fmt.Errorf("datachannel not ready") + return ErrDataChannelNotReady } if p.sendQueueClosed.Load() { - return fmt.Errorf("send queue closed") + return ErrSendQueueClosed } select { @@ -330,8 +356,8 @@ func (p *Peer) Send(data []byte) error { return nil case <-time.After(50 * time.Millisecond): queueLen := len(p.sendQueue) - log.Printf("[SEND_QUEUE] Timeout! queue_len=%d, dropping packet size=%d", queueLen, len(data)) - return fmt.Errorf("send queue timeout") + log.Printf("[SEND_QUEUE] Timeout! len=%d size=%d", queueLen, len(data)) + return ErrSendQueueTimeout } } @@ -377,10 +403,13 @@ func (p *Peer) sendHello() error { p.wsMu.Lock() defer p.wsMu.Unlock() - return p.ws.WriteJSON(hello) + if err := p.ws.WriteJSON(hello); err != nil { + return fmt.Errorf("write hello: %w", err) + } + return nil } -func (p *Peer) handleSignaling() { +func (p *Peer) handleSignaling(ctx context.Context) { //nolint:cyclop pubSent := false for { @@ -393,117 +422,35 @@ func (p *Peer) handleSignaling() { return } - p.wsMu.Lock() - if p.ws != nil { - p.ws.SetReadDeadline(time.Now().Add(60 * time.Second)) - } - p.wsMu.Unlock() + p.updateWSDeadline() uid, _ := msg["uid"].(string) - if _, ok := msg["ack"]; ok { p.resolveAck(uid) } if serverHello, ok := msg["serverHello"].(map[string]interface{}); ok { - p.startTelemetry(serverHello) + p.startTelemetry(ctx, serverHello) p.sendAck(uid) } - if _, ok := msg["updateDescription"]; ok { - p.sendAck(uid) - } - - if _, ok := msg["vadActivity"]; ok { - p.sendAck(uid) - } + p.handleCommonMessages(msg, uid) if isConferenceEndMessage(msg) { p.signalEnded("conference ended") return } - if _, ok := msg["ping"]; ok { - p.sendPong(uid) - continue - } - - if _, ok := msg["pong"]; ok { - p.sendAck(uid) - continue - } - if offer, ok := msg["subscriberSdpOffer"].(map[string]interface{}); ok && !pubSent { - sdp, _ := offer["sdp"].(string) - pcSeq, _ := offer["pcSeq"].(float64) - - if err := p.pcSub.SetRemoteDescription(webrtc.SessionDescription{ - Type: webrtc.SDPTypeOffer, - SDP: sdp, - }); err != nil { - log.Printf("SetRemoteDescription error: %v", err) + if err := p.handleSdpOffer(offer, uid); err != nil { + log.Printf("SDP offer error: %v", err) continue } - - answer, err := p.pcSub.CreateAnswer(nil) - if err != nil { - log.Printf("CreateAnswer error: %v", err) - continue - } - - if err := p.pcSub.SetLocalDescription(answer); err != nil { - log.Printf("SetLocalDescription error: %v", err) - continue - } - - p.wsMu.Lock() - p.ws.WriteJSON(map[string]interface{}{ - "uid": uuid.New().String(), - "subscriberSdpAnswer": map[string]interface{}{ - "pcSeq": int(pcSeq), - "sdp": answer.SDP, - }, - }) - p.wsMu.Unlock() - - p.sendAck(uid) - time.Sleep(300 * time.Millisecond) - - pubOffer, err := p.pcPub.CreateOffer(nil) - if err != nil { - log.Printf("CreateOffer error: %v", err) - continue - } - - if err := p.pcPub.SetLocalDescription(pubOffer); err != nil { - log.Printf("SetLocalDescription error: %v", err) - continue - } - - p.wsMu.Lock() - p.ws.WriteJSON(map[string]interface{}{ - "uid": uuid.New().String(), - "publisherSdpOffer": map[string]interface{}{ - "pcSeq": 1, - "sdp": pubOffer.SDP, - }, - }) - p.wsMu.Unlock() - pubSent = true } if answer, ok := msg["publisherSdpAnswer"].(map[string]interface{}); ok { - sdp, _ := answer["sdp"].(string) - - if err := p.pcPub.SetRemoteDescription(webrtc.SessionDescription{ - Type: webrtc.SDPTypeAnswer, - SDP: sdp, - }); err != nil { - log.Printf("SetRemoteDescription error: %v", err) - } - - p.sendAck(uid) + p.handleSdpAnswer(answer, uid) } if cand, ok := msg["webrtcIceCandidate"].(map[string]interface{}); ok { @@ -512,6 +459,94 @@ func (p *Peer) handleSignaling() { } } +func (p *Peer) updateWSDeadline() { + p.wsMu.Lock() + if p.ws != nil { + _ = p.ws.SetReadDeadline(time.Now().Add(60 * time.Second)) + } + p.wsMu.Unlock() +} + +func (p *Peer) handleCommonMessages(msg map[string]interface{}, uid string) { + if _, ok := msg["updateDescription"]; ok { + p.sendAck(uid) + } + if _, ok := msg["vadActivity"]; ok { + p.sendAck(uid) + } + if _, ok := msg["ping"]; ok { + p.sendPong(uid) + } + if _, ok := msg["pong"]; ok { + p.sendAck(uid) + } +} + +func (p *Peer) handleSdpOffer(offer map[string]interface{}, uid string) error { + sdp, _ := offer["sdp"].(string) + pcSeq, _ := offer["pcSeq"].(float64) + + if err := p.pcSub.SetRemoteDescription(webrtc.SessionDescription{ + Type: webrtc.SDPTypeOffer, + SDP: sdp, + }); err != nil { + return fmt.Errorf("set remote desc: %w", err) + } + + answer, err := p.pcSub.CreateAnswer(nil) + if err != nil { + return fmt.Errorf("create answer: %w", err) + } + + if err := p.pcSub.SetLocalDescription(answer); err != nil { + return fmt.Errorf("set local desc: %w", err) + } + + p.wsMu.Lock() + _ = p.ws.WriteJSON(map[string]interface{}{ + "uid": uuid.New().String(), + "subscriberSdpAnswer": map[string]interface{}{ + "pcSeq": int(pcSeq), + "sdp": answer.SDP, + }, + }) + p.wsMu.Unlock() + + p.sendAck(uid) + time.Sleep(300 * time.Millisecond) + + pubOffer, err := p.pcPub.CreateOffer(nil) + if err != nil { + return fmt.Errorf("create pub offer: %w", err) + } + + if err := p.pcPub.SetLocalDescription(pubOffer); err != nil { + return fmt.Errorf("set local pub desc: %w", err) + } + + p.wsMu.Lock() + _ = p.ws.WriteJSON(map[string]interface{}{ + "uid": uuid.New().String(), + "publisherSdpOffer": map[string]interface{}{ + "pcSeq": 1, + "sdp": pubOffer.SDP, + }, + }) + p.wsMu.Unlock() + return nil +} + +func (p *Peer) handleSdpAnswer(answer map[string]interface{}, uid string) { + sdp, _ := answer["sdp"].(string) + if err := p.pcPub.SetRemoteDescription(webrtc.SessionDescription{ + Type: webrtc.SDPTypeAnswer, + SDP: sdp, + }); err != nil { + log.Printf("SetRemoteDescription error: %v", err) + } + p.sendAck(uid) +} + func (p *Peer) handleICE(cand map[string]interface{}) { candStr, _ := cand["candidate"].(string) target, _ := cand["target"].(string) @@ -529,10 +564,11 @@ func (p *Peer) handleICE(cand map[string]interface{}) { SDPMLineIndex: func() *uint16 { v := uint16(sdpMLineIndex); return &v }(), } - if target == "SUBSCRIBER" { - p.pcSub.AddICECandidate(init) - } else if target == "PUBLISHER" { - p.pcPub.AddICECandidate(init) + switch target { + case "SUBSCRIBER": + _ = p.pcSub.AddICECandidate(init) + case "PUBLISHER": + _ = p.pcPub.AddICECandidate(init) } } @@ -544,12 +580,10 @@ func (p *Peer) sendAck(uid string) { p.wsMu.Lock() defer p.wsMu.Unlock() - p.ws.WriteJSON(map[string]interface{}{ + _ = p.ws.WriteJSON(map[string]interface{}{ "uid": uid, "ack": map[string]interface{}{ - "status": map[string]interface{}{ - "code": "OK", - }, + "status": map[string]interface{}{"code": "OK"}, }, }) } @@ -573,9 +607,7 @@ func (p *Peer) waitForAck(uid string, ch <-chan struct{}, timeout time.Duration) return false } - defer func() { - p.removeAckWaiter(uid) - }() + defer p.removeAckWaiter(uid) select { case <-ch: @@ -605,13 +637,13 @@ func (p *Peer) sendPong(uid string) { p.wsMu.Lock() defer p.wsMu.Unlock() - p.ws.WriteJSON(map[string]interface{}{ + _ = p.ws.WriteJSON(map[string]interface{}{ "uid": uid, "pong": map[string]interface{}{}, }) } -func (p *Peer) startTelemetry(serverHello map[string]interface{}) { +func (p *Peer) startTelemetry(ctx context.Context, serverHello map[string]interface{}) { //nolint:cyclop cfg, ok := serverHello["telemetryConfiguration"].(map[string]interface{}) if !ok { return @@ -625,7 +657,7 @@ func (p *Peer) startTelemetry(serverHello map[string]interface{}) { endpoint, _ = cfg["url"].(string) } if endpoint == "" { - logger.Verbose("Telemetry configuration has no endpoint; skipping XHR simulation") + logger.Verbosef("Telemetry endpoint missing") return } @@ -646,16 +678,16 @@ func (p *Peer) startTelemetry(serverHello map[string]interface{}) { ticker := time.NewTicker(interval) defer ticker.Stop() - p.sendTelemetry(endpoint, "join") + p.sendTelemetry(ctx, endpoint, "join") for { select { case <-ticker.C: - p.sendTelemetry(endpoint, "stats") + p.sendTelemetry(ctx, endpoint, "stats") case <-p.telemetryCh: - p.sendTelemetry(endpoint, "leave") + p.sendTelemetry(ctx, endpoint, "leave") return case <-p.closeCh: - p.sendTelemetry(endpoint, "leave") + p.sendTelemetry(ctx, endpoint, "leave") return } } @@ -671,7 +703,7 @@ func (p *Peer) stopTelemetry() { } } -func (p *Peer) sendTelemetry(endpoint, event string) { +func (p *Peer) sendTelemetry(ctx context.Context, endpoint, event string) { body, err := json.Marshal(map[string]interface{}{ "event": event, "timestamp": time.Now().UnixMilli(), @@ -688,9 +720,9 @@ func (p *Peer) sendTelemetry(endpoint, event string) { return } - req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(body)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body)) if err != nil { - logger.Verbose("Telemetry request skipped: %v", err) + logger.Verbosef("Telemetry req error: %v", err) return } req.Header.Set("Content-Type", "application/json") @@ -705,13 +737,10 @@ func (p *Peer) sendTelemetry(endpoint, event string) { client := protect.NewHTTPClient() resp, err := client.Do(req) if err != nil { - logger.Verbose("Telemetry send failed: %v", err) + logger.Verbosef("Telemetry send error: %v", err) return } - defer resp.Body.Close() - if resp.StatusCode >= 400 { - logger.Verbose("Telemetry endpoint returned %s", resp.Status) - } + defer func() { _ = resp.Body.Close() }() } func (p *Peer) signalEnded(reason string) { @@ -759,10 +788,9 @@ func (p *Peer) setupICEHandlers() { if c == nil { return } - init := c.ToJSON() p.wsMu.Lock() - p.ws.WriteJSON(map[string]interface{}{ + _ = p.ws.WriteJSON(map[string]interface{}{ "uid": uuid.New().String(), "webrtcIceCandidate": map[string]interface{}{ "candidate": init.Candidate, @@ -779,10 +807,9 @@ func (p *Peer) setupICEHandlers() { if c == nil { return } - init := c.ToJSON() p.wsMu.Lock() - p.ws.WriteJSON(map[string]interface{}{ + _ = p.ws.WriteJSON(map[string]interface{}{ "uid": uuid.New().String(), "webrtcIceCandidate": map[string]interface{}{ "candidate": init.Candidate, @@ -801,7 +828,6 @@ func (p *Peer) sendLeave(uid string) bool { defer p.wsMu.Unlock() if p.ws == nil { - log.Println("WebSocket already closed, cannot send leave") return false } @@ -813,45 +839,30 @@ func (p *Peer) sendLeave(uid string) bool { if err := p.ws.WriteJSON(leave); err != nil { log.Printf("Failed to send leave: %v", err) return false - } else { - log.Println("Sent leave message to server") } + log.Println("Sent leave message") return true } +// Close closes the peer connection and cleans up resources. func (p *Peer) Close() error { - log.Println("Closing peer connection...") - + log.Println("Closing peer...") alreadyClosing := p.closed.Swap(true) p.sendQueueClosed.Store(true) if !alreadyClosing { - log.Println("Sending leave message...") leaveUID := uuid.New().String() leaveAck := p.registerAckWaiter(leaveUID) if p.sendLeave(leaveUID) { - if p.waitForAck(leaveUID, leaveAck, 1500*time.Millisecond) { - log.Println("Leave acknowledged") - } else { - log.Println("Leave ack timeout") - } + _ = p.waitForAck(leaveUID, leaveAck, 1500*time.Millisecond) } else { p.removeAckWaiter(leaveUID) } - p.stopTelemetry() } - log.Println("Closing channels...") - if p.closeCh != nil { - select { - case <-p.closeCh: - default: - close(p.closeCh) - } - } + closeSignal(p.closeCh) - log.Println("Waiting for goroutines...") done := make(chan struct{}) go func() { p.wg.Wait() @@ -860,72 +871,47 @@ func (p *Peer) Close() error { select { case <-done: - log.Println("All goroutines finished") case <-time.After(2 * time.Second): - log.Println("Goroutine wait timeout") + log.Println("Wait timeout") } if p.dc != nil { - log.Println("Closing DataChannel...") - p.dc.Close() + _ = p.dc.Close() } - if p.pcPub != nil { - log.Println("Closing Publisher PeerConnection...") - p.pcPub.Close() + _ = p.pcPub.Close() } - if p.pcSub != nil { - log.Println("Closing Subscriber PeerConnection...") - p.pcSub.Close() + _ = p.pcSub.Close() } - if p.ws != nil { - log.Println("Closing WebSocket...") p.wsMu.Lock() - p.ws.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second)) - p.ws.Close() + _ = p.ws.WriteControl(websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), + time.Now().Add(time.Second)) + _ = p.ws.Close() p.wsMu.Unlock() } - log.Println("Peer closed") return nil } func (p *Peer) keepAlive(keepAliveCh <-chan struct{}) { - wsPingTicker := time.NewTicker(30 * time.Second) - defer wsPingTicker.Stop() - - appPingTicker := time.NewTicker(5 * time.Second) - defer appPingTicker.Stop() + wsTicker := time.NewTicker(30 * time.Second) + defer wsTicker.Stop() + appTicker := time.NewTicker(5 * time.Second) + defer appTicker.Stop() for { select { - case <-wsPingTicker.C: - p.wsMu.Lock() - if p.ws != nil { - if err := p.ws.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second)); err != nil { - log.Printf("WS Ping error: %v", err) - p.wsMu.Unlock() - p.queueReconnect() - return - } + case <-wsTicker.C: + if !p.sendWSPing() { + return } - p.wsMu.Unlock() - case <-appPingTicker.C: - p.wsMu.Lock() - if p.ws != nil { - if err := p.ws.WriteJSON(map[string]interface{}{ - "uid": uuid.New().String(), - "ping": map[string]interface{}{}, - }); err != nil { - log.Printf("App Ping error: %v", err) - p.wsMu.Unlock() - p.queueReconnect() - return - } + case <-appTicker.C: + if !p.sendAppPing() { + return } - p.wsMu.Unlock() case <-keepAliveCh: return case <-p.closeCh: @@ -934,40 +920,65 @@ func (p *Peer) keepAlive(keepAliveCh <-chan struct{}) { } } +func (p *Peer) sendWSPing() bool { + p.wsMu.Lock() + defer p.wsMu.Unlock() + if p.ws != nil { + if err := p.ws.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second)); err != nil { + log.Printf("WS Ping error: %v", err) + p.queueReconnect() + return false + } + } + return true +} + +func (p *Peer) sendAppPing() bool { + p.wsMu.Lock() + defer p.wsMu.Unlock() + if p.ws != nil { + if err := p.ws.WriteJSON(map[string]interface{}{ + "uid": uuid.New().String(), + "ping": map[string]interface{}{}, + }); err != nil { + log.Printf("App Ping error: %v", err) + p.queueReconnect() + return false + } + } + return true +} + func (p *Peer) reconnect(ctx context.Context) error { - log.Println("Reconnecting...") p.reconnecting.Store(true) defer p.reconnecting.Store(false) p.sendLeave(uuid.New().String()) time.Sleep(500 * time.Millisecond) - p.stopSession() if p.dc != nil { - p.dc.Close() + _ = p.dc.Close() } - if p.pcPub != nil { - p.pcPub.Close() + _ = p.pcPub.Close() } - if p.pcSub != nil { - p.pcSub.Close() + _ = p.pcSub.Close() } - if p.ws != nil { p.wsMu.Lock() - p.ws.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second)) - p.ws.Close() + _ = p.ws.WriteControl(websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), + time.Now().Add(time.Second)) + _ = p.ws.Close() p.wsMu.Unlock() } time.Sleep(3 * time.Second) - - conn, err := GetConnectionInfo(p.roomURL, p.name) + conn, err := GetConnectionInfo(ctx, p.roomURL, p.name) if err != nil { - return err + return fmt.Errorf("reconnect get info: %w", err) } p.conn = conn @@ -978,43 +989,41 @@ func (p *Peer) reconnect(ctx context.Context) error { if p.onReconnect != nil { p.onReconnect(p.dc) } - p.drainReconnectQueue() - return nil } -func (p *Peer) SetReconnectCallback(cb func(*webrtc.DataChannel)) { +func (p *Peer) SetReconnectCallback(cb func(*webrtc.DataChannel)) { //nolint:revive p.onReconnect = cb } -func (p *Peer) SetShouldReconnect(fn func() bool) { +func (p *Peer) SetShouldReconnect(fn func() bool) { //nolint:revive p.shouldReconnect = fn } -func (p *Peer) WatchConnection(ctx context.Context) { +func (p *Peer) WatchConnection(ctx context.Context) { //nolint:revive,cyclop const maxReconnects = 10 const reconnectWindow = 5 * time.Minute for { select { + case <-ctx.Done(): + return + case <-p.closeCh: + return case <-p.reconnectCh: - p.reconnectMu.Lock() - now := time.Now() - if now.Sub(p.lastReconnect) > reconnectWindow { + if time.Since(p.lastReconnect) > reconnectWindow { p.reconnectCount = 0 } + p.reconnectCount++ + p.lastReconnect = time.Now() - if p.reconnectCount >= maxReconnects { - log.Printf("Max reconnect attempts (%d) reached, stopping", maxReconnects) - p.reconnectMu.Unlock() + if p.reconnectCount > maxReconnects { + log.Printf("Max reconnects reached (%d)", maxReconnects) + p.signalEnded("reconnect limit reached") return } - p.reconnectCount++ - p.lastReconnect = now - p.reconnectMu.Unlock() - backoff := time.Duration(p.reconnectCount) * 2 * time.Second if backoff > 30*time.Second { backoff = 30 * time.Second @@ -1022,128 +1031,109 @@ func (p *Peer) WatchConnection(ctx context.Context) { for { if err := p.reconnect(ctx); err != nil { - log.Printf("Reconnect failed: %v, retrying in %v...", err, backoff) - time.Sleep(backoff) - continue + log.Printf("Reconnect failed: %v", err) + select { + case <-ctx.Done(): + return + case <-p.closeCh: + return + case <-time.After(backoff): + continue + } } - p.reconnectMu.Lock() - p.reconnectCount = 0 - p.reconnectMu.Unlock() - log.Println("Reconnected successfully") break } - case <-p.closeCh: - return - case <-ctx.Done(): - return } } } func (p *Peer) processSendQueue(workerID int, sessionCloseCh <-chan struct{}) { - for { select { - case data, ok := <-p.sendQueue: - if !ok { - return - } - if p.dc == nil || p.dc.ReadyState() != webrtc.DataChannelStateOpen { - continue - } - if p.trafficShape.MaxMessageSize > 0 && len(data) > p.trafficShape.MaxMessageSize { - log.Printf("[WORKER-%d] Refusing oversized DataChannel message size=%d limit=%d", workerID, len(data), p.trafficShape.MaxMessageSize) - continue - } - if delay := p.nextSendDelay(); delay > 0 { - time.Sleep(delay) - } - - // Wait until SCTP buffer drains. Dropping here would corrupt the - // carried TCP streams (the mux is a reliable transport); large - // downloads like Instagram/Twitter assets would hang forever - // waiting for the missing bytes. Backpressure already propagates - // upstream via CanSend() / the sendQueue length. - // Threshold is high (4MB) because a tight limit serialises sends: - // workers would pause on every frame, turning throughput into - // one chunk per 10ms drain cycle (~400KB/s). - waitStart := time.Now() - for p.dc.BufferedAmount() > 4*1024*1024 { - if p.dc.ReadyState() != webrtc.DataChannelStateOpen { - break - } - time.Sleep(10 * time.Millisecond) - } - if waited := time.Since(waitStart); waited > 500*time.Millisecond { - logger.Verbose("[WORKER-%d] Buffer drained after %v", workerID, waited) - } - - if p.dc == nil || p.dc.ReadyState() != webrtc.DataChannelStateOpen { - continue - } - - sendStart := time.Now() - if err := p.dc.Send(data); err != nil { - log.Printf("[WORKER-%d] Send error: %v", workerID, err) - } else { - elapsed := time.Since(sendStart) - if elapsed > 50*time.Millisecond { - log.Printf("[WORKER-%d] Sent %d bytes in %v (buffered: %d)", - workerID, len(data), elapsed, p.dc.BufferedAmount()) - } else { - logger.Verbose("[WORKER-%d] Sent %d bytes (buffered: %d)", - workerID, len(data), p.dc.BufferedAmount()) - } - } - case <-sessionCloseCh: return case <-p.closeCh: return + case data := <-p.sendQueue: + if len(data) > p.trafficShape.MaxMessageSize { + log.Printf("[WORKER-%d] Refusing oversized message size=%d limit=%d", + workerID, len(data), p.trafficShape.MaxMessageSize) + continue + } + + waited, err := p.waitBufferedAmount(workerID, sessionCloseCh) + if err != nil { + return + } + if waited > 0 { + logger.Verbosef("[WORKER-%d] Drained after %v", workerID, waited) + } + + if err := p.dc.Send(data); err != nil { + log.Printf("[WORKER-%d] Send error: %v", workerID, err) + p.queueReconnect() + return + } + + if p.trafficShape.MinDelay > 0 { + time.Sleep(p.calculateDelay()) + } } } } +func (p *Peer) waitBufferedAmount(workerID int, sessionCloseCh <-chan struct{}) (time.Duration, error) { + start := time.Now() + for p.dc.BufferedAmount() > 512*1024 { + select { + case <-sessionCloseCh: + return 0, ErrSessionClosed + case <-p.closeCh: + return 0, ErrPeerClosed + case <-time.After(10 * time.Millisecond): + if time.Since(start) > 5*time.Second { + log.Printf("[WORKER-%d] Buffer wait timeout", workerID) + return time.Since(start), nil + } + } + } + return time.Since(start), nil +} + +func (p *Peer) calculateDelay() time.Duration { + minDelay := p.trafficShape.MinDelay + maxDelay := p.trafficShape.MaxDelay + if maxDelay <= minDelay { + return minDelay + } + //nolint:gosec + return minDelay + time.Duration(rand.Int64N(int64(maxDelay-minDelay))) +} + func (p *Peer) monitorQueue(sessionCloseCh <-chan struct{}) { - ticker := time.NewTicker(3 * time.Second) + ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() for { select { - case <-ticker.C: - queueLen := len(p.sendQueue) - buffered := uint64(0) - if p.dc != nil { - buffered = p.dc.BufferedAmount() - } - if queueLen > 800 || buffered > 3*1024*1024 { - log.Printf("[QUEUE_MONITOR] queue_len=%d dc_buffered=%d MB", queueLen, buffered/(1024*1024)) - } case <-sessionCloseCh: return case <-p.closeCh: return + case <-ticker.C: + queueLen := len(p.sendQueue) + buffered := p.dc.BufferedAmount() + if queueLen > 100 || buffered > 1024*1024 { + log.Printf("[MONITOR] queue=%d, buffered=%d MB", + queueLen, buffered/(1024*1024)) + } } } } -func (p *Peer) CanSend() bool { - queueLen := len(p.sendQueue) - buffered := uint64(0) - if p.dc != nil { - buffered = p.dc.BufferedAmount() +func (p *Peer) CanSend() bool { //nolint:revive + if p.dc == nil || p.dc.ReadyState() != webrtc.DataChannelStateOpen { + return false } - return queueLen < 1000 && buffered < 3*1024*1024 -} - -func (p *Peer) nextSendDelay() time.Duration { - minDelay := p.trafficShape.MinDelay - maxDelay := p.trafficShape.MaxDelay - if maxDelay <= 0 { - return 0 - } - if maxDelay <= minDelay { - return maxDelay - } - return minDelay + time.Duration(rand.Int64N(int64(maxDelay-minDelay))) + return len(p.sendQueue) < 4000 }