feat: refactor client connection handling and error management

This commit is contained in:
zarazaex69
2026-04-21 01:51:48 +03:00
parent 97b12eb2d8
commit 513e2bdd9d
4 changed files with 177 additions and 249 deletions

View File

@@ -12,7 +12,6 @@ import (
"io"
"net"
"sync"
"sync/atomic"
"time"
"github.com/openlibrecommunity/olcrtc/internal/crypto"
@@ -23,36 +22,25 @@ import (
)
var (
// ErrKeySize is returned when the key size is not 32 bytes.
ErrKeySize = errors.New("key must be 32 bytes")
// ErrKeyStringLength is returned when the key string length is not 32.
ErrKeyStringLength = errors.New("key string length must be 32")
// ErrInvalidSocks5 is returned when the SOCKS version is not 5.
ErrInvalidSocks5 = errors.New("invalid SOCKS5 version")
// ErrNoLinks is returned when no links are available for sending.
ErrNoLinks = errors.New("no links available")
// ErrEncryptFailed is returned when encryption fails.
ErrEncryptFailed = errors.New("encrypt failed")
// ErrUnsupportedSocksCommand is returned when a SOCKS5 command is not supported.
ErrUnsupportedSocksCommand = errors.New("unsupported SOCKS5 command")
// ErrUnsupportedAddressType is returned when a SOCKS5 address type is not supported.
ErrUnsupportedAddressType = errors.New("unsupported address type")
// ErrTunnelSetupFailed is returned when the tunnel cannot be established.
ErrTunnelSetupFailed = errors.New("tunnel setup failed")
// ErrConnectFailed is returned when a tunnel connection fails.
ErrConnectFailed = errors.New("tunnel connection failed")
// ErrProxyAuth is returned when SOCKS proxy authentication fails.
ErrProxyAuth = errors.New("SOCKS proxy auth failed")
// ErrMuxExited is returned when the multiplexer loop exits unexpectedly.
ErrMuxExited = errors.New("multiplexer loop exited")
// ErrNoAvailableLinks is returned when no links are ready for sending.
ErrNoAvailableLinks = errors.New("no available links")
)
// Client handles local SOCKS5 connections and tunnels them through the selected runtime stack.
// Client handles local SOCKS5 connections and tunnels them to the server.
type Client struct {
links []link.Link
cipher *crypto.Cipher
mux *mux.Multiplexer
connections map[uint16]net.Conn
connMu sync.RWMutex
linkIdx atomic.Uint32
clientID uint32
activeClients atomic.Int32
wg sync.WaitGroup
dnsServer string
links []link.Link
cipher *crypto.Cipher
mux *mux.Multiplexer
connections map[uint16]net.Conn
connMu sync.RWMutex
clientID uint32
dnsServer string
}
// Run starts the client with the specified parameters.
@@ -71,8 +59,9 @@ func Run(
videoHeight int,
videoFPS int,
videoBitrate string,
videoHW string,
) error {
return RunWithReady(ctx, linkName, transportName, carrierName, roomURL, keyHex, localAddr, dnsServer, socksUser, socksPass, nil, videoWidth, videoHeight, videoFPS, videoBitrate)
return RunWithReady(ctx, linkName, transportName, carrierName, roomURL, keyHex, localAddr, dnsServer, socksUser, socksPass, nil, videoWidth, videoHeight, videoFPS, videoBitrate, videoHW)
}
// RunWithReady is like Run but accepts a callback that is called when the client is ready.
@@ -92,6 +81,7 @@ func RunWithReady(
videoHeight int,
videoFPS int,
videoBitrate string,
videoHW string,
) error {
runCtx, cancel := context.WithCancel(ctx)
defer cancel()
@@ -119,17 +109,17 @@ func RunWithReady(
const linkCount = 1
for i := range linkCount {
if err := c.addLink(runCtx, linkName, transportName, carrierName, roomURL, i, cancel, dnsServer, "", 0, videoWidth, videoHeight, videoFPS, videoBitrate); err != nil {
if err := c.addLink(runCtx, linkName, transportName, carrierName, roomURL, i, cancel, dnsServer, "", 0, videoWidth, videoHeight, videoFPS, videoBitrate, videoHW); err != nil {
return fmt.Errorf("addLink failed: %w", err)
}
}
lc := net.ListenConfig{}
ln, err := lc.Listen(runCtx, "tcp", localAddr)
ln, err := lc.Listen(runCtx, "tcp4", localAddr)
if err != nil {
return fmt.Errorf("listen failed: %w", err)
return fmt.Errorf("failed to listen on %s: %w", localAddr, err)
}
defer func() { _ = ln.Close() }()
defer ln.Close()
logger.Infof("SOCKS5 server listening on %s (ClientID: %d)", localAddr, clientID)
@@ -137,13 +127,17 @@ func RunWithReady(
onReady()
}
go c.acceptLoop(runCtx, ln)
errCh := make(chan error, 1)
go func() {
errCh <- c.acceptLoop(runCtx, ln)
}()
<-runCtx.Done()
c.shutdown()
c.wg.Wait()
return nil
select {
case <-runCtx.Done():
return nil
case err := <-errCh:
return err
}
}
func setupCipher(keyHex string) (*crypto.Cipher, error) {
@@ -152,15 +146,10 @@ func setupCipher(keyHex string) (*crypto.Cipher, error) {
return nil, fmt.Errorf("failed to decode key: %w", err)
}
if len(key) != 32 {
return nil, ErrKeySize
return nil, fmt.Errorf("key must be 32 bytes, got %d", len(key))
}
keyStr := string(key)
if len(keyStr) != 32 {
return nil, ErrKeyStringLength
}
cipher, err := crypto.NewCipher(keyStr)
cipher, err := crypto.NewCipher(string(key))
if err != nil {
return nil, fmt.Errorf("failed to create cipher: %w", err)
}
@@ -185,13 +174,12 @@ func (c *Client) setupMux() {
encrypted, err := c.cipher.Encrypt(frame)
if err != nil {
return fmt.Errorf("%w: %w", ErrEncryptFailed, err)
return err
}
if len(c.links) == 0 {
return ErrNoLinks
return ErrNoAvailableLinks
}
idx := c.linkIdx.Add(1) % uint32(len(c.links)) //nolint:gosec
return c.links[idx].Send(encrypted)
return c.links[0].Send(encrypted)
})
}
@@ -207,7 +195,7 @@ func (c *Client) addLink(
socksProxyAddr string,
socksProxyPort int,
videoWidth, videoHeight, videoFPS int,
videoBitrate string,
videoBitrate, videoHW string,
) error {
ln, err := link.New(ctx, linkName, link.Config{
Transport: transportName,
@@ -222,6 +210,7 @@ func (c *Client) addLink(
VideoHeight: videoHeight,
VideoFPS: videoFPS,
VideoBitrate: videoBitrate,
VideoHW: videoHW,
})
if err != nil {
return fmt.Errorf("failed to create link: %w", err)
@@ -237,25 +226,17 @@ func (c *Client) addLink(
c.handleLinkReconnect(linkID)
})
logger.Infof("Connecting link %d via %s/%s/%s...", linkID, linkName, transportName, carrierName)
if err := ln.Connect(ctx); err != nil {
return fmt.Errorf("failed to connect link: %w", err)
}
logger.Infof("Link %d connected", linkID)
c.wg.Add(1)
go func() {
defer c.wg.Done()
ln.WatchConnection(ctx)
}()
c.sendClientResetAsync("initial")
go ln.WatchConnection(ctx)
return nil
}
func (c *Client) handleLinkReconnect(linkID int) {
logger.Infof("link %d reconnect event", linkID)
c.sendResetSignal()
c.connMu.Lock()
for sid, conn := range c.connections {
@@ -269,228 +250,176 @@ func (c *Client) handleLinkReconnect(linkID int) {
c.mux.UpdateSendFunc(func(frame []byte) error {
encrypted, err := c.cipher.Encrypt(frame)
if err != nil {
return fmt.Errorf("%w: %w", ErrEncryptFailed, err)
return err
}
if len(c.links) == 0 {
return ErrNoLinks
return ErrNoAvailableLinks
}
idx := c.linkIdx.Add(1) % uint32(len(c.links)) //nolint:gosec
return c.links[idx].Send(encrypted)
return c.links[0].Send(encrypted)
})
c.mux.Reset()
c.sendClientResetAsync("reconnect")
}
func (c *Client) sendClientResetAsync(source string) {
c.wg.Add(1)
go func() {
defer c.wg.Done()
if err := c.mux.SendClientReset(); err != nil {
logger.Warnf("Failed to send client reset after %s: %v", source, err)
}
}()
}
func (c *Client) acceptLoop(ctx context.Context, ln net.Listener) {
for {
select {
case <-ctx.Done():
return
default:
conn, err := ln.Accept()
if err != nil {
logger.Debugf("Accept error: %v", err)
continue
}
go c.handleSOCKS5(ctx, conn)
}
func (c *Client) sendResetSignal() {
resetFrame := mux.BuildControlFrame(c.clientID, mux.ControlResetClient)
encrypted, _ := c.cipher.Encrypt(resetFrame)
if len(c.links) > 0 {
_ = c.links[0].Send(encrypted)
}
}
func (c *Client) handleSOCKS5(ctx context.Context, conn net.Conn) {
defer func() { _ = conn.Close() }()
func (c *Client) acceptLoop(ctx context.Context, ln net.Listener) error {
for {
conn, err := ln.Accept()
if err != nil {
select {
case <-ctx.Done():
return nil
default:
logger.Warnf("Accept error: %v", err)
continue
}
}
go c.handleSocks5(ctx, conn)
}
}
func (c *Client) handleSocks5(ctx context.Context, conn net.Conn) {
defer conn.Close()
if err := c.socks5Handshake(conn); err != nil {
logger.Debugf("SOCKS5 handshake failed: %v", err)
return
}
addr, port, err := c.socks5Request(conn)
targetAddr, targetPort, err := c.socks5Request(conn)
if err != nil {
logger.Debugf("SOCKS5 request failed: %v", err)
return
}
sid := c.mux.OpenStream()
defer c.mux.CloseStream(sid)
c.connMu.Lock()
c.connections[sid] = conn
c.connMu.Unlock()
defer func() {
c.connMu.Lock()
delete(c.connections, sid)
c.connMu.Unlock()
}()
logger.Infof("sid=%d tunnel to %s:%d", sid, addr, port)
logger.Infof("sid=%d tunnel to %s:%d", sid, targetAddr, targetPort)
if err := c.setupTunnel(ctx, sid, conn, addr, port); err != nil {
connectReq, _ := json.Marshal(map[string]any{
"cmd": "connect",
"addr": targetAddr,
"port": targetPort,
})
if err := c.mux.SendData(sid, connectReq); err != nil {
logger.Warnf("sid=%d tunnel setup failed: %v", sid, err)
_, _ = conn.Write(replyHostUnreachable())
return
}
c.activeClients.Add(1)
c.startStreamPump(ctx, sid, conn)
c.pumpToMux(sid, conn)
}
func (c *Client) setupTunnel(ctx context.Context, sid uint16, conn net.Conn, addr string, port int) error {
req := map[string]any{"cmd": "connect", "addr": addr, "port": port}
reqData, err := json.Marshal(req)
if err != nil {
return fmt.Errorf("marshal connect: %w", err)
}
if err := c.mux.SendData(sid, reqData); err != nil {
return fmt.Errorf("send connect: %w", err)
}
readyTimer := time.NewTimer(10 * time.Second)
defer readyTimer.Stop()
dataReady := c.mux.WaitForData(sid)
var initialData []byte
select {
case <-dataReady:
resp := c.mux.ReadStream(sid)
if len(resp) > 0 && resp[0] == 0x00 {
if _, err := conn.Write(replySuccess()); err != nil {
return fmt.Errorf("write success: %w", err)
}
} else {
_, _ = conn.Write(replyHostUnreachable())
return ErrTunnelSetupFailed
}
case <-time.After(15 * time.Second):
case <-readyTimer.C:
logger.Warnf("sid=%d tunnel setup failed: timeout waiting for remote ready", sid)
_, _ = conn.Write(replyHostUnreachable())
c.mux.CleanupDataChannel(sid)
return fmt.Errorf("%w: timeout", ErrTunnelSetupFailed)
case <-ctx.Done():
return fmt.Errorf("context cancelled: %w", ctx.Err())
return
case <-dataReady:
initialData = c.mux.ReadStream(sid)
if len(initialData) == 0 || initialData[0] != 0x00 {
logger.Warnf("sid=%d tunnel setup failed: invalid remote ready", sid)
_, _ = conn.Write(replyHostUnreachable())
return
}
}
c.mux.CleanupDataChannel(sid)
return nil
if _, err := conn.Write(replySuccess()); err != nil {
return
}
// Handle the rest of initialData if any (unlikely for 0x00 packet)
if len(initialData) > 1 {
if _, err := conn.Write(initialData[1:]); err != nil {
return
}
}
go c.pumpFromMux(ctx, sid, conn)
c.pumpToMux(sid, conn)
}
func (c *Client) socks5Handshake(conn net.Conn) error {
buf := make([]byte, 2)
if _, err := io.ReadFull(conn, buf); err != nil {
return fmt.Errorf("read header: %w", err)
return err
}
if buf[0] != 5 {
return ErrInvalidSocks5
return fmt.Errorf("invalid socks version: %d", buf[0])
}
methods := make([]byte, int(buf[1]))
methods := make([]byte, buf[1])
if _, err := io.ReadFull(conn, methods); err != nil {
return fmt.Errorf("read methods: %w", err)
return err
}
if _, err := conn.Write([]byte{5, 0}); err != nil {
return fmt.Errorf("write response: %w", err)
return err
}
return nil
}
func (c *Client) socks5Request(conn net.Conn) (string, int, error) {
buf := make([]byte, 4)
if _, err := io.ReadFull(conn, buf); err != nil {
return "", 0, fmt.Errorf("read request header: %w", err)
}
if buf[0] != 5 || buf[1] != 1 {
return "", 0, fmt.Errorf("%w: cmd=%d", ErrUnsupportedSocksCommand, buf[1])
}
addr, err := c.readSocks5Addr(conn, buf[3])
if err != nil {
header := make([]byte, 4)
if _, err := io.ReadFull(conn, header); err != nil {
return "", 0, err
}
if header[1] != 1 {
return "", 0, fmt.Errorf("unsupported socks command: %d", header[1])
}
var addr string
switch header[3] {
case 1: // IPv4
buf := make([]byte, 4)
if _, err := io.ReadFull(conn, buf); err != nil {
return "", 0, err
}
addr = net.IP(buf).String()
case 3: // Domain
lenBuf := make([]byte, 1)
if _, err := io.ReadFull(conn, lenBuf); err != nil {
return "", 0, err
}
buf := make([]byte, lenBuf[0])
if _, err := io.ReadFull(conn, buf); err != nil {
return "", 0, err
}
addr = string(buf)
default:
return "", 0, fmt.Errorf("unsupported address type: %d", header[3])
}
portBuf := make([]byte, 2)
if _, err := io.ReadFull(conn, portBuf); err != nil {
return "", 0, fmt.Errorf("read port: %w", err)
return "", 0, err
}
port := int(binary.BigEndian.Uint16(portBuf))
return addr, port, nil
}
func (c *Client) readSocks5Addr(conn net.Conn, addrType byte) (string, error) {
switch addrType {
case 1: // IPv4
ip := make([]byte, 4)
if _, err := io.ReadFull(conn, ip); err != nil {
return "", fmt.Errorf("read ipv4: %w", err)
}
return net.IP(ip).String(), nil
case 3: // Domain
lenBuf := make([]byte, 1)
if _, err := io.ReadFull(conn, lenBuf); err != nil {
return "", fmt.Errorf("read domain len: %w", err)
}
domain := make([]byte, int(lenBuf[0]))
if _, err := io.ReadFull(conn, domain); err != nil {
return "", fmt.Errorf("read domain: %w", err)
}
return string(domain), nil
case 4: // IPv6
ip := make([]byte, 16)
if _, err := io.ReadFull(conn, ip); err != nil {
return "", fmt.Errorf("read ipv6: %w", err)
}
return net.IP(ip).String(), nil
default:
return "", fmt.Errorf("%w: type=%d", ErrUnsupportedAddressType, addrType)
}
}
func (c *Client) onData(data []byte) {
plaintext, err := c.cipher.Decrypt(data)
if err != nil {
logger.Debugf("Decrypt error: %v", err)
return
}
c.mux.HandleFrame(plaintext)
}
func (c *Client) shutdown() {
c.connMu.Lock()
for _, conn := range c.connections {
if conn != nil {
_ = conn.Close()
}
}
c.connMu.Unlock()
for i, tr := range c.links {
logger.Infof("closing link %d", i)
_ = tr.Close()
}
}
func (c *Client) pumpToMux(sid uint16, conn net.Conn) {
defer func() {
c.activeClients.Add(-1)
_ = c.mux.CloseStream(sid)
c.connMu.Lock()
delete(c.connections, sid)
c.connMu.Unlock()
}()
buf := make([]byte, 16384)
totalSent := uint64(0)
lastLog := time.Now()
for {
n, err := conn.Read(buf)
if err != nil {
if totalSent > 1024*1024 {
logger.Infof("sid=%d done total=%dMB", sid, totalSent/(1024*1024))
}
return
}
@@ -501,41 +430,36 @@ func (c *Client) pumpToMux(sid uint16, conn net.Conn) {
if err := c.mux.SendData(sid, buf[:n]); err != nil {
return
}
}
}
totalSent += uint64(n) //nolint:gosec
if time.Since(lastLog) > 5*time.Second {
logger.Infof("sid=%d sent=%dMB", sid, totalSent/(1024*1024))
lastLog = time.Now()
func (c *Client) pumpFromMux(ctx context.Context, sid uint16, conn net.Conn) {
defer c.mux.CleanupDataChannel(sid)
dataReady := c.mux.WaitForData(sid)
for {
select {
case <-ctx.Done():
return
case <-dataReady:
data := c.mux.ReadStream(sid)
if len(data) > 0 {
if _, err := conn.Write(data); err != nil {
return
}
}
if c.mux.StreamClosed(sid) {
return
}
}
}
}
func (c *Client) startStreamPump(ctx context.Context, sid uint16, conn net.Conn) {
c.wg.Add(1)
go func() {
defer c.wg.Done()
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
data := c.mux.ReadStream(sid)
if len(data) > 0 {
if _, err := conn.Write(data); err != nil {
_ = c.mux.CloseStream(sid)
return
}
}
if c.mux.StreamClosed(sid) {
return
}
}
}
}()
func (c *Client) onData(data []byte) {
plaintext, err := c.cipher.Decrypt(data)
if err != nil {
return
}
c.mux.HandleFrame(plaintext)
}
func (c *Client) canSendData() bool {

View File

@@ -37,6 +37,7 @@ type Config struct {
VideoHeight int
VideoFPS int
VideoBitrate string
VideoHW string
}
// Factory creates a link instance.

View File

@@ -80,6 +80,7 @@ func Run(
videoHeight int,
videoFPS int,
videoBitrate string,
videoHW string,
) error {
runCtx, cancel := context.WithCancel(ctx)
defer cancel()

View File

@@ -61,6 +61,7 @@ type streamTransport struct {
videoH int
videoFPS int
videoBitrate string
videoHW string
}
// New creates a visual videochannel transport backed by a carrier-specific provider.
@@ -109,6 +110,7 @@ func New(ctx context.Context, cfg transport.Config) (transport.Transport, error)
videoH: cfg.VideoHeight,
videoFPS: cfg.VideoFPS,
videoBitrate: cfg.VideoBitrate,
videoHW: cfg.VideoHW,
}
if err := stream.AddTrack(track); err != nil {
@@ -124,7 +126,7 @@ func (p *streamTransport) Connect(ctx context.Context) error {
connectCtx, cancel := context.WithTimeout(ctx, defaultConnectTimeout)
defer cancel()
encoder, err := newFFmpegEncoder(p.codec, p.videoW, p.videoH, p.videoFPS, p.videoBitrate)
encoder, err := newFFmpegEncoder(p.codec, p.videoW, p.videoH, p.videoFPS, p.videoBitrate, p.videoHW)
if err != nil {
return err
}
@@ -328,7 +330,7 @@ func (p *streamTransport) handleRemoteTrack(track *webrtc.TrackRemote, _ *webrtc
return
}
decoder, err := newFFmpegDecoder(codec, p.videoW, p.videoH, p.videoFPS)
decoder, err := newFFmpegDecoder(codec, p.videoW, p.videoH, p.videoFPS, p.videoHW)
if err != nil {
logger.Warnf("videochannel decoder init failed: %v", err)
return