Files
olcrtc/internal/server/server.go
Kot-nikot c1a30b677b Implement SOCKS5 proxy support in Server
Added SOCKS5 proxy support to the server, including new fields for proxy address and port in the Server struct. Updated the Run function and related methods to handle proxy connections.
2026-04-11 09:43:13 +03:00

445 lines
10 KiB
Go

package server
import (
"context"
"crypto/rand"
"encoding/binary"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"log"
"net"
"sync"
"sync/atomic"
"time"
"github.com/pion/webrtc/v4"
"github.com/openlibrecommunity/olcrtc/internal/crypto"
"github.com/openlibrecommunity/olcrtc/internal/logger"
"github.com/openlibrecommunity/olcrtc/internal/mux"
"github.com/openlibrecommunity/olcrtc/internal/names"
"github.com/openlibrecommunity/olcrtc/internal/telemost"
)
type Server struct {
peers []*telemost.Peer
cipher *crypto.Cipher
mux *mux.Multiplexer
connections map[uint16]net.Conn
connMu sync.RWMutex
peerIdx atomic.Uint32
wg sync.WaitGroup
dnsServer string
dnsCache sync.Map
resolver *net.Resolver
socksProxyAddr string
socksProxyPort int
}
type ConnectRequest struct {
Cmd string `json:"cmd"`
Addr string `json:"addr"`
Port int `json:"port"`
}
func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer, socksProxyAddr string, socksProxyPort int) error {
var key []byte
var err error
if keyHex == "" {
key = make([]byte, 32)
if _, err := rand.Read(key); err != nil {
return err
}
log.Printf("Generated key: %x", key)
} else {
key, err = hex.DecodeString(keyHex)
if err != nil {
return err
}
if len(key) != 32 {
return fmt.Errorf("key must be 32 bytes, got %d", len(key))
}
}
keyStr := string(key)
if len(keyStr) != 32 {
return fmt.Errorf("key string length must be 32, got %d", len(keyStr))
}
cipher, err := crypto.NewCipher(keyStr)
if err != nil {
return err
}
s := &Server{
cipher: cipher,
connections: make(map[uint16]net.Conn),
peers: make([]*telemost.Peer, 0),
dnsServer: dnsServer,
socksProxyAddr: socksProxyAddr,
socksProxyPort: socksProxyPort,
}
if dnsServer == "" {
dnsServer = "1.1.1.1:53"
}
s.resolver = &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{Timeout: 3 * time.Second}
return d.DialContext(ctx, network, dnsServer)
},
}
peerCount := 1
if duo {
peerCount = 2
log.Println("Duo mode: using 2 parallel channels")
}
s.mux = mux.New(0, func(frame []byte) error {
for {
canSend := true
for _, peer := range s.peers {
if !peer.CanSend() {
canSend = false
break
}
}
if canSend {
break
}
time.Sleep(10 * time.Millisecond)
}
encrypted, err := s.cipher.Encrypt(frame)
if err != nil {
return err
}
idx := s.peerIdx.Add(1) % uint32(len(s.peers))
return s.peers[idx].Send(encrypted)
})
for i := 0; i < peerCount; i++ {
peer, err := telemost.NewPeer(roomURL, names.Generate(), s.onData)
if err != nil {
return err
}
s.peers = append(s.peers, peer)
peer.SetReconnectCallback(func(dc *webrtc.DataChannel) {
log.Printf("Server peer %d reconnected - resetting multiplexer state", i)
s.connMu.Lock()
for sid, conn := range s.connections {
if conn != nil {
conn.Close()
}
delete(s.connections, sid)
}
s.connMu.Unlock()
if dc != nil {
s.mux.UpdateSendFunc(func(frame []byte) error {
encrypted, err := s.cipher.Encrypt(frame)
if err != nil {
return err
}
idx := s.peerIdx.Add(1) % uint32(len(s.peers))
return s.peers[idx].Send(encrypted)
})
}
s.mux.Reset()
log.Println("Server multiplexer reset complete")
})
log.Printf("Connecting peer %d to Telemost...", i)
if err := peer.Connect(ctx); err != nil {
return err
}
log.Printf("Peer %d connected", i)
s.wg.Add(1)
go func() {
defer s.wg.Done()
peer.WatchConnection(ctx)
}()
}
err = s.run(ctx)
log.Println("Waiting for server goroutines...")
s.wg.Wait()
log.Println("Server goroutines finished")
return err
}
func (s *Server) socks5Connect(conn net.Conn, targetAddr string, targetPort int) error {
if _, err := conn.Write([]byte{5, 1, 0}); err != nil {
return err
}
resp := make([]byte, 2)
if _, err := io.ReadFull(conn, resp); err != nil {
return err
}
if resp[0] != 5 || resp[1] != 0 {
return fmt.Errorf("SOCKS5 auth failed")
}
req := []byte{5, 1, 0, 3}
req = append(req, byte(len(targetAddr)))
req = append(req, []byte(targetAddr)...)
req = append(req, byte(targetPort>>8), byte(targetPort))
if _, err := conn.Write(req); err != nil {
return err
}
resp = make([]byte, 10)
if _, err := io.ReadFull(conn, resp); err != nil {
return err
}
if resp[0] != 5 || resp[1] != 0 {
return fmt.Errorf("SOCKS5 connect failed: %d", resp[1])
}
return nil
}
func (s *Server) onData(data []byte) {
plaintext, err := s.cipher.Decrypt(data)
if err != nil {
logger.Debug("Decrypt error: %v", err)
return
}
if len(plaintext) >= 12 {
clientID := binary.BigEndian.Uint32(plaintext[0:4])
sid := binary.BigEndian.Uint16(plaintext[4:6])
length := binary.BigEndian.Uint16(plaintext[6:8])
if sid == 0xFFFF && length == 0xFFFF {
log.Printf("Received reset signal from client (clientID=%d) - cleaning up", clientID)
s.connMu.Lock()
for streamSid, conn := range s.connections {
stream := s.mux.GetStream(streamSid)
if stream != nil && stream.ClientID == clientID {
if conn != nil {
conn.Close()
}
delete(s.connections, streamSid)
}
}
s.connMu.Unlock()
}
} else if len(plaintext) >= 8 {
clientID := binary.BigEndian.Uint32(plaintext[0:4])
sid := binary.BigEndian.Uint16(plaintext[4:6])
length := binary.BigEndian.Uint16(plaintext[6:8])
if sid == 0xFFFF && length == 0xFFFF {
log.Printf("Received reset signal from client (clientID=%d) - cleaning up", clientID)
s.connMu.Lock()
for streamSid, conn := range s.connections {
stream := s.mux.GetStream(streamSid)
if stream != nil && stream.ClientID == clientID {
if conn != nil {
conn.Close()
}
delete(s.connections, streamSid)
}
}
s.connMu.Unlock()
}
}
s.mux.HandleFrame(plaintext)
}
func (s *Server) run(ctx context.Context) error {
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
log.Println("Server shutting down...")
s.connMu.Lock()
for _, conn := range s.connections {
if conn != nil {
conn.Close()
}
}
s.connMu.Unlock()
log.Printf("Closing %d peer(s)...", len(s.peers))
for i, peer := range s.peers {
log.Printf("Closing peer %d...", i)
peer.Close()
}
log.Println("All peers closed")
return nil
case <-ticker.C:
}
sids := s.mux.GetStreams()
for _, sid := range sids {
go func(sid uint16) {
data := s.mux.ReadStream(sid)
if len(data) > 0 {
s.connMu.RLock()
conn, exists := s.connections[sid]
s.connMu.RUnlock()
if exists && conn != nil {
if _, err := conn.Write(data); err != nil {
s.mux.CloseStream(sid)
conn.Close()
s.connMu.Lock()
delete(s.connections, sid)
s.connMu.Unlock()
}
} else {
var req ConnectRequest
if err := json.Unmarshal(data, &req); err == nil && req.Cmd == "connect" {
log.Printf("[SERVER] sid=%d RECEIVED_CONNECT_REQUEST %s:%d", sid, req.Addr, req.Port)
s.connMu.Lock()
if oldConn, exists := s.connections[sid]; exists && oldConn != nil {
oldConn.Close()
}
s.connMu.Unlock()
go s.handleConnect(sid, req)
}
}
}
if s.mux.StreamClosed(sid) {
s.connMu.Lock()
conn, exists := s.connections[sid]
if exists && conn != nil {
conn.Close()
delete(s.connections, sid)
}
s.connMu.Unlock()
}
}(sid)
}
}
}
func (s *Server) handleConnect(sid uint16, req ConnectRequest) {
startTime := time.Now()
addr := fmt.Sprintf("%s:%d", req.Addr, req.Port)
logger.Verbose("Handling connect request sid=%d to %s", sid, addr)
log.Printf("[SERVER] sid=%d CONNECT_START %s", sid, addr)
s.connMu.Lock()
oldConn, exists := s.connections[sid]
if exists && oldConn != nil {
log.Printf("Closing old connection for sid=%d", sid)
oldConn.Close()
delete(s.connections, sid)
}
s.connMu.Unlock()
dialStart := time.Now()
var conn net.Conn
var err error
if s.socksProxyAddr == "" {
dialer := &net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
Resolver: s.resolver,
}
conn, err = dialer.Dial("tcp4", addr)
logger.Verbose("TCP dial took %v for sid=%d (direct)", time.Since(dialStart), sid)
} else {
proxyAddr := fmt.Sprintf("%s:%d", s.socksProxyAddr, s.socksProxyPort)
dialer := &net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
}
conn, err = dialer.Dial("tcp4", proxyAddr)
if err == nil {
if err := s.socks5Connect(conn, req.Addr, req.Port); err != nil {
conn.Close()
err = fmt.Errorf("SOCKS5 connect failed: %v", err)
}
}
logger.Verbose("SOCKS5 proxy dial took %v for sid=%d", time.Since(dialStart), sid)
}
dialElapsed := time.Since(dialStart)
if err != nil {
log.Printf("[SERVER] sid=%d CONNECT_FAILED dial_time=%v total_elapsed=%v err=%v", sid, dialElapsed, time.Since(startTime), err)
go s.mux.CloseStream(sid)
return
}
s.connMu.Lock()
s.connections[sid] = conn
s.connMu.Unlock()
log.Printf("[SERVER] sid=%d CONNECT_SUCCESS dial_time=%v", sid, dialElapsed)
s.mux.SendData(sid, []byte{0x00})
go func() {
defer func() {
s.mux.CloseStream(sid)
s.connMu.Lock()
delete(s.connections, sid)
s.connMu.Unlock()
}()
buf := make([]byte, 16384)
totalSent := uint64(0)
lastLog := time.Now()
for {
n, err := conn.Read(buf)
if err != nil {
if totalSent > 1024*1024 {
log.Printf("[SERVER] sid=%d TRANSFER_COMPLETE total=%d MB", sid, totalSent/(1024*1024))
}
return
}
for !s.canSendData() {
time.Sleep(20 * time.Millisecond)
}
if err := s.mux.SendData(sid, buf[:n]); err != nil {
return
}
totalSent += uint64(n)
if time.Since(lastLog) > 5*time.Second {
log.Printf("[SERVER] sid=%d TRANSFER_PROGRESS sent=%d MB", sid, totalSent/(1024*1024))
lastLog = time.Now()
}
}
}()
}
func (s *Server) canSendData() bool {
for _, peer := range s.peers {
if !peer.CanSend() {
return false
}
}
return true
}