mirror of
https://github.com/openlibrecommunity/olcrtc.git
synced 2026-05-26 23:19:47 +00:00
535 lines
12 KiB
Go
535 lines
12 KiB
Go
// Package client implements the local SOCKS5 client side of the olcrtc tunnel.
|
|
package client
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/binary"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/openlibrecommunity/olcrtc/internal/crypto"
|
|
"github.com/openlibrecommunity/olcrtc/internal/link"
|
|
"github.com/openlibrecommunity/olcrtc/internal/logger"
|
|
"github.com/openlibrecommunity/olcrtc/internal/mux"
|
|
"github.com/openlibrecommunity/olcrtc/internal/names"
|
|
)
|
|
|
|
var (
|
|
// ErrKeySize is returned when the key size is not 32 bytes.
|
|
ErrKeySize = errors.New("key must be 32 bytes")
|
|
// ErrKeyStringLength is returned when the key string length is not 32.
|
|
ErrKeyStringLength = errors.New("key string length must be 32")
|
|
// ErrInvalidSocks5 is returned when the SOCKS version is not 5.
|
|
ErrInvalidSocks5 = errors.New("invalid SOCKS5 version")
|
|
// ErrNoPeers is returned when no peers are available for sending.
|
|
ErrNoPeers = errors.New("no peers available")
|
|
// ErrEncryptFailed is returned when encryption fails.
|
|
ErrEncryptFailed = errors.New("encrypt failed")
|
|
// ErrUnsupportedSocksCommand is returned when a SOCKS5 command is not supported.
|
|
ErrUnsupportedSocksCommand = errors.New("unsupported SOCKS5 command")
|
|
// ErrUnsupportedAddressType is returned when a SOCKS5 address type is not supported.
|
|
ErrUnsupportedAddressType = errors.New("unsupported address type")
|
|
// ErrTunnelSetupFailed is returned when the tunnel cannot be established.
|
|
ErrTunnelSetupFailed = errors.New("tunnel setup failed")
|
|
)
|
|
|
|
// Client handles local SOCKS5 connections and tunnels them via WebRTC.
|
|
type Client struct {
|
|
links []link.Link
|
|
cipher *crypto.Cipher
|
|
mux *mux.Multiplexer
|
|
connections map[uint16]net.Conn
|
|
connMu sync.RWMutex
|
|
peerIdx atomic.Uint32
|
|
clientID uint32
|
|
activeClients atomic.Int32
|
|
wg sync.WaitGroup
|
|
dnsServer string
|
|
}
|
|
|
|
// Run starts the client with the specified parameters.
|
|
func Run(
|
|
ctx context.Context,
|
|
transportName,
|
|
providerName,
|
|
roomURL,
|
|
keyHex string,
|
|
localAddr string,
|
|
dnsServer,
|
|
socksUser string,
|
|
socksPass string,
|
|
) error {
|
|
return RunWithReady(ctx, transportName, providerName, roomURL, keyHex, localAddr, dnsServer, socksUser, socksPass, nil)
|
|
}
|
|
|
|
// RunWithReady is like Run but accepts a callback that is called when the client is ready.
|
|
func RunWithReady(
|
|
ctx context.Context,
|
|
transportName,
|
|
providerName,
|
|
roomURL,
|
|
keyHex string,
|
|
localAddr string,
|
|
dnsServer,
|
|
_ string,
|
|
_ string,
|
|
onReady func(),
|
|
) error {
|
|
runCtx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
cipher, err := setupCipher(keyHex)
|
|
if err != nil {
|
|
return fmt.Errorf("setupCipher failed: %w", err)
|
|
}
|
|
|
|
clientIDBytes := make([]byte, 4)
|
|
if _, err := rand.Read(clientIDBytes); err != nil {
|
|
return fmt.Errorf("failed to generate client ID: %w", err)
|
|
}
|
|
clientID := binary.BigEndian.Uint32(clientIDBytes)
|
|
|
|
c := &Client{
|
|
cipher: cipher,
|
|
connections: make(map[uint16]net.Conn),
|
|
links: make([]link.Link, 0),
|
|
clientID: clientID,
|
|
dnsServer: dnsServer,
|
|
}
|
|
|
|
c.setupMux()
|
|
|
|
const peerCount = 1
|
|
for i := range peerCount {
|
|
if err := c.addTransport(runCtx, transportName, providerName, roomURL, i, cancel, dnsServer, "", 0); err != nil {
|
|
return fmt.Errorf("addTransport failed: %w", err)
|
|
}
|
|
}
|
|
|
|
lc := net.ListenConfig{}
|
|
ln, err := lc.Listen(runCtx, "tcp", localAddr)
|
|
if err != nil {
|
|
return fmt.Errorf("listen failed: %w", err)
|
|
}
|
|
defer func() { _ = ln.Close() }()
|
|
|
|
logger.Infof("SOCKS5 server listening on %s (ClientID: %d)", localAddr, clientID)
|
|
|
|
if onReady != nil {
|
|
onReady()
|
|
}
|
|
|
|
go c.acceptLoop(runCtx, ln)
|
|
|
|
<-runCtx.Done()
|
|
c.shutdown()
|
|
c.wg.Wait()
|
|
|
|
return nil
|
|
}
|
|
|
|
func setupCipher(keyHex string) (*crypto.Cipher, error) {
|
|
key, err := hex.DecodeString(keyHex)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decode key: %w", err)
|
|
}
|
|
if len(key) != 32 {
|
|
return nil, ErrKeySize
|
|
}
|
|
|
|
keyStr := string(key)
|
|
if len(keyStr) != 32 {
|
|
return nil, ErrKeyStringLength
|
|
}
|
|
|
|
cipher, err := crypto.NewCipher(keyStr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
|
}
|
|
return cipher, nil
|
|
}
|
|
|
|
func (c *Client) setupMux() {
|
|
c.mux = mux.New(c.clientID, func(frame []byte) error {
|
|
for {
|
|
canSend := true
|
|
for _, ln := range c.links {
|
|
if !ln.CanSend() {
|
|
canSend = false
|
|
break
|
|
}
|
|
}
|
|
if canSend {
|
|
break
|
|
}
|
|
time.Sleep(10 * time.Millisecond)
|
|
}
|
|
|
|
encrypted, err := c.cipher.Encrypt(frame)
|
|
if err != nil {
|
|
return fmt.Errorf("%w: %w", ErrEncryptFailed, err)
|
|
}
|
|
if len(c.links) == 0 {
|
|
return ErrNoPeers
|
|
}
|
|
idx := c.peerIdx.Add(1) % uint32(len(c.links)) //nolint:gosec
|
|
return c.links[idx].Send(encrypted)
|
|
})
|
|
}
|
|
|
|
func (c *Client) addTransport(
|
|
ctx context.Context,
|
|
transportName,
|
|
providerName,
|
|
roomURL string,
|
|
peerID int,
|
|
cancel context.CancelFunc,
|
|
dnsServer,
|
|
socksProxyAddr string,
|
|
socksProxyPort int,
|
|
) error {
|
|
ln, err := link.New(ctx, "direct", link.Config{
|
|
Transport: transportName,
|
|
Carrier: providerName,
|
|
RoomURL: roomURL,
|
|
Name: names.Generate(),
|
|
OnData: c.onData,
|
|
DNSServer: dnsServer,
|
|
ProxyAddr: socksProxyAddr,
|
|
ProxyPort: socksProxyPort,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create link: %w", err)
|
|
}
|
|
|
|
ln.SetEndedCallback(func(reason string) {
|
|
logger.Infof("Client transport %d reported conference end: %s", peerID, reason)
|
|
cancel()
|
|
})
|
|
c.links = append(c.links, ln)
|
|
|
|
ln.SetReconnectCallback(func() {
|
|
c.handleTransportReconnect(peerID)
|
|
})
|
|
|
|
logger.Infof("Connecting transport %d via %s/%s...", peerID, transportName, providerName)
|
|
if err := ln.Connect(ctx); err != nil {
|
|
return fmt.Errorf("failed to connect transport: %w", err)
|
|
}
|
|
logger.Infof("Transport %d connected", peerID)
|
|
|
|
c.wg.Add(1)
|
|
go func() {
|
|
defer c.wg.Done()
|
|
ln.WatchConnection(ctx)
|
|
}()
|
|
|
|
// Send initial reset to clean up any stale connections for this clientID on server
|
|
if err := c.mux.SendClientReset(); err != nil {
|
|
logger.Warnf("Failed to send initial client reset: %v", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Client) handleTransportReconnect(peerID int) {
|
|
logger.Infof("transport %d reconnect event", peerID)
|
|
|
|
c.connMu.Lock()
|
|
for sid, conn := range c.connections {
|
|
if conn != nil {
|
|
_ = conn.Close()
|
|
}
|
|
delete(c.connections, sid)
|
|
}
|
|
c.connMu.Unlock()
|
|
|
|
c.mux.UpdateSendFunc(func(frame []byte) error {
|
|
encrypted, err := c.cipher.Encrypt(frame)
|
|
if err != nil {
|
|
return fmt.Errorf("%w: %w", ErrEncryptFailed, err)
|
|
}
|
|
if len(c.links) == 0 {
|
|
return ErrNoPeers
|
|
}
|
|
idx := c.peerIdx.Add(1) % uint32(len(c.links)) //nolint:gosec
|
|
return c.links[idx].Send(encrypted)
|
|
})
|
|
c.mux.Reset()
|
|
|
|
if err := c.mux.SendClientReset(); err != nil {
|
|
logger.Warnf("Failed to send client reset after reconnect: %v", err)
|
|
}
|
|
}
|
|
|
|
func (c *Client) acceptLoop(ctx context.Context, ln net.Listener) {
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
default:
|
|
conn, err := ln.Accept()
|
|
if err != nil {
|
|
logger.Debugf("Accept error: %v", err)
|
|
continue
|
|
}
|
|
go c.handleSOCKS5(ctx, conn)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Client) handleSOCKS5(ctx context.Context, conn net.Conn) {
|
|
defer func() { _ = conn.Close() }()
|
|
|
|
if err := c.socks5Handshake(conn); err != nil {
|
|
logger.Debugf("SOCKS5 handshake failed: %v", err)
|
|
return
|
|
}
|
|
|
|
addr, port, err := c.socks5Request(conn)
|
|
if err != nil {
|
|
logger.Debugf("SOCKS5 request failed: %v", err)
|
|
return
|
|
}
|
|
|
|
sid := c.mux.OpenStream()
|
|
c.connMu.Lock()
|
|
c.connections[sid] = conn
|
|
c.connMu.Unlock()
|
|
|
|
logger.Infof("sid=%d tunnel to %s:%d", sid, addr, port)
|
|
|
|
if err := c.setupTunnel(ctx, sid, conn, addr, port); err != nil {
|
|
logger.Warnf("sid=%d tunnel setup failed: %v", sid, err)
|
|
return
|
|
}
|
|
|
|
c.activeClients.Add(1)
|
|
c.startStreamPump(ctx, sid, conn)
|
|
c.pumpToMux(sid, conn)
|
|
}
|
|
|
|
func (c *Client) setupTunnel(ctx context.Context, sid uint16, conn net.Conn, addr string, port int) error {
|
|
req := map[string]any{"cmd": "connect", "addr": addr, "port": port}
|
|
reqData, err := json.Marshal(req)
|
|
if err != nil {
|
|
return fmt.Errorf("marshal connect: %w", err)
|
|
}
|
|
|
|
if err := c.mux.SendData(sid, reqData); err != nil {
|
|
return fmt.Errorf("send connect: %w", err)
|
|
}
|
|
|
|
dataReady := c.mux.WaitForData(sid)
|
|
select {
|
|
case <-dataReady:
|
|
resp := c.mux.ReadStream(sid)
|
|
if len(resp) > 0 && resp[0] == 0x00 {
|
|
if _, err := conn.Write(replySuccess()); err != nil {
|
|
return fmt.Errorf("write success: %w", err)
|
|
}
|
|
} else {
|
|
_, _ = conn.Write(replyHostUnreachable())
|
|
return ErrTunnelSetupFailed
|
|
}
|
|
case <-time.After(15 * time.Second):
|
|
_, _ = conn.Write(replyHostUnreachable())
|
|
c.mux.CleanupDataChannel(sid)
|
|
return fmt.Errorf("%w: timeout", ErrTunnelSetupFailed)
|
|
case <-ctx.Done():
|
|
return fmt.Errorf("context cancelled: %w", ctx.Err())
|
|
}
|
|
c.mux.CleanupDataChannel(sid)
|
|
return nil
|
|
}
|
|
|
|
func (c *Client) socks5Handshake(conn net.Conn) error {
|
|
buf := make([]byte, 2)
|
|
if _, err := io.ReadFull(conn, buf); err != nil {
|
|
return fmt.Errorf("read header: %w", err)
|
|
}
|
|
|
|
if buf[0] != 5 {
|
|
return ErrInvalidSocks5
|
|
}
|
|
|
|
methods := make([]byte, int(buf[1]))
|
|
if _, err := io.ReadFull(conn, methods); err != nil {
|
|
return fmt.Errorf("read methods: %w", err)
|
|
}
|
|
|
|
if _, err := conn.Write([]byte{5, 0}); err != nil {
|
|
return fmt.Errorf("write response: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *Client) socks5Request(conn net.Conn) (string, int, error) {
|
|
buf := make([]byte, 4)
|
|
if _, err := io.ReadFull(conn, buf); err != nil {
|
|
return "", 0, fmt.Errorf("read request header: %w", err)
|
|
}
|
|
|
|
if buf[0] != 5 || buf[1] != 1 {
|
|
return "", 0, fmt.Errorf("%w: cmd=%d", ErrUnsupportedSocksCommand, buf[1])
|
|
}
|
|
|
|
addr, err := c.readSocks5Addr(conn, buf[3])
|
|
if err != nil {
|
|
return "", 0, err
|
|
}
|
|
|
|
portBuf := make([]byte, 2)
|
|
if _, err := io.ReadFull(conn, portBuf); err != nil {
|
|
return "", 0, fmt.Errorf("read port: %w", err)
|
|
}
|
|
port := int(binary.BigEndian.Uint16(portBuf))
|
|
|
|
return addr, port, nil
|
|
}
|
|
|
|
func (c *Client) readSocks5Addr(conn net.Conn, addrType byte) (string, error) {
|
|
switch addrType {
|
|
case 1: // IPv4
|
|
ip := make([]byte, 4)
|
|
if _, err := io.ReadFull(conn, ip); err != nil {
|
|
return "", fmt.Errorf("read ipv4: %w", err)
|
|
}
|
|
return net.IP(ip).String(), nil
|
|
case 3: // Domain
|
|
lenBuf := make([]byte, 1)
|
|
if _, err := io.ReadFull(conn, lenBuf); err != nil {
|
|
return "", fmt.Errorf("read domain len: %w", err)
|
|
}
|
|
domain := make([]byte, int(lenBuf[0]))
|
|
if _, err := io.ReadFull(conn, domain); err != nil {
|
|
return "", fmt.Errorf("read domain: %w", err)
|
|
}
|
|
return string(domain), nil
|
|
case 4: // IPv6
|
|
ip := make([]byte, 16)
|
|
if _, err := io.ReadFull(conn, ip); err != nil {
|
|
return "", fmt.Errorf("read ipv6: %w", err)
|
|
}
|
|
return net.IP(ip).String(), nil
|
|
default:
|
|
return "", fmt.Errorf("%w: type=%d", ErrUnsupportedAddressType, addrType)
|
|
}
|
|
}
|
|
|
|
func (c *Client) onData(data []byte) {
|
|
plaintext, err := c.cipher.Decrypt(data)
|
|
if err != nil {
|
|
logger.Debugf("Decrypt error: %v", err)
|
|
return
|
|
}
|
|
|
|
c.mux.HandleFrame(plaintext)
|
|
}
|
|
|
|
func (c *Client) shutdown() {
|
|
c.connMu.Lock()
|
|
for _, conn := range c.connections {
|
|
if conn != nil {
|
|
_ = conn.Close()
|
|
}
|
|
}
|
|
c.connMu.Unlock()
|
|
|
|
for i, tr := range c.links {
|
|
logger.Infof("closing transport %d", i)
|
|
_ = tr.Close()
|
|
}
|
|
}
|
|
|
|
func (c *Client) pumpToMux(sid uint16, conn net.Conn) {
|
|
defer func() {
|
|
c.activeClients.Add(-1)
|
|
_ = c.mux.CloseStream(sid)
|
|
c.connMu.Lock()
|
|
delete(c.connections, sid)
|
|
c.connMu.Unlock()
|
|
}()
|
|
|
|
buf := make([]byte, 16384)
|
|
totalSent := uint64(0)
|
|
lastLog := time.Now()
|
|
|
|
for {
|
|
n, err := conn.Read(buf)
|
|
if err != nil {
|
|
if totalSent > 1024*1024 {
|
|
logger.Infof("sid=%d done total=%dMB", sid, totalSent/(1024*1024))
|
|
}
|
|
return
|
|
}
|
|
|
|
for !c.canSendData() {
|
|
time.Sleep(20 * time.Millisecond)
|
|
}
|
|
|
|
if err := c.mux.SendData(sid, buf[:n]); err != nil {
|
|
return
|
|
}
|
|
|
|
totalSent += uint64(n) //nolint:gosec
|
|
if time.Since(lastLog) > 5*time.Second {
|
|
logger.Infof("sid=%d sent=%dMB", sid, totalSent/(1024*1024))
|
|
lastLog = time.Now()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Client) startStreamPump(ctx context.Context, sid uint16, conn net.Conn) {
|
|
c.wg.Add(1)
|
|
go func() {
|
|
defer c.wg.Done()
|
|
|
|
ticker := time.NewTicker(10 * time.Millisecond)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
data := c.mux.ReadStream(sid)
|
|
if len(data) > 0 {
|
|
if _, err := conn.Write(data); err != nil {
|
|
_ = c.mux.CloseStream(sid)
|
|
return
|
|
}
|
|
}
|
|
if c.mux.StreamClosed(sid) {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (c *Client) canSendData() bool {
|
|
for _, tr := range c.links {
|
|
if !tr.CanSend() {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
func replySuccess() []byte {
|
|
return []byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0}
|
|
}
|
|
|
|
func replyHostUnreachable() []byte {
|
|
return []byte{5, 4, 0, 1, 0, 0, 0, 0, 0, 0}
|
|
}
|