// Source: olcrtc/internal/client/client.go (488 lines, 10 KiB, Go)

// Package client implements the local SOCKS5 client side of the olcrtc tunnel.
package client
import (
"context"
"crypto/rand"
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"time"
"github.com/openlibrecommunity/olcrtc/internal/crypto"
"github.com/openlibrecommunity/olcrtc/internal/logger"
"github.com/openlibrecommunity/olcrtc/internal/mux"
"github.com/openlibrecommunity/olcrtc/internal/names"
"github.com/openlibrecommunity/olcrtc/internal/provider"
"github.com/pion/webrtc/v4"
)
// Sentinel errors produced by the client. Callers should match them with
// errors.Is, because some of them are returned wrapped.
var (
	// ErrKeySize reports that the decoded key is not exactly 32 bytes long.
	ErrKeySize = errors.New("key must be 32 bytes")

	// ErrKeyStringLength reports that the key, viewed as a string, is not
	// 32 bytes long.
	ErrKeyStringLength = errors.New("key string length must be 32")

	// ErrInvalidSocks5 reports that a local client spoke a protocol version
	// other than SOCKS5.
	ErrInvalidSocks5 = errors.New("invalid SOCKS5 version")

	// ErrNoPeers reports that a frame could not be sent because no peer is
	// registered.
	ErrNoPeers = errors.New("no peers available")

	// ErrEncryptFailed wraps a failure to encrypt an outgoing frame.
	ErrEncryptFailed = errors.New("encrypt failed")
)
// Client handles local SOCKS5 connections and tunnels them via WebRTC.
//
// A Client is created by Run; it contains a mutex and a WaitGroup and must
// therefore always be used through a pointer.
type Client struct {
	// Transport: the peers that carry encrypted frames and the round-robin
	// counter used to pick the next peer for a send.
	peers   []provider.Provider
	peerIdx atomic.Uint32

	// Framing: every frame is encrypted/decrypted with cipher and
	// multiplexed into streams by mux. clientID is the random identifier
	// handed to the multiplexer when it is created.
	cipher   *crypto.Cipher
	mux      *mux.Multiplexer
	clientID uint32

	// connections maps a mux stream ID to the local SOCKS5 connection that
	// stream serves. It is guarded by connMu.
	connections map[uint16]net.Conn
	connMu      sync.RWMutex

	// activeClients counts tunnels that are currently established.
	activeClients atomic.Int32

	// wg tracks the background goroutines (peer connection watchers and
	// stream pumps) that Run waits for before returning.
	wg sync.WaitGroup

	// dnsServer is the DNS server address given to Run.
	dnsServer string
}
// Run starts the client with the specified parameters and blocks until ctx is
// cancelled or a peer reports that the conference ended.
//
// Parameters:
//   - providerName, roomURL: select the provider implementation and the room
//     the peer joins.
//   - keyHex: hex encoding of the 32-byte tunnel key.
//   - localAddr: TCP address the local SOCKS5 listener binds to.
//   - dnsServer, socksProxyAddr, socksProxyPort: forwarded to the provider
//     configuration.
//
// It returns a wrapped error when the cipher, a peer or the listener cannot be
// set up, and nil after a normal shutdown. On every exit path the peers that
// were created are closed and background goroutines are waited for.
func Run(
	ctx context.Context,
	providerName,
	roomURL,
	keyHex string,
	localAddr string,
	dnsServer,
	socksProxyAddr string,
	socksProxyPort int,
) error {
	runCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	cipher, err := setupCipher(keyHex)
	if err != nil {
		return fmt.Errorf("setupCipher failed: %w", err)
	}

	clientIDBytes := make([]byte, 4)
	if _, err := rand.Read(clientIDBytes); err != nil {
		return fmt.Errorf("failed to generate client ID: %w", err)
	}
	clientID := binary.BigEndian.Uint32(clientIDBytes)

	c := &Client{
		cipher:      cipher,
		connections: make(map[uint16]net.Conn),
		peers:       make([]provider.Provider, 0),
		clientID:    clientID,
		dnsServer:   dnsServer,
	}
	c.setupMux()

	// abort releases everything created so far when setup fails part-way:
	// it stops the watcher goroutines, closes the peers and waits for the
	// goroutines to finish so nothing outlives a failed Run.
	abort := func() {
		cancel()
		c.shutdown()
		c.wg.Wait()
	}

	const peerCount = 1
	for i := range peerCount {
		if err := c.addPeer(runCtx, providerName, roomURL, i, cancel, dnsServer, socksProxyAddr, socksProxyPort); err != nil {
			abort()
			return fmt.Errorf("addPeer failed: %w", err)
		}
	}

	ln, err := net.Listen("tcp", localAddr)
	if err != nil {
		abort()
		return fmt.Errorf("listen failed: %w", err)
	}

	logger.Infof("SOCKS5 server listening on %s (ClientID: %d)", localAddr, clientID)
	go c.acceptLoop(runCtx, ln)

	<-runCtx.Done()

	// Close the listener before tearing down so the accept loop unblocks and
	// no new local connection is accepted while shutdown is in progress.
	_ = ln.Close()
	c.shutdown()
	c.wg.Wait()
	return nil
}
// setupCipher decodes the hex-encoded key and builds the tunnel cipher from it.
//
// keyHex must decode to exactly 32 bytes; otherwise ErrKeySize is returned.
// Hex-decoding and cipher-construction failures are returned wrapped.
func setupCipher(keyHex string) (*crypto.Cipher, error) {
	const keySize = 32

	key, err := hex.DecodeString(keyHex)
	if err != nil {
		return nil, fmt.Errorf("failed to decode key: %w", err)
	}
	if len(key) != keySize {
		return nil, ErrKeySize
	}

	// Converting []byte to string preserves the byte length, so the length
	// validated above also holds for the string handed to the cipher.
	cipher, err := crypto.NewCipher(string(key))
	if err != nil {
		return nil, fmt.Errorf("failed to create cipher: %w", err)
	}
	return cipher, nil
}
// setupMux creates the multiplexer for this client and installs its send
// function.
//
// The send function blocks (polling every 10ms) until every registered peer
// reports that it can send, encrypts the frame, and hands the ciphertext to
// the next peer in round-robin order. It returns ErrEncryptFailed (wrapping
// the cause) when encryption fails and ErrNoPeers when no peer is registered.
func (c *Client) setupMux() {
	send := func(frame []byte) error {
		// Backpressure: hold the frame until no peer is congested.
		for {
			blocked := false
			for _, p := range c.peers {
				if !p.CanSend() {
					blocked = true
					break
				}
			}
			if !blocked {
				break
			}
			time.Sleep(10 * time.Millisecond)
		}

		ciphertext, encErr := c.cipher.Encrypt(frame)
		if encErr != nil {
			return fmt.Errorf("%w: %w", ErrEncryptFailed, encErr)
		}

		peerCount := len(c.peers)
		if peerCount == 0 {
			return ErrNoPeers
		}
		next := c.peerIdx.Add(1) % uint32(peerCount) //nolint:gosec
		return c.peers[next].Send(ciphertext)
	}

	c.mux = mux.New(c.clientID, send)
}
// addPeer creates a provider peer, registers it with the client and connects it.
//
// peerID is only used in log messages and callbacks. cancel is invoked when the
// peer reports that the conference ended. dnsServer, socksProxyAddr and
// socksProxyPort are forwarded to the provider configuration.
//
// On success a goroutine tracked by c.wg watches the connection until ctx is
// done, and an initial client reset is sent through the mux. If connecting
// fails, the peer is unregistered and closed again before the wrapped error is
// returned, so a failed peer is never left behind in c.peers.
func (c *Client) addPeer(
	ctx context.Context,
	providerName,
	roomURL string,
	peerID int,
	cancel context.CancelFunc,
	dnsServer,
	socksProxyAddr string,
	socksProxyPort int,
) error {
	peer, err := provider.New(ctx, providerName, provider.Config{
		RoomURL:   roomURL,
		Name:      names.Generate(),
		OnData:    c.onData,
		DNSServer: dnsServer,
		ProxyAddr: socksProxyAddr,
		ProxyPort: socksProxyPort,
	})
	if err != nil {
		return fmt.Errorf("failed to create peer: %w", err)
	}

	peer.SetEndedCallback(func(reason string) {
		logger.Infof("Client peer %d reported conference end: %s", peerID, reason)
		cancel()
	})

	// The peer is registered before Connect, as in the original ordering;
	// remember its slot so it can be removed again if Connect fails.
	slot := len(c.peers)
	c.peers = append(c.peers, peer)

	peer.SetReconnectCallback(func(dc *webrtc.DataChannel) {
		c.handlePeerReconnect(peerID, dc)
	})

	logger.Infof("Connecting peer %d to %s...", peerID, providerName)
	if err := peer.Connect(ctx); err != nil {
		// Roll back the registration and release whatever the provider
		// allocated; the close is best-effort because the connect error is
		// the one worth reporting.
		c.peers = append(c.peers[:slot], c.peers[slot+1:]...)
		_ = peer.Close()
		return fmt.Errorf("failed to connect peer: %w", err)
	}
	logger.Infof("Peer %d connected", peerID)

	c.wg.Add(1)
	go func() {
		defer c.wg.Done()
		peer.WatchConnection(ctx)
	}()

	// Send initial reset to clean up any stale connections for this clientID on server
	if err := c.mux.SendClientReset(); err != nil {
		logger.Warnf("Failed to send initial client reset: %v", err)
	}
	return nil
}
// handlePeerReconnect reacts to a reconnect event from peer peerID.
//
// Every tracked local connection is closed and forgotten. When the event
// carries a data channel (dc != nil) the mux send function is re-installed,
// the mux is reset and a client reset is sent to the remote side; a failure to
// send that reset is only logged.
func (c *Client) handlePeerReconnect(peerID int, dc *webrtc.DataChannel) {
	haveChannel := dc != nil
	logger.Infof("peer %d reconnect event: dc=%v", peerID, haveChannel)

	// Existing tunnels cannot survive the reconnect: drop them all.
	c.connMu.Lock()
	for sid, conn := range c.connections {
		if conn != nil {
			_ = conn.Close()
		}
		delete(c.connections, sid)
	}
	c.connMu.Unlock()

	if !haveChannel {
		return
	}

	// NOTE(review): unlike the send function installed by setupMux, this one
	// does not wait for the peers to report CanSend before sending - confirm
	// that dropping the wait after a reconnect is intentional.
	resend := func(frame []byte) error {
		ciphertext, encErr := c.cipher.Encrypt(frame)
		if encErr != nil {
			return fmt.Errorf("%w: %w", ErrEncryptFailed, encErr)
		}
		peerCount := len(c.peers)
		if peerCount == 0 {
			return ErrNoPeers
		}
		next := c.peerIdx.Add(1) % uint32(peerCount) //nolint:gosec
		return c.peers[next].Send(ciphertext)
	}

	c.mux.UpdateSendFunc(resend)
	c.mux.Reset()
	if err := c.mux.SendClientReset(); err != nil {
		logger.Warnf("Failed to send client reset after reconnect: %v", err)
	}
}
// acceptLoop accepts local connections on ln and serves each one in its own
// goroutine via handleSOCKS5.
//
// The loop ends when ctx is done or when the listener has been closed
// (net.ErrClosed). Any other Accept error is logged and retried after a short
// pause, so a persistent failure (for example file-descriptor exhaustion)
// cannot turn into a busy loop. A connection that is accepted after ctx was
// cancelled is closed instead of being served.
func (c *Client) acceptLoop(ctx context.Context, ln net.Listener) {
	const retryDelay = 50 * time.Millisecond

	for {
		if ctx.Err() != nil {
			return
		}

		conn, err := ln.Accept()
		if err != nil {
			if ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
				return
			}
			logger.Debugf("Accept error: %v", err)

			// Back off before retrying, but stay responsive to cancellation.
			select {
			case <-ctx.Done():
				return
			case <-time.After(retryDelay):
			}
			continue
		}

		if ctx.Err() != nil {
			_ = conn.Close()
			return
		}
		go c.handleSOCKS5(ctx, conn)
	}
}
// handleSOCKS5 serves one local SOCKS5 connection from handshake to teardown.
//
// After the SOCKS5 handshake and request it opens a mux stream, sends a JSON
// "connect" request for the target and waits up to 15 seconds for the remote
// verdict: a first response byte of 0x00 means the tunnel is up, anything else
// (or a timeout) is answered with a "host unreachable" reply. Once the tunnel
// is established the connection is handed to startStreamPump and pumpToMux,
// and pumpToMux's deferred cleanup releases the stream.
//
// On every path that ends before that hand-off this function itself closes the
// mux stream and removes the connection from c.connections, so failed or
// cancelled attempts do not leak streams or map entries. conn is always closed
// on return.
func (c *Client) handleSOCKS5(ctx context.Context, conn net.Conn) {
	defer conn.Close()

	const connectTimeout = 15 * time.Second

	if err := c.socks5Handshake(conn); err != nil {
		logger.Debugf("SOCKS5 handshake failed: %v", err)
		return
	}
	addr, port, err := c.socks5Request(conn)
	if err != nil {
		logger.Debugf("SOCKS5 request failed: %v", err)
		return
	}

	sid := c.mux.OpenStream()
	c.connMu.Lock()
	c.connections[sid] = conn
	c.connMu.Unlock()

	// Until the tunnel is handed over to pumpToMux, this function owns the
	// stream and its map entry and must release both on any early return.
	established := false
	defer func() {
		if established {
			return
		}
		_ = c.mux.CloseStream(sid)
		c.connMu.Lock()
		delete(c.connections, sid)
		c.connMu.Unlock()
	}()

	logger.Infof("sid=%d tunnel to %s:%d", sid, addr, port)

	reqData, err := json.Marshal(map[string]any{
		"cmd":  "connect",
		"addr": addr,
		"port": port,
	})
	if err != nil {
		logger.Warnf("sid=%d encode connect request failed: %v", sid, err)
		return
	}
	if err := c.mux.SendData(sid, reqData); err != nil {
		logger.Warnf("sid=%d send connect failed: %v", sid, err)
		return
	}

	dataReady := c.mux.WaitForData(sid)
	timeout := time.NewTimer(connectTimeout)
	defer timeout.Stop()

	select {
	case <-dataReady:
		resp := c.mux.ReadStream(sid)
		c.mux.CleanupDataChannel(sid)
		if len(resp) == 0 || resp[0] != 0x00 {
			_, _ = conn.Write(replyHostUnreachable())
			return
		}
		if _, err := conn.Write(replySuccess()); err != nil {
			return
		}
	case <-timeout.C:
		c.mux.CleanupDataChannel(sid)
		_, _ = conn.Write(replyHostUnreachable())
		return
	case <-ctx.Done():
		c.mux.CleanupDataChannel(sid)
		return
	}

	// From here on pumpToMux's deferred cleanup is responsible for the stream.
	established = true
	c.activeClients.Add(1)
	c.startStreamPump(ctx, sid, conn)
	c.pumpToMux(sid, conn)
}
// socks5Handshake performs the SOCKS5 method negotiation on conn.
//
// It reads the client greeting (version, method count, methods) and selects
// the "no authentication" method (0x00), the only one this client supports.
// If the client did not offer that method, the "no acceptable methods" reply
// (0xFF) is sent and an error is returned, as RFC 1928 requires.
//
// It returns ErrInvalidSocks5 when the version byte is not 5, and wrapped I/O
// errors when reading the greeting or writing the selection fails.
func (c *Client) socks5Handshake(conn net.Conn) error {
	const (
		socksVersion   = 5
		methodNoAuth   = 0x00
		methodNoneFits = 0xFF
	)

	header := make([]byte, 2)
	if _, err := io.ReadFull(conn, header); err != nil {
		return fmt.Errorf("read SOCKS5 greeting: %w", err)
	}
	if header[0] != socksVersion {
		return ErrInvalidSocks5
	}

	methods := make([]byte, int(header[1]))
	if _, err := io.ReadFull(conn, methods); err != nil {
		return fmt.Errorf("read SOCKS5 auth methods: %w", err)
	}

	offered := false
	for _, m := range methods {
		if m == methodNoAuth {
			offered = true
			break
		}
	}
	if !offered {
		// Best-effort: the negotiation has failed either way.
		_, _ = conn.Write([]byte{socksVersion, methodNoneFits})
		return fmt.Errorf("no acceptable SOCKS5 auth method among %d offered", len(methods))
	}

	if _, err := conn.Write([]byte{socksVersion, methodNoAuth}); err != nil {
		return fmt.Errorf("write SOCKS5 method selection: %w", err)
	}
	return nil
}
// socks5Request reads a SOCKS5 request from conn and returns the requested
// destination as (address, port).
//
// Only the CONNECT command is supported. The address may be an IPv4 address,
// a domain name or an IPv6 address; IP addresses are returned in their textual
// form.
//
// Errors:
//   - ErrInvalidSocks5 when the version byte is not 5;
//   - an "unsupported SOCKS5 command" error for any command other than
//     CONNECT, after a "command not supported" (0x07) reply has been sent;
//   - an "unsupported address type" error for an unknown address type, after
//     an "address type not supported" (0x08) reply has been sent;
//   - an error for a zero-length domain name;
//   - wrapped I/O errors when the request cannot be read completely.
func (c *Client) socks5Request(conn net.Conn) (string, int, error) {
	const (
		socksVersion = 5
		cmdConnect   = 1

		atypIPv4   = 1
		atypDomain = 3
		atypIPv6   = 4

		repCommandNotSupported  = 0x07
		repAddrTypeNotSupported = 0x08
	)

	// reject tells the client why a well-formed request is refused. It is
	// best-effort: the request fails regardless of whether the reply arrives.
	reject := func(code byte) {
		_, _ = conn.Write([]byte{socksVersion, code, 0, atypIPv4, 0, 0, 0, 0, 0, 0})
	}

	header := make([]byte, 4)
	if _, err := io.ReadFull(conn, header); err != nil {
		return "", 0, fmt.Errorf("read SOCKS5 request header: %w", err)
	}
	if header[0] != socksVersion {
		return "", 0, ErrInvalidSocks5
	}
	if header[1] != cmdConnect {
		reject(repCommandNotSupported)
		return "", 0, fmt.Errorf("unsupported SOCKS5 command: %d", header[1])
	}

	var addr string
	switch header[3] {
	case atypIPv4:
		ip := make([]byte, net.IPv4len)
		if _, err := io.ReadFull(conn, ip); err != nil {
			return "", 0, fmt.Errorf("read IPv4 address: %w", err)
		}
		addr = net.IP(ip).String()
	case atypDomain:
		lenBuf := make([]byte, 1)
		if _, err := io.ReadFull(conn, lenBuf); err != nil {
			return "", 0, fmt.Errorf("read domain length: %w", err)
		}
		if lenBuf[0] == 0 {
			return "", 0, errors.New("empty domain name in SOCKS5 request")
		}
		domain := make([]byte, int(lenBuf[0]))
		if _, err := io.ReadFull(conn, domain); err != nil {
			return "", 0, fmt.Errorf("read domain name: %w", err)
		}
		addr = string(domain)
	case atypIPv6:
		ip := make([]byte, net.IPv6len)
		if _, err := io.ReadFull(conn, ip); err != nil {
			return "", 0, fmt.Errorf("read IPv6 address: %w", err)
		}
		addr = net.IP(ip).String()
	default:
		reject(repAddrTypeNotSupported)
		return "", 0, fmt.Errorf("unsupported address type: %d", header[3])
	}

	// The port follows the address as two bytes in network byte order.
	portBuf := make([]byte, 2)
	if _, err := io.ReadFull(conn, portBuf); err != nil {
		return "", 0, fmt.Errorf("read port: %w", err)
	}
	return addr, int(binary.BigEndian.Uint16(portBuf)), nil
}
// onData is the provider's receive callback: it decrypts an incoming payload
// and passes the resulting frame to the multiplexer. Payloads that fail to
// decrypt are logged at debug level and dropped.
func (c *Client) onData(data []byte) {
	if frame, decErr := c.cipher.Decrypt(data); decErr == nil {
		c.mux.HandleFrame(frame)
	} else {
		logger.Debugf("Decrypt error: %v", decErr)
	}
}
// shutdown closes every tracked local connection and then every peer.
//
// Close errors are ignored: this is best-effort teardown. Entries are left in
// c.connections; only the connections themselves are closed.
func (c *Client) shutdown() {
	// Local side first, under the lock that guards the map.
	c.connMu.Lock()
	for sid := range c.connections {
		if local := c.connections[sid]; local != nil {
			_ = local.Close()
		}
	}
	c.connMu.Unlock()

	// Then the transport peers, in registration order.
	for idx, p := range c.peers {
		logger.Infof("closing peer %d", idx)
		_ = p.Close()
	}
}
// pumpToMux copies data from the local connection into mux stream sid until
// the connection can no longer be read or the mux refuses a frame.
//
// Before each send it waits (polling every 20ms) until all peers report they
// can send. Bytes returned by a Read are forwarded even when that same Read
// also reports an error, as the io.Reader contract requires; a Read that
// returns no bytes sends nothing. Progress is logged at most every 5 seconds.
//
// On return the active-client counter is decremented, the stream is closed and
// the connection is removed from c.connections.
func (c *Client) pumpToMux(sid uint16, conn net.Conn) {
	defer func() {
		c.activeClients.Add(-1)
		_ = c.mux.CloseStream(sid)
		c.connMu.Lock()
		delete(c.connections, sid)
		c.connMu.Unlock()
	}()

	const mib = 1024 * 1024

	buf := make([]byte, 16384)
	totalSent := uint64(0)
	lastLog := time.Now()
	for {
		n, readErr := conn.Read(buf)

		// Handle the bytes first: a reader may deliver its final data
		// together with the error (typically io.EOF).
		if n > 0 {
			for !c.canSendData() {
				time.Sleep(20 * time.Millisecond)
			}
			if err := c.mux.SendData(sid, buf[:n]); err != nil {
				return
			}
			totalSent += uint64(n) //nolint:gosec
			if time.Since(lastLog) > 5*time.Second {
				logger.Infof("sid=%d sent=%dMB", sid, totalSent/mib)
				lastLog = time.Now()
			}
		}

		if readErr != nil {
			if totalSent > mib {
				logger.Infof("sid=%d done total=%dMB", sid, totalSent/mib)
			}
			return
		}
	}
}
// startStreamPump starts a goroutine (tracked by c.wg) that polls mux stream
// sid every 10ms and writes whatever it has buffered to the local connection.
//
// The goroutine ends when ctx is done, when writing to conn fails (the stream
// is then closed), or when the stream is reported closed. In the last case the
// closed state is sampled before the stream is drained, so data that arrived
// just ahead of the close is still delivered, and conn is closed afterwards so
// the local client sees end-of-stream and the pumpToMux goroutine reading from
// conn is released.
func (c *Client) startStreamPump(ctx context.Context, sid uint16, conn net.Conn) {
	c.wg.Add(1)
	go func() {
		defer c.wg.Done()

		ticker := time.NewTicker(10 * time.Millisecond)
		defer ticker.Stop()

		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
			}

			// Sample "closed" first: anything buffered before the close is
			// then guaranteed to be picked up by the reads below.
			closed := c.mux.StreamClosed(sid)

			for {
				data := c.mux.ReadStream(sid)
				if len(data) == 0 {
					break
				}
				if _, err := conn.Write(data); err != nil {
					_ = c.mux.CloseStream(sid)
					return
				}
				if !closed {
					// Normal pacing: one read per tick while the stream is open.
					break
				}
			}

			if closed {
				// Nothing more will arrive; closing is best-effort because
				// the handler also closes conn when it returns.
				_ = conn.Close()
				return
			}
		}
	}()
}
// canSendData reports whether every registered peer is currently able to
// send. It stops at the first peer that cannot and is true when there are no
// peers.
func (c *Client) canSendData() bool {
	ready := true
	for i := 0; ready && i < len(c.peers); i++ {
		ready = c.peers[i].CanSend()
	}
	return ready
}
// replySuccess returns a fresh 10-byte SOCKS5 reply with status 0x00
// (succeeded), address type IPv4 and an all-zero bound address and port.
func replySuccess() []byte {
	reply := make([]byte, 10)
	reply[0] = 5 // protocol version
	reply[3] = 1 // address type: IPv4
	return reply
}
// replyHostUnreachable returns a fresh 10-byte SOCKS5 reply with status 0x04
// (host unreachable), address type IPv4 and an all-zero bound address and
// port.
func replyHostUnreachable() []byte {
	reply := make([]byte, 10)
	reply[0] = 5 // protocol version
	reply[1] = 4 // status: host unreachable
	reply[3] = 1 // address type: IPv4
	return reply
}