From c1a30b677bab8a4b7b4157c41a8a245edb3aec84 Mon Sep 17 00:00:00 2001 From: Kot-nikot <127394891+Kot-nikot@users.noreply.github.com> Date: Sat, 11 Apr 2026 09:43:13 +0300 Subject: [PATCH] Implement SOCKS5 proxy support in Server Added SOCKS5 proxy support to the server, including new fields for proxy address and port in the Server struct. Updated the Run function and related methods to handle proxy connections. --- internal/server/server.go | 153 ++++++++++++++++++++++++++------------ 1 file changed, 104 insertions(+), 49 deletions(-) diff --git a/internal/server/server.go b/internal/server/server.go index 75efcc4..881ea12 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -7,6 +7,7 @@ import ( "encoding/hex" "encoding/json" "fmt" + "io" "log" "net" "sync" @@ -22,16 +23,18 @@ import ( ) type Server struct { - peers []*telemost.Peer - cipher *crypto.Cipher - mux *mux.Multiplexer - connections map[uint16]net.Conn - connMu sync.RWMutex - peerIdx atomic.Uint32 - wg sync.WaitGroup - dnsServer string - dnsCache sync.Map - resolver *net.Resolver + peers []*telemost.Peer + cipher *crypto.Cipher + mux *mux.Multiplexer + connections map[uint16]net.Conn + connMu sync.RWMutex + peerIdx atomic.Uint32 + wg sync.WaitGroup + dnsServer string + dnsCache sync.Map + resolver *net.Resolver + socksProxyAddr string + socksProxyPort int } type ConnectRequest struct { @@ -40,7 +43,7 @@ type ConnectRequest struct { Port int `json:"port"` } -func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer string) error { +func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer, socksProxyAddr string, socksProxyPort int) error { var key []byte var err error @@ -71,16 +74,18 @@ func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer string } s := &Server{ - cipher: cipher, - connections: make(map[uint16]net.Conn), - peers: make([]*telemost.Peer, 0), - dnsServer: dnsServer, + cipher: cipher, + connections: make(map[uint16]net.Conn), + peers: make([]*telemost.Peer, 0), + dnsServer: dnsServer, + socksProxyAddr: socksProxyAddr, + socksProxyPort: socksProxyPort, } - + if dnsServer == "" { dnsServer = "1.1.1.1:53" } - + s.resolver = &net.Resolver{ PreferGo: true, Dial: func(ctx context.Context, network, address string) (net.Conn, error) { @@ -109,7 +114,7 @@ func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer string } time.Sleep(10 * time.Millisecond) } - + encrypted, err := s.cipher.Encrypt(frame) if err != nil { return err @@ -127,7 +132,7 @@ func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer string peer.SetReconnectCallback(func(dc *webrtc.DataChannel) { log.Printf("Server peer %d reconnected - resetting multiplexer state", i) - + s.connMu.Lock() for sid, conn := range s.connections { if conn != nil { @@ -136,7 +141,7 @@ func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer string delete(s.connections, sid) } s.connMu.Unlock() - + if dc != nil { s.mux.UpdateSendFunc(func(frame []byte) error { encrypted, err := s.cipher.Encrypt(frame) @@ -147,9 +152,9 @@ func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer string return s.peers[idx].Send(encrypted) }) } - + s.mux.Reset() - + log.Println("Server multiplexer reset complete") }) @@ -167,14 +172,47 @@ func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer string } err = s.run(ctx) - + log.Println("Waiting for server goroutines...") s.wg.Wait() log.Println("Server goroutines finished") - + return err } +func (s *Server) socks5Connect(conn net.Conn, targetAddr string, targetPort int) error { + if _, err := conn.Write([]byte{5, 1, 0}); err != nil { + return err + } + + resp := make([]byte, 2) + if _, err := io.ReadFull(conn, resp); err != nil { + return err + } + if resp[0] != 5 || resp[1] != 0 { + return fmt.Errorf("SOCKS5 auth failed") + } + + req := []byte{5, 1, 0, 3} + req = append(req, byte(len(targetAddr))) + req = append(req, []byte(targetAddr)...) + req = append(req, byte(targetPort>>8), byte(targetPort)) + + if _, err := conn.Write(req); err != nil { + return err + } + + resp = make([]byte, 10) + if _, err := io.ReadFull(conn, resp); err != nil { + return err + } + if resp[0] != 5 || resp[1] != 0 { + return fmt.Errorf("SOCKS5 connect failed: %d", resp[1]) + } + + return nil +} + func (s *Server) onData(data []byte) { plaintext, err := s.cipher.Decrypt(data) if err != nil { @@ -186,7 +224,7 @@ func (s *Server) onData(data []byte) { clientID := binary.BigEndian.Uint32(plaintext[0:4]) sid := binary.BigEndian.Uint16(plaintext[4:6]) length := binary.BigEndian.Uint16(plaintext[6:8]) - + if sid == 0xFFFF && length == 0xFFFF { log.Printf("Received reset signal from client (clientID=%d) - cleaning up", clientID) s.connMu.Lock() @@ -205,7 +243,7 @@ func (s *Server) onData(data []byte) { clientID := binary.BigEndian.Uint32(plaintext[0:4]) sid := binary.BigEndian.Uint16(plaintext[4:6]) length := binary.BigEndian.Uint16(plaintext[6:8]) - + if sid == 0xFFFF && length == 0xFFFF { log.Printf("Received reset signal from client (clientID=%d) - cleaning up", clientID) s.connMu.Lock() @@ -228,7 +266,7 @@ func (s *Server) onData(data []byte) { func (s *Server) run(ctx context.Context) error { ticker := time.NewTicker(10 * time.Millisecond) defer ticker.Stop() - + for { select { case <-ctx.Done(): @@ -240,21 +278,21 @@ func (s *Server) run(ctx context.Context) error { } } s.connMu.Unlock() - + log.Printf("Closing %d peer(s)...", len(s.peers)) for i, peer := range s.peers { log.Printf("Closing peer %d...", i) peer.Close() } log.Println("All peers closed") - + return nil - + case <-ticker.C: } - + sids := s.mux.GetStreams() - + for _, sid := range sids { go func(sid uint16) { data := s.mux.ReadStream(sid) @@ -262,7 +300,7 @@ func (s *Server) run(ctx context.Context) error { s.connMu.RLock() conn, exists := s.connections[sid] s.connMu.RUnlock() - + if exists && conn != nil { if _, err := conn.Write(data); err != nil { s.mux.CloseStream(sid) @@ -315,28 +353,45 @@ func (s *Server) handleConnect(sid uint16, req ConnectRequest) { s.connMu.Unlock() dialStart := time.Now() - - dialer := &net.Dialer{ - Timeout: 10 * time.Second, - KeepAlive: 30 * time.Second, - Resolver: s.resolver, + var conn net.Conn + var err error + + if s.socksProxyAddr == "" { + dialer := &net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 30 * time.Second, + Resolver: s.resolver, + } + conn, err = dialer.Dial("tcp4", addr) + logger.Verbose("TCP dial took %v for sid=%d (direct)", time.Since(dialStart), sid) + } else { + proxyAddr := fmt.Sprintf("%s:%d", s.socksProxyAddr, s.socksProxyPort) + dialer := &net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 30 * time.Second, + } + conn, err = dialer.Dial("tcp4", proxyAddr) + if err == nil { + if err := s.socks5Connect(conn, req.Addr, req.Port); err != nil { + conn.Close() + err = fmt.Errorf("SOCKS5 connect failed: %v", err) + } + } + logger.Verbose("SOCKS5 proxy dial took %v for sid=%d", time.Since(dialStart), sid) } - - conn, err := dialer.Dial("tcp4", addr) + dialElapsed := time.Since(dialStart) - + if err != nil { log.Printf("[SERVER] sid=%d CONNECT_FAILED dial_time=%v total_elapsed=%v err=%v", sid, dialElapsed, time.Since(startTime), err) go s.mux.CloseStream(sid) return } - - logger.Verbose("TCP dial took %v for sid=%d", dialElapsed, sid) - + s.connMu.Lock() s.connections[sid] = conn s.connMu.Unlock() - + log.Printf("[SERVER] sid=%d CONNECT_SUCCESS dial_time=%v", sid, dialElapsed) s.mux.SendData(sid, []byte{0x00}) @@ -348,11 +403,11 @@ func (s *Server) handleConnect(sid uint16, req ConnectRequest) { delete(s.connections, sid) s.connMu.Unlock() }() - + buf := make([]byte, 16384) totalSent := uint64(0) lastLog := time.Now() - + for { n, err := conn.Read(buf) if err != nil { @@ -361,7 +416,7 @@ func (s *Server) handleConnect(sid uint16, req ConnectRequest) { } return } - + for !s.canSendData() { time.Sleep(20 * time.Millisecond) } @@ -369,7 +424,7 @@ func (s *Server) handleConnect(sid uint16, req ConnectRequest) { if err := s.mux.SendData(sid, buf[:n]); err != nil { return } - + totalSent += uint64(n) if time.Since(lastLog) > 5*time.Second { log.Printf("[SERVER] sid=%d TRANSFER_PROGRESS sent=%d MB", sid, totalSent/(1024*1024))