Implement SOCKS5 proxy support in Server

Added SOCKS5 proxy support to the server, including new fields for proxy address and port in the Server struct. Updated the Run function and related methods to handle proxy connections.
This commit is contained in:
Kot-nikot
2026-04-11 09:43:13 +03:00
committed by GitHub
parent ae99bf689e
commit c1a30b677b

View File

@@ -7,6 +7,7 @@ import (
"encoding/hex"
"encoding/json"
"fmt"
"io"
"log"
"net"
"sync"
@@ -22,16 +23,18 @@ import (
)
type Server struct {
peers []*telemost.Peer
cipher *crypto.Cipher
mux *mux.Multiplexer
connections map[uint16]net.Conn
connMu sync.RWMutex
peerIdx atomic.Uint32
wg sync.WaitGroup
dnsServer string
dnsCache sync.Map
resolver *net.Resolver
peers []*telemost.Peer
cipher *crypto.Cipher
mux *mux.Multiplexer
connections map[uint16]net.Conn
connMu sync.RWMutex
peerIdx atomic.Uint32
wg sync.WaitGroup
dnsServer string
dnsCache sync.Map
resolver *net.Resolver
socksProxyAddr string
socksProxyPort int
}
type ConnectRequest struct {
@@ -40,7 +43,7 @@ type ConnectRequest struct {
Port int `json:"port"`
}
func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer string) error {
func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer, socksProxyAddr string, socksProxyPort int) error {
var key []byte
var err error
@@ -71,16 +74,18 @@ func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer string
}
s := &Server{
cipher: cipher,
connections: make(map[uint16]net.Conn),
peers: make([]*telemost.Peer, 0),
dnsServer: dnsServer,
cipher: cipher,
connections: make(map[uint16]net.Conn),
peers: make([]*telemost.Peer, 0),
dnsServer: dnsServer,
socksProxyAddr: socksProxyAddr,
socksProxyPort: socksProxyPort,
}
if dnsServer == "" {
dnsServer = "1.1.1.1:53"
}
s.resolver = &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
@@ -109,7 +114,7 @@ func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer string
}
time.Sleep(10 * time.Millisecond)
}
encrypted, err := s.cipher.Encrypt(frame)
if err != nil {
return err
@@ -127,7 +132,7 @@ func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer string
peer.SetReconnectCallback(func(dc *webrtc.DataChannel) {
log.Printf("Server peer %d reconnected - resetting multiplexer state", i)
s.connMu.Lock()
for sid, conn := range s.connections {
if conn != nil {
@@ -136,7 +141,7 @@ func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer string
delete(s.connections, sid)
}
s.connMu.Unlock()
if dc != nil {
s.mux.UpdateSendFunc(func(frame []byte) error {
encrypted, err := s.cipher.Encrypt(frame)
@@ -147,9 +152,9 @@ func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer string
return s.peers[idx].Send(encrypted)
})
}
s.mux.Reset()
log.Println("Server multiplexer reset complete")
})
@@ -167,14 +172,47 @@ func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer string
}
err = s.run(ctx)
log.Println("Waiting for server goroutines...")
s.wg.Wait()
log.Println("Server goroutines finished")
return err
}
func (s *Server) socks5Connect(conn net.Conn, targetAddr string, targetPort int) error {
if _, err := conn.Write([]byte{5, 1, 0}); err != nil {
return err
}
resp := make([]byte, 2)
if _, err := io.ReadFull(conn, resp); err != nil {
return err
}
if resp[0] != 5 || resp[1] != 0 {
return fmt.Errorf("SOCKS5 auth failed")
}
req := []byte{5, 1, 0, 3}
req = append(req, byte(len(targetAddr)))
req = append(req, []byte(targetAddr)...)
req = append(req, byte(targetPort>>8), byte(targetPort))
if _, err := conn.Write(req); err != nil {
return err
}
resp = make([]byte, 10)
if _, err := io.ReadFull(conn, resp); err != nil {
return err
}
if resp[0] != 5 || resp[1] != 0 {
return fmt.Errorf("SOCKS5 connect failed: %d", resp[1])
}
return nil
}
func (s *Server) onData(data []byte) {
plaintext, err := s.cipher.Decrypt(data)
if err != nil {
@@ -186,7 +224,7 @@ func (s *Server) onData(data []byte) {
clientID := binary.BigEndian.Uint32(plaintext[0:4])
sid := binary.BigEndian.Uint16(plaintext[4:6])
length := binary.BigEndian.Uint16(plaintext[6:8])
if sid == 0xFFFF && length == 0xFFFF {
log.Printf("Received reset signal from client (clientID=%d) - cleaning up", clientID)
s.connMu.Lock()
@@ -205,7 +243,7 @@ func (s *Server) onData(data []byte) {
clientID := binary.BigEndian.Uint32(plaintext[0:4])
sid := binary.BigEndian.Uint16(plaintext[4:6])
length := binary.BigEndian.Uint16(plaintext[6:8])
if sid == 0xFFFF && length == 0xFFFF {
log.Printf("Received reset signal from client (clientID=%d) - cleaning up", clientID)
s.connMu.Lock()
@@ -228,7 +266,7 @@ func (s *Server) onData(data []byte) {
func (s *Server) run(ctx context.Context) error {
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
@@ -240,21 +278,21 @@ func (s *Server) run(ctx context.Context) error {
}
}
s.connMu.Unlock()
log.Printf("Closing %d peer(s)...", len(s.peers))
for i, peer := range s.peers {
log.Printf("Closing peer %d...", i)
peer.Close()
}
log.Println("All peers closed")
return nil
case <-ticker.C:
}
sids := s.mux.GetStreams()
for _, sid := range sids {
go func(sid uint16) {
data := s.mux.ReadStream(sid)
@@ -262,7 +300,7 @@ func (s *Server) run(ctx context.Context) error {
s.connMu.RLock()
conn, exists := s.connections[sid]
s.connMu.RUnlock()
if exists && conn != nil {
if _, err := conn.Write(data); err != nil {
s.mux.CloseStream(sid)
@@ -315,28 +353,45 @@ func (s *Server) handleConnect(sid uint16, req ConnectRequest) {
s.connMu.Unlock()
dialStart := time.Now()
dialer := &net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
Resolver: s.resolver,
var conn net.Conn
var err error
if s.socksProxyAddr == "" {
dialer := &net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
Resolver: s.resolver,
}
conn, err = dialer.Dial("tcp4", addr)
logger.Verbose("TCP dial took %v for sid=%d (direct)", time.Since(dialStart), sid)
} else {
proxyAddr := fmt.Sprintf("%s:%d", s.socksProxyAddr, s.socksProxyPort)
dialer := &net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
}
conn, err = dialer.Dial("tcp4", proxyAddr)
if err == nil {
if err := s.socks5Connect(conn, req.Addr, req.Port); err != nil {
conn.Close()
err = fmt.Errorf("SOCKS5 connect failed: %v", err)
}
}
logger.Verbose("SOCKS5 proxy dial took %v for sid=%d", time.Since(dialStart), sid)
}
conn, err := dialer.Dial("tcp4", addr)
dialElapsed := time.Since(dialStart)
if err != nil {
log.Printf("[SERVER] sid=%d CONNECT_FAILED dial_time=%v total_elapsed=%v err=%v", sid, dialElapsed, time.Since(startTime), err)
go s.mux.CloseStream(sid)
return
}
logger.Verbose("TCP dial took %v for sid=%d", dialElapsed, sid)
s.connMu.Lock()
s.connections[sid] = conn
s.connMu.Unlock()
log.Printf("[SERVER] sid=%d CONNECT_SUCCESS dial_time=%v", sid, dialElapsed)
s.mux.SendData(sid, []byte{0x00})
@@ -348,11 +403,11 @@ func (s *Server) handleConnect(sid uint16, req ConnectRequest) {
delete(s.connections, sid)
s.connMu.Unlock()
}()
buf := make([]byte, 16384)
totalSent := uint64(0)
lastLog := time.Now()
for {
n, err := conn.Read(buf)
if err != nil {
@@ -361,7 +416,7 @@ func (s *Server) handleConnect(sid uint16, req ConnectRequest) {
}
return
}
for !s.canSendData() {
time.Sleep(20 * time.Millisecond)
}
@@ -369,7 +424,7 @@ func (s *Server) handleConnect(sid uint16, req ConnectRequest) {
if err := s.mux.SendData(sid, buf[:n]); err != nil {
return
}
totalSent += uint64(n)
if time.Since(lastLog) > 5*time.Second {
log.Printf("[SERVER] sid=%d TRANSFER_PROGRESS sent=%d MB", sid, totalSent/(1024*1024))