Files
olcrtc/internal/client/client.go
2026-05-13 20:03:58 +03:00

594 lines
16 KiB
Go

// Package client implements the local SOCKS5 client side of the olcrtc tunnel.
package client
import (
"context"
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/openlibrecommunity/olcrtc/internal/crypto"
"github.com/openlibrecommunity/olcrtc/internal/handshake"
"github.com/openlibrecommunity/olcrtc/internal/link"
"github.com/openlibrecommunity/olcrtc/internal/logger"
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
"github.com/openlibrecommunity/olcrtc/internal/names"
"github.com/xtaci/smux"
)
var (
// ErrConnectFailed is returned when a tunnel connection fails.
ErrConnectFailed = errors.New("tunnel connection failed")
// ErrProxyAuth is returned when SOCKS proxy authentication fails.
ErrProxyAuth = errors.New("SOCKS proxy auth failed")
// ErrKeySize is returned when the encryption key is not 32 bytes.
ErrKeySize = errors.New("key must be 32 bytes")
// ErrInvalidSOCKSVersion is returned when the SOCKS version is not 5.
ErrInvalidSOCKSVersion = errors.New("invalid socks version")
// ErrUnsupportedSOCKSCommand is returned for unsupported SOCKS commands.
ErrUnsupportedSOCKSCommand = errors.New("unsupported socks command")
// ErrUnsupportedAddressType is returned for unsupported SOCKS address types.
ErrUnsupportedAddressType = errors.New("unsupported address type")
// ErrRemoteNotReady is returned when the server-side stream fails to signal readiness.
ErrRemoteNotReady = errors.New("remote not ready")
// ErrSOCKSAuthFailed is returned when username/password authentication is rejected.
ErrSOCKSAuthFailed = errors.New("SOCKS5 authentication failed")
// ErrSOCKSCredTooLong is returned when a SOCKS5 username or password exceeds 255 bytes.
ErrSOCKSCredTooLong = errors.New("socks5 user/pass exceeds 255 bytes")
)
// Client handles local SOCKS5 connections and tunnels them to the server.
type Client struct {
ln link.Link
cipher *crypto.Cipher
conn *muxconn.Conn
session *smux.Session
controlStrm *smux.Stream
sessMu sync.RWMutex
deviceID string
sessionID string
claims map[string]any
dnsServer string
socksUser string
socksPass string
}
// Config holds runtime configuration for [Run] and [RunWithReady].
type Config struct {
Link string
Transport string
Carrier string
RoomURL string
KeyHex string
LocalAddr string
DNSServer string
SOCKSUser string
SOCKSPass string
VideoWidth int
VideoHeight int
VideoFPS int
VideoBitrate string
VideoHW string
VideoQRSize int
VideoQRRecovery string
VideoCodec string
VideoTileModule int
VideoTileRS int
VP8FPS int
VP8BatchSize int
SEIFPS int
SEIBatchSize int
SEIFragmentSize int
SEIAckTimeoutMS int
Engine string
URL string
Token string
// DeviceID overrides the persistent client-side device identifier. Leave
// empty to derive one from DeviceIDPath (or generate a random one if both
// are empty).
DeviceID string
// DeviceIDPath is a file in which to persist the auto-generated device ID
// across restarts. Ignored when DeviceID is set explicitly.
DeviceIDPath string
// Claims is sent to the server in CLIENT_HELLO and forwarded verbatim to
// the server's AuthHook. Free-form key/value bag for plan, user, region, etc.
Claims map[string]any
}
// Run starts the client with the given configuration.
func Run(ctx context.Context, cfg Config) error {
return RunWithReady(ctx, cfg, nil)
}
// RunWithReady is like Run but invokes onReady once the local SOCKS listener is up.
func RunWithReady(ctx context.Context, cfg Config, onReady func()) error {
runCtx, cancel := context.WithCancel(ctx)
defer cancel()
cipher, err := setupCipher(cfg.KeyHex)
if err != nil {
return fmt.Errorf("setupCipher failed: %w", err)
}
deviceID, err := resolveDeviceID(cfg.DeviceID, cfg.DeviceIDPath)
if err != nil {
return fmt.Errorf("resolve device id: %w", err)
}
c := &Client{
cipher: cipher,
deviceID: deviceID,
claims: cfg.Claims,
dnsServer: cfg.DNSServer,
socksUser: cfg.SOCKSUser,
socksPass: cfg.SOCKSPass,
}
if err := c.bringUpLink(runCtx, cfg, cancel); err != nil {
return err
}
defer c.shutdown()
lc := net.ListenConfig{}
listener, err := lc.Listen(runCtx, "tcp4", cfg.LocalAddr)
if err != nil {
return fmt.Errorf("failed to listen on %s: %w", cfg.LocalAddr, err)
}
defer func() { _ = listener.Close() }()
logger.Infof("SOCKS5 server listening on %s", cfg.LocalAddr)
if onReady != nil {
onReady()
}
go c.acceptLoop(runCtx, listener)
<-runCtx.Done()
return nil
}
func (c *Client) bringUpLink(
ctx context.Context,
cfg Config,
cancel context.CancelFunc,
) error {
ln, err := link.New(ctx, cfg.Link, link.Config{
Transport: cfg.Transport,
Carrier: cfg.Carrier,
RoomURL: cfg.RoomURL,
Engine: cfg.Engine,
URL: cfg.URL,
Token: cfg.Token,
DeviceID: c.deviceID,
Name: names.Generate(),
OnData: c.onData,
DNSServer: cfg.DNSServer,
VideoWidth: cfg.VideoWidth,
VideoHeight: cfg.VideoHeight,
VideoFPS: cfg.VideoFPS,
VideoBitrate: cfg.VideoBitrate,
VideoHW: cfg.VideoHW,
VideoQRSize: cfg.VideoQRSize,
VideoQRRecovery: cfg.VideoQRRecovery,
VideoCodec: cfg.VideoCodec,
VideoTileModule: cfg.VideoTileModule,
VideoTileRS: cfg.VideoTileRS,
VP8FPS: cfg.VP8FPS,
VP8BatchSize: cfg.VP8BatchSize,
SEIFPS: cfg.SEIFPS,
SEIBatchSize: cfg.SEIBatchSize,
SEIFragmentSize: cfg.SEIFragmentSize,
SEIAckTimeoutMS: cfg.SEIAckTimeoutMS,
})
if err != nil {
return fmt.Errorf("failed to create link: %w", err)
}
c.ln = ln
ln.SetEndedCallback(func(reason string) {
logger.Infof("Client link reported conference end: %s", reason)
cancel()
})
ln.SetReconnectCallback(func() { c.handleReconnect() })
if err := ln.Connect(ctx); err != nil {
return fmt.Errorf("failed to connect link: %w", err)
}
c.conn = muxconn.New(ln, c.cipher)
sess, err := smux.Client(c.conn, smuxConfig())
if err != nil {
return fmt.Errorf("smux client: %w", err)
}
control, sid, err := openControlStream(sess, c.deviceID, c.claims)
if err != nil {
_ = sess.Close()
_ = c.conn.Close()
return fmt.Errorf("handshake: %w", err)
}
logger.Infof("session %s opened (device=%s)", sid, c.deviceID)
c.sessMu.Lock()
c.session = sess
c.controlStrm = control
c.sessionID = sid
c.sessMu.Unlock()
go ln.WatchConnection(ctx)
return nil
}
// openControlStream opens stream #1 on sess and performs the handshake.
// The stream stays open for the lifetime of the smux session — the server
// holds it parked, and it would carry future control messages.
func openControlStream(
sess *smux.Session,
deviceID string,
claims map[string]any,
) (*smux.Stream, string, error) {
stream, err := sess.OpenStream()
if err != nil {
return nil, "", fmt.Errorf("open control stream: %w", err)
}
_ = stream.SetDeadline(time.Now().Add(handshake.DefaultTimeout))
sid, err := handshake.Client(stream, deviceID, claims)
_ = stream.SetDeadline(time.Time{})
if err != nil {
_ = stream.Close()
return nil, "", err
}
return stream, sid, nil
}
// resolveDeviceID returns the device ID to send in CLIENT_HELLO.
//
// Precedence:
// 1. Explicit deviceID arg (Config.DeviceID) — used verbatim.
// 2. Persistent file at path (Config.DeviceIDPath) — read if it exists,
// otherwise generated and written for future runs.
// 3. Random UUID per run when both inputs are empty.
func resolveDeviceID(deviceID, path string) (string, error) {
if deviceID != "" {
return deviceID, nil
}
if path == "" {
return uuid.NewString(), nil
}
data, err := os.ReadFile(path)
if err == nil {
id := strings.TrimSpace(string(data))
if id != "" {
return id, nil
}
} else if !errors.Is(err, os.ErrNotExist) {
return "", fmt.Errorf("read device id %s: %w", path, err)
}
id := uuid.NewString()
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return "", fmt.Errorf("mkdir device id dir: %w", err)
}
if err := os.WriteFile(path, []byte(id+"\n"), 0o600); err != nil {
return "", fmt.Errorf("write device id %s: %w", path, err)
}
return id, nil
}
// smuxConfig returns the tuned smux config used on both ends.
func smuxConfig() *smux.Config {
cfg := smux.DefaultConfig()
cfg.Version = 2
cfg.KeepAliveDisabled = true
cfg.MaxFrameSize = 32768
cfg.MaxReceiveBuffer = 16 * 1024 * 1024
cfg.MaxStreamBuffer = 1024 * 1024
cfg.KeepAliveInterval = 10 * time.Second
cfg.KeepAliveTimeout = 60 * time.Second
return cfg
}
func (c *Client) handleReconnect() {
logger.Infof("client link reconnect - tearing down smux session")
c.sessMu.Lock()
if c.controlStrm != nil {
_ = c.controlStrm.Close()
c.controlStrm = nil
}
if c.session != nil {
_ = c.session.Close()
c.session = nil
}
if c.conn != nil {
_ = c.conn.Close()
c.conn = nil
}
c.sessionID = ""
c.sessMu.Unlock()
c.conn = muxconn.New(c.ln, c.cipher)
sess, err := smux.Client(c.conn, smuxConfig())
if err != nil {
logger.Warnf("smux re-init failed: %v", err)
return
}
control, sid, err := openControlStream(sess, c.deviceID, c.claims)
if err != nil {
logger.Warnf("handshake on reconnect failed: %v", err)
_ = sess.Close()
return
}
logger.Infof("session %s reopened (device=%s)", sid, c.deviceID)
c.sessMu.Lock()
c.session = sess
c.controlStrm = control
c.sessionID = sid
c.sessMu.Unlock()
}
func (c *Client) shutdown() {
c.sessMu.Lock()
if c.controlStrm != nil {
_ = c.controlStrm.Close()
}
if c.session != nil {
_ = c.session.Close()
}
if c.conn != nil {
_ = c.conn.Close()
}
c.sessMu.Unlock()
if c.ln != nil {
_ = c.ln.Close()
}
}
func setupCipher(keyHex string) (*crypto.Cipher, error) {
key, err := hex.DecodeString(keyHex)
if err != nil {
return nil, fmt.Errorf("failed to decode key: %w", err)
}
if len(key) != 32 {
return nil, fmt.Errorf("%w: got %d", ErrKeySize, len(key))
}
cipher, err := crypto.NewCipher(string(key))
if err != nil {
return nil, fmt.Errorf("failed to create cipher: %w", err)
}
return cipher, nil
}
func (c *Client) onData(data []byte) {
c.sessMu.RLock()
conn := c.conn
c.sessMu.RUnlock()
if conn != nil {
conn.Push(data)
}
}
func (c *Client) acceptLoop(ctx context.Context, ln net.Listener) {
for {
conn, err := ln.Accept()
if err != nil {
select {
case <-ctx.Done():
return
default:
logger.Warnf("Accept error: %v", err)
continue
}
}
go c.handleSocks5(ctx, conn)
}
}
func (c *Client) handleSocks5(_ context.Context, conn net.Conn) {
defer func() { _ = conn.Close() }()
if err := c.socks5Handshake(conn); err != nil {
return
}
targetAddr, targetPort, err := c.socks5Request(conn)
if err != nil {
return
}
c.sessMu.RLock()
sess := c.session
c.sessMu.RUnlock()
if sess == nil || sess.IsClosed() {
_, _ = conn.Write(replyHostUnreachable())
return
}
c.tunnel(conn, sess, targetAddr, targetPort)
}
func (c *Client) tunnel(conn net.Conn, sess *smux.Session, targetAddr string, targetPort int) {
stream, err := sess.OpenStream()
if err != nil {
logger.Warnf("OpenStream failed: %v", err)
_, _ = conn.Write(replyHostUnreachable())
return
}
defer func() { _ = stream.Close() }()
logger.Infof("sid=%d tunnel to %s:%d", stream.ID(), targetAddr, targetPort)
if err := c.sendConnectRequest(stream, targetAddr, targetPort); err != nil {
logger.Warnf("sid=%d connect failed: %v", stream.ID(), err)
_, _ = conn.Write(replyHostUnreachable())
return
}
if _, err := conn.Write(replySuccess()); err != nil {
return
}
go func() {
_, _ = io.Copy(stream, conn)
_ = stream.Close()
}()
_, _ = io.Copy(conn, stream)
}
func (c *Client) sendConnectRequest(stream *smux.Stream, targetAddr string, targetPort int) error {
connectReq, err := json.Marshal(map[string]any{
"cmd": "connect",
"addr": targetAddr,
"port": targetPort,
})
if err != nil {
return fmt.Errorf("sid=%d marshal connect req: %w", stream.ID(), err)
}
_ = stream.SetWriteDeadline(time.Now().Add(10 * time.Second))
if _, err := stream.Write(connectReq); err != nil {
return fmt.Errorf("sid=%d write connect req: %w", stream.ID(), err)
}
_ = stream.SetWriteDeadline(time.Time{})
ack := make([]byte, 1)
_ = stream.SetReadDeadline(time.Now().Add(15 * time.Second))
if _, err := io.ReadFull(stream, ack); err != nil || ack[0] != 0x00 {
return fmt.Errorf("sid=%d: %w (read_err=%w ack=%v)", stream.ID(), ErrRemoteNotReady, err, ack)
}
_ = stream.SetReadDeadline(time.Time{})
return nil
}
func (c *Client) socks5Handshake(conn net.Conn) error {
buf := make([]byte, 2)
if _, err := io.ReadFull(conn, buf); err != nil {
return fmt.Errorf("read socks5 header: %w", err)
}
if buf[0] != 5 {
return fmt.Errorf("%w: %d", ErrInvalidSOCKSVersion, buf[0])
}
methods := make([]byte, buf[1])
if _, err := io.ReadFull(conn, methods); err != nil {
return fmt.Errorf("read socks5 methods: %w", err)
}
if c.socksUser != "" {
// RFC 1929: method 0x02 = username/password auth.
if _, err := conn.Write([]byte{5, 2}); err != nil {
return fmt.Errorf("write socks5 auth method: %w", err)
}
if err := c.socks5UserPassAuth(conn); err != nil {
return err
}
return nil
}
if _, err := conn.Write([]byte{5, 0}); err != nil {
return fmt.Errorf("write socks5 auth: %w", err)
}
return nil
}
func (c *Client) socks5UserPassAuth(conn net.Conn) error {
header := make([]byte, 2)
if _, err := io.ReadFull(conn, header); err != nil {
return fmt.Errorf("read socks5 auth header: %w", err)
}
if header[0] != 0x01 {
return fmt.Errorf("%w: expected auth version 1, got %d", ErrInvalidSOCKSVersion, header[0])
}
ulen := int(header[1])
userBuf := make([]byte, ulen)
if _, err := io.ReadFull(conn, userBuf); err != nil {
return fmt.Errorf("read socks5 username: %w", err)
}
plenBuf := make([]byte, 1)
if _, err := io.ReadFull(conn, plenBuf); err != nil {
return fmt.Errorf("read socks5 plen: %w", err)
}
plen := int(plenBuf[0])
passBuf := make([]byte, plen)
if _, err := io.ReadFull(conn, passBuf); err != nil {
return fmt.Errorf("read socks5 password: %w", err)
}
if string(userBuf) != c.socksUser || string(passBuf) != c.socksPass {
_, _ = conn.Write([]byte{0x01, 0x01})
return ErrSOCKSAuthFailed
}
if _, err := conn.Write([]byte{0x01, 0x00}); err != nil {
return fmt.Errorf("write socks5 auth success: %w", err)
}
return nil
}
func (c *Client) socks5Request(conn net.Conn) (string, int, error) {
header := make([]byte, 4)
if _, err := io.ReadFull(conn, header); err != nil {
return "", 0, fmt.Errorf("read socks5 request: %w", err)
}
if header[1] != 1 {
return "", 0, fmt.Errorf("%w: %d", ErrUnsupportedSOCKSCommand, header[1])
}
addr, err := c.readSocks5Addr(conn, header[3])
if err != nil {
return "", 0, err
}
portBuf := make([]byte, 2)
if _, err := io.ReadFull(conn, portBuf); err != nil {
return "", 0, fmt.Errorf("read socks5 port: %w", err)
}
port := int(binary.BigEndian.Uint16(portBuf))
return addr, port, nil
}
func (c *Client) readSocks5Addr(conn net.Conn, addrType byte) (string, error) {
switch addrType {
case 1: // IPv4
buf := make([]byte, 4)
if _, err := io.ReadFull(conn, buf); err != nil {
return "", fmt.Errorf("read socks5 ipv4: %w", err)
}
return net.IP(buf).String(), nil
case 3: // Domain
lenBuf := make([]byte, 1)
if _, err := io.ReadFull(conn, lenBuf); err != nil {
return "", fmt.Errorf("read socks5 domain len: %w", err)
}
buf := make([]byte, lenBuf[0])
if _, err := io.ReadFull(conn, buf); err != nil {
return "", fmt.Errorf("read socks5 domain: %w", err)
}
return string(buf), nil
default:
return "", fmt.Errorf("%w: %d", ErrUnsupportedAddressType, addrType)
}
}
func replySuccess() []byte {
return []byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0}
}
func replyHostUnreachable() []byte {
return []byte{5, 4, 0, 1, 0, 0, 0, 0, 0, 0}
}