diff --git a/internal/server/server.go b/internal/server/server.go index 75efcc4..881ea12 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -7,6 +7,7 @@ import ( "encoding/hex" "encoding/json" "fmt" + "io" "log" "net" "sync" @@ -22,16 +23,18 @@ import ( ) type Server struct { - peers []*telemost.Peer - cipher *crypto.Cipher - mux *mux.Multiplexer - connections map[uint16]net.Conn - connMu sync.RWMutex - peerIdx atomic.Uint32 - wg sync.WaitGroup - dnsServer string - dnsCache sync.Map - resolver *net.Resolver + peers []*telemost.Peer + cipher *crypto.Cipher + mux *mux.Multiplexer + connections map[uint16]net.Conn + connMu sync.RWMutex + peerIdx atomic.Uint32 + wg sync.WaitGroup + dnsServer string + dnsCache sync.Map + resolver *net.Resolver + socksProxyAddr string + socksProxyPort int } type ConnectRequest struct { @@ -40,7 +43,7 @@ type ConnectRequest struct { Port int `json:"port"` } -func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer string) error { +func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer, socksProxyAddr string, socksProxyPort int) error { var key []byte var err error @@ -71,16 +74,18 @@ func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer string } s := &Server{ - cipher: cipher, - connections: make(map[uint16]net.Conn), - peers: make([]*telemost.Peer, 0), - dnsServer: dnsServer, + cipher: cipher, + connections: make(map[uint16]net.Conn), + peers: make([]*telemost.Peer, 0), + dnsServer: dnsServer, + socksProxyAddr: socksProxyAddr, + socksProxyPort: socksProxyPort, } - + if dnsServer == "" { dnsServer = "1.1.1.1:53" } - + s.resolver = &net.Resolver{ PreferGo: true, Dial: func(ctx context.Context, network, address string) (net.Conn, error) { @@ -109,7 +114,7 @@ func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer string } time.Sleep(10 * time.Millisecond) } - + encrypted, err := s.cipher.Encrypt(frame) if err != nil { return err @@ -127,7 +132,7 @@ func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer string peer.SetReconnectCallback(func(dc *webrtc.DataChannel) { log.Printf("Server peer %d reconnected - resetting multiplexer state", i) - + s.connMu.Lock() for sid, conn := range s.connections { if conn != nil { @@ -136,7 +141,7 @@ func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer string delete(s.connections, sid) } s.connMu.Unlock() - + if dc != nil { s.mux.UpdateSendFunc(func(frame []byte) error { encrypted, err := s.cipher.Encrypt(frame) @@ -147,9 +152,9 @@ func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer string return s.peers[idx].Send(encrypted) }) } - + s.mux.Reset() - + log.Println("Server multiplexer reset complete") }) @@ -167,14 +172,47 @@ func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer string } err = s.run(ctx) - + log.Println("Waiting for server goroutines...") s.wg.Wait() log.Println("Server goroutines finished") - + return err } +func (s *Server) socks5Connect(conn net.Conn, targetAddr string, targetPort int) error { + if _, err := conn.Write([]byte{5, 1, 0}); err != nil { + return err + } + + resp := make([]byte, 2) + if _, err := io.ReadFull(conn, resp); err != nil { + return err + } + if resp[0] != 5 || resp[1] != 0 { + return fmt.Errorf("SOCKS5 auth failed") + } + + req := []byte{5, 1, 0, 3} + req = append(req, byte(len(targetAddr))) + req = append(req, []byte(targetAddr)...) + req = append(req, byte(targetPort>>8), byte(targetPort)) + + if _, err := conn.Write(req); err != nil { + return err + } + + resp = make([]byte, 10) + if _, err := io.ReadFull(conn, resp); err != nil { + return err + } + if resp[0] != 5 || resp[1] != 0 { + return fmt.Errorf("SOCKS5 connect failed: %d", resp[1]) + } + + return nil +} + func (s *Server) onData(data []byte) { plaintext, err := s.cipher.Decrypt(data) if err != nil { @@ -186,7 +224,7 @@ func (s *Server) onData(data []byte) { clientID := binary.BigEndian.Uint32(plaintext[0:4]) sid := binary.BigEndian.Uint16(plaintext[4:6]) length := binary.BigEndian.Uint16(plaintext[6:8]) - + if sid == 0xFFFF && length == 0xFFFF { log.Printf("Received reset signal from client (clientID=%d) - cleaning up", clientID) s.connMu.Lock() @@ -205,7 +243,7 @@ func (s *Server) onData(data []byte) { clientID := binary.BigEndian.Uint32(plaintext[0:4]) sid := binary.BigEndian.Uint16(plaintext[4:6]) length := binary.BigEndian.Uint16(plaintext[6:8]) - + if sid == 0xFFFF && length == 0xFFFF { log.Printf("Received reset signal from client (clientID=%d) - cleaning up", clientID) s.connMu.Lock() @@ -228,7 +266,7 @@ func (s *Server) onData(data []byte) { func (s *Server) run(ctx context.Context) error { ticker := time.NewTicker(10 * time.Millisecond) defer ticker.Stop() - + for { select { case <-ctx.Done(): @@ -240,21 +278,21 @@ func (s *Server) run(ctx context.Context) error { } } s.connMu.Unlock() - + log.Printf("Closing %d peer(s)...", len(s.peers)) for i, peer := range s.peers { log.Printf("Closing peer %d...", i) peer.Close() } log.Println("All peers closed") - + return nil - + case <-ticker.C: } - + sids := s.mux.GetStreams() - + for _, sid := range sids { go func(sid uint16) { data := s.mux.ReadStream(sid) @@ -262,7 +300,7 @@ func (s *Server) run(ctx context.Context) error { s.connMu.RLock() conn, exists := s.connections[sid] s.connMu.RUnlock() - + if exists && conn != nil { if _, err := conn.Write(data); err != nil { s.mux.CloseStream(sid) @@ -315,28 +353,45 @@ func (s *Server) handleConnect(sid uint16, req ConnectRequest) { s.connMu.Unlock() dialStart := time.Now() - - dialer := &net.Dialer{ - Timeout: 10 * time.Second, - KeepAlive: 30 * time.Second, - Resolver: s.resolver, + var conn net.Conn + var err error + + if s.socksProxyAddr == "" { + dialer := &net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 30 * time.Second, + Resolver: s.resolver, + } + conn, err = dialer.Dial("tcp4", addr) + logger.Verbose("TCP dial took %v for sid=%d (direct)", time.Since(dialStart), sid) + } else { + proxyAddr := fmt.Sprintf("%s:%d", s.socksProxyAddr, s.socksProxyPort) + dialer := &net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 30 * time.Second, + } + conn, err = dialer.Dial("tcp4", proxyAddr) + if err == nil { + if err := s.socks5Connect(conn, req.Addr, req.Port); err != nil { + conn.Close() + err = fmt.Errorf("SOCKS5 connect failed: %v", err) + } + } + logger.Verbose("SOCKS5 proxy dial took %v for sid=%d", time.Since(dialStart), sid) } - - conn, err := dialer.Dial("tcp4", addr) + dialElapsed := time.Since(dialStart) - + if err != nil { log.Printf("[SERVER] sid=%d CONNECT_FAILED dial_time=%v total_elapsed=%v err=%v", sid, dialElapsed, time.Since(startTime), err) go s.mux.CloseStream(sid) return } - - logger.Verbose("TCP dial took %v for sid=%d", dialElapsed, sid) - + s.connMu.Lock() s.connections[sid] = conn s.connMu.Unlock() - + log.Printf("[SERVER] sid=%d CONNECT_SUCCESS dial_time=%v", sid, dialElapsed) s.mux.SendData(sid, []byte{0x00}) @@ -348,11 +403,11 @@ func (s *Server) handleConnect(sid uint16, req ConnectRequest) { delete(s.connections, sid) s.connMu.Unlock() }() - + buf := make([]byte, 16384) totalSent := uint64(0) lastLog := time.Now() - + for { n, err := conn.Read(buf) if err != nil { @@ -361,7 +416,7 @@ func (s *Server) handleConnect(sid uint16, req ConnectRequest) { } return } - + for !s.canSendData() { time.Sleep(20 * time.Millisecond) } @@ -369,7 +424,7 @@ func (s *Server) handleConnect(sid uint16, req ConnectRequest) { if err := s.mux.SendData(sid, buf[:n]); err != nil { return } - + totalSent += uint64(n) if time.Since(lastLog) > 5*time.Second { log.Printf("[SERVER] sid=%d TRANSFER_PROGRESS sent=%d MB", sid, totalSent/(1024*1024))