From 513e2bdd9dc5c0c69b78d1878b7d58103acf9fa0 Mon Sep 17 00:00:00 2001 From: zarazaex69 Date: Tue, 21 Apr 2026 01:51:48 +0300 Subject: [PATCH] feat: refactor client connection handling and error management --- internal/client/client.go | 418 ++++++++----------- internal/link/link.go | 1 + internal/server/server.go | 1 + internal/transport/videochannel/transport.go | 6 +- 4 files changed, 177 insertions(+), 249 deletions(-) diff --git a/internal/client/client.go b/internal/client/client.go index 9eadd0d..d8e4f44 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -12,7 +12,6 @@ import ( "io" "net" "sync" - "sync/atomic" "time" "github.com/openlibrecommunity/olcrtc/internal/crypto" @@ -23,36 +22,25 @@ import ( ) var ( - // ErrKeySize is returned when the key size is not 32 bytes. - ErrKeySize = errors.New("key must be 32 bytes") - // ErrKeyStringLength is returned when the key string length is not 32. - ErrKeyStringLength = errors.New("key string length must be 32") - // ErrInvalidSocks5 is returned when the SOCKS version is not 5. - ErrInvalidSocks5 = errors.New("invalid SOCKS5 version") - // ErrNoLinks is returned when no links are available for sending. - ErrNoLinks = errors.New("no links available") - // ErrEncryptFailed is returned when encryption fails. - ErrEncryptFailed = errors.New("encrypt failed") - // ErrUnsupportedSocksCommand is returned when a SOCKS5 command is not supported. - ErrUnsupportedSocksCommand = errors.New("unsupported SOCKS5 command") - // ErrUnsupportedAddressType is returned when a SOCKS5 address type is not supported. - ErrUnsupportedAddressType = errors.New("unsupported address type") - // ErrTunnelSetupFailed is returned when the tunnel cannot be established. - ErrTunnelSetupFailed = errors.New("tunnel setup failed") + // ErrConnectFailed is returned when a tunnel connection fails. + ErrConnectFailed = errors.New("tunnel connection failed") + // ErrProxyAuth is returned when SOCKS proxy authentication fails. + ErrProxyAuth = errors.New("SOCKS proxy auth failed") + // ErrMuxExited is returned when the multiplexer loop exits unexpectedly. + ErrMuxExited = errors.New("multiplexer loop exited") + // ErrNoAvailableLinks is returned when no links are ready for sending. + ErrNoAvailableLinks = errors.New("no available links") ) -// Client handles local SOCKS5 connections and tunnels them through the selected runtime stack. +// Client handles local SOCKS5 connections and tunnels them to the server. type Client struct { - links []link.Link - cipher *crypto.Cipher - mux *mux.Multiplexer - connections map[uint16]net.Conn - connMu sync.RWMutex - linkIdx atomic.Uint32 - clientID uint32 - activeClients atomic.Int32 - wg sync.WaitGroup - dnsServer string + links []link.Link + cipher *crypto.Cipher + mux *mux.Multiplexer + connections map[uint16]net.Conn + connMu sync.RWMutex + clientID uint32 + dnsServer string } // Run starts the client with the specified parameters. @@ -71,8 +59,9 @@ func Run( videoHeight int, videoFPS int, videoBitrate string, + videoHW string, ) error { - return RunWithReady(ctx, linkName, transportName, carrierName, roomURL, keyHex, localAddr, dnsServer, socksUser, socksPass, nil, videoWidth, videoHeight, videoFPS, videoBitrate) + return RunWithReady(ctx, linkName, transportName, carrierName, roomURL, keyHex, localAddr, dnsServer, socksUser, socksPass, nil, videoWidth, videoHeight, videoFPS, videoBitrate, videoHW) } // RunWithReady is like Run but accepts a callback that is called when the client is ready. @@ -92,6 +81,7 @@ func RunWithReady( videoHeight int, videoFPS int, videoBitrate string, + videoHW string, ) error { runCtx, cancel := context.WithCancel(ctx) defer cancel() @@ -119,17 +109,17 @@ func RunWithReady( const linkCount = 1 for i := range linkCount { - if err := c.addLink(runCtx, linkName, transportName, carrierName, roomURL, i, cancel, dnsServer, "", 0, videoWidth, videoHeight, videoFPS, videoBitrate); err != nil { + if err := c.addLink(runCtx, linkName, transportName, carrierName, roomURL, i, cancel, dnsServer, "", 0, videoWidth, videoHeight, videoFPS, videoBitrate, videoHW); err != nil { return fmt.Errorf("addLink failed: %w", err) } } lc := net.ListenConfig{} - ln, err := lc.Listen(runCtx, "tcp", localAddr) + ln, err := lc.Listen(runCtx, "tcp4", localAddr) if err != nil { - return fmt.Errorf("listen failed: %w", err) + return fmt.Errorf("failed to listen on %s: %w", localAddr, err) } - defer func() { _ = ln.Close() }() + defer ln.Close() logger.Infof("SOCKS5 server listening on %s (ClientID: %d)", localAddr, clientID) @@ -137,13 +127,17 @@ func RunWithReady( onReady() } - go c.acceptLoop(runCtx, ln) + errCh := make(chan error, 1) + go func() { + errCh <- c.acceptLoop(runCtx, ln) + }() - <-runCtx.Done() - c.shutdown() - c.wg.Wait() - - return nil + select { + case <-runCtx.Done(): + return nil + case err := <-errCh: + return err + } } func setupCipher(keyHex string) (*crypto.Cipher, error) { @@ -152,15 +146,10 @@ func setupCipher(keyHex string) (*crypto.Cipher, error) { return nil, fmt.Errorf("failed to decode key: %w", err) } if len(key) != 32 { - return nil, ErrKeySize + return nil, fmt.Errorf("key must be 32 bytes, got %d", len(key)) } - keyStr := string(key) - if len(keyStr) != 32 { - return nil, ErrKeyStringLength - } - - cipher, err := crypto.NewCipher(keyStr) + cipher, err := crypto.NewCipher(string(key)) if err != nil { return nil, fmt.Errorf("failed to create cipher: %w", err) } @@ -185,13 +174,12 @@ func (c *Client) setupMux() { encrypted, err := c.cipher.Encrypt(frame) if err != nil { - return fmt.Errorf("%w: %w", ErrEncryptFailed, err) + return err } if len(c.links) == 0 { - return ErrNoLinks + return ErrNoAvailableLinks } - idx := c.linkIdx.Add(1) % uint32(len(c.links)) //nolint:gosec - return c.links[idx].Send(encrypted) + return c.links[0].Send(encrypted) }) } @@ -207,7 +195,7 @@ func (c *Client) addLink( socksProxyAddr string, socksProxyPort int, videoWidth, videoHeight, videoFPS int, - videoBitrate string, + videoBitrate, videoHW string, ) error { ln, err := link.New(ctx, linkName, link.Config{ Transport: transportName, @@ -222,6 +210,7 @@ func (c *Client) addLink( VideoHeight: videoHeight, VideoFPS: videoFPS, VideoBitrate: videoBitrate, + VideoHW: videoHW, }) if err != nil { return fmt.Errorf("failed to create link: %w", err) @@ -237,25 +226,17 @@ func (c *Client) addLink( c.handleLinkReconnect(linkID) }) - logger.Infof("Connecting link %d via %s/%s/%s...", linkID, linkName, transportName, carrierName) if err := ln.Connect(ctx); err != nil { return fmt.Errorf("failed to connect link: %w", err) } - logger.Infof("Link %d connected", linkID) - - c.wg.Add(1) - go func() { - defer c.wg.Done() - ln.WatchConnection(ctx) - }() - - c.sendClientResetAsync("initial") + go ln.WatchConnection(ctx) return nil } func (c *Client) handleLinkReconnect(linkID int) { logger.Infof("link %d reconnect event", linkID) + c.sendResetSignal() c.connMu.Lock() for sid, conn := range c.connections { @@ -269,228 +250,176 @@ func (c *Client) handleLinkReconnect(linkID int) { c.mux.UpdateSendFunc(func(frame []byte) error { encrypted, err := c.cipher.Encrypt(frame) if err != nil { - return fmt.Errorf("%w: %w", ErrEncryptFailed, err) + return err } if len(c.links) == 0 { - return ErrNoLinks + return ErrNoAvailableLinks } - idx := c.linkIdx.Add(1) % uint32(len(c.links)) //nolint:gosec - return c.links[idx].Send(encrypted) + return c.links[0].Send(encrypted) }) c.mux.Reset() - - c.sendClientResetAsync("reconnect") } -func (c *Client) sendClientResetAsync(source string) { - c.wg.Add(1) - go func() { - defer c.wg.Done() - if err := c.mux.SendClientReset(); err != nil { - logger.Warnf("Failed to send client reset after %s: %v", source, err) - } - }() -} - -func (c *Client) acceptLoop(ctx context.Context, ln net.Listener) { - for { - select { - case <-ctx.Done(): - return - default: - conn, err := ln.Accept() - if err != nil { - logger.Debugf("Accept error: %v", err) - continue - } - go c.handleSOCKS5(ctx, conn) - } +func (c *Client) sendResetSignal() { + resetFrame := mux.BuildControlFrame(c.clientID, mux.ControlResetClient) + encrypted, _ := c.cipher.Encrypt(resetFrame) + if len(c.links) > 0 { + _ = c.links[0].Send(encrypted) } } -func (c *Client) handleSOCKS5(ctx context.Context, conn net.Conn) { - defer func() { _ = conn.Close() }() +func (c *Client) acceptLoop(ctx context.Context, ln net.Listener) error { + for { + conn, err := ln.Accept() + if err != nil { + select { + case <-ctx.Done(): + return nil + default: + logger.Warnf("Accept error: %v", err) + continue + } + } + go c.handleSocks5(ctx, conn) + } +} + +func (c *Client) handleSocks5(ctx context.Context, conn net.Conn) { + defer conn.Close() if err := c.socks5Handshake(conn); err != nil { - logger.Debugf("SOCKS5 handshake failed: %v", err) return } - addr, port, err := c.socks5Request(conn) + targetAddr, targetPort, err := c.socks5Request(conn) if err != nil { - logger.Debugf("SOCKS5 request failed: %v", err) return } sid := c.mux.OpenStream() + defer c.mux.CloseStream(sid) + c.connMu.Lock() c.connections[sid] = conn c.connMu.Unlock() + defer func() { + c.connMu.Lock() + delete(c.connections, sid) + c.connMu.Unlock() + }() - logger.Infof("sid=%d tunnel to %s:%d", sid, addr, port) + logger.Infof("sid=%d tunnel to %s:%d", sid, targetAddr, targetPort) - if err := c.setupTunnel(ctx, sid, conn, addr, port); err != nil { + connectReq, _ := json.Marshal(map[string]any{ + "cmd": "connect", + "addr": targetAddr, + "port": targetPort, + }) + + if err := c.mux.SendData(sid, connectReq); err != nil { logger.Warnf("sid=%d tunnel setup failed: %v", sid, err) + _, _ = conn.Write(replyHostUnreachable()) return } - c.activeClients.Add(1) - c.startStreamPump(ctx, sid, conn) - c.pumpToMux(sid, conn) -} - -func (c *Client) setupTunnel(ctx context.Context, sid uint16, conn net.Conn, addr string, port int) error { - req := map[string]any{"cmd": "connect", "addr": addr, "port": port} - reqData, err := json.Marshal(req) - if err != nil { - return fmt.Errorf("marshal connect: %w", err) - } - - if err := c.mux.SendData(sid, reqData); err != nil { - return fmt.Errorf("send connect: %w", err) - } + readyTimer := time.NewTimer(10 * time.Second) + defer readyTimer.Stop() dataReady := c.mux.WaitForData(sid) + + var initialData []byte select { - case <-dataReady: - resp := c.mux.ReadStream(sid) - if len(resp) > 0 && resp[0] == 0x00 { - if _, err := conn.Write(replySuccess()); err != nil { - return fmt.Errorf("write success: %w", err) - } - } else { - _, _ = conn.Write(replyHostUnreachable()) - return ErrTunnelSetupFailed - } - case <-time.After(15 * time.Second): + case <-readyTimer.C: + logger.Warnf("sid=%d tunnel setup failed: timeout waiting for remote ready", sid) _, _ = conn.Write(replyHostUnreachable()) - c.mux.CleanupDataChannel(sid) - return fmt.Errorf("%w: timeout", ErrTunnelSetupFailed) - case <-ctx.Done(): - return fmt.Errorf("context cancelled: %w", ctx.Err()) + return + case <-dataReady: + initialData = c.mux.ReadStream(sid) + if len(initialData) == 0 || initialData[0] != 0x00 { + logger.Warnf("sid=%d tunnel setup failed: invalid remote ready", sid) + _, _ = conn.Write(replyHostUnreachable()) + return + } } - c.mux.CleanupDataChannel(sid) - return nil + + if _, err := conn.Write(replySuccess()); err != nil { + return + } + + // Handle the rest of initialData if any (unlikely for 0x00 packet) + if len(initialData) > 1 { + if _, err := conn.Write(initialData[1:]); err != nil { + return + } + } + + go c.pumpFromMux(ctx, sid, conn) + c.pumpToMux(sid, conn) } func (c *Client) socks5Handshake(conn net.Conn) error { buf := make([]byte, 2) if _, err := io.ReadFull(conn, buf); err != nil { - return fmt.Errorf("read header: %w", err) + return err } - if buf[0] != 5 { - return ErrInvalidSocks5 + return fmt.Errorf("invalid socks version: %d", buf[0]) } - - methods := make([]byte, int(buf[1])) + methods := make([]byte, buf[1]) if _, err := io.ReadFull(conn, methods); err != nil { - return fmt.Errorf("read methods: %w", err) + return err } - if _, err := conn.Write([]byte{5, 0}); err != nil { - return fmt.Errorf("write response: %w", err) + return err } return nil } func (c *Client) socks5Request(conn net.Conn) (string, int, error) { - buf := make([]byte, 4) - if _, err := io.ReadFull(conn, buf); err != nil { - return "", 0, fmt.Errorf("read request header: %w", err) - } - - if buf[0] != 5 || buf[1] != 1 { - return "", 0, fmt.Errorf("%w: cmd=%d", ErrUnsupportedSocksCommand, buf[1]) - } - - addr, err := c.readSocks5Addr(conn, buf[3]) - if err != nil { + header := make([]byte, 4) + if _, err := io.ReadFull(conn, header); err != nil { return "", 0, err } + if header[1] != 1 { + return "", 0, fmt.Errorf("unsupported socks command: %d", header[1]) + } + + var addr string + switch header[3] { + case 1: // IPv4 + buf := make([]byte, 4) + if _, err := io.ReadFull(conn, buf); err != nil { + return "", 0, err + } + addr = net.IP(buf).String() + case 3: // Domain + lenBuf := make([]byte, 1) + if _, err := io.ReadFull(conn, lenBuf); err != nil { + return "", 0, err + } + buf := make([]byte, lenBuf[0]) + if _, err := io.ReadFull(conn, buf); err != nil { + return "", 0, err + } + addr = string(buf) + default: + return "", 0, fmt.Errorf("unsupported address type: %d", header[3]) + } portBuf := make([]byte, 2) if _, err := io.ReadFull(conn, portBuf); err != nil { - return "", 0, fmt.Errorf("read port: %w", err) + return "", 0, err } port := int(binary.BigEndian.Uint16(portBuf)) return addr, port, nil } -func (c *Client) readSocks5Addr(conn net.Conn, addrType byte) (string, error) { - switch addrType { - case 1: // IPv4 - ip := make([]byte, 4) - if _, err := io.ReadFull(conn, ip); err != nil { - return "", fmt.Errorf("read ipv4: %w", err) - } - return net.IP(ip).String(), nil - case 3: // Domain - lenBuf := make([]byte, 1) - if _, err := io.ReadFull(conn, lenBuf); err != nil { - return "", fmt.Errorf("read domain len: %w", err) - } - domain := make([]byte, int(lenBuf[0])) - if _, err := io.ReadFull(conn, domain); err != nil { - return "", fmt.Errorf("read domain: %w", err) - } - return string(domain), nil - case 4: // IPv6 - ip := make([]byte, 16) - if _, err := io.ReadFull(conn, ip); err != nil { - return "", fmt.Errorf("read ipv6: %w", err) - } - return net.IP(ip).String(), nil - default: - return "", fmt.Errorf("%w: type=%d", ErrUnsupportedAddressType, addrType) - } -} - -func (c *Client) onData(data []byte) { - plaintext, err := c.cipher.Decrypt(data) - if err != nil { - logger.Debugf("Decrypt error: %v", err) - return - } - - c.mux.HandleFrame(plaintext) -} - -func (c *Client) shutdown() { - c.connMu.Lock() - for _, conn := range c.connections { - if conn != nil { - _ = conn.Close() - } - } - c.connMu.Unlock() - - for i, tr := range c.links { - logger.Infof("closing link %d", i) - _ = tr.Close() - } -} - func (c *Client) pumpToMux(sid uint16, conn net.Conn) { - defer func() { - c.activeClients.Add(-1) - _ = c.mux.CloseStream(sid) - c.connMu.Lock() - delete(c.connections, sid) - c.connMu.Unlock() - }() - buf := make([]byte, 16384) - totalSent := uint64(0) - lastLog := time.Now() - for { n, err := conn.Read(buf) if err != nil { - if totalSent > 1024*1024 { - logger.Infof("sid=%d done total=%dMB", sid, totalSent/(1024*1024)) - } return } @@ -501,41 +430,36 @@ func (c *Client) pumpToMux(sid uint16, conn net.Conn) { if err := c.mux.SendData(sid, buf[:n]); err != nil { return } + } +} - totalSent += uint64(n) //nolint:gosec - if time.Since(lastLog) > 5*time.Second { - logger.Infof("sid=%d sent=%dMB", sid, totalSent/(1024*1024)) - lastLog = time.Now() +func (c *Client) pumpFromMux(ctx context.Context, sid uint16, conn net.Conn) { + defer c.mux.CleanupDataChannel(sid) + dataReady := c.mux.WaitForData(sid) + for { + select { + case <-ctx.Done(): + return + case <-dataReady: + data := c.mux.ReadStream(sid) + if len(data) > 0 { + if _, err := conn.Write(data); err != nil { + return + } + } + if c.mux.StreamClosed(sid) { + return + } } } } -func (c *Client) startStreamPump(ctx context.Context, sid uint16, conn net.Conn) { - c.wg.Add(1) - go func() { - defer c.wg.Done() - - ticker := time.NewTicker(10 * time.Millisecond) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - data := c.mux.ReadStream(sid) - if len(data) > 0 { - if _, err := conn.Write(data); err != nil { - _ = c.mux.CloseStream(sid) - return - } - } - if c.mux.StreamClosed(sid) { - return - } - } - } - }() +func (c *Client) onData(data []byte) { + plaintext, err := c.cipher.Decrypt(data) + if err != nil { + return + } + c.mux.HandleFrame(plaintext) } func (c *Client) canSendData() bool { diff --git a/internal/link/link.go b/internal/link/link.go index 8c02198..02c6bf4 100644 --- a/internal/link/link.go +++ b/internal/link/link.go @@ -37,6 +37,7 @@ type Config struct { VideoHeight int VideoFPS int VideoBitrate string + VideoHW string } // Factory creates a link instance. diff --git a/internal/server/server.go b/internal/server/server.go index c88bb8a..7bfed1e 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -80,6 +80,7 @@ func Run( videoHeight int, videoFPS int, videoBitrate string, + videoHW string, ) error { runCtx, cancel := context.WithCancel(ctx) defer cancel() diff --git a/internal/transport/videochannel/transport.go b/internal/transport/videochannel/transport.go index aa71457..626605c 100644 --- a/internal/transport/videochannel/transport.go +++ b/internal/transport/videochannel/transport.go @@ -61,6 +61,7 @@ type streamTransport struct { videoH int videoFPS int videoBitrate string + videoHW string } // New creates a visual videochannel transport backed by a carrier-specific provider. @@ -109,6 +110,7 @@ func New(ctx context.Context, cfg transport.Config) (transport.Transport, error) videoH: cfg.VideoHeight, videoFPS: cfg.VideoFPS, videoBitrate: cfg.VideoBitrate, + videoHW: cfg.VideoHW, } if err := stream.AddTrack(track); err != nil { @@ -124,7 +126,7 @@ func (p *streamTransport) Connect(ctx context.Context) error { connectCtx, cancel := context.WithTimeout(ctx, defaultConnectTimeout) defer cancel() - encoder, err := newFFmpegEncoder(p.codec, p.videoW, p.videoH, p.videoFPS, p.videoBitrate) + encoder, err := newFFmpegEncoder(p.codec, p.videoW, p.videoH, p.videoFPS, p.videoBitrate, p.videoHW) if err != nil { return err } @@ -328,7 +330,7 @@ func (p *streamTransport) handleRemoteTrack(track *webrtc.TrackRemote, _ *webrtc return } - decoder, err := newFFmpegDecoder(codec, p.videoW, p.videoH, p.videoFPS) + decoder, err := newFFmpegDecoder(codec, p.videoW, p.videoH, p.videoFPS, p.videoHW) if err != nil { logger.Warnf("videochannel decoder init failed: %v", err) return