// Package client implements the local SOCKS5 client side of the olcrtc tunnel.
package client

import (
	"context"
	"encoding/binary"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"sync"
	"time"

	"github.com/openlibrecommunity/olcrtc/internal/crypto"
	"github.com/openlibrecommunity/olcrtc/internal/link"
	"github.com/openlibrecommunity/olcrtc/internal/logger"
	"github.com/openlibrecommunity/olcrtc/internal/muxconn"
	"github.com/openlibrecommunity/olcrtc/internal/names"
	"github.com/xtaci/smux"
)

var (
	// ErrConnectFailed is returned when a tunnel connection fails.
	// NOTE(review): not referenced by the code in this file; presumably used
	// elsewhere in the package — confirm before removing.
	ErrConnectFailed = errors.New("tunnel connection failed")
	// ErrProxyAuth is returned when SOCKS proxy authentication fails.
	// NOTE(review): not referenced by the code in this file either.
	ErrProxyAuth = errors.New("SOCKS proxy auth failed")
)

// Client handles local SOCKS5 connections and tunnels them to the server.
//
// Each accepted SOCKS5 connection becomes one smux stream multiplexed over a
// single transport link.
type Client struct {
	// ln is the transport link; it is created by bringUpLink and closed by
	// shutdown.
	ln link.Link
	// cipher is used when wrapping ln in a muxconn.Conn.
	cipher *crypto.Cipher
	// conn adapts ln to a net.Conn-like stream for smux; onData pushes inbound
	// link payloads into it. Guarded by sessMu (replaced on reconnect).
	conn *muxconn.Conn
	// session is the smux client session over conn; SOCKS connections open
	// their streams from it. Guarded by sessMu (replaced on reconnect).
	session *smux.Session
	// sessMu guards conn and session; readers take the read lock.
	sessMu sync.RWMutex
	// dnsServer is the DNS server given to RunWithReady.
	// NOTE(review): not read elsewhere in this file — the value passed to the
	// link comes from the function argument, not this field.
	dnsServer string
}

// Run starts the client with the specified parameters and blocks until it
// stops. It is equivalent to RunWithReady with a nil onReady callback.
//
// keyHex must hex-decode to a 32-byte key. localAddr is the tcp4 address the
// SOCKS5 server listens on. The video*/vp8* parameters are forwarded unchanged
// to the link configuration.
//
// NOTE(review): socksUser and socksPass are forwarded to RunWithReady, which
// currently ignores them (its corresponding parameters are blank `_`), so no
// SOCKS5 authentication is enforced.
func Run(
	ctx context.Context,
	linkName, transportName, carrierName, roomURL, keyHex string,
	localAddr string,
	dnsServer, socksUser string,
	socksPass string,
	videoWidth int,
	videoHeight int,
	videoFPS int,
	videoBitrate string,
	videoHW string,
	videoQRSize int,
	videoQRRecovery string,
	videoCodec string,
	videoTileModule int,
	videoTileRS int,
	vp8FPS int,
	vp8BatchSize int,
) error {
	return RunWithReady(ctx, linkName, transportName, carrierName, roomURL, keyHex,
		localAddr, dnsServer, socksUser, socksPass, nil,
		videoWidth, videoHeight, videoFPS, videoBitrate, videoHW,
		videoQRSize, videoQRRecovery, videoCodec, videoTileModule, videoTileRS,
		vp8FPS, vp8BatchSize)
}

// RunWithReady is like Run but accepts a callback that is called when the client is ready.
func RunWithReady( ctx context.Context, linkName, transportName, carrierName, roomURL, keyHex string, localAddr string, dnsServer, _ string, _ string, onReady func(), videoWidth int, videoHeight int, videoFPS int, videoBitrate string, videoHW string, videoQRSize int, videoQRRecovery string, videoCodec string, videoTileModule int, videoTileRS int, vp8FPS int, vp8BatchSize int, ) error { runCtx, cancel := context.WithCancel(ctx) defer cancel() cipher, err := setupCipher(keyHex) if err != nil { return fmt.Errorf("setupCipher failed: %w", err) } c := &Client{cipher: cipher, dnsServer: dnsServer} if err := c.bringUpLink( runCtx, linkName, transportName, carrierName, roomURL, cancel, dnsServer, "", 0, videoWidth, videoHeight, videoFPS, videoBitrate, videoHW, videoQRSize, videoQRRecovery, videoCodec, videoTileModule, videoTileRS, vp8FPS, vp8BatchSize, ); err != nil { return err } defer c.shutdown() lc := net.ListenConfig{} listener, err := lc.Listen(runCtx, "tcp4", localAddr) if err != nil { return fmt.Errorf("failed to listen on %s: %w", localAddr, err) } defer listener.Close() logger.Infof("SOCKS5 server listening on %s", localAddr) if onReady != nil { onReady() } errCh := make(chan error, 1) go func() { errCh <- c.acceptLoop(runCtx, listener) }() select { case <-runCtx.Done(): return nil case err := <-errCh: return err } } func (c *Client) bringUpLink( ctx context.Context, linkName, transportName, carrierName, roomURL string, cancel context.CancelFunc, dnsServer, socksProxyAddr string, socksProxyPort int, videoWidth, videoHeight, videoFPS int, videoBitrate, videoHW string, videoQRSize int, videoQRRecovery string, videoCodec string, videoTileModule, videoTileRS int, vp8FPS, vp8BatchSize int, ) error { ln, err := link.New(ctx, linkName, link.Config{ Transport: transportName, Carrier: carrierName, RoomURL: roomURL, Name: names.Generate(), OnData: c.onData, DNSServer: dnsServer, ProxyAddr: socksProxyAddr, ProxyPort: socksProxyPort, VideoWidth: videoWidth, VideoHeight: 
videoHeight, VideoFPS: videoFPS, VideoBitrate: videoBitrate, VideoHW: videoHW, VideoQRSize: videoQRSize, VideoQRRecovery: videoQRRecovery, VideoCodec: videoCodec, VideoTileModule: videoTileModule, VideoTileRS: videoTileRS, VP8FPS: vp8FPS, VP8BatchSize: vp8BatchSize, }) if err != nil { return fmt.Errorf("failed to create link: %w", err) } c.ln = ln ln.SetEndedCallback(func(reason string) { logger.Infof("Client link reported conference end: %s", reason) cancel() }) ln.SetReconnectCallback(func() { c.handleReconnect() }) if err := ln.Connect(ctx); err != nil { return fmt.Errorf("failed to connect link: %w", err) } c.conn = muxconn.New(ln, c.cipher) sess, err := smux.Client(c.conn, smuxConfig()) if err != nil { return fmt.Errorf("smux client: %w", err) } c.sessMu.Lock() c.session = sess c.sessMu.Unlock() go ln.WatchConnection(ctx) return nil } // smuxConfig returns the tuned smux config used on both ends. func smuxConfig() *smux.Config { cfg := smux.DefaultConfig() cfg.Version = 2 cfg.MaxFrameSize = 32768 cfg.MaxReceiveBuffer = 16 * 1024 * 1024 cfg.MaxStreamBuffer = 1024 * 1024 cfg.KeepAliveInterval = 10 * time.Second cfg.KeepAliveTimeout = 60 * time.Second return cfg } func (c *Client) handleReconnect() { logger.Infof("client link reconnect — tearing down smux session") c.sessMu.Lock() if c.session != nil { _ = c.session.Close() c.session = nil } if c.conn != nil { _ = c.conn.Close() c.conn = nil } c.sessMu.Unlock() // New SOCKS5 connections will fail until the link comes back up; the // caller will reissue them. Existing streams die with the smux session. 
c.conn = muxconn.New(c.ln, c.cipher) sess, err := smux.Client(c.conn, smuxConfig()) if err != nil { logger.Warnf("smux re-init failed: %v", err) return } c.sessMu.Lock() c.session = sess c.sessMu.Unlock() } func (c *Client) shutdown() { c.sessMu.Lock() if c.session != nil { _ = c.session.Close() } if c.conn != nil { _ = c.conn.Close() } c.sessMu.Unlock() if c.ln != nil { _ = c.ln.Close() } } func setupCipher(keyHex string) (*crypto.Cipher, error) { key, err := hex.DecodeString(keyHex) if err != nil { return nil, fmt.Errorf("failed to decode key: %w", err) } if len(key) != 32 { return nil, fmt.Errorf("key must be 32 bytes, got %d", len(key)) } cipher, err := crypto.NewCipher(string(key)) if err != nil { return nil, fmt.Errorf("failed to create cipher: %w", err) } return cipher, nil } func (c *Client) onData(data []byte) { c.sessMu.RLock() conn := c.conn c.sessMu.RUnlock() if conn != nil { conn.Push(data) } } func (c *Client) acceptLoop(ctx context.Context, ln net.Listener) error { for { conn, err := ln.Accept() if err != nil { select { case <-ctx.Done(): return nil default: logger.Warnf("Accept error: %v", err) continue } } go c.handleSocks5(ctx, conn) } } func (c *Client) handleSocks5(ctx context.Context, conn net.Conn) { defer conn.Close() if err := c.socks5Handshake(conn); err != nil { return } targetAddr, targetPort, err := c.socks5Request(conn) if err != nil { return } c.sessMu.RLock() sess := c.session c.sessMu.RUnlock() if sess == nil || sess.IsClosed() { _, _ = conn.Write(replyHostUnreachable()) return } stream, err := sess.OpenStream() if err != nil { logger.Warnf("OpenStream failed: %v", err) _, _ = conn.Write(replyHostUnreachable()) return } defer stream.Close() logger.Infof("sid=%d tunnel to %s:%d", stream.ID(), targetAddr, targetPort) connectReq, _ := json.Marshal(map[string]any{ "cmd": "connect", "addr": targetAddr, "port": targetPort, }) _ = stream.SetWriteDeadline(time.Now().Add(10 * time.Second)) if _, err := stream.Write(connectReq); err != nil { 
logger.Warnf("sid=%d connect req failed: %v", stream.ID(), err) _, _ = conn.Write(replyHostUnreachable()) return } _ = stream.SetWriteDeadline(time.Time{}) ack := make([]byte, 1) _ = stream.SetReadDeadline(time.Now().Add(15 * time.Second)) if _, err := io.ReadFull(stream, ack); err != nil || ack[0] != 0x00 { logger.Warnf("sid=%d remote ready failed: err=%v ack=%v", stream.ID(), err, ack) _, _ = conn.Write(replyHostUnreachable()) return } _ = stream.SetReadDeadline(time.Time{}) if _, err := conn.Write(replySuccess()); err != nil { return } go func() { _, _ = io.Copy(stream, conn) _ = stream.Close() }() _, _ = io.Copy(conn, stream) _ = ctx // keep signature } func (c *Client) socks5Handshake(conn net.Conn) error { buf := make([]byte, 2) if _, err := io.ReadFull(conn, buf); err != nil { return err } if buf[0] != 5 { return fmt.Errorf("invalid socks version: %d", buf[0]) } methods := make([]byte, buf[1]) if _, err := io.ReadFull(conn, methods); err != nil { return err } if _, err := conn.Write([]byte{5, 0}); err != nil { return err } return nil } func (c *Client) socks5Request(conn net.Conn) (string, int, error) { header := make([]byte, 4) if _, err := io.ReadFull(conn, header); err != nil { return "", 0, err } if header[1] != 1 { return "", 0, fmt.Errorf("unsupported socks command: %d", header[1]) } var addr string switch header[3] { case 1: // IPv4 buf := make([]byte, 4) if _, err := io.ReadFull(conn, buf); err != nil { return "", 0, err } addr = net.IP(buf).String() case 3: // Domain lenBuf := make([]byte, 1) if _, err := io.ReadFull(conn, lenBuf); err != nil { return "", 0, err } buf := make([]byte, lenBuf[0]) if _, err := io.ReadFull(conn, buf); err != nil { return "", 0, err } addr = string(buf) default: return "", 0, fmt.Errorf("unsupported address type: %d", header[3]) } portBuf := make([]byte, 2) if _, err := io.ReadFull(conn, portBuf); err != nil { return "", 0, err } port := int(binary.BigEndian.Uint16(portBuf)) return addr, port, nil } func replySuccess() 
[]byte { return []byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0} } func replyHostUnreachable() []byte { return []byte{5, 4, 0, 1, 0, 0, 0, 0, 0, 0} }