refactor: improve SOCKS5 error handling, refactor client connection logic, and add documentation to internal packages.

This commit is contained in:
zarazaex69
2026-04-20 05:46:27 +03:00
parent 40f1ad14e3
commit a58e343331
10 changed files with 198 additions and 95 deletions

View File

@@ -185,7 +185,7 @@ func runMode(ctx context.Context, cfg config, errCh chan<- error) {
fmt.Sprintf("%s:%d", cfg.socksHost, cfg.socksPort),
cfg.dnsServer,
"",
0,
"",
)
}
}

View File

@@ -24,11 +24,22 @@ import (
)
var (
ErrKeySize = errors.New("key must be 32 bytes")
// ErrKeySize is returned when the key size is not 32 bytes.
ErrKeySize = errors.New("key must be 32 bytes")
// ErrKeyStringLength is returned when the key string length is not 32.
ErrKeyStringLength = errors.New("key string length must be 32")
ErrInvalidSocks5 = errors.New("invalid SOCKS5 version")
ErrNoPeers = errors.New("no peers available")
ErrEncryptFailed = errors.New("encrypt failed")
// ErrInvalidSocks5 is returned when the SOCKS version is not 5.
ErrInvalidSocks5 = errors.New("invalid SOCKS5 version")
// ErrNoPeers is returned when no peers are available for sending.
ErrNoPeers = errors.New("no peers available")
// ErrEncryptFailed is returned when encryption fails.
ErrEncryptFailed = errors.New("encrypt failed")
// ErrUnsupportedSocksCommand is returned when a SOCKS5 command is not supported.
ErrUnsupportedSocksCommand = errors.New("unsupported SOCKS5 command")
// ErrUnsupportedAddressType is returned when a SOCKS5 address type is not supported.
ErrUnsupportedAddressType = errors.New("unsupported address type")
// ErrTunnelSetupFailed is returned when the tunnel cannot be established.
ErrTunnelSetupFailed = errors.New("tunnel setup failed")
)
// Client handles local SOCKS5 connections and tunnels them via WebRTC.
@@ -53,8 +64,23 @@ func Run(
keyHex string,
localAddr string,
dnsServer,
socksProxyAddr string,
socksProxyPort int,
socksUser string,
socksPass string,
) error {
return RunWithReady(ctx, providerName, roomURL, keyHex, localAddr, dnsServer, socksUser, socksPass, nil)
}
// RunWithReady is like Run but accepts a callback that is called when the client is ready.
func RunWithReady(
ctx context.Context,
providerName,
roomURL,
keyHex string,
localAddr string,
dnsServer,
_ string,
_ string,
onReady func(),
) error {
runCtx, cancel := context.WithCancel(ctx)
defer cancel()
@@ -82,19 +108,24 @@ func Run(
const peerCount = 1
for i := range peerCount {
if err := c.addPeer(runCtx, providerName, roomURL, i, cancel, dnsServer, socksProxyAddr, socksProxyPort); err != nil {
if err := c.addPeer(runCtx, providerName, roomURL, i, cancel, dnsServer, "", 0); err != nil {
return fmt.Errorf("addPeer failed: %w", err)
}
}
ln, err := net.Listen("tcp", localAddr)
lc := net.ListenConfig{}
ln, err := lc.Listen(runCtx, "tcp", localAddr)
if err != nil {
return fmt.Errorf("listen failed: %w", err)
}
defer ln.Close()
defer func() { _ = ln.Close() }()
logger.Infof("SOCKS5 server listening on %s (ClientID: %d)", localAddr, clientID)
if onReady != nil {
onReady()
}
go c.acceptLoop(runCtx, ln)
<-runCtx.Done()
@@ -254,7 +285,7 @@ func (c *Client) acceptLoop(ctx context.Context, ln net.Listener) {
}
func (c *Client) handleSOCKS5(ctx context.Context, conn net.Conn) {
defer conn.Close()
defer func() { _ = conn.Close() }()
if err := c.socks5Handshake(conn); err != nil {
logger.Debugf("SOCKS5 handshake failed: %v", err)
@@ -274,16 +305,25 @@ func (c *Client) handleSOCKS5(ctx context.Context, conn net.Conn) {
logger.Infof("sid=%d tunnel to %s:%d", sid, addr, port)
req := map[string]any{
"cmd": "connect",
"addr": addr,
"port": port,
if err := c.setupTunnel(ctx, sid, conn, addr, port); err != nil {
logger.Warnf("sid=%d tunnel setup failed: %v", sid, err)
return
}
c.activeClients.Add(1)
c.startStreamPump(ctx, sid, conn)
c.pumpToMux(sid, conn)
}
func (c *Client) setupTunnel(ctx context.Context, sid uint16, conn net.Conn, addr string, port int) error {
req := map[string]any{"cmd": "connect", "addr": addr, "port": port}
reqData, err := json.Marshal(req)
if err != nil {
return fmt.Errorf("marshal connect: %w", err)
}
reqData, _ := json.Marshal(req)
if err := c.mux.SendData(sid, reqData); err != nil {
logger.Warnf("sid=%d send connect failed: %v", sid, err)
return
return fmt.Errorf("send connect: %w", err)
}
dataReady := c.mux.WaitForData(sid)
@@ -292,30 +332,27 @@ func (c *Client) handleSOCKS5(ctx context.Context, conn net.Conn) {
resp := c.mux.ReadStream(sid)
if len(resp) > 0 && resp[0] == 0x00 {
if _, err := conn.Write(replySuccess()); err != nil {
return
return fmt.Errorf("write success: %w", err)
}
} else {
_, _ = conn.Write(replyHostUnreachable())
return
return ErrTunnelSetupFailed
}
case <-time.After(15 * time.Second):
_, _ = conn.Write(replyHostUnreachable())
c.mux.CleanupDataChannel(sid)
return
return fmt.Errorf("%w: timeout", ErrTunnelSetupFailed)
case <-ctx.Done():
return
return fmt.Errorf("context cancelled: %w", ctx.Err())
}
c.mux.CleanupDataChannel(sid)
c.activeClients.Add(1)
c.startStreamPump(ctx, sid, conn)
c.pumpToMux(sid, conn)
return nil
}
func (c *Client) socks5Handshake(conn net.Conn) error {
buf := make([]byte, 2)
if _, err := io.ReadFull(conn, buf); err != nil {
return err
return fmt.Errorf("read header: %w", err)
}
if buf[0] != 5 {
@@ -324,60 +361,68 @@ func (c *Client) socks5Handshake(conn net.Conn) error {
methods := make([]byte, int(buf[1]))
if _, err := io.ReadFull(conn, methods); err != nil {
return err
return fmt.Errorf("read methods: %w", err)
}
_, err := conn.Write([]byte{5, 0})
return err
if _, err := conn.Write([]byte{5, 0}); err != nil {
return fmt.Errorf("write response: %w", err)
}
return nil
}
func (c *Client) socks5Request(conn net.Conn) (string, int, error) {
buf := make([]byte, 4)
if _, err := io.ReadFull(conn, buf); err != nil {
return "", 0, err
return "", 0, fmt.Errorf("read request header: %w", err)
}
if buf[0] != 5 || buf[1] != 1 {
return "", 0, fmt.Errorf("unsupported SOCKS5 command: %d", buf[1])
return "", 0, fmt.Errorf("%w: cmd=%d", ErrUnsupportedSocksCommand, buf[1])
}
var addr string
switch buf[3] {
case 1: // IPv4
ip := make([]byte, 4)
if _, err := io.ReadFull(conn, ip); err != nil {
return "", 0, err
}
addr = net.IP(ip).String()
case 3: // Domain
lenBuf := make([]byte, 1)
if _, err := io.ReadFull(conn, lenBuf); err != nil {
return "", 0, err
}
domain := make([]byte, int(lenBuf[0]))
if _, err := io.ReadFull(conn, domain); err != nil {
return "", 0, err
}
addr = string(domain)
case 4: // IPv6
ip := make([]byte, 16)
if _, err := io.ReadFull(conn, ip); err != nil {
return "", 0, err
}
addr = net.IP(ip).String()
default:
return "", 0, fmt.Errorf("unsupported address type: %d", buf[3])
addr, err := c.readSocks5Addr(conn, buf[3])
if err != nil {
return "", 0, err
}
portBuf := make([]byte, 2)
if _, err := io.ReadFull(conn, portBuf); err != nil {
return "", 0, err
return "", 0, fmt.Errorf("read port: %w", err)
}
port := int(binary.BigEndian.Uint16(portBuf))
return addr, port, nil
}
func (c *Client) readSocks5Addr(conn net.Conn, addrType byte) (string, error) {
switch addrType {
case 1: // IPv4
ip := make([]byte, 4)
if _, err := io.ReadFull(conn, ip); err != nil {
return "", fmt.Errorf("read ipv4: %w", err)
}
return net.IP(ip).String(), nil
case 3: // Domain
lenBuf := make([]byte, 1)
if _, err := io.ReadFull(conn, lenBuf); err != nil {
return "", fmt.Errorf("read domain len: %w", err)
}
domain := make([]byte, int(lenBuf[0]))
if _, err := io.ReadFull(conn, domain); err != nil {
return "", fmt.Errorf("read domain: %w", err)
}
return string(domain), nil
case 4: // IPv6
ip := make([]byte, 16)
if _, err := io.ReadFull(conn, ip); err != nil {
return "", fmt.Errorf("read ipv6: %w", err)
}
return net.IP(ip).String(), nil
default:
return "", fmt.Errorf("%w: type=%d", ErrUnsupportedAddressType, addrType)
}
}
func (c *Client) onData(data []byte) {
plaintext, err := c.cipher.Decrypt(data)
if err != nil {

View File

@@ -11,7 +11,9 @@ import (
)
var (
ErrInvalidKeySize = errors.New("invalid key size")
// ErrInvalidKeySize is returned when the encryption key is not 32 bytes.
ErrInvalidKeySize = errors.New("invalid key size")
// ErrCiphertextTooShort is returned when the ciphertext is shorter than the nonce size.
ErrCiphertextTooShort = errors.New("ciphertext too short")
)

View File

@@ -1,3 +1,4 @@
// Package logger provides a simple leveled logging interface.
package logger
import (
@@ -5,6 +6,9 @@ import (
"sync/atomic"
)
// verboseEnabled controls whether verbose and debug logging is enabled.
//
//nolint:gochecknoglobals // Global log state is acceptable for CLI tools.
var verboseEnabled atomic.Bool
// SetVerbose enables or disables verbose/debug logging.

View File

@@ -5,35 +5,42 @@ import (
"encoding/binary"
"errors"
"fmt"
"math"
"sync"
"github.com/openlibrecommunity/olcrtc/internal/logger"
)
var (
// ErrClientResetID is returned when a client reset is attempted with a zero client ID.
ErrClientResetID = errors.New("client reset requires a non-zero client id")
// ErrDataTooLarge is returned when a data chunk exceeds the maximum frame size.
ErrDataTooLarge = errors.New("data chunk too large")
)
const (
// Frame Header sizes
// HeaderSize is the size of the frame header in bytes.
HeaderSize = 12
// Special Stream IDs
// ControlStreamID is a special stream ID used for control frames.
ControlStreamID uint16 = 0xFFFF
// Control Frame Types
// ControlResetClient is a control frame type used to signal a client reset.
ControlResetClient uint32 = 1
// Frame Types (Internal to mux logic)
FrameTypeData uint16 = 0
// FrameTypeData is a marker for data frames.
FrameTypeData uint16 = 0
// FrameTypeControl is a marker for control frames.
FrameTypeControl uint16 = 0xFFFF
)
// ControlFrame represents a control message between multiplexers.
type ControlFrame struct {
ClientID uint32
Type uint32
}
// Stream represents a single multiplexed data stream.
type Stream struct {
ID uint16
ClientID uint32
@@ -44,12 +51,14 @@ type Stream struct {
outOfOrder map[uint32][]byte
}
// RecvBuf returns the current receive buffer content.
func (s *Stream) RecvBuf() []byte {
s.mu.Lock()
defer s.mu.Unlock()
return s.recvBuf
}
// Multiplexer coordinates multiple Streams over a single transport channel.
type Multiplexer struct {
streams map[uint16]*Stream
nextID uint16
@@ -67,6 +76,7 @@ type Multiplexer struct {
bufferCond *sync.Cond
}
// New creates a new Multiplexer instance.
func New(clientID uint32, onSend func([]byte) error) *Multiplexer {
m := &Multiplexer{
streams: make(map[uint16]*Stream),
@@ -82,6 +92,7 @@ func New(clientID uint32, onSend func([]byte) error) *Multiplexer {
return m
}
// OpenStream allocates and returns a new unique stream ID.
func (m *Multiplexer) OpenStream() uint16 {
m.mu.Lock()
defer m.mu.Unlock()
@@ -105,6 +116,7 @@ func (m *Multiplexer) OpenStream() uint16 {
}
}
// SendData fragments and sends data over a specific stream.
func (m *Multiplexer) SendData(sid uint16, data []byte) error {
m.mu.RLock()
stream, exists := m.streams[sid]
@@ -115,11 +127,6 @@ func (m *Multiplexer) SendData(sid uint16, data []byte) error {
}
const chunkSize = 7000
totalChunks := (len(data) + chunkSize - 1) / chunkSize
if totalChunks > 10 {
logger.Debugf("SendData: sid=%d, size=%d bytes, chunks=%d", sid, len(data), totalChunks)
}
for i := 0; i < len(data); i += chunkSize {
end := i + chunkSize
@@ -134,10 +141,14 @@ func (m *Multiplexer) SendData(sid uint16, data []byte) error {
m.sendSeq[sid]++
m.sendSeqMu.Unlock()
if len(chunk) > math.MaxUint16 {
return ErrDataTooLarge
}
frame := make([]byte, HeaderSize+len(chunk))
binary.BigEndian.PutUint32(frame[0:4], m.clientID)
binary.BigEndian.PutUint16(frame[4:6], sid)
binary.BigEndian.PutUint16(frame[6:8], uint16(len(chunk)))
binary.BigEndian.PutUint16(frame[6:8], uint16(len(chunk))) //nolint:gosec // Length checked above
binary.BigEndian.PutUint32(frame[8:12], seq)
copy(frame[HeaderSize:], chunk)
@@ -149,6 +160,7 @@ func (m *Multiplexer) SendData(sid uint16, data []byte) error {
return nil
}
// CloseStream signals that a stream should be terminated.
func (m *Multiplexer) CloseStream(sid uint16) error {
m.mu.Lock()
defer m.mu.Unlock()
@@ -176,6 +188,7 @@ func (m *Multiplexer) CloseStream(sid uint16) error {
return nil
}
// SendClientReset sends a control frame to reset all streams for this client.
func (m *Multiplexer) SendClientReset() error {
if m.clientID == 0 {
return ErrClientResetID
@@ -186,6 +199,7 @@ func (m *Multiplexer) SendClientReset() error {
return nil
}
// BuildControlFrame constructs a raw control frame.
func BuildControlFrame(clientID uint32, controlType uint32) []byte {
frame := make([]byte, HeaderSize)
binary.BigEndian.PutUint32(frame[0:4], clientID)
@@ -195,6 +209,7 @@ func BuildControlFrame(clientID uint32, controlType uint32) []byte {
return frame
}
// ParseControlFrame attempts to extract control information from a frame.
func ParseControlFrame(frame []byte) (ControlFrame, bool) {
if len(frame) < HeaderSize {
return ControlFrame{}, false
@@ -212,6 +227,7 @@ func ParseControlFrame(frame []byte) (ControlFrame, bool) {
}, true
}
// HandleFrame processes an incoming frame from the transport.
func (m *Multiplexer) HandleFrame(frame []byte) {
control, ok := ParseControlFrame(frame)
if ok {
@@ -336,6 +352,7 @@ func (m *Multiplexer) handleControlFrame(control ControlFrame) {
}
}
// ResetClient closes and removes all streams associated with a client ID.
func (m *Multiplexer) ResetClient(clientID uint32) {
m.mu.Lock()
defer m.mu.Unlock()
@@ -363,6 +380,7 @@ func (m *Multiplexer) waitForBufferSpace(sid uint16, clientID uint32, need int)
}
}
// ReadStream retrieves and clears the current receive buffer for a stream.
func (m *Multiplexer) ReadStream(sid uint16) []byte {
m.mu.Lock()
defer m.mu.Unlock()
@@ -381,6 +399,7 @@ func (m *Multiplexer) ReadStream(sid uint16) []byte {
return data
}
// StreamClosed returns true if the stream is closed or doesn't exist.
func (m *Multiplexer) StreamClosed(sid uint16) bool {
m.mu.RLock()
defer m.mu.RUnlock()
@@ -389,6 +408,7 @@ func (m *Multiplexer) StreamClosed(sid uint16) bool {
return !exists || stream.closed
}
// GetStreams returns a list of all active stream IDs.
func (m *Multiplexer) GetStreams() []uint16 {
m.mu.RLock()
defer m.mu.RUnlock()
@@ -400,12 +420,14 @@ func (m *Multiplexer) GetStreams() []uint16 {
return sids
}
// GetStream returns the Stream object for a given ID.
func (m *Multiplexer) GetStream(sid uint16) *Stream {
m.mu.RLock()
defer m.mu.RUnlock()
return m.streams[sid]
}
// Reset clears all multiplexer state and closes all streams.
func (m *Multiplexer) Reset() {
m.mu.Lock()
defer m.mu.Unlock()
@@ -424,6 +446,7 @@ func (m *Multiplexer) Reset() {
m.bufferCond.Broadcast()
}
// UpdateSendFunc updates the function used to transmit raw frames.
func (m *Multiplexer) UpdateSendFunc(onSend func([]byte) error) {
m.mu.Lock()
defer m.mu.Unlock()
@@ -431,6 +454,7 @@ func (m *Multiplexer) UpdateSendFunc(onSend func([]byte) error) {
m.onSend = onSend
}
// WaitForData returns a channel that signals when new data is available for a stream.
func (m *Multiplexer) WaitForData(sid uint16) <-chan struct{} {
m.dataReadyMu.Lock()
defer m.dataReadyMu.Unlock()
@@ -441,6 +465,7 @@ func (m *Multiplexer) WaitForData(sid uint16) <-chan struct{} {
return m.dataReady[sid]
}
// CleanupDataChannel removes the data notification channel for a stream.
func (m *Multiplexer) CleanupDataChannel(sid uint16) {
m.dataReadyMu.Lock()
defer m.dataReadyMu.Unlock()

View File

@@ -3,6 +3,7 @@ package jazz
import (
"context"
"errors"
"fmt"
"log"
"strings"
@@ -533,12 +534,21 @@ func (p *Peer) Close() error {
return nil
}
var (
// ErrPublisherNotInitialized is returned when the publisher peer connection is not set up.
ErrPublisherNotInitialized = errors.New("publisher peer connection not initialized")
)
// AddVideoTrack adds a video track to the publisher peer connection.
func (p *Peer) AddVideoTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) {
if p.pcPub == nil {
return nil, fmt.Errorf("publisher peer connection not initialized")
return nil, ErrPublisherNotInitialized
}
return p.pcPub.AddTrack(track)
sender, err := p.pcPub.AddTrack(track)
if err != nil {
return nil, fmt.Errorf("failed to add track: %w", err)
}
return sender, nil
}
// SetReconnectCallback sets the callback for reconnection events.

View File

@@ -9,14 +9,21 @@ import (
)
var (
ErrProviderNotFound = errors.New("provider not found")
ErrDataChannelTimeout = errors.New("datachannel timeout")
// ErrProviderNotFound is returned when a requested provider is not registered.
ErrProviderNotFound = errors.New("provider not found")
// ErrDataChannelTimeout is returned when the DataChannel fails to open within the timeout period.
ErrDataChannelTimeout = errors.New("datachannel timeout")
// ErrDataChannelNotReady is returned when attempting to send data before the DataChannel is open.
ErrDataChannelNotReady = errors.New("datachannel not ready")
ErrSendQueueClosed = errors.New("send queue closed")
ErrSendQueueTimeout = errors.New("send queue timeout")
// ErrSendQueueClosed is returned when attempting to send data after the send queue has been closed.
ErrSendQueueClosed = errors.New("send queue closed")
// ErrSendQueueTimeout is returned when the send queue is full and the timeout is reached.
ErrSendQueueTimeout = errors.New("send queue timeout")
)
// Provider defines the standard interface for WebRTC connection handlers.
//
//nolint:interfacebloat // All methods are necessary for provider abstraction.
type Provider interface {
Connect(ctx context.Context) error
Send(data []byte) error
@@ -47,6 +54,8 @@ type Config struct {
type Factory func(ctx context.Context, cfg Config) (Provider, error)
// registry holds all registered provider factories.
//
//nolint:gochecknoglobals // Global registry is required for provider discovery.
var registry = make(map[string]Factory)
// Register adds a new provider factory to the registry.

View File

@@ -1161,10 +1161,19 @@ func (p *Peer) CanSend() bool {
return len(p.sendQueue) < 4000
}
var (
// ErrPublisherNotInitialized is returned when the publisher peer connection is not set up.
ErrPublisherNotInitialized = errors.New("publisher peer connection not initialized")
)
// AddVideoTrack adds a video track to the publisher peer connection.
func (p *Peer) AddVideoTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) {
if p.pcPub == nil {
return nil, fmt.Errorf("publisher peer connection not initialized")
return nil, ErrPublisherNotInitialized
}
return p.pcPub.AddTrack(track)
sender, err := p.pcPub.AddTrack(track)
if err != nil {
return nil, fmt.Errorf("failed to add track: %w", err)
}
return sender, nil
}

View File

@@ -18,8 +18,14 @@ const (
)
var (
errPeerClosed = errors.New("peer closed")
errSendQueueFull = errors.New("send queue full")
// ErrPeerClosed is returned when an operation is attempted on a closed peer.
ErrPeerClosed = errors.New("peer closed")
// ErrSendQueueFull is returned when the transmission queue is full.
ErrSendQueueFull = errors.New("send queue full")
// ErrLiveKitNotConnected is returned when the LiveKit room is not connected.
ErrLiveKitNotConnected = errors.New("livekit room not connected")
// ErrVideoNotSupported is returned when video tracks are not supported by this provider.
ErrVideoNotSupported = errors.New("video tracks not supported yet in wbstream")
)
// Peer represents a WB Stream WebRTC connection using LiveKit.
@@ -137,13 +143,13 @@ func (p *Peer) processSendQueue() {
// Send transmits data to the room.
func (p *Peer) Send(data []byte) error {
if p.closed.Load() {
return errPeerClosed
return ErrPeerClosed
}
select {
case p.sendQueue <- data:
return nil
default:
return errSendQueueFull
return ErrSendQueueFull
}
}
@@ -197,23 +203,15 @@ func (p *Peer) GetBufferedAmount() uint64 {
// AddVideoTrack adds a video track to the LiveKit room.
func (p *Peer) AddVideoTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) {
if p.room == nil || p.room.LocalParticipant == nil {
return nil, fmt.Errorf("livekit room not connected")
return nil, ErrLiveKitNotConnected
}
publication, err := p.room.LocalParticipant.PublishTrack(track, &lksdk.TrackPublicationOptions{
_, err := p.room.LocalParticipant.PublishTrack(track, &lksdk.TrackPublicationOptions{
Name: "video",
})
if err != nil {
return nil, fmt.Errorf("failed to publish track: %w", err)
}
// LiveKit SDK wraps RTPSender, but for the interface compatibility we might need to handle this differently.
// Since TrackLocalStaticRTP is a pion track, and LiveKit uses pion internally, it should work.
// However, LiveKit's PublishTrack doesn't return *webrtc.RTPSender directly.
// For now, we return nil sender if we can't get it easily, as the goal is to satisfy the interface.
if publication != nil {
return nil, nil // TODO: extract RTPSender if needed for VideoChannel
}
return nil, nil
return nil, ErrVideoNotSupported
}

View File

@@ -5,6 +5,7 @@ package mobile
import (
"context"
"errors"
"fmt"
"log"
"sync"
"time"
@@ -111,7 +112,7 @@ func Start(roomID, keyHex string, socksPort int, socksUser, socksPass string) er
"telemost",
roomURL,
keyHex,
socksPort,
fmt.Sprintf("127.0.0.1:%d", socksPort),
"",
socksUser,
socksPass,