Files
olcrtc/internal/client/client.go
2026-05-16 02:10:33 +03:00

741 lines
20 KiB
Go

// Package client implements the local SOCKS5 client side of the olcrtc tunnel.
package client
import (
"context"
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/openlibrecommunity/olcrtc/internal/control"
"github.com/openlibrecommunity/olcrtc/internal/crypto"
"github.com/openlibrecommunity/olcrtc/internal/handshake"
"github.com/openlibrecommunity/olcrtc/internal/link"
"github.com/openlibrecommunity/olcrtc/internal/logger"
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
"github.com/openlibrecommunity/olcrtc/internal/names"
"github.com/xtaci/smux"
)
var (
// ErrConnectFailed is returned when a tunnel connection fails.
ErrConnectFailed = errors.New("tunnel connection failed")
// ErrProxyAuth is returned when SOCKS proxy authentication fails.
ErrProxyAuth = errors.New("SOCKS proxy auth failed")
// ErrKeySize is returned when the encryption key is not 32 bytes.
ErrKeySize = errors.New("key must be 32 bytes")
// ErrInvalidSOCKSVersion is returned when the SOCKS version is not 5.
ErrInvalidSOCKSVersion = errors.New("invalid socks version")
// ErrUnsupportedSOCKSCommand is returned for unsupported SOCKS commands.
ErrUnsupportedSOCKSCommand = errors.New("unsupported socks command")
// ErrUnsupportedAddressType is returned for unsupported SOCKS address types.
ErrUnsupportedAddressType = errors.New("unsupported address type")
// ErrRemoteNotReady is returned when the server-side stream fails to signal readiness.
ErrRemoteNotReady = errors.New("remote not ready")
// ErrSOCKSAuthFailed is returned when username/password authentication is rejected.
ErrSOCKSAuthFailed = errors.New("SOCKS5 authentication failed")
// ErrSOCKSCredTooLong is returned when a SOCKS5 username or password exceeds 255 bytes.
ErrSOCKSCredTooLong = errors.New("socks5 user/pass exceeds 255 bytes")
)
// Client handles local SOCKS5 connections and tunnels them to the server.
type Client struct {
ln link.Link
cipher *crypto.Cipher
conn *muxconn.Conn
session *smux.Session
controlStrm *smux.Stream
controlStop context.CancelFunc
sessMu sync.RWMutex
reconnectMu sync.Mutex
deviceID string
sessionID string
claims map[string]any
dnsServer string
socksUser string
socksPass string
}
// Config holds runtime configuration for [Run] and [RunWithReady].
type Config struct {
Link string
Transport string
Carrier string
RoomURL string
KeyHex string
LocalAddr string
DNSServer string
SOCKSUser string
SOCKSPass string
VideoWidth int
VideoHeight int
VideoFPS int
VideoBitrate string
VideoHW string
VideoQRSize int
VideoQRRecovery string
VideoCodec string
VideoTileModule int
VideoTileRS int
VP8FPS int
VP8BatchSize int
SEIFPS int
SEIBatchSize int
SEIFragmentSize int
SEIAckTimeoutMS int
Engine string
URL string
Token string
Liveness control.Config
// DeviceID overrides the persistent client-side device identifier. Leave
// empty to derive one from DeviceIDPath (or generate a random one if both
// are empty).
DeviceID string
// DeviceIDPath is a file in which to persist the auto-generated device ID
// across restarts. Ignored when DeviceID is set explicitly.
DeviceIDPath string
// Claims is sent to the server in CLIENT_HELLO and forwarded verbatim to
// the server's AuthHook. Free-form key/value bag for plan, user, region, etc.
Claims map[string]any
}
// Run starts the client with the given configuration.
func Run(ctx context.Context, cfg Config) error {
return RunWithReady(ctx, cfg, nil)
}
// RunWithReady is like Run but invokes onReady once the local SOCKS listener is up.
func RunWithReady(ctx context.Context, cfg Config, onReady func()) error {
runCtx, cancel := context.WithCancel(ctx)
defer cancel()
cipher, err := setupCipher(cfg.KeyHex)
if err != nil {
return fmt.Errorf("setupCipher failed: %w", err)
}
deviceID, err := resolveDeviceID(cfg.DeviceID, cfg.DeviceIDPath)
if err != nil {
return fmt.Errorf("resolve device id: %w", err)
}
c := &Client{
cipher: cipher,
deviceID: deviceID,
claims: cfg.Claims,
dnsServer: cfg.DNSServer,
socksUser: cfg.SOCKSUser,
socksPass: cfg.SOCKSPass,
}
// shutdown is registered BEFORE bringUpLink so we always close any
// link/session that bringUpLink managed to set up before it
// errored out. The previous ordering returned early on failure
// (e.g. handshake timeout against a wedged seichannel transport)
// without ever calling Close on the carrier link, leaving our MUC
// presence behind as a ghost participant in the next test that
// joined the same room. shutdown is nil-safe — it skips fields
// that bringUpLink hadn't populated yet.
defer c.shutdown()
if err := c.bringUpLink(runCtx, cfg, cancel); err != nil {
return err
}
lc := net.ListenConfig{}
listener, err := lc.Listen(runCtx, "tcp4", cfg.LocalAddr)
if err != nil {
return fmt.Errorf("failed to listen on %s: %w", cfg.LocalAddr, err)
}
defer func() { _ = listener.Close() }()
logger.Infof("SOCKS5 server listening on %s", cfg.LocalAddr)
if onReady != nil {
onReady()
}
go c.acceptLoop(runCtx, listener)
<-runCtx.Done()
return nil
}
func (c *Client) bringUpLink(
ctx context.Context,
cfg Config,
cancel context.CancelFunc,
) error {
ln, err := link.New(ctx, cfg.Link, link.Config{
Transport: cfg.Transport,
Carrier: cfg.Carrier,
RoomURL: cfg.RoomURL,
Engine: cfg.Engine,
URL: cfg.URL,
Token: cfg.Token,
DeviceID: c.deviceID,
Name: names.Generate(),
OnData: c.onData,
DNSServer: cfg.DNSServer,
VideoWidth: cfg.VideoWidth,
VideoHeight: cfg.VideoHeight,
VideoFPS: cfg.VideoFPS,
VideoBitrate: cfg.VideoBitrate,
VideoHW: cfg.VideoHW,
VideoQRSize: cfg.VideoQRSize,
VideoQRRecovery: cfg.VideoQRRecovery,
VideoCodec: cfg.VideoCodec,
VideoTileModule: cfg.VideoTileModule,
VideoTileRS: cfg.VideoTileRS,
VP8FPS: cfg.VP8FPS,
VP8BatchSize: cfg.VP8BatchSize,
SEIFPS: cfg.SEIFPS,
SEIBatchSize: cfg.SEIBatchSize,
SEIFragmentSize: cfg.SEIFragmentSize,
SEIAckTimeoutMS: cfg.SEIAckTimeoutMS,
})
if err != nil {
return fmt.Errorf("failed to create link: %w", err)
}
c.ln = ln
ln.SetEndedCallback(func(reason string) {
logger.Infof("Client link reported conference end: %s", reason)
cancel()
})
ln.SetShouldReconnect(func() bool { return ctx.Err() == nil })
ln.SetReconnectCallback(func() {
if ctx.Err() != nil {
return
}
if !c.handleReconnect(ctx, cfg, cancel) {
cancel()
}
})
if err := ln.Connect(ctx); err != nil {
return fmt.Errorf("failed to connect link: %w", err)
}
c.conn = muxconn.New(ln, c.cipher)
sess, err := smux.Client(c.conn, smuxConfig())
if err != nil {
return fmt.Errorf("smux client: %w", err)
}
control, sid, err := openControlStream(sess, c.deviceID, c.claims)
if err != nil {
_ = sess.Close()
_ = c.conn.Close()
return fmt.Errorf("handshake: %w", err)
}
logger.Infof("session %s opened (device=%s)", sid, c.deviceID)
c.sessMu.Lock()
c.session = sess
c.controlStrm = control
c.sessionID = sid
c.sessMu.Unlock()
c.startControlLoop(ctx, cfg, cancel, control)
go ln.WatchConnection(ctx)
return nil
}
// openControlStream opens stream #1 on sess and performs the handshake.
// The stream stays open for the lifetime of the smux session and carries
// post-handshake control messages.
func openControlStream(
sess *smux.Session,
deviceID string,
claims map[string]any,
) (*smux.Stream, string, error) {
return openControlStreamTimeout(sess, deviceID, claims, handshake.DefaultTimeout)
}
func openControlStreamTimeout(
sess *smux.Session,
deviceID string,
claims map[string]any,
timeout time.Duration,
) (*smux.Stream, string, error) {
stream, err := sess.OpenStream()
if err != nil {
return nil, "", fmt.Errorf("open control stream: %w", err)
}
_ = stream.SetDeadline(time.Now().Add(timeout))
sid, err := handshake.Client(stream, deviceID, claims)
_ = stream.SetDeadline(time.Time{})
if err != nil {
_ = stream.Close()
return nil, "", fmt.Errorf("handshake client: %w", err)
}
return stream, sid, nil
}
// resolveDeviceID returns the device ID to send in CLIENT_HELLO.
//
// Precedence:
// 1. Explicit deviceID arg (Config.DeviceID) — used verbatim.
// 2. Persistent file at path (Config.DeviceIDPath) — read if it exists,
// otherwise generated and written for future runs.
// 3. Random UUID per run when both inputs are empty.
func resolveDeviceID(deviceID, path string) (string, error) {
if deviceID != "" {
return deviceID, nil
}
if path == "" {
return uuid.NewString(), nil
}
// #nosec G304 -- persistent device ID path is explicit user configuration.
data, err := os.ReadFile(path)
if err == nil {
id := strings.TrimSpace(string(data))
if id != "" {
return id, nil
}
} else if !errors.Is(err, os.ErrNotExist) {
return "", fmt.Errorf("read device id %s: %w", path, err)
}
id := uuid.NewString()
if err := os.MkdirAll(filepath.Dir(path), 0o750); err != nil {
return "", fmt.Errorf("mkdir device id dir: %w", err)
}
if err := os.WriteFile(path, []byte(id+"\n"), 0o600); err != nil {
return "", fmt.Errorf("write device id %s: %w", path, err)
}
return id, nil
}
// smuxConfig returns the tuned smux config used on both ends.
func smuxConfig() *smux.Config {
cfg := smux.DefaultConfig()
cfg.Version = 2
cfg.KeepAliveDisabled = true
cfg.MaxFrameSize = 32768
cfg.MaxReceiveBuffer = 16 * 1024 * 1024
cfg.MaxStreamBuffer = 1024 * 1024
cfg.KeepAliveInterval = 10 * time.Second
cfg.KeepAliveTimeout = 60 * time.Second
return cfg
}
func (c *Client) handleReconnect(ctx context.Context, cfg Config, cancel context.CancelFunc) bool {
c.reconnectMu.Lock()
defer c.reconnectMu.Unlock()
logger.Infof("client link reconnect - tearing down smux session")
// Install a fresh muxconn immediately so onData never hits nil while
// the old session is being torn down. tryReopenSession will swap it
// again with its own conn on each attempt.
newConn := muxconn.New(c.ln, c.cipher)
c.sessMu.Lock()
oldControl := c.controlStrm
oldControlStop := c.controlStop
oldSess := c.session
oldConn := c.conn
c.conn = newConn
c.session = nil
c.controlStrm = nil
c.controlStop = nil
c.sessionID = ""
c.sessMu.Unlock()
if oldControlStop != nil {
oldControlStop()
}
if oldControl != nil {
_ = oldControl.Close()
}
if oldSess != nil {
_ = oldSess.Close()
}
if oldConn != nil {
_ = oldConn.Close()
}
// Server-side may still be tearing down its own session when our callback
// fires — carriers don't guarantee reconnect callbacks are delivered to both
// peers atomically. Retry the handshake a few times, building a fresh
// muxconn+smux pair on each attempt so a failed smux.Close doesn't corrupt
// the byte stream for subsequent attempts.
const (
maxAttempts = 5
attemptDelay = 300 * time.Millisecond
)
for attempt := 1; attempt <= maxAttempts; attempt++ {
if c.tryReopenSession(ctx, cfg, cancel, attempt) {
return true
}
select {
case <-ctx.Done():
return false
case <-time.After(attemptDelay):
}
}
logger.Warnf("client reconnect: exhausted %d handshake attempts", maxAttempts)
return false
}
func (c *Client) tryReopenSession(
ctx context.Context,
cfg Config,
cancel context.CancelFunc,
attempt int,
) bool {
conn := muxconn.New(c.ln, c.cipher)
c.sessMu.Lock()
old := c.conn
c.conn = conn
c.sessMu.Unlock()
if old != nil {
_ = old.Close()
}
sess, err := smux.Client(conn, smuxConfig())
if err != nil {
logger.Warnf("smux re-init failed (attempt %d): %v", attempt, err)
return false
}
control, sid, err := openControlStreamTimeout(sess, c.deviceID, c.claims, 2*time.Second)
if err != nil {
logger.Warnf("handshake on reconnect failed (attempt %d): %v", attempt, err)
_ = sess.Close()
return false
}
logger.Infof("session %s reopened (device=%s)", sid, c.deviceID)
c.sessMu.Lock()
c.session = sess
c.controlStrm = control
c.sessionID = sid
c.sessMu.Unlock()
c.startControlLoop(ctx, cfg, cancel, control)
return true
}
func (c *Client) startControlLoop(
ctx context.Context,
cfg Config,
cancel context.CancelFunc,
stream *smux.Stream,
) {
controlCtx, stop := context.WithCancel(ctx)
c.sessMu.Lock()
c.controlStop = stop
c.sessMu.Unlock()
liveness := cfg.Liveness
onPong := liveness.OnPong
onUnhealthy := liveness.OnUnhealthy
liveness.OnPong = func(h control.Health) {
c.sessMu.RLock()
sid := c.sessionID
c.sessMu.RUnlock()
logger.Debugf("control alive session=%s rtt=%v seq=%d", sid, h.RTT, h.Seq)
if onPong != nil {
onPong(h)
}
}
liveness.OnUnhealthy = func(missed int) {
logger.Warnf("control stream unhealthy on client: missed_pongs=%d", missed)
if onUnhealthy != nil {
onUnhealthy(missed)
}
}
go func() {
err := control.Run(controlCtx, stream, liveness)
if controlCtx.Err() != nil || ctx.Err() != nil {
return
}
if err != nil {
logger.Warnf("client control stream ended: %v", err)
}
if !c.handleReconnect(ctx, cfg, cancel) {
cancel()
}
}()
}
func (c *Client) shutdown() {
c.sessMu.Lock()
control := c.controlStrm
controlStop := c.controlStop
sess := c.session
conn := c.conn
c.controlStrm = nil
c.controlStop = nil
c.session = nil
c.conn = nil
c.sessMu.Unlock()
if controlStop != nil {
controlStop()
}
if conn != nil {
_ = conn.Close()
}
if c.ln != nil {
_ = c.ln.Close()
}
if control != nil {
_ = control.Close()
}
if sess != nil {
_ = sess.Close()
}
}
func setupCipher(keyHex string) (*crypto.Cipher, error) {
key, err := hex.DecodeString(keyHex)
if err != nil {
return nil, fmt.Errorf("failed to decode key: %w", err)
}
if len(key) != 32 {
return nil, fmt.Errorf("%w: got %d", ErrKeySize, len(key))
}
cipher, err := crypto.NewCipher(string(key))
if err != nil {
return nil, fmt.Errorf("failed to create cipher: %w", err)
}
return cipher, nil
}
func (c *Client) onData(data []byte) {
c.sessMu.RLock()
conn := c.conn
c.sessMu.RUnlock()
if conn != nil {
conn.Push(data)
}
}
func (c *Client) acceptLoop(ctx context.Context, ln net.Listener) {
for {
conn, err := ln.Accept()
if err != nil {
select {
case <-ctx.Done():
return
default:
logger.Warnf("Accept error: %v", err)
continue
}
}
go c.handleSocks5(ctx, conn)
}
}
func (c *Client) handleSocks5(_ context.Context, conn net.Conn) {
defer func() { _ = conn.Close() }()
if err := c.socks5Handshake(conn); err != nil {
return
}
targetAddr, targetPort, err := c.socks5Request(conn)
if err != nil {
return
}
c.sessMu.RLock()
sess := c.session
c.sessMu.RUnlock()
if sess == nil || sess.IsClosed() {
_, _ = conn.Write(replyHostUnreachable())
return
}
c.tunnel(conn, sess, targetAddr, targetPort)
}
func (c *Client) tunnel(conn net.Conn, sess *smux.Session, targetAddr string, targetPort int) {
stream, err := sess.OpenStream()
if err != nil {
logger.Warnf("OpenStream failed: %v", err)
_, _ = conn.Write(replyHostUnreachable())
return
}
defer func() { _ = stream.Close() }()
logger.Infof("sid=%d tunnel to %s:%d", stream.ID(), targetAddr, targetPort)
if err := c.sendConnectRequest(stream, targetAddr, targetPort); err != nil {
logger.Warnf("sid=%d connect failed: %v", stream.ID(), err)
_, _ = conn.Write(replyHostUnreachable())
return
}
if _, err := conn.Write(replySuccess()); err != nil {
return
}
go func() {
_, _ = io.Copy(stream, conn)
_ = stream.Close()
}()
_, _ = io.Copy(conn, stream)
}
func (c *Client) sendConnectRequest(stream *smux.Stream, targetAddr string, targetPort int) error {
connectReq, err := json.Marshal(map[string]any{
"cmd": "connect",
"addr": targetAddr,
"port": targetPort,
})
if err != nil {
return fmt.Errorf("sid=%d marshal connect req: %w", stream.ID(), err)
}
_ = stream.SetWriteDeadline(time.Now().Add(10 * time.Second))
if _, err := stream.Write(connectReq); err != nil {
return fmt.Errorf("sid=%d write connect req: %w", stream.ID(), err)
}
_ = stream.SetWriteDeadline(time.Time{})
ack := make([]byte, 1)
_ = stream.SetReadDeadline(time.Now().Add(15 * time.Second))
if _, err := io.ReadFull(stream, ack); err != nil || ack[0] != 0x00 {
return fmt.Errorf("sid=%d: %w (read_err=%w ack=%v)", stream.ID(), ErrRemoteNotReady, err, ack)
}
_ = stream.SetReadDeadline(time.Time{})
return nil
}
func (c *Client) socks5Handshake(conn net.Conn) error {
buf := make([]byte, 2)
if _, err := io.ReadFull(conn, buf); err != nil {
return fmt.Errorf("read socks5 header: %w", err)
}
if buf[0] != 5 {
return fmt.Errorf("%w: %d", ErrInvalidSOCKSVersion, buf[0])
}
methods := make([]byte, buf[1])
if _, err := io.ReadFull(conn, methods); err != nil {
return fmt.Errorf("read socks5 methods: %w", err)
}
if c.socksUser != "" {
// RFC 1929: method 0x02 = username/password auth.
if _, err := conn.Write([]byte{5, 2}); err != nil {
return fmt.Errorf("write socks5 auth method: %w", err)
}
if err := c.socks5UserPassAuth(conn); err != nil {
return err
}
return nil
}
if _, err := conn.Write([]byte{5, 0}); err != nil {
return fmt.Errorf("write socks5 auth: %w", err)
}
return nil
}
func (c *Client) socks5UserPassAuth(conn net.Conn) error {
header := make([]byte, 2)
if _, err := io.ReadFull(conn, header); err != nil {
return fmt.Errorf("read socks5 auth header: %w", err)
}
if header[0] != 0x01 {
return fmt.Errorf("%w: expected auth version 1, got %d", ErrInvalidSOCKSVersion, header[0])
}
ulen := int(header[1])
userBuf := make([]byte, ulen)
if _, err := io.ReadFull(conn, userBuf); err != nil {
return fmt.Errorf("read socks5 username: %w", err)
}
plenBuf := make([]byte, 1)
if _, err := io.ReadFull(conn, plenBuf); err != nil {
return fmt.Errorf("read socks5 plen: %w", err)
}
plen := int(plenBuf[0])
passBuf := make([]byte, plen)
if _, err := io.ReadFull(conn, passBuf); err != nil {
return fmt.Errorf("read socks5 password: %w", err)
}
if string(userBuf) != c.socksUser || string(passBuf) != c.socksPass {
_, _ = conn.Write([]byte{0x01, 0x01})
return ErrSOCKSAuthFailed
}
if _, err := conn.Write([]byte{0x01, 0x00}); err != nil {
return fmt.Errorf("write socks5 auth success: %w", err)
}
return nil
}
func (c *Client) socks5Request(conn net.Conn) (string, int, error) {
header := make([]byte, 4)
if _, err := io.ReadFull(conn, header); err != nil {
return "", 0, fmt.Errorf("read socks5 request: %w", err)
}
if header[1] != 1 {
return "", 0, fmt.Errorf("%w: %d", ErrUnsupportedSOCKSCommand, header[1])
}
addr, err := c.readSocks5Addr(conn, header[3])
if err != nil {
return "", 0, err
}
portBuf := make([]byte, 2)
if _, err := io.ReadFull(conn, portBuf); err != nil {
return "", 0, fmt.Errorf("read socks5 port: %w", err)
}
port := int(binary.BigEndian.Uint16(portBuf))
return addr, port, nil
}
func (c *Client) readSocks5Addr(conn net.Conn, addrType byte) (string, error) {
switch addrType {
case 1: // IPv4
buf := make([]byte, 4)
if _, err := io.ReadFull(conn, buf); err != nil {
return "", fmt.Errorf("read socks5 ipv4: %w", err)
}
return net.IP(buf).String(), nil
case 3: // Domain
lenBuf := make([]byte, 1)
if _, err := io.ReadFull(conn, lenBuf); err != nil {
return "", fmt.Errorf("read socks5 domain len: %w", err)
}
buf := make([]byte, lenBuf[0])
if _, err := io.ReadFull(conn, buf); err != nil {
return "", fmt.Errorf("read socks5 domain: %w", err)
}
return string(buf), nil
default:
return "", fmt.Errorf("%w: %d", ErrUnsupportedAddressType, addrType)
}
}
func replySuccess() []byte {
return []byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0}
}
func replyHostUnreachable() []byte {
return []byte{5, 4, 0, 1, 0, 0, 0, 0, 0, 0}
}