Files
olcrtc/internal/client/client.go

799 lines
22 KiB
Go

// Package client implements the local SOCKS5 client side of the olcrtc tunnel.
package client
import (
"context"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/openlibrecommunity/olcrtc/internal/control"
"github.com/openlibrecommunity/olcrtc/internal/crypto"
"github.com/openlibrecommunity/olcrtc/internal/handshake"
"github.com/openlibrecommunity/olcrtc/internal/logger"
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
"github.com/openlibrecommunity/olcrtc/internal/names"
"github.com/openlibrecommunity/olcrtc/internal/runtime"
"github.com/openlibrecommunity/olcrtc/internal/transport"
"github.com/xtaci/smux"
)
var (
// ErrConnectFailed is returned when a tunnel connection fails.
ErrConnectFailed = errors.New("tunnel connection failed")
// ErrProxyAuth is returned when SOCKS proxy authentication fails.
ErrProxyAuth = errors.New("SOCKS proxy auth failed")
// ErrKeySize is returned when the encryption key is not 32 bytes.
// Re-exported from runtime for compatibility with errors.Is callers.
ErrKeySize = runtime.ErrKeySize
// ErrInvalidSOCKSVersion is returned when the SOCKS version is not 5.
ErrInvalidSOCKSVersion = errors.New("invalid socks version")
// ErrUnsupportedSOCKSCommand is returned for unsupported SOCKS commands.
ErrUnsupportedSOCKSCommand = errors.New("unsupported socks command")
// ErrUnsupportedAddressType is returned for unsupported SOCKS address types.
ErrUnsupportedAddressType = errors.New("unsupported address type")
// ErrRemoteNotReady is returned when the server-side stream fails to signal readiness.
ErrRemoteNotReady = errors.New("remote not ready")
// ErrSOCKSAuthFailed is returned when username/password authentication is rejected.
ErrSOCKSAuthFailed = errors.New("SOCKS5 authentication failed")
// ErrSOCKSCredTooLong is returned when a SOCKS5 username or password exceeds 255 bytes.
ErrSOCKSCredTooLong = errors.New("socks5 user/pass exceeds 255 bytes")
)
// Client handles local SOCKS5 connections and tunnels them to the server.
type Client struct {
ln transport.Transport
cipher *crypto.Cipher
conn *muxconn.Conn
session *smux.Session
controlStrm *smux.Stream
controlStop context.CancelFunc
sessMu sync.RWMutex
reconnectMu sync.Mutex
health *runtime.HealthTracker
deviceID string
sessionID string
claims map[string]any
dnsServer string
socksUser string
socksPass string
}
// HealthFunc is called when the client control health snapshot changes.
type HealthFunc func(control.Status)
// Config holds runtime configuration for [Run] and [RunWithReady].
type Config struct {
Transport string
Carrier string
RoomURL string
ChannelID string
KeyHex string
LocalAddr string
DNSServer string
SOCKSUser string
SOCKSPass string
TransportOptions transport.Options
Engine string
URL string
Token string
Liveness control.Config
Traffic transport.TrafficConfig
// DeviceID overrides the persistent client-side device identifier. Leave
// empty to derive one from DeviceIDPath (or generate a random one if both
// are empty).
DeviceID string
// DeviceIDPath is a file in which to persist the auto-generated device ID
// across restarts. Ignored when DeviceID is set explicitly.
DeviceIDPath string
// Claims is sent to the server in CLIENT_HELLO and forwarded verbatim to
// the server's AuthHook. Free-form key/value bag for plan, user, region, etc.
Claims map[string]any
// OnHealth receives liveness/reconnect status updates. Nil means no-op.
OnHealth HealthFunc
}
// Run starts the client with the given configuration.
func Run(ctx context.Context, cfg Config) error {
return RunWithReady(ctx, cfg, nil)
}
// RunWithReady is like Run but invokes onReady once the local SOCKS listener is up.
func RunWithReady(ctx context.Context, cfg Config, onReady func()) error {
runCtx, cancel := context.WithCancel(ctx)
defer cancel()
cipher, err := setupCipher(cfg.KeyHex)
if err != nil {
return fmt.Errorf("setupCipher failed: %w", err)
}
deviceID, err := resolveDeviceID(cfg.DeviceID, cfg.DeviceIDPath)
if err != nil {
return fmt.Errorf("resolve device id: %w", err)
}
c := &Client{
cipher: cipher,
deviceID: deviceID,
claims: cfg.Claims,
dnsServer: cfg.DNSServer,
socksUser: cfg.SOCKSUser,
socksPass: cfg.SOCKSPass,
health: runtime.NewHealthTracker(cfg.OnHealth),
}
// shutdown is registered BEFORE bringUpLink so we always close any
// link/session that bringUpLink managed to set up before it
// errored out. The previous ordering returned early on failure
// (e.g. handshake timeout against a wedged seichannel transport)
// without ever calling Close on the carrier link, leaving our MUC
// presence behind as a ghost participant in the next test that
// joined the same room. shutdown is nil-safe - it skips fields
// that bringUpLink hadn't populated yet.
defer c.shutdown()
if err := c.bringUpLink(runCtx, cfg, cancel); err != nil {
return err
}
lc := net.ListenConfig{}
listener, err := lc.Listen(runCtx, "tcp4", cfg.LocalAddr)
if err != nil {
return fmt.Errorf("failed to listen on %s: %w", cfg.LocalAddr, err)
}
defer func() { _ = listener.Close() }()
logger.Infof("SOCKS5 server listening on %s", cfg.LocalAddr)
if onReady != nil {
onReady()
}
go c.acceptLoop(runCtx, listener)
<-runCtx.Done()
return nil
}
func (c *Client) bringUpLink(
ctx context.Context,
cfg Config,
cancel context.CancelFunc,
) error {
ln, err := transport.New(ctx, cfg.Transport, transport.Config{
Carrier: cfg.Carrier,
RoomURL: cfg.RoomURL,
Engine: cfg.Engine,
URL: cfg.URL,
Token: cfg.Token,
ChannelID: cfg.ChannelID,
DeviceID: c.deviceID,
Name: names.Generate(),
OnData: c.onData,
DNSServer: cfg.DNSServer,
Options: cfg.TransportOptions,
Traffic: cfg.Traffic,
})
if err != nil {
return fmt.Errorf("failed to create link: %w", err)
}
c.ln = ln
ln.SetEndedCallback(func(reason string) {
logger.Infof("Client link reported conference end: %s", reason)
cancel()
})
ln.SetShouldReconnect(func() bool { return ctx.Err() == nil })
ln.SetReconnectCallback(func() {
if ctx.Err() != nil {
return
}
// Carrier callback fires after the link is back up. If handshake
// still fails it usually means the server hasn't completed its
// own reinstall yet - keep the listener up and wait for either
// another callback or a future liveness loss to re-trigger.
c.handleReconnect(ctx, cfg, cancel, "carrier")
})
if err := ln.Connect(ctx); err != nil {
return fmt.Errorf("failed to connect link: %w", err)
}
c.conn = muxconn.New(ln, c.cipher)
sess, err := smux.Client(c.conn, smuxConfig(linkMaxPayload(ln)))
if err != nil {
return fmt.Errorf("smux client: %w", err)
}
control, sid, err := openControlStream(ctx, sess, c.deviceID, c.claims)
if err != nil {
_ = sess.Close()
_ = c.conn.Close()
return fmt.Errorf("handshake: %w", err)
}
logger.Infof("session %s opened (device=%s)", sid, c.deviceID)
c.sessMu.Lock()
c.session = sess
c.controlStrm = control
c.sessionID = sid
c.sessMu.Unlock()
c.recordSession(sid)
c.startControlLoop(ctx, cfg, cancel, control)
go ln.WatchConnection(ctx)
return nil
}
// openControlStream opens stream #1 on sess and performs the handshake.
// The stream stays open for the lifetime of the smux session and carries
// post-handshake control messages.
func openControlStream(
ctx context.Context,
sess *smux.Session,
deviceID string,
claims map[string]any,
) (*smux.Stream, string, error) {
return openControlStreamTimeout(ctx, sess, deviceID, claims, handshake.DefaultTimeout)
}
func openControlStreamTimeout(
ctx context.Context,
sess *smux.Session,
deviceID string,
claims map[string]any,
timeout time.Duration,
) (*smux.Stream, string, error) {
stream, err := sess.OpenStream()
if err != nil {
return nil, "", fmt.Errorf("open control stream: %w", err)
}
done := make(chan struct{})
go func() {
select {
case <-ctx.Done():
_ = stream.Close()
case <-done:
}
}()
defer close(done)
_ = stream.SetDeadline(time.Now().Add(timeout))
sid, err := handshake.Client(stream, deviceID, claims)
_ = stream.SetDeadline(time.Time{})
if err != nil {
_ = stream.Close()
if ctx.Err() != nil {
return nil, "", fmt.Errorf("handshake client: %w", ctx.Err())
}
return nil, "", fmt.Errorf("handshake client: %w", err)
}
return stream, sid, nil
}
// resolveDeviceID returns the device ID to send in CLIENT_HELLO.
//
// Precedence:
// 1. Explicit deviceID arg (Config.DeviceID) - used verbatim.
// 2. Persistent file at path (Config.DeviceIDPath) - read if it exists,
// otherwise generated and written for future runs.
// 3. Random UUID per run when both inputs are empty.
func resolveDeviceID(deviceID, path string) (string, error) {
if deviceID != "" {
return deviceID, nil
}
if path == "" {
return uuid.NewString(), nil
}
// #nosec G304 -- persistent device ID path is explicit user configuration.
data, err := os.ReadFile(path)
if err == nil {
id := strings.TrimSpace(string(data))
if id != "" {
return id, nil
}
} else if !errors.Is(err, os.ErrNotExist) {
return "", fmt.Errorf("read device id %s: %w", path, err)
}
id := uuid.NewString()
if err := os.MkdirAll(filepath.Dir(path), 0o750); err != nil {
return "", fmt.Errorf("mkdir device id dir: %w", err)
}
if err := os.WriteFile(path, []byte(id+"\n"), 0o600); err != nil {
return "", fmt.Errorf("write device id %s: %w", path, err)
}
return id, nil
}
func smuxConfig(maxWirePayload int) *smux.Config {
return runtime.SmuxConfig(maxWirePayload)
}
func linkMaxPayload(tr transport.Transport) int {
return runtime.MaxPayload(tr)
}
func (c *Client) handleReconnect(ctx context.Context, cfg Config, cancel context.CancelFunc, reason string) {
c.reconnectMu.Lock()
defer c.reconnectMu.Unlock()
c.recordReconnect()
logger.Infof("client reconnect reason=%s - tearing down smux session", reason)
c.resetLinkPeer()
// Install a fresh muxconn immediately so onData never hits nil while
// the old session is being torn down. tryReopenSession will swap it
// again with its own conn on each attempt.
newConn := muxconn.New(c.ln, c.cipher)
c.sessMu.Lock()
oldControl := c.controlStrm
oldControlStop := c.controlStop
oldSess := c.session
oldConn := c.conn
c.conn = newConn
c.session = nil
c.controlStrm = nil
c.controlStop = nil
c.sessionID = ""
c.sessMu.Unlock()
if oldControlStop != nil {
oldControlStop()
}
if oldSess != nil {
_ = oldSess.Close()
}
if oldConn != nil {
_ = oldConn.Close()
}
if oldControl != nil {
_ = oldControl.Close()
}
// When liveness on top of a still-"connected" carrier expires, the
// underlying ICE/data path has gone silent without the engine noticing.
// Re-handshaking over the dead carrier just times out repeatedly, so
// ask the carrier to rebuild itself; the new carrier will fire its own
// reconnect callback which then drives a fresh handshake.
if reason == "liveness" && c.ln != nil {
c.ln.Reconnect("liveness")
}
c.retryHandshake(ctx, cfg, cancel, reason)
}
func (c *Client) retryHandshake(ctx context.Context, cfg Config, cancel context.CancelFunc, reason string) {
const (
initialDelay = 300 * time.Millisecond
maxDelay = 5 * time.Second
)
delay := initialDelay
for attempt := 1; ; attempt++ {
if ctx.Err() != nil {
return
}
logger.Infof("client reconnect attempt=%d reason=%s", attempt, reason)
if c.tryReopenSession(ctx, cfg, cancel, attempt) {
return
}
// Don't fail the whole process on liveness reconnect: the carrier
// rebuild may take dozens of seconds (e.g. ICE restart on a flaky
// network). Keep the SOCKS5 listener open and wait - handleSocks5
// will return host-unreachable to clients until we recover. For
// carrier-driven reconnects the callback fires after the link is
// already up, so a missed handshake is more suspicious; cap it.
if reason == "carrier" && attempt >= 5 {
logger.Warnf("client reconnect: exhausted %d handshake attempts (reason=%s) - keeping listener up", attempt, reason)
return
}
select {
case <-ctx.Done():
return
case <-time.After(delay):
}
if delay < maxDelay {
delay *= 2
if delay > maxDelay {
delay = maxDelay
}
}
}
}
func (c *Client) resetLinkPeer() {
c.sessMu.RLock()
ln := c.ln
c.sessMu.RUnlock()
if resetter, ok := ln.(interface{ ResetPeer() }); ok {
resetter.ResetPeer()
}
}
func (c *Client) tryReopenSession(
ctx context.Context,
cfg Config,
cancel context.CancelFunc,
attempt int,
) bool {
conn := muxconn.New(c.ln, c.cipher)
c.sessMu.Lock()
old := c.conn
c.conn = conn
c.sessMu.Unlock()
if old != nil {
_ = old.Close()
}
sess, err := smux.Client(conn, smuxConfig(linkMaxPayload(c.ln)))
if err != nil {
logger.Warnf("smux re-init failed (attempt %d): %v", attempt, err)
return false
}
control, sid, err := openControlStreamTimeout(ctx, sess, c.deviceID, c.claims, 2*time.Second)
if err != nil {
logger.Warnf("handshake on reconnect failed (attempt %d): %v", attempt, err)
_ = sess.Close()
return false
}
logger.Infof("session %s reopened (device=%s)", sid, c.deviceID)
c.sessMu.Lock()
c.session = sess
c.controlStrm = control
c.sessionID = sid
c.sessMu.Unlock()
c.recordSession(sid)
c.startControlLoop(ctx, cfg, cancel, control)
return true
}
func (c *Client) startControlLoop(
ctx context.Context,
cfg Config,
cancel context.CancelFunc,
stream *smux.Stream,
) {
controlCtx, stop := context.WithCancel(ctx)
c.sessMu.Lock()
c.controlStop = stop
c.sessMu.Unlock()
liveness := cfg.Liveness
onPong := liveness.OnPong
onMissedPong := liveness.OnMissedPong
onUnhealthy := liveness.OnUnhealthy
liveness.OnPong = func(h control.Health) {
c.sessMu.RLock()
sid := c.sessionID
c.sessMu.RUnlock()
c.recordPong(h)
logger.Debugf("control alive session=%s rtt=%v seq=%d", sid, h.RTT, h.Seq)
if onPong != nil {
onPong(h)
}
}
liveness.OnMissedPong = func(missed int) {
c.recordMissed(missed)
logger.Warnf("control missed pong on client: missed_pongs=%d", missed)
if onMissedPong != nil {
onMissedPong(missed)
}
}
liveness.OnUnhealthy = func(missed int) {
c.recordUnhealthy(missed)
logger.Warnf("control stream unhealthy on client: missed_pongs=%d", missed)
if onUnhealthy != nil {
onUnhealthy(missed)
}
}
go func() {
err := control.Run(controlCtx, stream, liveness)
if controlCtx.Err() != nil || ctx.Err() != nil {
return
}
if err != nil {
logger.Warnf("client control stream ended: %v", err)
}
// handleReconnect now retries indefinitely on liveness so it only
// returns false on ctx cancellation; don't tear down the client.
c.handleReconnect(ctx, cfg, cancel, "liveness")
}()
}
// Status returns the latest client-side control health snapshot.
func (c *Client) Status() control.Status {
return c.health.Status()
}
func (c *Client) recordSession(sessionID string) { c.health.RecordSession(sessionID) }
func (c *Client) recordPong(h control.Health) { c.health.RecordPong(h) }
func (c *Client) recordMissed(missed int) { c.health.RecordMissed(missed) }
func (c *Client) recordUnhealthy(missed int) { c.health.RecordUnhealthy(missed) }
func (c *Client) recordReconnect() { c.health.RecordReconnect() }
func (c *Client) shutdown() {
c.sessMu.Lock()
control := c.controlStrm
controlStop := c.controlStop
sess := c.session
conn := c.conn
c.controlStrm = nil
c.controlStop = nil
c.session = nil
c.conn = nil
c.sessMu.Unlock()
notifyControlClose(control)
if controlStop != nil {
controlStop()
}
if sess != nil {
_ = sess.Close()
}
if conn != nil {
_ = conn.Close()
}
if c.ln != nil {
_ = c.ln.Close()
}
if control != nil {
_ = control.Close()
}
}
func notifyControlClose(stream *smux.Stream) {
if stream == nil {
return
}
_ = stream.SetWriteDeadline(time.Now().Add(2 * time.Second))
if err := control.SendClose(stream); err == nil {
time.Sleep(200 * time.Millisecond)
}
_ = stream.SetWriteDeadline(time.Time{})
_ = stream.CloseWrite()
}
func setupCipher(keyHex string) (*crypto.Cipher, error) {
cipher, err := runtime.SetupCipher(keyHex)
if err != nil {
return nil, fmt.Errorf("client: %w", err)
}
return cipher, nil
}
func (c *Client) onData(data []byte) {
c.sessMu.RLock()
conn := c.conn
c.sessMu.RUnlock()
if conn != nil {
conn.Push(data)
}
}
func (c *Client) acceptLoop(ctx context.Context, ln net.Listener) {
for {
conn, err := ln.Accept()
if err != nil {
select {
case <-ctx.Done():
return
default:
logger.Warnf("Accept error: %v", err)
continue
}
}
go c.handleSocks5(ctx, conn)
}
}
func (c *Client) handleSocks5(_ context.Context, conn net.Conn) {
defer func() { _ = conn.Close() }()
if err := c.socks5Handshake(conn); err != nil {
return
}
targetAddr, targetPort, err := c.socks5Request(conn)
if err != nil {
return
}
c.sessMu.RLock()
sess := c.session
c.sessMu.RUnlock()
if sess == nil || sess.IsClosed() {
_, _ = conn.Write(replyHostUnreachable())
return
}
c.tunnel(conn, sess, targetAddr, targetPort)
}
func (c *Client) tunnel(conn net.Conn, sess *smux.Session, targetAddr string, targetPort int) {
stream, err := sess.OpenStream()
if err != nil {
logger.Warnf("OpenStream failed: %v", err)
_, _ = conn.Write(replyHostUnreachable())
return
}
defer func() { _ = stream.Close() }()
logger.Infof("sid=%d tunnel to %s:%d", stream.ID(), targetAddr, targetPort)
if err := c.sendConnectRequest(stream, targetAddr, targetPort); err != nil {
logger.Warnf("sid=%d connect failed: %v", stream.ID(), err)
_, _ = conn.Write(replyHostUnreachable())
return
}
if _, err := conn.Write(replySuccess()); err != nil {
return
}
go func() {
_, _ = io.Copy(stream, conn)
_ = stream.Close()
}()
_, _ = io.Copy(conn, stream)
}
func (c *Client) sendConnectRequest(stream *smux.Stream, targetAddr string, targetPort int) error {
connectReq, err := json.Marshal(map[string]any{
"cmd": "connect",
"addr": targetAddr,
"port": targetPort,
})
if err != nil {
return fmt.Errorf("sid=%d marshal connect req: %w", stream.ID(), err)
}
_ = stream.SetWriteDeadline(time.Now().Add(10 * time.Second))
if _, err := stream.Write(connectReq); err != nil {
return fmt.Errorf("sid=%d write connect req: %w", stream.ID(), err)
}
_ = stream.SetWriteDeadline(time.Time{})
ack := make([]byte, 1)
_ = stream.SetReadDeadline(time.Now().Add(15 * time.Second))
if _, err := io.ReadFull(stream, ack); err != nil || ack[0] != 0x00 {
return fmt.Errorf("sid=%d: %w (read_err=%w ack=%v)", stream.ID(), ErrRemoteNotReady, err, ack)
}
_ = stream.SetReadDeadline(time.Time{})
return nil
}
func (c *Client) socks5Handshake(conn net.Conn) error {
buf := make([]byte, 2)
if _, err := io.ReadFull(conn, buf); err != nil {
return fmt.Errorf("read socks5 header: %w", err)
}
if buf[0] != 5 {
return fmt.Errorf("%w: %d", ErrInvalidSOCKSVersion, buf[0])
}
methods := make([]byte, buf[1])
if _, err := io.ReadFull(conn, methods); err != nil {
return fmt.Errorf("read socks5 methods: %w", err)
}
if c.socksUser != "" {
// RFC 1929: method 0x02 = username/password auth.
if _, err := conn.Write([]byte{5, 2}); err != nil {
return fmt.Errorf("write socks5 auth method: %w", err)
}
if err := c.socks5UserPassAuth(conn); err != nil {
return err
}
return nil
}
if _, err := conn.Write([]byte{5, 0}); err != nil {
return fmt.Errorf("write socks5 auth: %w", err)
}
return nil
}
func (c *Client) socks5UserPassAuth(conn net.Conn) error {
header := make([]byte, 2)
if _, err := io.ReadFull(conn, header); err != nil {
return fmt.Errorf("read socks5 auth header: %w", err)
}
if header[0] != 0x01 {
return fmt.Errorf("%w: expected auth version 1, got %d", ErrInvalidSOCKSVersion, header[0])
}
ulen := int(header[1])
userBuf := make([]byte, ulen)
if _, err := io.ReadFull(conn, userBuf); err != nil {
return fmt.Errorf("read socks5 username: %w", err)
}
plenBuf := make([]byte, 1)
if _, err := io.ReadFull(conn, plenBuf); err != nil {
return fmt.Errorf("read socks5 plen: %w", err)
}
plen := int(plenBuf[0])
passBuf := make([]byte, plen)
if _, err := io.ReadFull(conn, passBuf); err != nil {
return fmt.Errorf("read socks5 password: %w", err)
}
if string(userBuf) != c.socksUser || string(passBuf) != c.socksPass {
_, _ = conn.Write([]byte{0x01, 0x01})
return ErrSOCKSAuthFailed
}
if _, err := conn.Write([]byte{0x01, 0x00}); err != nil {
return fmt.Errorf("write socks5 auth success: %w", err)
}
return nil
}
func (c *Client) socks5Request(conn net.Conn) (string, int, error) {
header := make([]byte, 4)
if _, err := io.ReadFull(conn, header); err != nil {
return "", 0, fmt.Errorf("read socks5 request: %w", err)
}
if header[1] != 1 {
return "", 0, fmt.Errorf("%w: %d", ErrUnsupportedSOCKSCommand, header[1])
}
addr, err := c.readSocks5Addr(conn, header[3])
if err != nil {
return "", 0, err
}
portBuf := make([]byte, 2)
if _, err := io.ReadFull(conn, portBuf); err != nil {
return "", 0, fmt.Errorf("read socks5 port: %w", err)
}
port := int(binary.BigEndian.Uint16(portBuf))
return addr, port, nil
}
func (c *Client) readSocks5Addr(conn net.Conn, addrType byte) (string, error) {
switch addrType {
case 1: // IPv4
buf := make([]byte, 4)
if _, err := io.ReadFull(conn, buf); err != nil {
return "", fmt.Errorf("read socks5 ipv4: %w", err)
}
return net.IP(buf).String(), nil
case 3: // Domain
lenBuf := make([]byte, 1)
if _, err := io.ReadFull(conn, lenBuf); err != nil {
return "", fmt.Errorf("read socks5 domain len: %w", err)
}
buf := make([]byte, lenBuf[0])
if _, err := io.ReadFull(conn, buf); err != nil {
return "", fmt.Errorf("read socks5 domain: %w", err)
}
return string(buf), nil
default:
return "", fmt.Errorf("%w: %d", ErrUnsupportedAddressType, addrType)
}
}
func replySuccess() []byte {
return []byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0}
}
func replyHostUnreachable() []byte {
return []byte{5, 4, 0, 1, 0, 0, 0, 0, 0, 0}
}