refactor(client): replace log.Printf with logger and standardize

This commit is contained in:
zarazaex69
2026-04-20 05:29:27 +03:00
parent a0b6ef0f35
commit d1d82ff6a3

View File

@@ -10,9 +10,7 @@ import (
"errors"
"fmt"
"io"
"log"
"net"
"strconv"
"sync"
"sync/atomic"
"time"
@@ -26,232 +24,358 @@ import (
)
var (
errInvalidKeyLength = errors.New("key must be 32 bytes")
errInvalidKeyStringLength = errors.New("key string length must be 32")
errNoConnectedPeers = errors.New("no connected peers available")
ErrKeySize = errors.New("key must be 32 bytes")
ErrKeyStringLength = errors.New("key string length must be 32")
ErrInvalidSocks5 = errors.New("invalid SOCKS5 version")
ErrNoPeers = errors.New("no peers available")
ErrEncryptFailed = errors.New("encrypt failed")
)
// Client handles local SOCKS5 connections and tunnels them through WebRTC.
// Client handles local SOCKS5 connections and tunnels them via WebRTC.
type Client struct {
peers []provider.Provider
cipher *crypto.Cipher
mux *mux.Multiplexer
clientID uint32
peerIdx atomic.Uint32
wg sync.WaitGroup
peers []provider.Provider
cipher *crypto.Cipher
mux *mux.Multiplexer
connections map[uint16]net.Conn
connMu sync.RWMutex
peerIdx atomic.Uint32
clientID uint32
activeClients atomic.Int32
wg sync.WaitGroup
dnsServer string
}
const defaultSOCKSListenHost = "127.0.0.1"
// Run starts the client with the specified parameters.
func Run(
ctx context.Context,
providerName,
roomURL,
keyHex string,
socksPort int,
socksHost,
socksUser,
socksPass string,
) error {
return RunWithReady(ctx, providerName, roomURL, keyHex, socksPort, socksHost, socksUser, socksPass, nil)
}
// RunWithReady starts the client and calls onReady when it is listening.
func RunWithReady(
ctx context.Context,
providerName,
roomURL,
keyHex string,
socksPort int,
socksHost,
socksUser,
socksPass string,
onReady func(),
localAddr string,
dnsServer,
socksProxyAddr string,
socksProxyPort int,
) error {
runCtx, cancel := context.WithCancel(ctx)
defer cancel()
key, err := decodeKey(keyHex)
cipher, err := setupCipher(keyHex)
if err != nil {
return fmt.Errorf("decodeKey failed: %w", err)
return fmt.Errorf("setupCipher failed: %w", err)
}
keyStr := string(key)
if len(keyStr) != 32 {
return fmt.Errorf("%w: got %d", errInvalidKeyStringLength, len(keyStr))
}
cipher, err := crypto.NewCipher(keyStr)
if err != nil {
return fmt.Errorf("create cipher: %w", err)
clientIDBytes := make([]byte, 4)
if _, err := rand.Read(clientIDBytes); err != nil {
return fmt.Errorf("failed to generate client ID: %w", err)
}
clientID := binary.BigEndian.Uint32(clientIDBytes)
c := &Client{
cipher: cipher,
clientID: uint32(time.Now().UnixNano() & 0xFFFFFFFF),
peers: make([]provider.Provider, 0, 1),
cipher: cipher,
connections: make(map[uint16]net.Conn),
peers: make([]provider.Provider, 0),
clientID: clientID,
dnsServer: dnsServer,
}
c.mux = mux.New(c.clientID, c.sendFrame)
c.setupMux()
for peerID := range 1 {
if err := c.addPeer(runCtx, providerName, roomURL, peerID, cancel); err != nil {
const peerCount = 1
for i := range peerCount {
if err := c.addPeer(runCtx, providerName, roomURL, i, cancel, dnsServer, socksProxyAddr, socksProxyPort); err != nil {
return fmt.Errorf("addPeer failed: %w", err)
}
}
time.Sleep(100 * time.Millisecond)
c.sendResetSignal()
ln, err := net.Listen("tcp", localAddr)
if err != nil {
return fmt.Errorf("listen failed: %w", err)
}
defer ln.Close()
err = c.runSOCKS5(runCtx, socksHost, socksPort, socksUser, socksPass, onReady)
logger.Infof("SOCKS5 server listening on %s (ClientID: %d)", localAddr, clientID)
go c.acceptLoop(runCtx, ln)
<-runCtx.Done()
c.shutdown()
c.wg.Wait()
return err
}
func decodeKey(keyHex string) ([]byte, error) {
if keyHex == "" {
key := make([]byte, 32)
if _, err := rand.Read(key); err != nil {
return nil, fmt.Errorf("generate random key: %w", err)
}
log.Printf("Generated key: %x", key)
return key, nil
}
key, err := hex.DecodeString(keyHex)
if err != nil {
return nil, fmt.Errorf("decode hex key: %w", err)
}
if len(key) != 32 {
return nil, fmt.Errorf("%w: got %d", errInvalidKeyLength, len(key))
}
return key, nil
}
func (c *Client) sendFrame(frame []byte) error {
waitUntilPeersCanSend(c.peers)
encrypted, err := c.cipher.Encrypt(frame)
if err != nil {
return fmt.Errorf("encrypt outgoing frame: %w", err)
}
peer, err := c.nextPeer()
if err != nil {
return err
}
if err := peer.Send(encrypted); err != nil {
return fmt.Errorf("send frame via peer: %w", err)
}
return nil
}
func waitUntilPeersCanSend(peers []provider.Provider) {
for {
canSend := true
for _, peer := range peers {
if !peer.CanSend() {
canSend = false
break
}
}
if canSend {
return
}
time.Sleep(10 * time.Millisecond)
func setupCipher(keyHex string) (*crypto.Cipher, error) {
key, err := hex.DecodeString(keyHex)
if err != nil {
return nil, fmt.Errorf("failed to decode key: %w", err)
}
if len(key) != 32 {
return nil, ErrKeySize
}
keyStr := string(key)
if len(keyStr) != 32 {
return nil, ErrKeyStringLength
}
cipher, err := crypto.NewCipher(keyStr)
if err != nil {
return nil, fmt.Errorf("failed to create cipher: %w", err)
}
return cipher, nil
}
// nextPeer returns the next provider for load balancing.
//
//nolint:ireturn
func (c *Client) nextPeer() (provider.Provider, error) {
switch len(c.peers) {
case 0:
return nil, errNoConnectedPeers
case 1:
return c.peers[0], nil
default:
return c.peers[int(c.peerIdx.Add(1)%2)], nil
}
func (c *Client) setupMux() {
c.mux = mux.New(c.clientID, func(frame []byte) error {
for {
canSend := true
for _, peer := range c.peers {
if !peer.CanSend() {
canSend = false
break
}
}
if canSend {
break
}
time.Sleep(10 * time.Millisecond)
}
encrypted, err := c.cipher.Encrypt(frame)
if err != nil {
return fmt.Errorf("%w: %w", ErrEncryptFailed, err)
}
if len(c.peers) == 0 {
return ErrNoPeers
}
idx := c.peerIdx.Add(1) % uint32(len(c.peers)) //nolint:gosec
return c.peers[idx].Send(encrypted)
})
}
func (c *Client) addPeer(
runCtx context.Context,
ctx context.Context,
providerName,
roomURL string,
peerID int,
cancel context.CancelFunc,
dnsServer,
socksProxyAddr string,
socksProxyPort int,
) error {
peer, err := provider.New(runCtx, providerName, provider.Config{
RoomURL: roomURL,
Name: names.Generate(),
OnData: c.onData,
peer, err := provider.New(ctx, providerName, provider.Config{
RoomURL: roomURL,
Name: names.Generate(),
OnData: c.onData,
DNSServer: dnsServer,
ProxyAddr: socksProxyAddr,
ProxyPort: socksProxyPort,
})
if err != nil {
return fmt.Errorf("create peer %d: %w", peerID, err)
return fmt.Errorf("failed to create peer: %w", err)
}
peer.SetEndedCallback(func(reason string) {
log.Printf("Client peer %d reported conference end: %s", peerID, reason)
logger.Infof("Client peer %d reported conference end: %s", peerID, reason)
cancel()
})
peer.SetReconnectCallback(func(dc *webrtc.DataChannel) {
c.onReconnect(peerID, dc)
})
c.peers = append(c.peers, peer)
log.Printf("Connecting peer %d to %s...", peerID, providerName)
if err := peer.Connect(runCtx); err != nil {
return fmt.Errorf("connect peer %d: %w", peerID, err)
peer.SetReconnectCallback(func(dc *webrtc.DataChannel) {
c.handlePeerReconnect(peerID, dc)
})
logger.Infof("Connecting peer %d to %s...", peerID, providerName)
if err := peer.Connect(ctx); err != nil {
return fmt.Errorf("failed to connect peer: %w", err)
}
log.Printf("Peer %d connected", peerID)
logger.Infof("Peer %d connected", peerID)
c.wg.Add(1)
go func() {
defer c.wg.Done()
peer.WatchConnection(runCtx)
peer.WatchConnection(ctx)
}()
// Send initial reset to clean up any stale connections for this clientID on server
if err := c.mux.SendClientReset(); err != nil {
logger.Warnf("Failed to send initial client reset: %v", err)
}
return nil
}
func (c *Client) onReconnect(peerID int, dc *webrtc.DataChannel) {
log.Printf("peer %d reconnect event: dc=%v", peerID, dc != nil)
func (c *Client) handlePeerReconnect(peerID int, dc *webrtc.DataChannel) {
logger.Infof("peer %d reconnect event: dc=%v", peerID, dc != nil)
c.connMu.Lock()
for sid, conn := range c.connections {
if conn != nil {
_ = conn.Close()
}
delete(c.connections, sid)
}
c.connMu.Unlock()
if dc != nil {
c.mux.UpdateSendFunc(c.sendFrame)
c.mux.UpdateSendFunc(func(frame []byte) error {
encrypted, err := c.cipher.Encrypt(frame)
if err != nil {
return fmt.Errorf("%w: %w", ErrEncryptFailed, err)
}
if len(c.peers) == 0 {
return ErrNoPeers
}
idx := c.peerIdx.Add(1) % uint32(len(c.peers)) //nolint:gosec
return c.peers[idx].Send(encrypted)
})
c.mux.Reset()
if err := c.mux.SendClientReset(); err != nil {
logger.Warnf("Failed to send client reset after reconnect: %v", err)
}
}
}
func (c *Client) sendResetSignal() {
resetFrame := mux.BuildControlFrame(c.clientID, mux.ControlResetClient)
encrypted, err := c.cipher.Encrypt(resetFrame)
if err != nil {
log.Printf("Failed to encrypt reset signal: %v", err)
func (c *Client) acceptLoop(ctx context.Context, ln net.Listener) {
for {
select {
case <-ctx.Done():
return
default:
conn, err := ln.Accept()
if err != nil {
logger.Debugf("Accept error: %v", err)
continue
}
go c.handleSOCKS5(ctx, conn)
}
}
}
func (c *Client) handleSOCKS5(ctx context.Context, conn net.Conn) {
defer conn.Close()
if err := c.socks5Handshake(conn); err != nil {
logger.Debugf("SOCKS5 handshake failed: %v", err)
return
}
for _, peer := range c.peers {
if err := peer.Send(encrypted); err != nil {
log.Printf("Failed to send reset signal to server: %v", err)
}
addr, port, err := c.socks5Request(conn)
if err != nil {
logger.Debugf("SOCKS5 request failed: %v", err)
return
}
log.Printf("Sent reset signal to server (clientID=%d)", c.clientID)
sid := c.mux.OpenStream()
c.connMu.Lock()
c.connections[sid] = conn
c.connMu.Unlock()
logger.Infof("sid=%d tunnel to %s:%d", sid, addr, port)
req := map[string]any{
"cmd": "connect",
"addr": addr,
"port": port,
}
reqData, _ := json.Marshal(req)
if err := c.mux.SendData(sid, reqData); err != nil {
logger.Warnf("sid=%d send connect failed: %v", sid, err)
return
}
dataReady := c.mux.WaitForData(sid)
select {
case <-dataReady:
resp := c.mux.ReadStream(sid)
if len(resp) > 0 && resp[0] == 0x00 {
if _, err := conn.Write(replySuccess()); err != nil {
return
}
} else {
_, _ = conn.Write(replyHostUnreachable())
return
}
case <-time.After(15 * time.Second):
_, _ = conn.Write(replyHostUnreachable())
c.mux.CleanupDataChannel(sid)
return
case <-ctx.Done():
return
}
c.mux.CleanupDataChannel(sid)
c.activeClients.Add(1)
c.startStreamPump(ctx, sid, conn)
c.pumpToMux(sid, conn)
}
func (c *Client) socks5Handshake(conn net.Conn) error {
buf := make([]byte, 2)
if _, err := io.ReadFull(conn, buf); err != nil {
return err
}
if buf[0] != 5 {
return ErrInvalidSocks5
}
methods := make([]byte, int(buf[1]))
if _, err := io.ReadFull(conn, methods); err != nil {
return err
}
_, err := conn.Write([]byte{5, 0})
return err
}
func (c *Client) socks5Request(conn net.Conn) (string, int, error) {
buf := make([]byte, 4)
if _, err := io.ReadFull(conn, buf); err != nil {
return "", 0, err
}
if buf[0] != 5 || buf[1] != 1 {
return "", 0, fmt.Errorf("unsupported SOCKS5 command: %d", buf[1])
}
var addr string
switch buf[3] {
case 1: // IPv4
ip := make([]byte, 4)
if _, err := io.ReadFull(conn, ip); err != nil {
return "", 0, err
}
addr = net.IP(ip).String()
case 3: // Domain
lenBuf := make([]byte, 1)
if _, err := io.ReadFull(conn, lenBuf); err != nil {
return "", 0, err
}
domain := make([]byte, int(lenBuf[0]))
if _, err := io.ReadFull(conn, domain); err != nil {
return "", 0, err
}
addr = string(domain)
case 4: // IPv6
ip := make([]byte, 16)
if _, err := io.ReadFull(conn, ip); err != nil {
return "", 0, err
}
addr = net.IP(ip).String()
default:
return "", 0, fmt.Errorf("unsupported address type: %d", buf[3])
}
portBuf := make([]byte, 2)
if _, err := io.ReadFull(conn, portBuf); err != nil {
return "", 0, err
}
port := int(binary.BigEndian.Uint16(portBuf))
return addr, port, nil
}
func (c *Client) onData(data []byte) {
@@ -264,347 +388,100 @@ func (c *Client) onData(data []byte) {
c.mux.HandleFrame(plaintext)
}
func (c *Client) runSOCKS5(
ctx context.Context,
host string,
port int,
username,
password string,
onReady func(),
) error {
if host == "" {
host = defaultSOCKSListenHost
}
listenAddr := net.JoinHostPort(host, strconv.Itoa(port))
var lc net.ListenConfig
listener, err := lc.Listen(ctx, "tcp", listenAddr)
if err != nil {
return fmt.Errorf("listen on %s: %w", listenAddr, err)
}
log.Printf("SOCKS5 proxy listening on %s (auth=%v)", listenAddr, username != "")
if onReady != nil {
onReady()
}
go func() {
<-ctx.Done()
if err := listener.Close(); err != nil {
logger.Debugf("SOCKS5 listener close error: %v", err)
func (c *Client) shutdown() {
c.connMu.Lock()
for _, conn := range c.connections {
if conn != nil {
_ = conn.Close()
}
}
c.connMu.Unlock()
for i, peer := range c.peers {
logger.Infof("closing peer %d", i)
_ = peer.Close()
}
}
func (c *Client) pumpToMux(sid uint16, conn net.Conn) {
defer func() {
c.activeClients.Add(-1)
_ = c.mux.CloseStream(sid)
c.connMu.Lock()
delete(c.connections, sid)
c.connMu.Unlock()
}()
buf := make([]byte, 16384)
totalSent := uint64(0)
lastLog := time.Now()
for {
conn, err := listener.Accept()
n, err := conn.Read(buf)
if err != nil {
select {
case <-ctx.Done():
c.closePeers()
return nil
default:
log.Printf("accept error: %v", err)
continue
if totalSent > 1024*1024 {
logger.Infof("sid=%d done total=%dMB", sid, totalSent/(1024*1024))
}
return
}
go c.handleSOCKS5(conn, username, password)
}
}
for !c.canSendData() {
time.Sleep(20 * time.Millisecond)
}
func (c *Client) closePeers() {
for _, peer := range c.peers {
if err := peer.Close(); err != nil {
logger.Debugf("Peer close error: %v", err)
if err := c.mux.SendData(sid, buf[:n]); err != nil {
return
}
totalSent += uint64(n) //nolint:gosec
if time.Since(lastLog) > 5*time.Second {
logger.Infof("sid=%d sent=%dMB", sid, totalSent/(1024*1024))
lastLog = time.Now()
}
}
}
//nolint:cyclop // SOCKS5 parsing is inherently stateful and mirrors the protocol handshake.
func (c *Client) handleSOCKS5(conn net.Conn, username, password string) {
defer func() {
if err := conn.Close(); err != nil {
logger.Debugf("SOCKS5 connection close error: %v", err)
}
}()
buf := make([]byte, 513)
if !readSOCKSVersionAndMethods(conn, buf) {
return
}
nmethods := buf[1]
if _, err := io.ReadFull(conn, buf[:nmethods]); err != nil {
return
}
requireAuth := username != ""
wantMethod := byte(0x00)
if requireAuth {
wantMethod = 0x02
}
if !supportsMethod(buf[:nmethods], wantMethod) {
writeResponse(conn, replyUnsupportedSOCKSMethod())
return
}
writeResponse(conn, []byte{5, wantMethod})
if requireAuth && !authenticateSOCKSUser(conn, buf, username, password) {
return
}
addr, port, ok := readConnectTarget(conn, buf)
if !ok {
return
}
sid := c.mux.OpenStream()
logger.Verbosef("SOCKS5 opened stream sid=%d for %s:%d", sid, addr, port)
log.Printf("sid=%d socks5 %s:%d", sid, addr, port)
if !c.sendConnectRequest(sid, addr, port) {
return
}
if !c.waitConnectResponse(conn, sid) {
return
}
c.mux.ReadStream(sid)
writeResponse(conn, replySuccess())
c.proxyStream(conn, sid)
}
func readSOCKSVersionAndMethods(conn net.Conn, buf []byte) bool {
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
return false
}
return buf[0] == 5
}
func supportsMethod(methods []byte, wantMethod byte) bool {
for _, method := range methods {
if method == wantMethod {
return true
}
}
return false
}
func authenticateSOCKSUser(conn net.Conn, buf []byte, username, password string) bool {
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
return false
}
if buf[0] != 0x01 {
return false
}
ulen := int(buf[1])
if _, err := io.ReadFull(conn, buf[:ulen+1]); err != nil {
return false
}
gotUser := string(buf[:ulen])
plen := int(buf[ulen])
if _, err := io.ReadFull(conn, buf[:plen]); err != nil {
return false
}
gotPass := string(buf[:plen])
if gotUser != username || gotPass != password {
writeResponse(conn, replyAuthFailed())
return false
}
writeResponse(conn, replyAuthOK())
return true
}
func readConnectTarget(conn net.Conn, buf []byte) (string, uint16, bool) {
if _, err := io.ReadFull(conn, buf[:4]); err != nil {
return "", 0, false
}
if buf[1] != 1 {
writeResponse(conn, replyCommandNotSupported())
return "", 0, false
}
addr, ok := readTargetAddress(conn, buf, buf[3])
if !ok {
return "", 0, false
}
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
return "", 0, false
}
return addr, binary.BigEndian.Uint16(buf[:2]), true
}
func readTargetAddress(conn net.Conn, buf []byte, atyp byte) (string, bool) {
switch atyp {
case 1:
if _, err := io.ReadFull(conn, buf[:4]); err != nil {
return "", false
}
return fmt.Sprintf("%d.%d.%d.%d", buf[0], buf[1], buf[2], buf[3]), true
case 3:
if _, err := io.ReadFull(conn, buf[:1]); err != nil {
return "", false
}
length := buf[0]
if _, err := io.ReadFull(conn, buf[:length]); err != nil {
return "", false
}
return string(buf[:length]), true
default:
writeResponse(conn, replyAddressNotSupported())
return "", false
}
}
func (c *Client) sendConnectRequest(sid uint16, addr string, port uint16) bool {
reqData, err := json.Marshal(struct {
Cmd string `json:"cmd"`
Addr string `json:"addr"`
Port uint16 `json:"port"`
}{
Cmd: "connect",
Addr: addr,
Port: port,
})
if err != nil {
logger.Debugf("Connect request marshal error: %v", err)
return false
}
if err := c.mux.SendData(sid, reqData); err != nil {
logger.Debugf("Connect request send error: %v", err)
return false
}
return true
}
func (c *Client) waitConnectResponse(conn net.Conn, sid uint16) bool {
dataReady := c.mux.WaitForData(sid)
timeout := time.NewTimer(10 * time.Second)
defer timeout.Stop()
select {
case <-dataReady:
stream := c.mux.GetStream(sid)
if stream == nil || len(stream.RecvBuf()) == 0 {
writeResponse(conn, replyHostUnreachable())
return false
}
case <-timeout.C:
writeResponse(conn, replyHostUnreachable())
return false
}
return true
}
//nolint:cyclop // The stream pump handles two coordinated goroutines and shutdown races in one place.
func (c *Client) proxyStream(conn net.Conn, sid uint16) {
done := make(chan struct{})
streamClosed := make(chan struct{})
func (c *Client) startStreamPump(ctx context.Context, sid uint16, conn net.Conn) {
c.wg.Add(1)
go func() {
defer close(done)
buf := make([]byte, 32768)
for {
n, err := conn.Read(buf)
if err != nil {
if err := c.mux.CloseStream(sid); err != nil {
logger.Debugf("Close stream error: %v", err)
}
return
}
if err := c.mux.SendData(sid, buf[:n]); err != nil {
return
}
}
}()
go func() {
defer close(streamClosed)
defer c.mux.CleanupDataChannel(sid)
defer c.wg.Done()
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-done:
case <-ctx.Done():
return
case <-ticker.C:
data := c.mux.ReadStream(sid)
if len(data) > 0 && !writeStreamData(conn, data) {
return
if len(data) > 0 {
if _, err := conn.Write(data); err != nil {
_ = c.mux.CloseStream(sid)
return
}
}
if c.mux.StreamClosed(sid) {
return
}
}
}
}()
select {
case <-done:
case <-streamClosed:
}
}
func writeStreamData(conn net.Conn, data []byte) bool {
for len(data) > 0 {
n, err := conn.Write(data)
if err != nil {
func (c *Client) canSendData() bool {
for _, peer := range c.peers {
if !peer.CanSend() {
return false
}
data = data[n:]
}
return true
}
func writeResponse(conn net.Conn, response []byte) {
if _, err := conn.Write(response); err != nil {
logger.Debugf("SOCKS5 response write error: %v", err)
}
}
func replyUnsupportedSOCKSMethod() []byte {
return []byte{5, 0xFF}
}
func replyAuthFailed() []byte {
return []byte{0x01, 0x01}
}
func replyAuthOK() []byte {
return []byte{0x01, 0x00}
}
func replyCommandNotSupported() []byte {
return []byte{5, 7, 0, 1, 0, 0, 0, 0, 0, 0}
}
func replyAddressNotSupported() []byte {
return []byte{5, 8, 0, 1, 0, 0, 0, 0, 0, 0}
}
func replyHostUnreachable() []byte {
return []byte{5, 4, 0, 1, 0, 0, 0, 0, 0, 0}
}
func replySuccess() []byte {
return []byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0}
}
func replyHostUnreachable() []byte {
return []byte{5, 4, 0, 1, 0, 0, 0, 0, 0, 0}
}