diff --git a/internal/client/client.go b/internal/client/client.go index 9712bea..7a23454 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -10,9 +10,7 @@ import ( "errors" "fmt" "io" - "log" "net" - "strconv" "sync" "sync/atomic" "time" @@ -26,232 +24,358 @@ import ( ) var ( - errInvalidKeyLength = errors.New("key must be 32 bytes") - errInvalidKeyStringLength = errors.New("key string length must be 32") - errNoConnectedPeers = errors.New("no connected peers available") + ErrKeySize = errors.New("key must be 32 bytes") + ErrKeyStringLength = errors.New("key string length must be 32") + ErrInvalidSocks5 = errors.New("invalid SOCKS5 version") + ErrNoPeers = errors.New("no peers available") + ErrEncryptFailed = errors.New("encrypt failed") ) -// Client handles local SOCKS5 connections and tunnels them through WebRTC. +// Client handles local SOCKS5 connections and tunnels them via WebRTC. type Client struct { - peers []provider.Provider - cipher *crypto.Cipher - mux *mux.Multiplexer - clientID uint32 - peerIdx atomic.Uint32 - wg sync.WaitGroup + peers []provider.Provider + cipher *crypto.Cipher + mux *mux.Multiplexer + connections map[uint16]net.Conn + connMu sync.RWMutex + peerIdx atomic.Uint32 + clientID uint32 + activeClients atomic.Int32 + wg sync.WaitGroup + dnsServer string } -const defaultSOCKSListenHost = "127.0.0.1" - // Run starts the client with the specified parameters. func Run( ctx context.Context, providerName, roomURL, keyHex string, - socksPort int, - socksHost, - socksUser, - socksPass string, -) error { - return RunWithReady(ctx, providerName, roomURL, keyHex, socksPort, socksHost, socksUser, socksPass, nil) -} - -// RunWithReady starts the client and calls onReady when it is listening. -func RunWithReady( - ctx context.Context, - providerName, - roomURL, - keyHex string, - socksPort int, - socksHost, - socksUser, - socksPass string, - onReady func(), + localAddr string, + dnsServer, + socksProxyAddr string, + socksProxyPort int, ) error { runCtx, cancel := context.WithCancel(ctx) defer cancel() - key, err := decodeKey(keyHex) + cipher, err := setupCipher(keyHex) if err != nil { - return fmt.Errorf("decodeKey failed: %w", err) + return fmt.Errorf("setupCipher failed: %w", err) } - keyStr := string(key) - if len(keyStr) != 32 { - return fmt.Errorf("%w: got %d", errInvalidKeyStringLength, len(keyStr)) - } - - cipher, err := crypto.NewCipher(keyStr) - if err != nil { - return fmt.Errorf("create cipher: %w", err) + clientIDBytes := make([]byte, 4) + if _, err := rand.Read(clientIDBytes); err != nil { + return fmt.Errorf("failed to generate client ID: %w", err) } + clientID := binary.BigEndian.Uint32(clientIDBytes) c := &Client{ - cipher: cipher, - clientID: uint32(time.Now().UnixNano() & 0xFFFFFFFF), - peers: make([]provider.Provider, 0, 1), + cipher: cipher, + connections: make(map[uint16]net.Conn), + peers: make([]provider.Provider, 0), + clientID: clientID, + dnsServer: dnsServer, } - c.mux = mux.New(c.clientID, c.sendFrame) + c.setupMux() - for peerID := range 1 { - if err := c.addPeer(runCtx, providerName, roomURL, peerID, cancel); err != nil { + const peerCount = 1 + for i := range peerCount { + if err := c.addPeer(runCtx, providerName, roomURL, i, cancel, dnsServer, socksProxyAddr, socksProxyPort); err != nil { return fmt.Errorf("addPeer failed: %w", err) } } - time.Sleep(100 * time.Millisecond) - c.sendResetSignal() + ln, err := net.Listen("tcp", localAddr) + if err != nil { + return fmt.Errorf("listen failed: %w", err) + } + defer ln.Close() - err = c.runSOCKS5(runCtx, socksHost, socksPort, socksUser, socksPass, onReady) + logger.Infof("SOCKS5 server listening on %s (ClientID: %d)", localAddr, clientID) + go c.acceptLoop(runCtx, ln) + + <-runCtx.Done() + c.shutdown() c.wg.Wait() - return err -} - -func decodeKey(keyHex string) ([]byte, error) { - if keyHex == "" { - key := make([]byte, 32) - if _, err := rand.Read(key); err != nil { - return nil, fmt.Errorf("generate random key: %w", err) - } - - log.Printf("Generated key: %x", key) - return key, nil - } - - key, err := hex.DecodeString(keyHex) - if err != nil { - return nil, fmt.Errorf("decode hex key: %w", err) - } - - if len(key) != 32 { - return nil, fmt.Errorf("%w: got %d", errInvalidKeyLength, len(key)) - } - - return key, nil -} - -func (c *Client) sendFrame(frame []byte) error { - waitUntilPeersCanSend(c.peers) - - encrypted, err := c.cipher.Encrypt(frame) - if err != nil { - return fmt.Errorf("encrypt outgoing frame: %w", err) - } - - peer, err := c.nextPeer() - if err != nil { - return err - } - - if err := peer.Send(encrypted); err != nil { - return fmt.Errorf("send frame via peer: %w", err) - } - return nil } -func waitUntilPeersCanSend(peers []provider.Provider) { - for { - canSend := true - for _, peer := range peers { - if !peer.CanSend() { - canSend = false - break - } - } - - if canSend { - return - } - - time.Sleep(10 * time.Millisecond) +func setupCipher(keyHex string) (*crypto.Cipher, error) { + key, err := hex.DecodeString(keyHex) + if err != nil { + return nil, fmt.Errorf("failed to decode key: %w", err) } + if len(key) != 32 { + return nil, ErrKeySize + } + + keyStr := string(key) + if len(keyStr) != 32 { + return nil, ErrKeyStringLength + } + + cipher, err := crypto.NewCipher(keyStr) + if err != nil { + return nil, fmt.Errorf("failed to create cipher: %w", err) + } + return cipher, nil } -// nextPeer returns the next provider for load balancing. -// -//nolint:ireturn -func (c *Client) nextPeer() (provider.Provider, error) { - switch len(c.peers) { - case 0: - return nil, errNoConnectedPeers - case 1: - return c.peers[0], nil - default: - return c.peers[int(c.peerIdx.Add(1)%2)], nil - } +func (c *Client) setupMux() { + c.mux = mux.New(c.clientID, func(frame []byte) error { + for { + canSend := true + for _, peer := range c.peers { + if !peer.CanSend() { + canSend = false + break + } + } + if canSend { + break + } + time.Sleep(10 * time.Millisecond) + } + + encrypted, err := c.cipher.Encrypt(frame) + if err != nil { + return fmt.Errorf("%w: %w", ErrEncryptFailed, err) + } + if len(c.peers) == 0 { + return ErrNoPeers + } + idx := c.peerIdx.Add(1) % uint32(len(c.peers)) //nolint:gosec + return c.peers[idx].Send(encrypted) + }) } func (c *Client) addPeer( - runCtx context.Context, + ctx context.Context, providerName, roomURL string, peerID int, cancel context.CancelFunc, + dnsServer, + socksProxyAddr string, + socksProxyPort int, ) error { - peer, err := provider.New(runCtx, providerName, provider.Config{ - RoomURL: roomURL, - Name: names.Generate(), - OnData: c.onData, + peer, err := provider.New(ctx, providerName, provider.Config{ + RoomURL: roomURL, + Name: names.Generate(), + OnData: c.onData, + DNSServer: dnsServer, + ProxyAddr: socksProxyAddr, + ProxyPort: socksProxyPort, }) if err != nil { - return fmt.Errorf("create peer %d: %w", peerID, err) + return fmt.Errorf("failed to create peer: %w", err) } peer.SetEndedCallback(func(reason string) { - log.Printf("Client peer %d reported conference end: %s", peerID, reason) + logger.Infof("Client peer %d reported conference end: %s", peerID, reason) cancel() }) - - peer.SetReconnectCallback(func(dc *webrtc.DataChannel) { - c.onReconnect(peerID, dc) - }) - c.peers = append(c.peers, peer) - log.Printf("Connecting peer %d to %s...", peerID, providerName) - if err := peer.Connect(runCtx); err != nil { - return fmt.Errorf("connect peer %d: %w", peerID, err) + peer.SetReconnectCallback(func(dc *webrtc.DataChannel) { + c.handlePeerReconnect(peerID, dc) + }) + + logger.Infof("Connecting peer %d to %s...", peerID, providerName) + if err := peer.Connect(ctx); err != nil { + return fmt.Errorf("failed to connect peer: %w", err) } - log.Printf("Peer %d connected", peerID) + logger.Infof("Peer %d connected", peerID) c.wg.Add(1) go func() { defer c.wg.Done() - peer.WatchConnection(runCtx) + peer.WatchConnection(ctx) }() + // Send initial reset to clean up any stale connections for this clientID on server + if err := c.mux.SendClientReset(); err != nil { + logger.Warnf("Failed to send initial client reset: %v", err) + } + return nil } -func (c *Client) onReconnect(peerID int, dc *webrtc.DataChannel) { - log.Printf("peer %d reconnect event: dc=%v", peerID, dc != nil) +func (c *Client) handlePeerReconnect(peerID int, dc *webrtc.DataChannel) { + logger.Infof("peer %d reconnect event: dc=%v", peerID, dc != nil) + + c.connMu.Lock() + for sid, conn := range c.connections { + if conn != nil { + _ = conn.Close() + } + delete(c.connections, sid) + } + c.connMu.Unlock() if dc != nil { - c.mux.UpdateSendFunc(c.sendFrame) + c.mux.UpdateSendFunc(func(frame []byte) error { + encrypted, err := c.cipher.Encrypt(frame) + if err != nil { + return fmt.Errorf("%w: %w", ErrEncryptFailed, err) + } + if len(c.peers) == 0 { + return ErrNoPeers + } + idx := c.peerIdx.Add(1) % uint32(len(c.peers)) //nolint:gosec + return c.peers[idx].Send(encrypted) + }) c.mux.Reset() + + if err := c.mux.SendClientReset(); err != nil { + logger.Warnf("Failed to send client reset after reconnect: %v", err) + } } } -func (c *Client) sendResetSignal() { - resetFrame := mux.BuildControlFrame(c.clientID, mux.ControlResetClient) - encrypted, err := c.cipher.Encrypt(resetFrame) - if err != nil { - log.Printf("Failed to encrypt reset signal: %v", err) +func (c *Client) acceptLoop(ctx context.Context, ln net.Listener) { + for { + select { + case <-ctx.Done(): + return + default: + conn, err := ln.Accept() + if err != nil { + logger.Debugf("Accept error: %v", err) + continue + } + go c.handleSOCKS5(ctx, conn) + } + } +} + +func (c *Client) handleSOCKS5(ctx context.Context, conn net.Conn) { + defer conn.Close() + + if err := c.socks5Handshake(conn); err != nil { + logger.Debugf("SOCKS5 handshake failed: %v", err) return } - for _, peer := range c.peers { - if err := peer.Send(encrypted); err != nil { - log.Printf("Failed to send reset signal to server: %v", err) - } + addr, port, err := c.socks5Request(conn) + if err != nil { + logger.Debugf("SOCKS5 request failed: %v", err) + return } - log.Printf("Sent reset signal to server (clientID=%d)", c.clientID) + sid := c.mux.OpenStream() + c.connMu.Lock() + c.connections[sid] = conn + c.connMu.Unlock() + + logger.Infof("sid=%d tunnel to %s:%d", sid, addr, port) + + req := map[string]any{ + "cmd": "connect", + "addr": addr, + "port": port, + } + reqData, _ := json.Marshal(req) + + if err := c.mux.SendData(sid, reqData); err != nil { + logger.Warnf("sid=%d send connect failed: %v", sid, err) + return + } + + dataReady := c.mux.WaitForData(sid) + select { + case <-dataReady: + resp := c.mux.ReadStream(sid) + if len(resp) > 0 && resp[0] == 0x00 { + if _, err := conn.Write(replySuccess()); err != nil { + return + } + } else { + _, _ = conn.Write(replyHostUnreachable()) + return + } + case <-time.After(15 * time.Second): + _, _ = conn.Write(replyHostUnreachable()) + c.mux.CleanupDataChannel(sid) + return + case <-ctx.Done(): + return + } + c.mux.CleanupDataChannel(sid) + + c.activeClients.Add(1) + c.startStreamPump(ctx, sid, conn) + c.pumpToMux(sid, conn) +} + +func (c *Client) socks5Handshake(conn net.Conn) error { + buf := make([]byte, 2) + if _, err := io.ReadFull(conn, buf); err != nil { + return err + } + + if buf[0] != 5 { + return ErrInvalidSocks5 + } + + methods := make([]byte, int(buf[1])) + if _, err := io.ReadFull(conn, methods); err != nil { + return err + } + + _, err := conn.Write([]byte{5, 0}) + return err +} + +func (c *Client) socks5Request(conn net.Conn) (string, int, error) { + buf := make([]byte, 4) + if _, err := io.ReadFull(conn, buf); err != nil { + return "", 0, err + } + + if buf[0] != 5 || buf[1] != 1 { + return "", 0, fmt.Errorf("unsupported SOCKS5 command: %d", buf[1]) + } + + var addr string + switch buf[3] { + case 1: // IPv4 + ip := make([]byte, 4) + if _, err := io.ReadFull(conn, ip); err != nil { + return "", 0, err + } + addr = net.IP(ip).String() + case 3: // Domain + lenBuf := make([]byte, 1) + if _, err := io.ReadFull(conn, lenBuf); err != nil { + return "", 0, err + } + domain := make([]byte, int(lenBuf[0])) + if _, err := io.ReadFull(conn, domain); err != nil { + return "", 0, err + } + addr = string(domain) + case 4: // IPv6 + ip := make([]byte, 16) + if _, err := io.ReadFull(conn, ip); err != nil { + return "", 0, err + } + addr = net.IP(ip).String() + default: + return "", 0, fmt.Errorf("unsupported address type: %d", buf[3]) + } + + portBuf := make([]byte, 2) + if _, err := io.ReadFull(conn, portBuf); err != nil { + return "", 0, err + } + port := int(binary.BigEndian.Uint16(portBuf)) + + return addr, port, nil } func (c *Client) onData(data []byte) { @@ -264,347 +388,100 @@ func (c *Client) onData(data []byte) { c.mux.HandleFrame(plaintext) } -func (c *Client) runSOCKS5( - ctx context.Context, - host string, - port int, - username, - password string, - onReady func(), -) error { - if host == "" { - host = defaultSOCKSListenHost - } - - listenAddr := net.JoinHostPort(host, strconv.Itoa(port)) - var lc net.ListenConfig - listener, err := lc.Listen(ctx, "tcp", listenAddr) - if err != nil { - return fmt.Errorf("listen on %s: %w", listenAddr, err) - } - - log.Printf("SOCKS5 proxy listening on %s (auth=%v)", listenAddr, username != "") - if onReady != nil { - onReady() - } - - go func() { - <-ctx.Done() - if err := listener.Close(); err != nil { - logger.Debugf("SOCKS5 listener close error: %v", err) +func (c *Client) shutdown() { + c.connMu.Lock() + for _, conn := range c.connections { + if conn != nil { + _ = conn.Close() } + } + c.connMu.Unlock() + + for i, peer := range c.peers { + logger.Infof("closing peer %d", i) + _ = peer.Close() + } +} + +func (c *Client) pumpToMux(sid uint16, conn net.Conn) { + defer func() { + c.activeClients.Add(-1) + _ = c.mux.CloseStream(sid) + c.connMu.Lock() + delete(c.connections, sid) + c.connMu.Unlock() }() + buf := make([]byte, 16384) + totalSent := uint64(0) + lastLog := time.Now() + for { - conn, err := listener.Accept() + n, err := conn.Read(buf) if err != nil { - select { - case <-ctx.Done(): - c.closePeers() - return nil - default: - log.Printf("accept error: %v", err) - continue + if totalSent > 1024*1024 { + logger.Infof("sid=%d done total=%dMB", sid, totalSent/(1024*1024)) } + return } - go c.handleSOCKS5(conn, username, password) - } -} + for !c.canSendData() { + time.Sleep(20 * time.Millisecond) + } -func (c *Client) closePeers() { - for _, peer := range c.peers { - if err := peer.Close(); err != nil { - logger.Debugf("Peer close error: %v", err) + if err := c.mux.SendData(sid, buf[:n]); err != nil { + return + } + + totalSent += uint64(n) //nolint:gosec + if time.Since(lastLog) > 5*time.Second { + logger.Infof("sid=%d sent=%dMB", sid, totalSent/(1024*1024)) + lastLog = time.Now() } } } -//nolint:cyclop // SOCKS5 parsing is inherently stateful and mirrors the protocol handshake. -func (c *Client) handleSOCKS5(conn net.Conn, username, password string) { - defer func() { - if err := conn.Close(); err != nil { - logger.Debugf("SOCKS5 connection close error: %v", err) - } - }() - - buf := make([]byte, 513) - if !readSOCKSVersionAndMethods(conn, buf) { - return - } - - nmethods := buf[1] - if _, err := io.ReadFull(conn, buf[:nmethods]); err != nil { - return - } - - requireAuth := username != "" - wantMethod := byte(0x00) - if requireAuth { - wantMethod = 0x02 - } - - if !supportsMethod(buf[:nmethods], wantMethod) { - writeResponse(conn, replyUnsupportedSOCKSMethod()) - return - } - writeResponse(conn, []byte{5, wantMethod}) - - if requireAuth && !authenticateSOCKSUser(conn, buf, username, password) { - return - } - - addr, port, ok := readConnectTarget(conn, buf) - if !ok { - return - } - - sid := c.mux.OpenStream() - logger.Verbosef("SOCKS5 opened stream sid=%d for %s:%d", sid, addr, port) - log.Printf("sid=%d socks5 %s:%d", sid, addr, port) - - if !c.sendConnectRequest(sid, addr, port) { - return - } - - if !c.waitConnectResponse(conn, sid) { - return - } - - c.mux.ReadStream(sid) - writeResponse(conn, replySuccess()) - c.proxyStream(conn, sid) -} - -func readSOCKSVersionAndMethods(conn net.Conn, buf []byte) bool { - if _, err := io.ReadFull(conn, buf[:2]); err != nil { - return false - } - - return buf[0] == 5 -} - -func supportsMethod(methods []byte, wantMethod byte) bool { - for _, method := range methods { - if method == wantMethod { - return true - } - } - - return false -} - -func authenticateSOCKSUser(conn net.Conn, buf []byte, username, password string) bool { - if _, err := io.ReadFull(conn, buf[:2]); err != nil { - return false - } - if buf[0] != 0x01 { - return false - } - - ulen := int(buf[1]) - if _, err := io.ReadFull(conn, buf[:ulen+1]); err != nil { - return false - } - - gotUser := string(buf[:ulen]) - plen := int(buf[ulen]) - if _, err := io.ReadFull(conn, buf[:plen]); err != nil { - return false - } - - gotPass := string(buf[:plen]) - if gotUser != username || gotPass != password { - writeResponse(conn, replyAuthFailed()) - return false - } - - writeResponse(conn, replyAuthOK()) - return true -} - -func readConnectTarget(conn net.Conn, buf []byte) (string, uint16, bool) { - if _, err := io.ReadFull(conn, buf[:4]); err != nil { - return "", 0, false - } - - if buf[1] != 1 { - writeResponse(conn, replyCommandNotSupported()) - return "", 0, false - } - - addr, ok := readTargetAddress(conn, buf, buf[3]) - if !ok { - return "", 0, false - } - - if _, err := io.ReadFull(conn, buf[:2]); err != nil { - return "", 0, false - } - - return addr, binary.BigEndian.Uint16(buf[:2]), true -} - -func readTargetAddress(conn net.Conn, buf []byte, atyp byte) (string, bool) { - switch atyp { - case 1: - if _, err := io.ReadFull(conn, buf[:4]); err != nil { - return "", false - } - return fmt.Sprintf("%d.%d.%d.%d", buf[0], buf[1], buf[2], buf[3]), true - case 3: - if _, err := io.ReadFull(conn, buf[:1]); err != nil { - return "", false - } - - length := buf[0] - if _, err := io.ReadFull(conn, buf[:length]); err != nil { - return "", false - } - return string(buf[:length]), true - default: - writeResponse(conn, replyAddressNotSupported()) - return "", false - } -} - -func (c *Client) sendConnectRequest(sid uint16, addr string, port uint16) bool { - reqData, err := json.Marshal(struct { - Cmd string `json:"cmd"` - Addr string `json:"addr"` - Port uint16 `json:"port"` - }{ - Cmd: "connect", - Addr: addr, - Port: port, - }) - if err != nil { - logger.Debugf("Connect request marshal error: %v", err) - return false - } - - if err := c.mux.SendData(sid, reqData); err != nil { - logger.Debugf("Connect request send error: %v", err) - return false - } - - return true -} - -func (c *Client) waitConnectResponse(conn net.Conn, sid uint16) bool { - dataReady := c.mux.WaitForData(sid) - timeout := time.NewTimer(10 * time.Second) - defer timeout.Stop() - - select { - case <-dataReady: - stream := c.mux.GetStream(sid) - if stream == nil || len(stream.RecvBuf()) == 0 { - writeResponse(conn, replyHostUnreachable()) - return false - } - case <-timeout.C: - writeResponse(conn, replyHostUnreachable()) - return false - } - - return true -} - -//nolint:cyclop // The stream pump handles two coordinated goroutines and shutdown races in one place. -func (c *Client) proxyStream(conn net.Conn, sid uint16) { - done := make(chan struct{}) - streamClosed := make(chan struct{}) - +func (c *Client) startStreamPump(ctx context.Context, sid uint16, conn net.Conn) { + c.wg.Add(1) go func() { - defer close(done) - buf := make([]byte, 32768) - for { - n, err := conn.Read(buf) - if err != nil { - if err := c.mux.CloseStream(sid); err != nil { - logger.Debugf("Close stream error: %v", err) - } - return - } - if err := c.mux.SendData(sid, buf[:n]); err != nil { - return - } - } - }() - - go func() { - defer close(streamClosed) - defer c.mux.CleanupDataChannel(sid) + defer c.wg.Done() ticker := time.NewTicker(10 * time.Millisecond) defer ticker.Stop() for { select { - case <-done: + case <-ctx.Done(): return case <-ticker.C: data := c.mux.ReadStream(sid) - if len(data) > 0 && !writeStreamData(conn, data) { - return + if len(data) > 0 { + if _, err := conn.Write(data); err != nil { + _ = c.mux.CloseStream(sid) + return + } } - if c.mux.StreamClosed(sid) { return } } } }() - - select { - case <-done: - case <-streamClosed: - } } -func writeStreamData(conn net.Conn, data []byte) bool { - for len(data) > 0 { - n, err := conn.Write(data) - if err != nil { +func (c *Client) canSendData() bool { + for _, peer := range c.peers { + if !peer.CanSend() { return false } - data = data[n:] } - return true } -func writeResponse(conn net.Conn, response []byte) { - if _, err := conn.Write(response); err != nil { - logger.Debugf("SOCKS5 response write error: %v", err) - } -} - -func replyUnsupportedSOCKSMethod() []byte { - return []byte{5, 0xFF} -} - -func replyAuthFailed() []byte { - return []byte{0x01, 0x01} -} - -func replyAuthOK() []byte { - return []byte{0x01, 0x00} -} - -func replyCommandNotSupported() []byte { - return []byte{5, 7, 0, 1, 0, 0, 0, 0, 0, 0} -} - -func replyAddressNotSupported() []byte { - return []byte{5, 8, 0, 1, 0, 0, 0, 0, 0, 0} -} - -func replyHostUnreachable() []byte { - return []byte{5, 4, 0, 1, 0, 0, 0, 0, 0, 0} -} - func replySuccess() []byte { return []byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0} } + +func replyHostUnreachable() []byte { + return []byte{5, 4, 0, 1, 0, 0, 0, 0, 0, 0} +}