mirror of
https://github.com/openlibrecommunity/olcrtc.git
synced 2026-05-28 16:09:48 +00:00
Added SOCKS5 proxy support to the server, including new fields for proxy address and port in the Server struct. Updated the Run function and related methods to handle proxy connections.
445 lines
10 KiB
Go
445 lines
10 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/binary"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/pion/webrtc/v4"
|
|
"github.com/openlibrecommunity/olcrtc/internal/crypto"
|
|
"github.com/openlibrecommunity/olcrtc/internal/logger"
|
|
"github.com/openlibrecommunity/olcrtc/internal/mux"
|
|
"github.com/openlibrecommunity/olcrtc/internal/names"
|
|
"github.com/openlibrecommunity/olcrtc/internal/telemost"
|
|
)
|
|
|
|
type Server struct {
|
|
peers []*telemost.Peer
|
|
cipher *crypto.Cipher
|
|
mux *mux.Multiplexer
|
|
connections map[uint16]net.Conn
|
|
connMu sync.RWMutex
|
|
peerIdx atomic.Uint32
|
|
wg sync.WaitGroup
|
|
dnsServer string
|
|
dnsCache sync.Map
|
|
resolver *net.Resolver
|
|
socksProxyAddr string
|
|
socksProxyPort int
|
|
}
|
|
|
|
type ConnectRequest struct {
|
|
Cmd string `json:"cmd"`
|
|
Addr string `json:"addr"`
|
|
Port int `json:"port"`
|
|
}
|
|
|
|
func Run(ctx context.Context, roomURL, keyHex string, duo bool, dnsServer, socksProxyAddr string, socksProxyPort int) error {
|
|
var key []byte
|
|
var err error
|
|
|
|
if keyHex == "" {
|
|
key = make([]byte, 32)
|
|
if _, err := rand.Read(key); err != nil {
|
|
return err
|
|
}
|
|
log.Printf("Generated key: %x", key)
|
|
} else {
|
|
key, err = hex.DecodeString(keyHex)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if len(key) != 32 {
|
|
return fmt.Errorf("key must be 32 bytes, got %d", len(key))
|
|
}
|
|
}
|
|
|
|
keyStr := string(key)
|
|
if len(keyStr) != 32 {
|
|
return fmt.Errorf("key string length must be 32, got %d", len(keyStr))
|
|
}
|
|
|
|
cipher, err := crypto.NewCipher(keyStr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
s := &Server{
|
|
cipher: cipher,
|
|
connections: make(map[uint16]net.Conn),
|
|
peers: make([]*telemost.Peer, 0),
|
|
dnsServer: dnsServer,
|
|
socksProxyAddr: socksProxyAddr,
|
|
socksProxyPort: socksProxyPort,
|
|
}
|
|
|
|
if dnsServer == "" {
|
|
dnsServer = "1.1.1.1:53"
|
|
}
|
|
|
|
s.resolver = &net.Resolver{
|
|
PreferGo: true,
|
|
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
|
d := net.Dialer{Timeout: 3 * time.Second}
|
|
return d.DialContext(ctx, network, dnsServer)
|
|
},
|
|
}
|
|
|
|
peerCount := 1
|
|
if duo {
|
|
peerCount = 2
|
|
log.Println("Duo mode: using 2 parallel channels")
|
|
}
|
|
|
|
s.mux = mux.New(0, func(frame []byte) error {
|
|
for {
|
|
canSend := true
|
|
for _, peer := range s.peers {
|
|
if !peer.CanSend() {
|
|
canSend = false
|
|
break
|
|
}
|
|
}
|
|
if canSend {
|
|
break
|
|
}
|
|
time.Sleep(10 * time.Millisecond)
|
|
}
|
|
|
|
encrypted, err := s.cipher.Encrypt(frame)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
idx := s.peerIdx.Add(1) % uint32(len(s.peers))
|
|
return s.peers[idx].Send(encrypted)
|
|
})
|
|
|
|
for i := 0; i < peerCount; i++ {
|
|
peer, err := telemost.NewPeer(roomURL, names.Generate(), s.onData)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
s.peers = append(s.peers, peer)
|
|
|
|
peer.SetReconnectCallback(func(dc *webrtc.DataChannel) {
|
|
log.Printf("Server peer %d reconnected - resetting multiplexer state", i)
|
|
|
|
s.connMu.Lock()
|
|
for sid, conn := range s.connections {
|
|
if conn != nil {
|
|
conn.Close()
|
|
}
|
|
delete(s.connections, sid)
|
|
}
|
|
s.connMu.Unlock()
|
|
|
|
if dc != nil {
|
|
s.mux.UpdateSendFunc(func(frame []byte) error {
|
|
encrypted, err := s.cipher.Encrypt(frame)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
idx := s.peerIdx.Add(1) % uint32(len(s.peers))
|
|
return s.peers[idx].Send(encrypted)
|
|
})
|
|
}
|
|
|
|
s.mux.Reset()
|
|
|
|
log.Println("Server multiplexer reset complete")
|
|
})
|
|
|
|
log.Printf("Connecting peer %d to Telemost...", i)
|
|
if err := peer.Connect(ctx); err != nil {
|
|
return err
|
|
}
|
|
log.Printf("Peer %d connected", i)
|
|
|
|
s.wg.Add(1)
|
|
go func() {
|
|
defer s.wg.Done()
|
|
peer.WatchConnection(ctx)
|
|
}()
|
|
}
|
|
|
|
err = s.run(ctx)
|
|
|
|
log.Println("Waiting for server goroutines...")
|
|
s.wg.Wait()
|
|
log.Println("Server goroutines finished")
|
|
|
|
return err
|
|
}
|
|
|
|
func (s *Server) socks5Connect(conn net.Conn, targetAddr string, targetPort int) error {
|
|
if _, err := conn.Write([]byte{5, 1, 0}); err != nil {
|
|
return err
|
|
}
|
|
|
|
resp := make([]byte, 2)
|
|
if _, err := io.ReadFull(conn, resp); err != nil {
|
|
return err
|
|
}
|
|
if resp[0] != 5 || resp[1] != 0 {
|
|
return fmt.Errorf("SOCKS5 auth failed")
|
|
}
|
|
|
|
req := []byte{5, 1, 0, 3}
|
|
req = append(req, byte(len(targetAddr)))
|
|
req = append(req, []byte(targetAddr)...)
|
|
req = append(req, byte(targetPort>>8), byte(targetPort))
|
|
|
|
if _, err := conn.Write(req); err != nil {
|
|
return err
|
|
}
|
|
|
|
resp = make([]byte, 10)
|
|
if _, err := io.ReadFull(conn, resp); err != nil {
|
|
return err
|
|
}
|
|
if resp[0] != 5 || resp[1] != 0 {
|
|
return fmt.Errorf("SOCKS5 connect failed: %d", resp[1])
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) onData(data []byte) {
|
|
plaintext, err := s.cipher.Decrypt(data)
|
|
if err != nil {
|
|
logger.Debug("Decrypt error: %v", err)
|
|
return
|
|
}
|
|
|
|
if len(plaintext) >= 12 {
|
|
clientID := binary.BigEndian.Uint32(plaintext[0:4])
|
|
sid := binary.BigEndian.Uint16(plaintext[4:6])
|
|
length := binary.BigEndian.Uint16(plaintext[6:8])
|
|
|
|
if sid == 0xFFFF && length == 0xFFFF {
|
|
log.Printf("Received reset signal from client (clientID=%d) - cleaning up", clientID)
|
|
s.connMu.Lock()
|
|
for streamSid, conn := range s.connections {
|
|
stream := s.mux.GetStream(streamSid)
|
|
if stream != nil && stream.ClientID == clientID {
|
|
if conn != nil {
|
|
conn.Close()
|
|
}
|
|
delete(s.connections, streamSid)
|
|
}
|
|
}
|
|
s.connMu.Unlock()
|
|
}
|
|
} else if len(plaintext) >= 8 {
|
|
clientID := binary.BigEndian.Uint32(plaintext[0:4])
|
|
sid := binary.BigEndian.Uint16(plaintext[4:6])
|
|
length := binary.BigEndian.Uint16(plaintext[6:8])
|
|
|
|
if sid == 0xFFFF && length == 0xFFFF {
|
|
log.Printf("Received reset signal from client (clientID=%d) - cleaning up", clientID)
|
|
s.connMu.Lock()
|
|
for streamSid, conn := range s.connections {
|
|
stream := s.mux.GetStream(streamSid)
|
|
if stream != nil && stream.ClientID == clientID {
|
|
if conn != nil {
|
|
conn.Close()
|
|
}
|
|
delete(s.connections, streamSid)
|
|
}
|
|
}
|
|
s.connMu.Unlock()
|
|
}
|
|
}
|
|
|
|
s.mux.HandleFrame(plaintext)
|
|
}
|
|
|
|
func (s *Server) run(ctx context.Context) error {
|
|
ticker := time.NewTicker(10 * time.Millisecond)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
log.Println("Server shutting down...")
|
|
s.connMu.Lock()
|
|
for _, conn := range s.connections {
|
|
if conn != nil {
|
|
conn.Close()
|
|
}
|
|
}
|
|
s.connMu.Unlock()
|
|
|
|
log.Printf("Closing %d peer(s)...", len(s.peers))
|
|
for i, peer := range s.peers {
|
|
log.Printf("Closing peer %d...", i)
|
|
peer.Close()
|
|
}
|
|
log.Println("All peers closed")
|
|
|
|
return nil
|
|
|
|
case <-ticker.C:
|
|
}
|
|
|
|
sids := s.mux.GetStreams()
|
|
|
|
for _, sid := range sids {
|
|
go func(sid uint16) {
|
|
data := s.mux.ReadStream(sid)
|
|
if len(data) > 0 {
|
|
s.connMu.RLock()
|
|
conn, exists := s.connections[sid]
|
|
s.connMu.RUnlock()
|
|
|
|
if exists && conn != nil {
|
|
if _, err := conn.Write(data); err != nil {
|
|
s.mux.CloseStream(sid)
|
|
conn.Close()
|
|
s.connMu.Lock()
|
|
delete(s.connections, sid)
|
|
s.connMu.Unlock()
|
|
}
|
|
} else {
|
|
var req ConnectRequest
|
|
if err := json.Unmarshal(data, &req); err == nil && req.Cmd == "connect" {
|
|
log.Printf("[SERVER] sid=%d RECEIVED_CONNECT_REQUEST %s:%d", sid, req.Addr, req.Port)
|
|
s.connMu.Lock()
|
|
if oldConn, exists := s.connections[sid]; exists && oldConn != nil {
|
|
oldConn.Close()
|
|
}
|
|
s.connMu.Unlock()
|
|
go s.handleConnect(sid, req)
|
|
}
|
|
}
|
|
}
|
|
|
|
if s.mux.StreamClosed(sid) {
|
|
s.connMu.Lock()
|
|
conn, exists := s.connections[sid]
|
|
if exists && conn != nil {
|
|
conn.Close()
|
|
delete(s.connections, sid)
|
|
}
|
|
s.connMu.Unlock()
|
|
}
|
|
}(sid)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Server) handleConnect(sid uint16, req ConnectRequest) {
|
|
startTime := time.Now()
|
|
addr := fmt.Sprintf("%s:%d", req.Addr, req.Port)
|
|
logger.Verbose("Handling connect request sid=%d to %s", sid, addr)
|
|
log.Printf("[SERVER] sid=%d CONNECT_START %s", sid, addr)
|
|
|
|
s.connMu.Lock()
|
|
oldConn, exists := s.connections[sid]
|
|
if exists && oldConn != nil {
|
|
log.Printf("Closing old connection for sid=%d", sid)
|
|
oldConn.Close()
|
|
delete(s.connections, sid)
|
|
}
|
|
s.connMu.Unlock()
|
|
|
|
dialStart := time.Now()
|
|
var conn net.Conn
|
|
var err error
|
|
|
|
if s.socksProxyAddr == "" {
|
|
dialer := &net.Dialer{
|
|
Timeout: 10 * time.Second,
|
|
KeepAlive: 30 * time.Second,
|
|
Resolver: s.resolver,
|
|
}
|
|
conn, err = dialer.Dial("tcp4", addr)
|
|
logger.Verbose("TCP dial took %v for sid=%d (direct)", time.Since(dialStart), sid)
|
|
} else {
|
|
proxyAddr := fmt.Sprintf("%s:%d", s.socksProxyAddr, s.socksProxyPort)
|
|
dialer := &net.Dialer{
|
|
Timeout: 10 * time.Second,
|
|
KeepAlive: 30 * time.Second,
|
|
}
|
|
conn, err = dialer.Dial("tcp4", proxyAddr)
|
|
if err == nil {
|
|
if err := s.socks5Connect(conn, req.Addr, req.Port); err != nil {
|
|
conn.Close()
|
|
err = fmt.Errorf("SOCKS5 connect failed: %v", err)
|
|
}
|
|
}
|
|
logger.Verbose("SOCKS5 proxy dial took %v for sid=%d", time.Since(dialStart), sid)
|
|
}
|
|
|
|
dialElapsed := time.Since(dialStart)
|
|
|
|
if err != nil {
|
|
log.Printf("[SERVER] sid=%d CONNECT_FAILED dial_time=%v total_elapsed=%v err=%v", sid, dialElapsed, time.Since(startTime), err)
|
|
go s.mux.CloseStream(sid)
|
|
return
|
|
}
|
|
|
|
s.connMu.Lock()
|
|
s.connections[sid] = conn
|
|
s.connMu.Unlock()
|
|
|
|
log.Printf("[SERVER] sid=%d CONNECT_SUCCESS dial_time=%v", sid, dialElapsed)
|
|
|
|
s.mux.SendData(sid, []byte{0x00})
|
|
|
|
go func() {
|
|
defer func() {
|
|
s.mux.CloseStream(sid)
|
|
s.connMu.Lock()
|
|
delete(s.connections, sid)
|
|
s.connMu.Unlock()
|
|
}()
|
|
|
|
buf := make([]byte, 16384)
|
|
totalSent := uint64(0)
|
|
lastLog := time.Now()
|
|
|
|
for {
|
|
n, err := conn.Read(buf)
|
|
if err != nil {
|
|
if totalSent > 1024*1024 {
|
|
log.Printf("[SERVER] sid=%d TRANSFER_COMPLETE total=%d MB", sid, totalSent/(1024*1024))
|
|
}
|
|
return
|
|
}
|
|
|
|
for !s.canSendData() {
|
|
time.Sleep(20 * time.Millisecond)
|
|
}
|
|
|
|
if err := s.mux.SendData(sid, buf[:n]); err != nil {
|
|
return
|
|
}
|
|
|
|
totalSent += uint64(n)
|
|
if time.Since(lastLog) > 5*time.Second {
|
|
log.Printf("[SERVER] sid=%d TRANSFER_PROGRESS sent=%d MB", sid, totalSent/(1024*1024))
|
|
lastLog = time.Now()
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (s *Server) canSendData() bool {
|
|
for _, peer := range s.peers {
|
|
if !peer.CanSend() {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|