mirror of
https://github.com/openlibrecommunity/olcrtc.git
synced 2026-05-30 08:59:43 +00:00
refactor(client): replace log.Printf with logger and standardize
This commit is contained in:
@@ -10,9 +10,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -26,232 +24,358 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
errInvalidKeyLength = errors.New("key must be 32 bytes")
|
||||
errInvalidKeyStringLength = errors.New("key string length must be 32")
|
||||
errNoConnectedPeers = errors.New("no connected peers available")
|
||||
ErrKeySize = errors.New("key must be 32 bytes")
|
||||
ErrKeyStringLength = errors.New("key string length must be 32")
|
||||
ErrInvalidSocks5 = errors.New("invalid SOCKS5 version")
|
||||
ErrNoPeers = errors.New("no peers available")
|
||||
ErrEncryptFailed = errors.New("encrypt failed")
|
||||
)
|
||||
|
||||
// Client handles local SOCKS5 connections and tunnels them through WebRTC.
|
||||
// Client handles local SOCKS5 connections and tunnels them via WebRTC.
|
||||
type Client struct {
|
||||
peers []provider.Provider
|
||||
cipher *crypto.Cipher
|
||||
mux *mux.Multiplexer
|
||||
clientID uint32
|
||||
peerIdx atomic.Uint32
|
||||
wg sync.WaitGroup
|
||||
peers []provider.Provider
|
||||
cipher *crypto.Cipher
|
||||
mux *mux.Multiplexer
|
||||
connections map[uint16]net.Conn
|
||||
connMu sync.RWMutex
|
||||
peerIdx atomic.Uint32
|
||||
clientID uint32
|
||||
activeClients atomic.Int32
|
||||
wg sync.WaitGroup
|
||||
dnsServer string
|
||||
}
|
||||
|
||||
const defaultSOCKSListenHost = "127.0.0.1"
|
||||
|
||||
// Run starts the client with the specified parameters.
|
||||
func Run(
|
||||
ctx context.Context,
|
||||
providerName,
|
||||
roomURL,
|
||||
keyHex string,
|
||||
socksPort int,
|
||||
socksHost,
|
||||
socksUser,
|
||||
socksPass string,
|
||||
) error {
|
||||
return RunWithReady(ctx, providerName, roomURL, keyHex, socksPort, socksHost, socksUser, socksPass, nil)
|
||||
}
|
||||
|
||||
// RunWithReady starts the client and calls onReady when it is listening.
|
||||
func RunWithReady(
|
||||
ctx context.Context,
|
||||
providerName,
|
||||
roomURL,
|
||||
keyHex string,
|
||||
socksPort int,
|
||||
socksHost,
|
||||
socksUser,
|
||||
socksPass string,
|
||||
onReady func(),
|
||||
localAddr string,
|
||||
dnsServer,
|
||||
socksProxyAddr string,
|
||||
socksProxyPort int,
|
||||
) error {
|
||||
runCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
key, err := decodeKey(keyHex)
|
||||
cipher, err := setupCipher(keyHex)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decodeKey failed: %w", err)
|
||||
return fmt.Errorf("setupCipher failed: %w", err)
|
||||
}
|
||||
|
||||
keyStr := string(key)
|
||||
if len(keyStr) != 32 {
|
||||
return fmt.Errorf("%w: got %d", errInvalidKeyStringLength, len(keyStr))
|
||||
}
|
||||
|
||||
cipher, err := crypto.NewCipher(keyStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create cipher: %w", err)
|
||||
clientIDBytes := make([]byte, 4)
|
||||
if _, err := rand.Read(clientIDBytes); err != nil {
|
||||
return fmt.Errorf("failed to generate client ID: %w", err)
|
||||
}
|
||||
clientID := binary.BigEndian.Uint32(clientIDBytes)
|
||||
|
||||
c := &Client{
|
||||
cipher: cipher,
|
||||
clientID: uint32(time.Now().UnixNano() & 0xFFFFFFFF),
|
||||
peers: make([]provider.Provider, 0, 1),
|
||||
cipher: cipher,
|
||||
connections: make(map[uint16]net.Conn),
|
||||
peers: make([]provider.Provider, 0),
|
||||
clientID: clientID,
|
||||
dnsServer: dnsServer,
|
||||
}
|
||||
|
||||
c.mux = mux.New(c.clientID, c.sendFrame)
|
||||
c.setupMux()
|
||||
|
||||
for peerID := range 1 {
|
||||
if err := c.addPeer(runCtx, providerName, roomURL, peerID, cancel); err != nil {
|
||||
const peerCount = 1
|
||||
for i := range peerCount {
|
||||
if err := c.addPeer(runCtx, providerName, roomURL, i, cancel, dnsServer, socksProxyAddr, socksProxyPort); err != nil {
|
||||
return fmt.Errorf("addPeer failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
c.sendResetSignal()
|
||||
ln, err := net.Listen("tcp", localAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen failed: %w", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
err = c.runSOCKS5(runCtx, socksHost, socksPort, socksUser, socksPass, onReady)
|
||||
logger.Infof("SOCKS5 server listening on %s (ClientID: %d)", localAddr, clientID)
|
||||
|
||||
go c.acceptLoop(runCtx, ln)
|
||||
|
||||
<-runCtx.Done()
|
||||
c.shutdown()
|
||||
c.wg.Wait()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func decodeKey(keyHex string) ([]byte, error) {
|
||||
if keyHex == "" {
|
||||
key := make([]byte, 32)
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
return nil, fmt.Errorf("generate random key: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("Generated key: %x", key)
|
||||
return key, nil
|
||||
}
|
||||
|
||||
key, err := hex.DecodeString(keyHex)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode hex key: %w", err)
|
||||
}
|
||||
|
||||
if len(key) != 32 {
|
||||
return nil, fmt.Errorf("%w: got %d", errInvalidKeyLength, len(key))
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (c *Client) sendFrame(frame []byte) error {
|
||||
waitUntilPeersCanSend(c.peers)
|
||||
|
||||
encrypted, err := c.cipher.Encrypt(frame)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt outgoing frame: %w", err)
|
||||
}
|
||||
|
||||
peer, err := c.nextPeer()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := peer.Send(encrypted); err != nil {
|
||||
return fmt.Errorf("send frame via peer: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func waitUntilPeersCanSend(peers []provider.Provider) {
|
||||
for {
|
||||
canSend := true
|
||||
for _, peer := range peers {
|
||||
if !peer.CanSend() {
|
||||
canSend = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if canSend {
|
||||
return
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
func setupCipher(keyHex string) (*crypto.Cipher, error) {
|
||||
key, err := hex.DecodeString(keyHex)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode key: %w", err)
|
||||
}
|
||||
if len(key) != 32 {
|
||||
return nil, ErrKeySize
|
||||
}
|
||||
|
||||
keyStr := string(key)
|
||||
if len(keyStr) != 32 {
|
||||
return nil, ErrKeyStringLength
|
||||
}
|
||||
|
||||
cipher, err := crypto.NewCipher(keyStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
return cipher, nil
|
||||
}
|
||||
|
||||
// nextPeer returns the next provider for load balancing.
|
||||
//
|
||||
//nolint:ireturn
|
||||
func (c *Client) nextPeer() (provider.Provider, error) {
|
||||
switch len(c.peers) {
|
||||
case 0:
|
||||
return nil, errNoConnectedPeers
|
||||
case 1:
|
||||
return c.peers[0], nil
|
||||
default:
|
||||
return c.peers[int(c.peerIdx.Add(1)%2)], nil
|
||||
}
|
||||
func (c *Client) setupMux() {
|
||||
c.mux = mux.New(c.clientID, func(frame []byte) error {
|
||||
for {
|
||||
canSend := true
|
||||
for _, peer := range c.peers {
|
||||
if !peer.CanSend() {
|
||||
canSend = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if canSend {
|
||||
break
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
encrypted, err := c.cipher.Encrypt(frame)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %w", ErrEncryptFailed, err)
|
||||
}
|
||||
if len(c.peers) == 0 {
|
||||
return ErrNoPeers
|
||||
}
|
||||
idx := c.peerIdx.Add(1) % uint32(len(c.peers)) //nolint:gosec
|
||||
return c.peers[idx].Send(encrypted)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Client) addPeer(
|
||||
runCtx context.Context,
|
||||
ctx context.Context,
|
||||
providerName,
|
||||
roomURL string,
|
||||
peerID int,
|
||||
cancel context.CancelFunc,
|
||||
dnsServer,
|
||||
socksProxyAddr string,
|
||||
socksProxyPort int,
|
||||
) error {
|
||||
peer, err := provider.New(runCtx, providerName, provider.Config{
|
||||
RoomURL: roomURL,
|
||||
Name: names.Generate(),
|
||||
OnData: c.onData,
|
||||
peer, err := provider.New(ctx, providerName, provider.Config{
|
||||
RoomURL: roomURL,
|
||||
Name: names.Generate(),
|
||||
OnData: c.onData,
|
||||
DNSServer: dnsServer,
|
||||
ProxyAddr: socksProxyAddr,
|
||||
ProxyPort: socksProxyPort,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("create peer %d: %w", peerID, err)
|
||||
return fmt.Errorf("failed to create peer: %w", err)
|
||||
}
|
||||
|
||||
peer.SetEndedCallback(func(reason string) {
|
||||
log.Printf("Client peer %d reported conference end: %s", peerID, reason)
|
||||
logger.Infof("Client peer %d reported conference end: %s", peerID, reason)
|
||||
cancel()
|
||||
})
|
||||
|
||||
peer.SetReconnectCallback(func(dc *webrtc.DataChannel) {
|
||||
c.onReconnect(peerID, dc)
|
||||
})
|
||||
|
||||
c.peers = append(c.peers, peer)
|
||||
|
||||
log.Printf("Connecting peer %d to %s...", peerID, providerName)
|
||||
if err := peer.Connect(runCtx); err != nil {
|
||||
return fmt.Errorf("connect peer %d: %w", peerID, err)
|
||||
peer.SetReconnectCallback(func(dc *webrtc.DataChannel) {
|
||||
c.handlePeerReconnect(peerID, dc)
|
||||
})
|
||||
|
||||
logger.Infof("Connecting peer %d to %s...", peerID, providerName)
|
||||
if err := peer.Connect(ctx); err != nil {
|
||||
return fmt.Errorf("failed to connect peer: %w", err)
|
||||
}
|
||||
log.Printf("Peer %d connected", peerID)
|
||||
logger.Infof("Peer %d connected", peerID)
|
||||
|
||||
c.wg.Add(1)
|
||||
go func() {
|
||||
defer c.wg.Done()
|
||||
peer.WatchConnection(runCtx)
|
||||
peer.WatchConnection(ctx)
|
||||
}()
|
||||
|
||||
// Send initial reset to clean up any stale connections for this clientID on server
|
||||
if err := c.mux.SendClientReset(); err != nil {
|
||||
logger.Warnf("Failed to send initial client reset: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) onReconnect(peerID int, dc *webrtc.DataChannel) {
|
||||
log.Printf("peer %d reconnect event: dc=%v", peerID, dc != nil)
|
||||
func (c *Client) handlePeerReconnect(peerID int, dc *webrtc.DataChannel) {
|
||||
logger.Infof("peer %d reconnect event: dc=%v", peerID, dc != nil)
|
||||
|
||||
c.connMu.Lock()
|
||||
for sid, conn := range c.connections {
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
delete(c.connections, sid)
|
||||
}
|
||||
c.connMu.Unlock()
|
||||
|
||||
if dc != nil {
|
||||
c.mux.UpdateSendFunc(c.sendFrame)
|
||||
c.mux.UpdateSendFunc(func(frame []byte) error {
|
||||
encrypted, err := c.cipher.Encrypt(frame)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %w", ErrEncryptFailed, err)
|
||||
}
|
||||
if len(c.peers) == 0 {
|
||||
return ErrNoPeers
|
||||
}
|
||||
idx := c.peerIdx.Add(1) % uint32(len(c.peers)) //nolint:gosec
|
||||
return c.peers[idx].Send(encrypted)
|
||||
})
|
||||
c.mux.Reset()
|
||||
|
||||
if err := c.mux.SendClientReset(); err != nil {
|
||||
logger.Warnf("Failed to send client reset after reconnect: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) sendResetSignal() {
|
||||
resetFrame := mux.BuildControlFrame(c.clientID, mux.ControlResetClient)
|
||||
encrypted, err := c.cipher.Encrypt(resetFrame)
|
||||
if err != nil {
|
||||
log.Printf("Failed to encrypt reset signal: %v", err)
|
||||
func (c *Client) acceptLoop(ctx context.Context, ln net.Listener) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
logger.Debugf("Accept error: %v", err)
|
||||
continue
|
||||
}
|
||||
go c.handleSOCKS5(ctx, conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) handleSOCKS5(ctx context.Context, conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
if err := c.socks5Handshake(conn); err != nil {
|
||||
logger.Debugf("SOCKS5 handshake failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, peer := range c.peers {
|
||||
if err := peer.Send(encrypted); err != nil {
|
||||
log.Printf("Failed to send reset signal to server: %v", err)
|
||||
}
|
||||
addr, port, err := c.socks5Request(conn)
|
||||
if err != nil {
|
||||
logger.Debugf("SOCKS5 request failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Sent reset signal to server (clientID=%d)", c.clientID)
|
||||
sid := c.mux.OpenStream()
|
||||
c.connMu.Lock()
|
||||
c.connections[sid] = conn
|
||||
c.connMu.Unlock()
|
||||
|
||||
logger.Infof("sid=%d tunnel to %s:%d", sid, addr, port)
|
||||
|
||||
req := map[string]any{
|
||||
"cmd": "connect",
|
||||
"addr": addr,
|
||||
"port": port,
|
||||
}
|
||||
reqData, _ := json.Marshal(req)
|
||||
|
||||
if err := c.mux.SendData(sid, reqData); err != nil {
|
||||
logger.Warnf("sid=%d send connect failed: %v", sid, err)
|
||||
return
|
||||
}
|
||||
|
||||
dataReady := c.mux.WaitForData(sid)
|
||||
select {
|
||||
case <-dataReady:
|
||||
resp := c.mux.ReadStream(sid)
|
||||
if len(resp) > 0 && resp[0] == 0x00 {
|
||||
if _, err := conn.Write(replySuccess()); err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
_, _ = conn.Write(replyHostUnreachable())
|
||||
return
|
||||
}
|
||||
case <-time.After(15 * time.Second):
|
||||
_, _ = conn.Write(replyHostUnreachable())
|
||||
c.mux.CleanupDataChannel(sid)
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
c.mux.CleanupDataChannel(sid)
|
||||
|
||||
c.activeClients.Add(1)
|
||||
c.startStreamPump(ctx, sid, conn)
|
||||
c.pumpToMux(sid, conn)
|
||||
}
|
||||
|
||||
func (c *Client) socks5Handshake(conn net.Conn) error {
|
||||
buf := make([]byte, 2)
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if buf[0] != 5 {
|
||||
return ErrInvalidSocks5
|
||||
}
|
||||
|
||||
methods := make([]byte, int(buf[1]))
|
||||
if _, err := io.ReadFull(conn, methods); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := conn.Write([]byte{5, 0})
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Client) socks5Request(conn net.Conn) (string, int, error) {
|
||||
buf := make([]byte, 4)
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
if buf[0] != 5 || buf[1] != 1 {
|
||||
return "", 0, fmt.Errorf("unsupported SOCKS5 command: %d", buf[1])
|
||||
}
|
||||
|
||||
var addr string
|
||||
switch buf[3] {
|
||||
case 1: // IPv4
|
||||
ip := make([]byte, 4)
|
||||
if _, err := io.ReadFull(conn, ip); err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
addr = net.IP(ip).String()
|
||||
case 3: // Domain
|
||||
lenBuf := make([]byte, 1)
|
||||
if _, err := io.ReadFull(conn, lenBuf); err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
domain := make([]byte, int(lenBuf[0]))
|
||||
if _, err := io.ReadFull(conn, domain); err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
addr = string(domain)
|
||||
case 4: // IPv6
|
||||
ip := make([]byte, 16)
|
||||
if _, err := io.ReadFull(conn, ip); err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
addr = net.IP(ip).String()
|
||||
default:
|
||||
return "", 0, fmt.Errorf("unsupported address type: %d", buf[3])
|
||||
}
|
||||
|
||||
portBuf := make([]byte, 2)
|
||||
if _, err := io.ReadFull(conn, portBuf); err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
port := int(binary.BigEndian.Uint16(portBuf))
|
||||
|
||||
return addr, port, nil
|
||||
}
|
||||
|
||||
func (c *Client) onData(data []byte) {
|
||||
@@ -264,347 +388,100 @@ func (c *Client) onData(data []byte) {
|
||||
c.mux.HandleFrame(plaintext)
|
||||
}
|
||||
|
||||
func (c *Client) runSOCKS5(
|
||||
ctx context.Context,
|
||||
host string,
|
||||
port int,
|
||||
username,
|
||||
password string,
|
||||
onReady func(),
|
||||
) error {
|
||||
if host == "" {
|
||||
host = defaultSOCKSListenHost
|
||||
}
|
||||
|
||||
listenAddr := net.JoinHostPort(host, strconv.Itoa(port))
|
||||
var lc net.ListenConfig
|
||||
listener, err := lc.Listen(ctx, "tcp", listenAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen on %s: %w", listenAddr, err)
|
||||
}
|
||||
|
||||
log.Printf("SOCKS5 proxy listening on %s (auth=%v)", listenAddr, username != "")
|
||||
if onReady != nil {
|
||||
onReady()
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
if err := listener.Close(); err != nil {
|
||||
logger.Debugf("SOCKS5 listener close error: %v", err)
|
||||
func (c *Client) shutdown() {
|
||||
c.connMu.Lock()
|
||||
for _, conn := range c.connections {
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
}
|
||||
c.connMu.Unlock()
|
||||
|
||||
for i, peer := range c.peers {
|
||||
logger.Infof("closing peer %d", i)
|
||||
_ = peer.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) pumpToMux(sid uint16, conn net.Conn) {
|
||||
defer func() {
|
||||
c.activeClients.Add(-1)
|
||||
_ = c.mux.CloseStream(sid)
|
||||
c.connMu.Lock()
|
||||
delete(c.connections, sid)
|
||||
c.connMu.Unlock()
|
||||
}()
|
||||
|
||||
buf := make([]byte, 16384)
|
||||
totalSent := uint64(0)
|
||||
lastLog := time.Now()
|
||||
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.closePeers()
|
||||
return nil
|
||||
default:
|
||||
log.Printf("accept error: %v", err)
|
||||
continue
|
||||
if totalSent > 1024*1024 {
|
||||
logger.Infof("sid=%d done total=%dMB", sid, totalSent/(1024*1024))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
go c.handleSOCKS5(conn, username, password)
|
||||
}
|
||||
}
|
||||
for !c.canSendData() {
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
|
||||
func (c *Client) closePeers() {
|
||||
for _, peer := range c.peers {
|
||||
if err := peer.Close(); err != nil {
|
||||
logger.Debugf("Peer close error: %v", err)
|
||||
if err := c.mux.SendData(sid, buf[:n]); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
totalSent += uint64(n) //nolint:gosec
|
||||
if time.Since(lastLog) > 5*time.Second {
|
||||
logger.Infof("sid=%d sent=%dMB", sid, totalSent/(1024*1024))
|
||||
lastLog = time.Now()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:cyclop // SOCKS5 parsing is inherently stateful and mirrors the protocol handshake.
|
||||
func (c *Client) handleSOCKS5(conn net.Conn, username, password string) {
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
logger.Debugf("SOCKS5 connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
buf := make([]byte, 513)
|
||||
if !readSOCKSVersionAndMethods(conn, buf) {
|
||||
return
|
||||
}
|
||||
|
||||
nmethods := buf[1]
|
||||
if _, err := io.ReadFull(conn, buf[:nmethods]); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
requireAuth := username != ""
|
||||
wantMethod := byte(0x00)
|
||||
if requireAuth {
|
||||
wantMethod = 0x02
|
||||
}
|
||||
|
||||
if !supportsMethod(buf[:nmethods], wantMethod) {
|
||||
writeResponse(conn, replyUnsupportedSOCKSMethod())
|
||||
return
|
||||
}
|
||||
writeResponse(conn, []byte{5, wantMethod})
|
||||
|
||||
if requireAuth && !authenticateSOCKSUser(conn, buf, username, password) {
|
||||
return
|
||||
}
|
||||
|
||||
addr, port, ok := readConnectTarget(conn, buf)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
sid := c.mux.OpenStream()
|
||||
logger.Verbosef("SOCKS5 opened stream sid=%d for %s:%d", sid, addr, port)
|
||||
log.Printf("sid=%d socks5 %s:%d", sid, addr, port)
|
||||
|
||||
if !c.sendConnectRequest(sid, addr, port) {
|
||||
return
|
||||
}
|
||||
|
||||
if !c.waitConnectResponse(conn, sid) {
|
||||
return
|
||||
}
|
||||
|
||||
c.mux.ReadStream(sid)
|
||||
writeResponse(conn, replySuccess())
|
||||
c.proxyStream(conn, sid)
|
||||
}
|
||||
|
||||
func readSOCKSVersionAndMethods(conn net.Conn, buf []byte) bool {
|
||||
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return buf[0] == 5
|
||||
}
|
||||
|
||||
func supportsMethod(methods []byte, wantMethod byte) bool {
|
||||
for _, method := range methods {
|
||||
if method == wantMethod {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func authenticateSOCKSUser(conn net.Conn, buf []byte, username, password string) bool {
|
||||
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
|
||||
return false
|
||||
}
|
||||
if buf[0] != 0x01 {
|
||||
return false
|
||||
}
|
||||
|
||||
ulen := int(buf[1])
|
||||
if _, err := io.ReadFull(conn, buf[:ulen+1]); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
gotUser := string(buf[:ulen])
|
||||
plen := int(buf[ulen])
|
||||
if _, err := io.ReadFull(conn, buf[:plen]); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
gotPass := string(buf[:plen])
|
||||
if gotUser != username || gotPass != password {
|
||||
writeResponse(conn, replyAuthFailed())
|
||||
return false
|
||||
}
|
||||
|
||||
writeResponse(conn, replyAuthOK())
|
||||
return true
|
||||
}
|
||||
|
||||
func readConnectTarget(conn net.Conn, buf []byte) (string, uint16, bool) {
|
||||
if _, err := io.ReadFull(conn, buf[:4]); err != nil {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
if buf[1] != 1 {
|
||||
writeResponse(conn, replyCommandNotSupported())
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
addr, ok := readTargetAddress(conn, buf, buf[3])
|
||||
if !ok {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
return addr, binary.BigEndian.Uint16(buf[:2]), true
|
||||
}
|
||||
|
||||
func readTargetAddress(conn net.Conn, buf []byte, atyp byte) (string, bool) {
|
||||
switch atyp {
|
||||
case 1:
|
||||
if _, err := io.ReadFull(conn, buf[:4]); err != nil {
|
||||
return "", false
|
||||
}
|
||||
return fmt.Sprintf("%d.%d.%d.%d", buf[0], buf[1], buf[2], buf[3]), true
|
||||
case 3:
|
||||
if _, err := io.ReadFull(conn, buf[:1]); err != nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
length := buf[0]
|
||||
if _, err := io.ReadFull(conn, buf[:length]); err != nil {
|
||||
return "", false
|
||||
}
|
||||
return string(buf[:length]), true
|
||||
default:
|
||||
writeResponse(conn, replyAddressNotSupported())
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) sendConnectRequest(sid uint16, addr string, port uint16) bool {
|
||||
reqData, err := json.Marshal(struct {
|
||||
Cmd string `json:"cmd"`
|
||||
Addr string `json:"addr"`
|
||||
Port uint16 `json:"port"`
|
||||
}{
|
||||
Cmd: "connect",
|
||||
Addr: addr,
|
||||
Port: port,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Debugf("Connect request marshal error: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if err := c.mux.SendData(sid, reqData); err != nil {
|
||||
logger.Debugf("Connect request send error: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *Client) waitConnectResponse(conn net.Conn, sid uint16) bool {
|
||||
dataReady := c.mux.WaitForData(sid)
|
||||
timeout := time.NewTimer(10 * time.Second)
|
||||
defer timeout.Stop()
|
||||
|
||||
select {
|
||||
case <-dataReady:
|
||||
stream := c.mux.GetStream(sid)
|
||||
if stream == nil || len(stream.RecvBuf()) == 0 {
|
||||
writeResponse(conn, replyHostUnreachable())
|
||||
return false
|
||||
}
|
||||
case <-timeout.C:
|
||||
writeResponse(conn, replyHostUnreachable())
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
//nolint:cyclop // The stream pump handles two coordinated goroutines and shutdown races in one place.
|
||||
func (c *Client) proxyStream(conn net.Conn, sid uint16) {
|
||||
done := make(chan struct{})
|
||||
streamClosed := make(chan struct{})
|
||||
|
||||
func (c *Client) startStreamPump(ctx context.Context, sid uint16, conn net.Conn) {
|
||||
c.wg.Add(1)
|
||||
go func() {
|
||||
defer close(done)
|
||||
buf := make([]byte, 32768)
|
||||
for {
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
if err := c.mux.CloseStream(sid); err != nil {
|
||||
logger.Debugf("Close stream error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err := c.mux.SendData(sid, buf[:n]); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer close(streamClosed)
|
||||
defer c.mux.CleanupDataChannel(sid)
|
||||
defer c.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(10 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
data := c.mux.ReadStream(sid)
|
||||
if len(data) > 0 && !writeStreamData(conn, data) {
|
||||
return
|
||||
if len(data) > 0 {
|
||||
if _, err := conn.Write(data); err != nil {
|
||||
_ = c.mux.CloseStream(sid)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if c.mux.StreamClosed(sid) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-streamClosed:
|
||||
}
|
||||
}
|
||||
|
||||
func writeStreamData(conn net.Conn, data []byte) bool {
|
||||
for len(data) > 0 {
|
||||
n, err := conn.Write(data)
|
||||
if err != nil {
|
||||
func (c *Client) canSendData() bool {
|
||||
for _, peer := range c.peers {
|
||||
if !peer.CanSend() {
|
||||
return false
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func writeResponse(conn net.Conn, response []byte) {
|
||||
if _, err := conn.Write(response); err != nil {
|
||||
logger.Debugf("SOCKS5 response write error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func replyUnsupportedSOCKSMethod() []byte {
|
||||
return []byte{5, 0xFF}
|
||||
}
|
||||
|
||||
func replyAuthFailed() []byte {
|
||||
return []byte{0x01, 0x01}
|
||||
}
|
||||
|
||||
func replyAuthOK() []byte {
|
||||
return []byte{0x01, 0x00}
|
||||
}
|
||||
|
||||
func replyCommandNotSupported() []byte {
|
||||
return []byte{5, 7, 0, 1, 0, 0, 0, 0, 0, 0}
|
||||
}
|
||||
|
||||
func replyAddressNotSupported() []byte {
|
||||
return []byte{5, 8, 0, 1, 0, 0, 0, 0, 0, 0}
|
||||
}
|
||||
|
||||
func replyHostUnreachable() []byte {
|
||||
return []byte{5, 4, 0, 1, 0, 0, 0, 0, 0, 0}
|
||||
}
|
||||
|
||||
func replySuccess() []byte {
|
||||
return []byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0}
|
||||
}
|
||||
|
||||
func replyHostUnreachable() []byte {
|
||||
return []byte{5, 4, 0, 1, 0, 0, 0, 0, 0, 0}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user