Merge branch 'transport/smux' into transport/videochannel

This commit is contained in:
zarazaex69
2026-05-03 01:24:39 +03:00
6 changed files with 483 additions and 1144 deletions

1
go.mod
View File

@@ -68,6 +68,7 @@ require (
github.com/tjfoc/gmsm v1.4.1 // indirect
github.com/twitchtv/twirp v8.1.3+incompatible // indirect
github.com/wlynxg/anet v0.0.5 // indirect
github.com/xtaci/smux v1.5.57 // indirect
github.com/zeebo/xxh3 v1.1.0 // indirect
go.opentelemetry.io/otel v1.40.0 // indirect
go.uber.org/atomic v1.11.0 // indirect

2
go.sum
View File

@@ -226,6 +226,8 @@ github.com/xtaci/kcp-go/v5 v5.6.72 h1:FLaQPalgpufJYQRk0OK+gErEhXGLUPjv6FSRPrFR8L
github.com/xtaci/kcp-go/v5 v5.6.72/go.mod h1:9O3D8WR+cyyUjGiTILYfg17vn72otWuXK2AFfqIe6CM=
github.com/xtaci/lossyconn v0.0.0-20190602105132-8df528c0c9ae h1:J0GxkO96kL4WF+AIT3M4mfUVinOCPgf2uUWYFUzN0sM=
github.com/xtaci/lossyconn v0.0.0-20190602105132-8df528c0c9ae/go.mod h1:gXtu8J62kEgmN++bm9BVICuT/e8yiLI2KFobd/TRFsE=
github.com/xtaci/smux v1.5.57 h1:N72VbGoSYxgcm6mPOYX0QzEZNVD3UI/JlVvAtXF+WrY=
github.com/xtaci/smux v1.5.57/go.mod h1:IGQ9QYrBphmb/4aTnLEcJby0TNr3NV+OslIOMrX825Q=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/zarazaex69/b v0.0.0-20260423064626-c0bd20863b89 h1:ytA0RfQZTYfjqFA9lBJMX1DTnXpTuKg0nf4udgdpunE=
github.com/zarazaex69/b v0.0.0-20260423064626-c0bd20863b89/go.mod h1:OUqzZNoXsg+ccaiAnSe0t4f8qc0W/cFx6io0lWsE1Gw=

View File

@@ -3,7 +3,6 @@ package client
import (
"context"
"crypto/rand"
"encoding/binary"
"encoding/hex"
"encoding/json"
@@ -17,8 +16,9 @@ import (
"github.com/openlibrecommunity/olcrtc/internal/crypto"
"github.com/openlibrecommunity/olcrtc/internal/link"
"github.com/openlibrecommunity/olcrtc/internal/logger"
"github.com/openlibrecommunity/olcrtc/internal/mux"
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
"github.com/openlibrecommunity/olcrtc/internal/names"
"github.com/xtaci/smux"
)
var (
@@ -26,21 +26,16 @@ var (
ErrConnectFailed = errors.New("tunnel connection failed")
// ErrProxyAuth is returned when SOCKS proxy authentication fails.
ErrProxyAuth = errors.New("SOCKS proxy auth failed")
// ErrMuxExited is returned when the multiplexer loop exits unexpectedly.
ErrMuxExited = errors.New("multiplexer loop exited")
// ErrNoAvailableLinks is returned when no links are ready for sending.
ErrNoAvailableLinks = errors.New("no available links")
)
// Client handles local SOCKS5 connections and tunnels them to the server.
type Client struct {
links []link.Link
cipher *crypto.Cipher
mux *mux.Multiplexer
connections map[uint16]net.Conn
connMu sync.RWMutex
clientID uint32
dnsServer string
ln link.Link
cipher *crypto.Cipher
conn *muxconn.Conn
session *smux.Session
sessMu sync.RWMutex
dnsServer string
}
// Run starts the client with the specified parameters.
@@ -105,37 +100,27 @@ func RunWithReady(
return fmt.Errorf("setupCipher failed: %w", err)
}
clientIDBytes := make([]byte, 4)
if _, err := rand.Read(clientIDBytes); err != nil {
return fmt.Errorf("failed to generate client ID: %w", err)
}
clientID := binary.BigEndian.Uint32(clientIDBytes)
c := &Client{cipher: cipher, dnsServer: dnsServer}
c := &Client{
cipher: cipher,
connections: make(map[uint16]net.Conn),
links: make([]link.Link, 0),
clientID: clientID,
dnsServer: dnsServer,
}
c.setupMux()
const linkCount = 1
for i := range linkCount {
if err := c.addLink(runCtx, linkName, transportName, carrierName, roomURL, i, cancel, dnsServer, "", 0, videoWidth, videoHeight, videoFPS, videoBitrate, videoHW, videoQRSize, videoQRRecovery, videoCodec, videoTileModule, videoTileRS, vp8FPS, vp8BatchSize); err != nil {
return fmt.Errorf("addLink failed: %w", err)
}
if err := c.bringUpLink(
runCtx, linkName, transportName, carrierName, roomURL, cancel,
dnsServer, "", 0,
videoWidth, videoHeight, videoFPS, videoBitrate, videoHW,
videoQRSize, videoQRRecovery, videoCodec, videoTileModule, videoTileRS,
vp8FPS, vp8BatchSize,
); err != nil {
return err
}
defer c.shutdown()
lc := net.ListenConfig{}
ln, err := lc.Listen(runCtx, "tcp4", localAddr)
listener, err := lc.Listen(runCtx, "tcp4", localAddr)
if err != nil {
return fmt.Errorf("failed to listen on %s: %w", localAddr, err)
}
defer ln.Close()
defer listener.Close()
logger.Infof("SOCKS5 server listening on %s (ClientID: %d)", localAddr, clientID)
logger.Infof("SOCKS5 server listening on %s", localAddr)
if onReady != nil {
onReady()
@@ -143,96 +128,30 @@ func RunWithReady(
errCh := make(chan error, 1)
go func() {
errCh <- c.acceptLoop(runCtx, ln)
errCh <- c.acceptLoop(runCtx, listener)
}()
select {
case <-runCtx.Done():
c.shutdown()
return nil
case err := <-errCh:
return err
}
}
func (c *Client) shutdown() {
c.connMu.Lock()
for _, conn := range c.connections {
if conn != nil {
_ = conn.Close()
}
}
c.connMu.Unlock()
for i, ln := range c.links {
logger.Infof("closing link %d", i)
_ = ln.Close()
}
}
func setupCipher(keyHex string) (*crypto.Cipher, error) {
key, err := hex.DecodeString(keyHex)
if err != nil {
return nil, fmt.Errorf("failed to decode key: %w", err)
}
if len(key) != 32 {
return nil, fmt.Errorf("key must be 32 bytes, got %d", len(key))
}
cipher, err := crypto.NewCipher(string(key))
if err != nil {
return nil, fmt.Errorf("failed to create cipher: %w", err)
}
return cipher, nil
}
func (c *Client) setupMux() {
c.mux = mux.New(c.clientID, func(frame []byte) error {
for {
canSend := true
for _, ln := range c.links {
if !ln.CanSend() {
canSend = false
break
}
}
if canSend {
break
}
time.Sleep(10 * time.Millisecond)
}
encrypted, err := c.cipher.Encrypt(frame)
if err != nil {
return err
}
if len(c.links) == 0 {
return ErrNoAvailableLinks
}
return c.links[0].Send(encrypted)
})
}
func (c *Client) addLink(
func (c *Client) bringUpLink(
ctx context.Context,
linkName,
transportName,
carrierName,
roomURL string,
linkID int,
linkName, transportName, carrierName, roomURL string,
cancel context.CancelFunc,
dnsServer,
socksProxyAddr string,
dnsServer, socksProxyAddr string,
socksProxyPort int,
videoWidth, videoHeight, videoFPS int,
videoBitrate, videoHW string,
videoQRSize int,
videoQRRecovery string,
videoCodec string,
videoTileModule int,
videoTileRS int,
vp8FPS int,
vp8BatchSize int,
videoTileModule, videoTileRS int,
vp8FPS, vp8BatchSize int,
) error {
ln, err := link.New(ctx, linkName, link.Config{
Transport: transportName,
@@ -259,56 +178,104 @@ func (c *Client) addLink(
if err != nil {
return fmt.Errorf("failed to create link: %w", err)
}
c.ln = ln
ln.SetEndedCallback(func(reason string) {
logger.Infof("Client link %d reported conference end: %s", linkID, reason)
logger.Infof("Client link reported conference end: %s", reason)
cancel()
})
c.links = append(c.links, ln)
ln.SetReconnectCallback(func() {
c.handleLinkReconnect(linkID)
})
ln.SetReconnectCallback(func() { c.handleReconnect() })
if err := ln.Connect(ctx); err != nil {
return fmt.Errorf("failed to connect link: %w", err)
}
c.conn = muxconn.New(ln, c.cipher)
sess, err := smux.Client(c.conn, smuxConfig())
if err != nil {
return fmt.Errorf("smux client: %w", err)
}
c.sessMu.Lock()
c.session = sess
c.sessMu.Unlock()
go ln.WatchConnection(ctx)
return nil
}
func (c *Client) handleLinkReconnect(linkID int) {
logger.Infof("link %d reconnect event", linkID)
c.sendResetSignal()
c.connMu.Lock()
for sid, conn := range c.connections {
if conn != nil {
_ = conn.Close()
}
delete(c.connections, sid)
}
c.connMu.Unlock()
c.mux.UpdateSendFunc(func(frame []byte) error {
encrypted, err := c.cipher.Encrypt(frame)
if err != nil {
return err
}
if len(c.links) == 0 {
return ErrNoAvailableLinks
}
return c.links[0].Send(encrypted)
})
c.mux.Reset()
// smuxConfig returns the tuned smux config used on both ends.
func smuxConfig() *smux.Config {
cfg := smux.DefaultConfig()
cfg.Version = 2
cfg.MaxFrameSize = 32768
cfg.MaxReceiveBuffer = 16 * 1024 * 1024
cfg.MaxStreamBuffer = 1024 * 1024
cfg.KeepAliveInterval = 10 * time.Second
cfg.KeepAliveTimeout = 60 * time.Second
return cfg
}
func (c *Client) sendResetSignal() {
resetFrame := mux.BuildControlFrame(c.clientID, mux.ControlResetClient)
encrypted, _ := c.cipher.Encrypt(resetFrame)
if len(c.links) > 0 {
_ = c.links[0].Send(encrypted)
func (c *Client) handleReconnect() {
logger.Infof("client link reconnect — tearing down smux session")
c.sessMu.Lock()
if c.session != nil {
_ = c.session.Close()
c.session = nil
}
if c.conn != nil {
_ = c.conn.Close()
c.conn = nil
}
c.sessMu.Unlock()
// New SOCKS5 connections will fail until the link comes back up; the
// caller will reissue them. Existing streams die with the smux session.
c.conn = muxconn.New(c.ln, c.cipher)
sess, err := smux.Client(c.conn, smuxConfig())
if err != nil {
logger.Warnf("smux re-init failed: %v", err)
return
}
c.sessMu.Lock()
c.session = sess
c.sessMu.Unlock()
}
func (c *Client) shutdown() {
c.sessMu.Lock()
if c.session != nil {
_ = c.session.Close()
}
if c.conn != nil {
_ = c.conn.Close()
}
c.sessMu.Unlock()
if c.ln != nil {
_ = c.ln.Close()
}
}
func setupCipher(keyHex string) (*crypto.Cipher, error) {
key, err := hex.DecodeString(keyHex)
if err != nil {
return nil, fmt.Errorf("failed to decode key: %w", err)
}
if len(key) != 32 {
return nil, fmt.Errorf("key must be 32 bytes, got %d", len(key))
}
cipher, err := crypto.NewCipher(string(key))
if err != nil {
return nil, fmt.Errorf("failed to create cipher: %w", err)
}
return cipher, nil
}
func (c *Client) onData(data []byte) {
c.sessMu.RLock()
conn := c.conn
c.sessMu.RUnlock()
if conn != nil {
conn.Push(data)
}
}
@@ -340,19 +307,23 @@ func (c *Client) handleSocks5(ctx context.Context, conn net.Conn) {
return
}
sid := c.mux.OpenStream()
defer c.mux.CloseStream(sid)
c.sessMu.RLock()
sess := c.session
c.sessMu.RUnlock()
if sess == nil || sess.IsClosed() {
_, _ = conn.Write(replyHostUnreachable())
return
}
c.connMu.Lock()
c.connections[sid] = conn
c.connMu.Unlock()
defer func() {
c.connMu.Lock()
delete(c.connections, sid)
c.connMu.Unlock()
}()
stream, err := sess.OpenStream()
if err != nil {
logger.Warnf("OpenStream failed: %v", err)
_, _ = conn.Write(replyHostUnreachable())
return
}
defer stream.Close()
logger.Infof("sid=%d tunnel to %s:%d", sid, targetAddr, targetPort)
logger.Infof("sid=%d tunnel to %s:%d", stream.ID(), targetAddr, targetPort)
connectReq, _ := json.Marshal(map[string]any{
"cmd": "connect",
@@ -360,45 +331,34 @@ func (c *Client) handleSocks5(ctx context.Context, conn net.Conn) {
"port": targetPort,
})
if err := c.mux.SendData(sid, connectReq); err != nil {
logger.Warnf("sid=%d tunnel setup failed: %v", sid, err)
_ = stream.SetWriteDeadline(time.Now().Add(10 * time.Second))
if _, err := stream.Write(connectReq); err != nil {
logger.Warnf("sid=%d connect req failed: %v", stream.ID(), err)
_, _ = conn.Write(replyHostUnreachable())
return
}
_ = stream.SetWriteDeadline(time.Time{})
readyTimer := time.NewTimer(10 * time.Second)
defer readyTimer.Stop()
dataReady := c.mux.WaitForData(sid)
var initialData []byte
select {
case <-readyTimer.C:
logger.Warnf("sid=%d tunnel setup failed: timeout waiting for remote ready", sid)
ack := make([]byte, 1)
_ = stream.SetReadDeadline(time.Now().Add(15 * time.Second))
if _, err := io.ReadFull(stream, ack); err != nil || ack[0] != 0x00 {
logger.Warnf("sid=%d remote ready failed: err=%v ack=%v", stream.ID(), err, ack)
_, _ = conn.Write(replyHostUnreachable())
return
case <-dataReady:
initialData = c.mux.ReadStream(sid)
if len(initialData) == 0 || initialData[0] != 0x00 {
logger.Warnf("sid=%d tunnel setup failed: invalid remote ready", sid)
_, _ = conn.Write(replyHostUnreachable())
return
}
}
_ = stream.SetReadDeadline(time.Time{})
if _, err := conn.Write(replySuccess()); err != nil {
return
}
// Handle the rest of initialData if any (unlikely for 0x00 packet)
if len(initialData) > 1 {
if _, err := conn.Write(initialData[1:]); err != nil {
return
}
}
go func() {
_, _ = io.Copy(stream, conn)
_ = stream.Close()
}()
_, _ = io.Copy(conn, stream)
go c.pumpFromMux(ctx, sid, conn)
c.pumpToMux(sid, conn)
_ = ctx // keep signature
}
func (c *Client) socks5Handshake(conn net.Conn) error {
@@ -459,62 +419,6 @@ func (c *Client) socks5Request(conn net.Conn) (string, int, error) {
return addr, port, nil
}
func (c *Client) pumpToMux(sid uint16, conn net.Conn) {
buf := make([]byte, 16384)
for {
n, err := conn.Read(buf)
if err != nil {
return
}
for !c.canSendData() {
time.Sleep(20 * time.Millisecond)
}
if err := c.mux.SendData(sid, buf[:n]); err != nil {
return
}
}
}
func (c *Client) pumpFromMux(ctx context.Context, sid uint16, conn net.Conn) {
defer c.mux.CleanupDataChannel(sid)
dataReady := c.mux.WaitForData(sid)
for {
select {
case <-ctx.Done():
return
case <-dataReady:
data := c.mux.ReadStream(sid)
if len(data) > 0 {
if _, err := conn.Write(data); err != nil {
return
}
}
if c.mux.StreamClosed(sid) {
return
}
}
}
}
func (c *Client) onData(data []byte) {
plaintext, err := c.cipher.Decrypt(data)
if err != nil {
return
}
c.mux.HandleFrame(plaintext)
}
func (c *Client) canSendData() bool {
for _, tr := range c.links {
if !tr.CanSend() {
return false
}
}
return true
}
func replySuccess() []byte {
return []byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0}
}

View File

@@ -1,477 +0,0 @@
// Package mux provides a multiplexer for multiple streams over a single connection.
package mux
import (
"encoding/binary"
"errors"
"fmt"
"math"
"sync"
"github.com/openlibrecommunity/olcrtc/internal/logger"
)
var (
// ErrClientResetID is returned when a client reset is attempted with a zero client ID.
ErrClientResetID = errors.New("client reset requires a non-zero client id")
// ErrDataTooLarge is returned when a data chunk exceeds the maximum frame size.
ErrDataTooLarge = errors.New("data chunk too large")
)
const (
// HeaderSize is the size of the frame header in bytes.
HeaderSize = 12
// ControlStreamID is a special stream ID used for control frames.
ControlStreamID uint16 = 0xFFFF
// ControlResetClient is a control frame type used to signal a client reset.
ControlResetClient uint32 = 1
// FrameTypeData is a marker for data frames.
FrameTypeData uint16 = 0
// FrameTypeControl is a marker for control frames.
FrameTypeControl uint16 = 0xFFFF
)
// ControlFrame represents a control message between multiplexers.
type ControlFrame struct {
ClientID uint32
Type uint32
}
// Stream represents a single multiplexed data stream.
type Stream struct {
ID uint16
ClientID uint32
recvBuf []byte
closed bool
mu sync.Mutex
nextSeq uint32
outOfOrder map[uint32][]byte
}
// RecvBuf returns the current receive buffer content.
func (s *Stream) RecvBuf() []byte {
s.mu.Lock()
defer s.mu.Unlock()
return s.recvBuf
}
// Multiplexer coordinates multiple Streams over a single transport channel.
type Multiplexer struct {
streams map[uint16]*Stream
nextID uint16
clientID uint32
onSend func([]byte) error
mu sync.RWMutex
maxStreams int
maxBufferSize int
dataReady map[uint16]chan struct{}
dataReadyMu sync.Mutex
sendSeq map[uint16]uint32
sendSeqMu sync.Mutex
// bufferCond is used to wait for space in receive buffers
bufferCond *sync.Cond
}
// New creates a new Multiplexer instance.
func New(clientID uint32, onSend func([]byte) error) *Multiplexer {
m := &Multiplexer{
streams: make(map[uint16]*Stream),
nextID: 1,
clientID: clientID,
onSend: onSend,
maxStreams: 10000,
maxBufferSize: 32 * 1024 * 1024,
dataReady: make(map[uint16]chan struct{}),
sendSeq: make(map[uint16]uint32),
}
m.bufferCond = sync.NewCond(&m.mu)
return m
}
// OpenStream allocates and returns a new unique stream ID.
func (m *Multiplexer) OpenStream() uint16 {
m.mu.Lock()
defer m.mu.Unlock()
for {
sid := m.nextID
m.nextID++
if m.nextID == 0 {
m.nextID = 1
}
if _, exists := m.streams[sid]; !exists {
m.streams[sid] = &Stream{
ID: sid,
recvBuf: make([]byte, 0),
nextSeq: 0,
outOfOrder: make(map[uint32][]byte),
}
return sid
}
}
}
// SendData fragments and sends data over a specific stream.
func (m *Multiplexer) SendData(sid uint16, data []byte) error {
m.mu.RLock()
stream, exists := m.streams[sid]
m.mu.RUnlock()
if !exists || stream.closed {
return nil
}
const chunkSize = 7000
for i := 0; i < len(data); i += chunkSize {
end := i + chunkSize
if end > len(data) {
end = len(data)
}
chunk := data[i:end]
m.sendSeqMu.Lock()
seq := m.sendSeq[sid]
m.sendSeq[sid]++
m.sendSeqMu.Unlock()
if len(chunk) > math.MaxUint16 {
return ErrDataTooLarge
}
frame := make([]byte, HeaderSize+len(chunk))
binary.BigEndian.PutUint32(frame[0:4], m.clientID)
binary.BigEndian.PutUint16(frame[4:6], sid)
binary.BigEndian.PutUint16(frame[6:8], uint16(len(chunk))) //nolint:gosec // Length checked above
binary.BigEndian.PutUint32(frame[8:12], seq)
copy(frame[HeaderSize:], chunk)
if err := m.onSend(frame); err != nil {
return fmt.Errorf("onSend failed: %w", err)
}
}
return nil
}
// CloseStream signals that a stream should be terminated.
func (m *Multiplexer) CloseStream(sid uint16) error {
m.mu.Lock()
defer m.mu.Unlock()
if stream, exists := m.streams[sid]; exists {
stream.closed = true
}
m.sendSeqMu.Lock()
delete(m.sendSeq, sid)
m.sendSeqMu.Unlock()
// Notify anyone waiting for buffer space that a stream is closed
m.bufferCond.Broadcast()
frame := make([]byte, HeaderSize)
binary.BigEndian.PutUint32(frame[0:4], m.clientID)
binary.BigEndian.PutUint16(frame[4:6], sid)
binary.BigEndian.PutUint16(frame[6:8], 0)
binary.BigEndian.PutUint32(frame[8:12], 0)
if err := m.onSend(frame); err != nil {
return fmt.Errorf("onSend failed: %w", err)
}
return nil
}
// SendClientReset sends a control frame to reset all streams for this client.
func (m *Multiplexer) SendClientReset() error {
if m.clientID == 0 {
return ErrClientResetID
}
if err := m.onSend(BuildControlFrame(m.clientID, ControlResetClient)); err != nil {
return fmt.Errorf("onSend failed: %w", err)
}
return nil
}
// BuildControlFrame constructs a raw control frame.
func BuildControlFrame(clientID uint32, controlType uint32) []byte {
frame := make([]byte, HeaderSize)
binary.BigEndian.PutUint32(frame[0:4], clientID)
binary.BigEndian.PutUint16(frame[4:6], ControlStreamID)
binary.BigEndian.PutUint16(frame[6:8], 0xFFFF) // Use 0xFFFF as a marker for control
binary.BigEndian.PutUint32(frame[8:12], controlType)
return frame
}
// ParseControlFrame attempts to extract control information from a frame.
func ParseControlFrame(frame []byte) (ControlFrame, bool) {
if len(frame) < HeaderSize {
return ControlFrame{}, false
}
sid := binary.BigEndian.Uint16(frame[4:6])
length := binary.BigEndian.Uint16(frame[6:8])
if sid != ControlStreamID || length != 0xFFFF {
return ControlFrame{}, false
}
return ControlFrame{
ClientID: binary.BigEndian.Uint32(frame[0:4]),
Type: binary.BigEndian.Uint32(frame[8:12]),
}, true
}
// HandleFrame processes an incoming frame from the transport.
func (m *Multiplexer) HandleFrame(frame []byte) {
control, ok := ParseControlFrame(frame)
if ok {
m.handleControlFrame(control)
return
}
if len(frame) < HeaderSize {
return
}
clientID := binary.BigEndian.Uint32(frame[0:4])
sid := binary.BigEndian.Uint16(frame[4:6])
length := binary.BigEndian.Uint16(frame[6:8])
seq := binary.BigEndian.Uint32(frame[8:12])
if length == 0 {
m.handleCloseStreamFrame(sid, clientID)
return
}
if len(frame) < HeaderSize+int(length) {
return
}
m.processDataFrame(sid, clientID, seq, frame[HeaderSize:HeaderSize+int(length)])
}
func (m *Multiplexer) handleCloseStreamFrame(sid uint16, clientID uint32) {
m.mu.Lock()
defer m.mu.Unlock()
if stream, exists := m.streams[sid]; exists && stream.ClientID == clientID {
stream.closed = true
m.bufferCond.Broadcast()
}
}
func (m *Multiplexer) processDataFrame(sid uint16, clientID uint32, seq uint32, data []byte) {
m.mu.Lock()
defer m.mu.Unlock()
stream := m.getOrCreateStream(sid, clientID)
if stream == nil {
return
}
if seq == stream.nextSeq {
if s := m.waitForBufferSpace(sid, clientID, len(data)); s != nil {
s.recvBuf = append(s.recvBuf, data...)
s.nextSeq++
m.applyOutOfOrder(s, sid, clientID)
m.notifyDataReady(sid)
}
} else if seq > stream.nextSeq {
if len(stream.outOfOrder) < 100 {
stream.outOfOrder[seq] = append([]byte(nil), data...)
}
}
}
func (m *Multiplexer) getOrCreateStream(sid uint16, clientID uint32) *Stream {
stream, exists := m.streams[sid]
if !exists {
if len(m.streams) >= m.maxStreams {
return nil
}
stream = &Stream{
ID: sid,
ClientID: clientID,
recvBuf: make([]byte, 0),
nextSeq: 0,
outOfOrder: make(map[uint32][]byte),
}
m.streams[sid] = stream
return stream
}
if stream.ClientID != clientID {
stream.ClientID = clientID
stream.recvBuf = make([]byte, 0)
stream.closed = false
stream.nextSeq = 0
stream.outOfOrder = make(map[uint32][]byte)
m.bufferCond.Broadcast()
}
return stream
}
func (m *Multiplexer) applyOutOfOrder(stream *Stream, sid uint16, clientID uint32) {
for {
nextData, ok := stream.outOfOrder[stream.nextSeq]
if !ok {
break
}
if s := m.waitForBufferSpace(sid, clientID, len(nextData)); s == nil {
return
}
stream.recvBuf = append(stream.recvBuf, nextData...)
delete(stream.outOfOrder, stream.nextSeq)
stream.nextSeq++
logger.Verbosef("Applied out-of-order packet sid=%d seq=%d", sid, stream.nextSeq-1)
}
}
func (m *Multiplexer) notifyDataReady(sid uint16) {
m.dataReadyMu.Lock()
defer m.dataReadyMu.Unlock()
if ch, ok := m.dataReady[sid]; ok {
select {
case ch <- struct{}{}:
default:
}
}
}
func (m *Multiplexer) handleControlFrame(control ControlFrame) {
switch control.Type {
case ControlResetClient:
m.ResetClient(control.ClientID)
default:
logger.Debugf("Unknown mux control frame type=%d clientID=%d", control.Type, control.ClientID)
}
}
// ResetClient closes and removes all streams associated with a client ID.
func (m *Multiplexer) ResetClient(clientID uint32) {
m.mu.Lock()
defer m.mu.Unlock()
for streamSid, stream := range m.streams {
if stream.ClientID == clientID {
stream.closed = true
delete(m.streams, streamSid)
}
}
m.bufferCond.Broadcast()
}
func (m *Multiplexer) waitForBufferSpace(sid uint16, clientID uint32, need int) *Stream {
for {
stream, ok := m.streams[sid]
if !ok || stream.ClientID != clientID || stream.closed {
return nil
}
if len(stream.recvBuf)+need <= m.maxBufferSize {
return stream
}
// Wait for space to become available
m.bufferCond.Wait()
}
}
// ReadStream retrieves and clears the current receive buffer for a stream.
func (m *Multiplexer) ReadStream(sid uint16) []byte {
m.mu.Lock()
defer m.mu.Unlock()
stream, exists := m.streams[sid]
if !exists || len(stream.recvBuf) == 0 {
return nil
}
data := stream.recvBuf
stream.recvBuf = make([]byte, 0)
// Notify producers that space is now available
m.bufferCond.Broadcast()
return data
}
// StreamClosed returns true if the stream is closed or doesn't exist.
func (m *Multiplexer) StreamClosed(sid uint16) bool {
m.mu.RLock()
defer m.mu.RUnlock()
stream, exists := m.streams[sid]
return !exists || stream.closed
}
// GetStreams returns a list of all active stream IDs.
func (m *Multiplexer) GetStreams() []uint16 {
m.mu.RLock()
defer m.mu.RUnlock()
sids := make([]uint16, 0, len(m.streams))
for sid := range m.streams {
sids = append(sids, sid)
}
return sids
}
// GetStream returns the Stream object for a given ID.
func (m *Multiplexer) GetStream(sid uint16) *Stream {
m.mu.RLock()
defer m.mu.RUnlock()
return m.streams[sid]
}
// Reset clears all multiplexer state and closes all streams.
func (m *Multiplexer) Reset() {
m.mu.Lock()
defer m.mu.Unlock()
for _, stream := range m.streams {
stream.closed = true
}
m.streams = make(map[uint16]*Stream)
m.nextID = 1
m.sendSeqMu.Lock()
m.sendSeq = make(map[uint16]uint32)
m.sendSeqMu.Unlock()
m.bufferCond.Broadcast()
}
// UpdateSendFunc updates the function used to transmit raw frames.
func (m *Multiplexer) UpdateSendFunc(onSend func([]byte) error) {
m.mu.Lock()
defer m.mu.Unlock()
m.onSend = onSend
}
// WaitForData returns a channel that signals when new data is available for a stream.
func (m *Multiplexer) WaitForData(sid uint16) <-chan struct{} {
m.dataReadyMu.Lock()
defer m.dataReadyMu.Unlock()
if _, ok := m.dataReady[sid]; !ok {
m.dataReady[sid] = make(chan struct{}, 1)
}
return m.dataReady[sid]
}
// CleanupDataChannel removes the data notification channel for a stream.
func (m *Multiplexer) CleanupDataChannel(sid uint16) {
m.dataReadyMu.Lock()
defer m.dataReadyMu.Unlock()
if ch, ok := m.dataReady[sid]; ok {
close(ch)
delete(m.dataReady, sid)
}
}

119
internal/muxconn/conn.go Normal file
View File

@@ -0,0 +1,119 @@
// Package muxconn adapts a link.Link into an io.ReadWriteCloser suitable for
// driving a smux session. The wrapper applies AEAD on every wire-bound write
// and inverts it on every received message before exposing the bytes as a
// byte stream.
//
// Link semantics are message-oriented: each Send produces exactly one OnData
// on the peer. smux operates on a pure byte stream (header + payload may be
// glued or split across reads). We bridge by:
//
// - Treating each Push as an opaque chunk appended to an internal byte
// buffer that Read drains in arbitrary slices.
// - Letting smux's sendLoop call Write once per frame; we encrypt and hand
// the whole buffer to the link as a single message. Length boundaries
// are preserved end-to-end by the transport (KCP length-prefix framing
// in vp8channel, native message boundaries in datachannel).
package muxconn
import (
"errors"
"io"
"sync"
"time"
"github.com/openlibrecommunity/olcrtc/internal/crypto"
"github.com/openlibrecommunity/olcrtc/internal/link"
)
// ErrClosed is returned from Read/Write after the conn has been closed.
var ErrClosed = errors.New("muxconn: closed")
// Conn is an io.ReadWriteCloser over a link.Link with optional AEAD wrapping.
type Conn struct {
ln link.Link
cipher *crypto.Cipher
mu sync.Mutex
cond *sync.Cond
buf []byte
closed bool
}
// New wires a Conn over the given link. Push must be set as the link's OnData
// callback before this conn is used.
func New(ln link.Link, cipher *crypto.Cipher) *Conn {
c := &Conn{ln: ln, cipher: cipher}
c.cond = sync.NewCond(&c.mu)
return c
}
// Push hands an encrypted wire payload (one OnData event) to the conn.
func (c *Conn) Push(ciphertext []byte) {
pt, err := c.cipher.Decrypt(ciphertext)
if err != nil {
return
}
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return
}
c.buf = append(c.buf, pt...)
c.cond.Broadcast()
}
// Read implements io.Reader. Blocks until at least one byte is available.
func (c *Conn) Read(p []byte) (int, error) {
c.mu.Lock()
defer c.mu.Unlock()
for !c.closed && len(c.buf) == 0 {
c.cond.Wait()
}
if len(c.buf) == 0 {
return 0, io.EOF
}
n := copy(p, c.buf)
c.buf = c.buf[n:]
return n, nil
}
// Write encrypts p and ships it to the link as a single message. Blocks while
// the link signals back-pressure.
func (c *Conn) Write(p []byte) (int, error) {
for {
if c.isClosed() {
return 0, ErrClosed
}
if c.ln.CanSend() {
break
}
time.Sleep(10 * time.Millisecond)
}
enc, err := c.cipher.Encrypt(p)
if err != nil {
return 0, err
}
if err := c.ln.Send(enc); err != nil {
return 0, err
}
return len(p), nil
}
// Close unblocks any pending Read with io.EOF.
func (c *Conn) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return nil
}
c.closed = true
c.cond.Broadcast()
return nil
}
func (c *Conn) isClosed() bool {
c.mu.Lock()
defer c.mu.Unlock()
return c.closed
}

View File

@@ -11,44 +11,32 @@ import (
"net"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/openlibrecommunity/olcrtc/internal/crypto"
"github.com/openlibrecommunity/olcrtc/internal/link"
"github.com/openlibrecommunity/olcrtc/internal/logger"
"github.com/openlibrecommunity/olcrtc/internal/mux"
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
"github.com/openlibrecommunity/olcrtc/internal/names"
"github.com/xtaci/smux"
)
var (
// ErrKeySize is returned when the encryption key is not 32 bytes.
ErrKeySize = errors.New("key must be 32 bytes")
// ErrKeyStringLength is returned when the encryption key string length is not 32.
ErrKeyStringLength = errors.New("key string length must be 32")
// ErrSocks5AuthFailed is returned when SOCKS5 authentication fails.
ErrSocks5AuthFailed = errors.New("SOCKS5 auth failed")
// ErrSocks5ConnectFailed is returned when SOCKS5 connection fails.
ErrSocks5ConnectFailed = errors.New("SOCKS5 connect failed")
// ErrNoLinks is returned when no links are available.
ErrNoLinks = errors.New("no links available")
// ErrDialProxy is returned when dialing the proxy fails.
ErrDialProxy = errors.New("failed to dial proxy")
// ErrEncryptFailed is returned when encryption fails.
ErrEncryptFailed = errors.New("encrypt failed")
)
// Server handles incoming tunnel connections and proxies their traffic.
type Server struct {
links []link.Link
ln link.Link
cipher *crypto.Cipher
mux *mux.Multiplexer
connections map[uint16]net.Conn
connMu sync.RWMutex
streamPumps map[uint16]net.Conn
pumpMu sync.Mutex
linkIdx atomic.Uint32
activeClients atomic.Int32
conn *muxconn.Conn
session *smux.Session
sessMu sync.RWMutex
wg sync.WaitGroup
dnsServer string
resolver *net.Resolver
@@ -97,25 +85,22 @@ func Run(
s := &Server{
cipher: cipher,
connections: make(map[uint16]net.Conn),
streamPumps: make(map[uint16]net.Conn),
links: make([]link.Link, 0),
dnsServer: dnsServer,
socksProxyAddr: socksProxyAddr,
socksProxyPort: socksProxyPort,
}
s.setupResolver()
s.setupMux()
const linkCount = 1
for i := range linkCount {
if err := s.addLink(runCtx, linkName, transportName, carrierName, roomURL, i, cancel, videoWidth, videoHeight, videoFPS, videoBitrate, videoHW, videoQRSize, videoQRRecovery, videoCodec, videoTileModule, videoTileRS, vp8FPS, vp8BatchSize); err != nil {
return fmt.Errorf("addLink failed: %w", err)
}
if err := s.bringUpLink(
runCtx, linkName, transportName, carrierName, roomURL, cancel,
videoWidth, videoHeight, videoFPS, videoBitrate, videoHW,
videoQRSize, videoQRRecovery, videoCodec, videoTileModule, videoTileRS,
vp8FPS, vp8BatchSize,
); err != nil {
return err
}
err = s.runLoop(runCtx)
err = s.serve(runCtx)
s.shutdown()
s.wg.Wait()
@@ -136,12 +121,7 @@ func setupCipher(keyHex string) (*crypto.Cipher, error) {
return nil, fmt.Errorf("%w, got %d", ErrKeySize, len(key))
}
keyStr := string(key)
if len(keyStr) != 32 {
return nil, fmt.Errorf("%w, got %d", ErrKeyStringLength, len(keyStr))
}
cipher, err := crypto.NewCipher(keyStr)
cipher, err := crypto.NewCipher(string(key))
if err != nil {
return nil, fmt.Errorf("failed to create cipher: %w", err)
}
@@ -158,51 +138,30 @@ func (s *Server) setupResolver() {
}
}
func (s *Server) setupMux() {
s.mux = mux.New(0, func(frame []byte) error {
for {
canSend := true
for _, ln := range s.links {
if !ln.CanSend() {
canSend = false
break
}
}
if canSend {
break
}
time.Sleep(10 * time.Millisecond)
}
encrypted, err := s.cipher.Encrypt(frame)
if err != nil {
return fmt.Errorf("%w: %w", ErrEncryptFailed, err)
}
if len(s.links) == 0 {
return ErrNoLinks
}
idx := s.linkIdx.Add(1) % uint32(len(s.links)) //nolint:gosec
return s.links[idx].Send(encrypted)
})
// smuxConfig mirrors the client side. Both peers must agree on Version and
// MaxFrameSize.
func smuxConfig() *smux.Config {
cfg := smux.DefaultConfig()
cfg.Version = 2
cfg.MaxFrameSize = 32768
cfg.MaxReceiveBuffer = 16 * 1024 * 1024
cfg.MaxStreamBuffer = 1024 * 1024
cfg.KeepAliveInterval = 10 * time.Second
cfg.KeepAliveTimeout = 60 * time.Second
return cfg
}
func (s *Server) addLink(
func (s *Server) bringUpLink(
ctx context.Context,
linkName,
transportName,
carrierName,
roomURL string,
linkID int,
linkName, transportName, carrierName, roomURL string,
cancel context.CancelFunc,
videoWidth, videoHeight, videoFPS int,
videoBitrate, videoHW string,
videoQRSize int,
videoQRRecovery string,
videoCodec string,
videoTileModule int,
videoTileRS int,
vp8FPS int,
vp8BatchSize int,
videoTileModule, videoTileRS int,
vp8FPS, vp8BatchSize int,
) error {
ln, err := link.New(ctx, linkName, link.Config{
Transport: transportName,
@@ -229,22 +188,21 @@ func (s *Server) addLink(
if err != nil {
return fmt.Errorf("failed to create link: %w", err)
}
s.ln = ln
ln.SetEndedCallback(func(reason string) {
logger.Infof("Server link %d reported conference end: %s", linkID, reason)
logger.Infof("Server link reported conference end: %s", reason)
cancel()
})
s.links = append(s.links, ln)
ln.SetReconnectCallback(func() { s.handleReconnect() })
ln.SetReconnectCallback(func() {
s.handleLinkReconnect(linkID)
})
logger.Infof("Connecting link %d via %s/%s/%s...", linkID, linkName, transportName, carrierName)
logger.Infof("Connecting link via %s/%s/%s...", linkName, transportName, carrierName)
if err := ln.Connect(ctx); err != nil {
return fmt.Errorf("failed to connect link: %w", err)
}
logger.Infof("Link %d connected", linkID)
logger.Infof("Link connected")
s.installSession()
s.wg.Add(1)
go func() {
@@ -254,30 +212,195 @@ func (s *Server) addLink(
return nil
}
func (s *Server) handleLinkReconnect(linkID int) {
logger.Infof("link %d reconnect event", linkID)
s.connMu.Lock()
for sid, conn := range s.connections {
if conn != nil {
_ = conn.Close()
}
delete(s.connections, sid)
func (s *Server) installSession() {
conn := muxconn.New(s.ln, s.cipher)
sess, err := smux.Server(conn, smuxConfig())
if err != nil {
logger.Warnf("smux server init failed: %v", err)
return
}
s.connMu.Unlock()
s.sessMu.Lock()
s.conn = conn
s.session = sess
s.sessMu.Unlock()
}
s.mux.UpdateSendFunc(func(frame []byte) error {
encrypted, err := s.cipher.Encrypt(frame)
func (s *Server) handleReconnect() {
logger.Infof("server link reconnect — tearing down smux session")
s.sessMu.Lock()
if s.session != nil {
_ = s.session.Close()
s.session = nil
}
if s.conn != nil {
_ = s.conn.Close()
s.conn = nil
}
s.sessMu.Unlock()
s.installSession()
}
func (s *Server) onData(data []byte) {
s.sessMu.RLock()
conn := s.conn
s.sessMu.RUnlock()
if conn != nil {
conn.Push(data)
}
}
// serve drives the smux Accept loop, spawning a tunnel per inbound stream.
// The loop tolerates session bounces (reconnects) by waiting until a fresh
// session is installed instead of terminating the server.
func (s *Server) serve(ctx context.Context) error {
for {
if ctx.Err() != nil {
return nil
}
s.sessMu.RLock()
sess := s.session
s.sessMu.RUnlock()
if sess == nil {
select {
case <-ctx.Done():
return nil
case <-time.After(50 * time.Millisecond):
continue
}
}
stream, err := sess.AcceptStream()
if err != nil {
return fmt.Errorf("%w: %w", ErrEncryptFailed, err)
// Session is torn down (reconnect or close). If we're shutting
// down, exit; otherwise wait for a new session and retry.
if ctx.Err() != nil {
return nil
}
logger.Infof("AcceptStream returned %v — waiting for new session", err)
time.Sleep(100 * time.Millisecond)
continue
}
if len(s.links) == 0 {
return ErrNoLinks
s.wg.Add(1)
go func() {
defer s.wg.Done()
s.handleStream(ctx, stream)
}()
}
}
func (s *Server) shutdown() {
s.sessMu.Lock()
if s.session != nil {
_ = s.session.Close()
}
if s.conn != nil {
_ = s.conn.Close()
}
s.sessMu.Unlock()
if s.ln != nil {
_ = s.ln.Close()
}
}
func (s *Server) handleStream(_ context.Context, stream *smux.Stream) {
defer stream.Close()
// Read the connect JSON. The client writes the whole JSON in one
// stream.Write so it usually arrives intact; tolerate fragmentation
// by reading incrementally up to a sane cap.
const maxConnReq = 4096
header := make([]byte, 0, 256)
tmp := make([]byte, 256)
_ = stream.SetReadDeadline(time.Now().Add(15 * time.Second))
for {
n, err := stream.Read(tmp)
if n > 0 {
header = append(header, tmp[:n]...)
if req, ok := parseConnectRequest(header); ok {
_ = stream.SetReadDeadline(time.Time{})
s.dispatch(stream, req)
return
}
}
idx := s.linkIdx.Add(1) % uint32(len(s.links)) //nolint:gosec
return s.links[idx].Send(encrypted)
})
s.mux.Reset()
if err != nil {
return
}
if len(header) > maxConnReq {
return
}
}
}
func parseConnectRequest(buf []byte) (ConnectRequest, bool) {
var req ConnectRequest
if err := json.Unmarshal(buf, &req); err != nil {
return req, false
}
if req.Cmd != "connect" {
return req, false
}
return req, true
}
func (s *Server) dispatch(stream *smux.Stream, req ConnectRequest) {
addr := net.JoinHostPort(req.Addr, strconv.Itoa(req.Port))
logger.Infof("sid=%d connect %s", stream.ID(), addr)
dialStart := time.Now()
conn, err := s.dial(req)
dialElapsed := time.Since(dialStart)
if err != nil {
logger.Infof("sid=%d dial %s failed (%v): %v", stream.ID(), addr, dialElapsed, err)
return
}
defer conn.Close()
logger.Infof("sid=%d connected %s in %v", stream.ID(), addr, dialElapsed)
if _, err := stream.Write([]byte{0x00}); err != nil {
return
}
go func() {
_, _ = io.Copy(stream, conn)
_ = stream.Close()
}()
_, _ = io.Copy(conn, stream)
}
func (s *Server) dial(req ConnectRequest) (net.Conn, error) {
addr := net.JoinHostPort(req.Addr, strconv.Itoa(req.Port))
if s.socksProxyAddr == "" {
dialer := &net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
Resolver: s.resolver,
}
conn, err := dialer.Dial("tcp4", addr)
if err != nil {
return nil, fmt.Errorf("dial failed: %w", err)
}
return conn, nil
}
proxyAddr := net.JoinHostPort(s.socksProxyAddr, strconv.Itoa(s.socksProxyPort))
dialer := &net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
}
conn, err := dialer.Dial("tcp4", proxyAddr)
if err != nil {
return nil, fmt.Errorf("failed to dial proxy: %w", err)
}
if err := s.socks5Connect(conn, req.Addr, req.Port); err != nil {
_ = conn.Close()
return nil, err
}
return conn, nil
}
func (s *Server) socks5Connect(conn net.Conn, targetAddr string, targetPort int) error {
@@ -318,336 +441,3 @@ func (s *Server) socks5Connect(conn net.Conn, targetAddr string, targetPort int)
return nil
}
func (s *Server) onData(data []byte) {
plaintext, err := s.cipher.Decrypt(data)
if err != nil {
logger.Debugf("Decrypt error: %v", err)
return
}
if control, ok := mux.ParseControlFrame(plaintext); ok && control.Type == mux.ControlResetClient {
logger.Infof("Received reset signal from client (clientID=%d)", control.ClientID)
s.closeClientConnections(control.ClientID)
}
s.mux.HandleFrame(plaintext)
}
func (s *Server) closeClientConnections(clientID uint32) {
s.connMu.Lock()
defer s.connMu.Unlock()
for streamSid, conn := range s.connections {
stream := s.mux.GetStream(streamSid)
if stream != nil && stream.ClientID == clientID {
if conn != nil {
_ = conn.Close()
}
delete(s.connections, streamSid)
}
}
}
func (s *Server) runLoop(ctx context.Context) error {
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return nil
case <-ticker.C:
s.processMuxStreams(ctx)
}
}
}
func (s *Server) shutdown() {
s.connMu.Lock()
for _, conn := range s.connections {
if conn != nil {
_ = conn.Close()
}
}
s.connMu.Unlock()
s.pumpMu.Lock()
for _, conn := range s.streamPumps {
if conn != nil {
_ = conn.Close()
}
}
s.pumpMu.Unlock()
for i, tr := range s.links {
logger.Infof("closing link %d", i)
_ = tr.Close()
}
}
func (s *Server) processMuxStreams(ctx context.Context) {
sids := s.mux.GetStreams()
for _, sid := range sids {
if s.mux.StreamClosed(sid) {
s.closeStreamConnection(sid)
continue
}
if s.hasConnection(sid) {
continue
}
data := s.mux.ReadStream(sid)
if len(data) == 0 {
continue
}
var req ConnectRequest
if err := json.Unmarshal(data, &req); err == nil && req.Cmd == "connect" {
logger.Infof("sid=%d connect %s:%d", sid, req.Addr, req.Port)
s.closeStreamConnection(sid)
go s.handleConnect(ctx, sid, req)
}
}
}
func (s *Server) hasConnection(sid uint16) bool {
s.connMu.RLock()
defer s.connMu.RUnlock()
return s.connections[sid] != nil
}
func (s *Server) closeStreamConnection(sid uint16) {
s.connMu.Lock()
conn := s.connections[sid]
if conn != nil {
_ = conn.Close()
delete(s.connections, sid)
}
s.connMu.Unlock()
}
func (s *Server) closeStreamConnectionIfCurrent(sid uint16, expected net.Conn) {
s.connMu.Lock()
conn := s.connections[sid]
if conn == expected {
_ = conn.Close()
delete(s.connections, sid)
}
s.connMu.Unlock()
}
func (s *Server) markStreamPump(sid uint16, conn net.Conn) bool {
s.pumpMu.Lock()
defer s.pumpMu.Unlock()
if current := s.streamPumps[sid]; current == conn {
return false
} else if current != nil {
_ = current.Close()
}
s.streamPumps[sid] = conn
return true
}
func (s *Server) unmarkStreamPump(sid uint16, conn net.Conn) {
s.pumpMu.Lock()
if s.streamPumps[sid] == conn {
delete(s.streamPumps, sid)
}
s.pumpMu.Unlock()
}
func (s *Server) handleConnect(ctx context.Context, sid uint16, req ConnectRequest) {
addr := net.JoinHostPort(req.Addr, strconv.Itoa(req.Port))
s.closeStreamConnection(sid)
dialStart := time.Now()
conn, err := s.dial(req)
dialElapsed := time.Since(dialStart)
if err != nil {
logger.Infof("sid=%d dial %s failed (%v): %v", sid, addr, dialElapsed, err)
_ = s.mux.CloseStream(sid)
return
}
s.connMu.Lock()
s.connections[sid] = conn
s.connMu.Unlock()
logger.Infof("sid=%d connected %s in %v", sid, addr, dialElapsed)
s.activeClients.Add(1)
_ = s.mux.SendData(sid, []byte{0x00})
s.startStreamPump(ctx, sid, conn)
go s.pumpToMux(sid, conn)
}
func (s *Server) dial(req ConnectRequest) (net.Conn, error) {
addr := net.JoinHostPort(req.Addr, strconv.Itoa(req.Port))
if s.socksProxyAddr == "" {
dialer := &net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
Resolver: s.resolver,
}
conn, err := dialer.Dial("tcp4", addr)
if err != nil {
return nil, fmt.Errorf("dial failed: %w", err)
}
return conn, nil
}
proxyAddr := net.JoinHostPort(s.socksProxyAddr, strconv.Itoa(s.socksProxyPort))
dialer := &net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
}
conn, err := dialer.Dial("tcp4", proxyAddr)
if err != nil {
return nil, fmt.Errorf("failed to dial proxy: %w", err)
}
if err := s.socks5Connect(conn, req.Addr, req.Port); err != nil {
_ = conn.Close()
return nil, err
}
return conn, nil
}
func (s *Server) pumpToMux(sid uint16, conn net.Conn) {
defer func() {
s.activeClients.Add(-1)
_ = s.mux.CloseStream(sid)
s.connMu.Lock()
delete(s.connections, sid)
s.connMu.Unlock()
}()
// Decoupling queue: Read goroutine pushes here, sender goroutine drains
// to mux.SendData. Without this, slow channel back-pressure stalls the
// upstream Read which can cause TCP receive window to collapse to zero
// and effectively wedge the connection (peer stops sending and never
// resumes even though our channel is healthy).
type chunk struct{ data []byte }
queue := make(chan chunk, 64)
doneSender := make(chan struct{})
go func() {
defer close(doneSender)
for c := range queue {
for !s.canSendData() {
time.Sleep(20 * time.Millisecond)
}
if err := s.mux.SendData(sid, c.data); err != nil {
return
}
}
}()
// queueHasSpace blocks until the decoupling queue has room or the
// sender goroutine has exited. We wait here *before* arming the
// upstream read deadline so that channel back-pressure isn't billed
// to the socket as idle time and doesn't trip a spurious i/o timeout.
queueHasSpace := func() bool {
for {
if len(queue) < cap(queue) {
return true
}
select {
case <-doneSender:
return false
case <-time.After(10 * time.Millisecond):
}
}
}
buf := make([]byte, 16384)
// Idle timeout for genuinely dead upstreams. Only armed when we are
// actively waiting on the socket (queue has space). During internal
// back-pressure the deadline is not in effect, so flow-control pauses
// don't get mis-classified as remote death.
const idleReadTimeout = 60 * time.Second
for {
if !queueHasSpace() {
close(queue)
<-doneSender
return
}
// Arm the deadline only when we actually want bytes from the peer.
_ = conn.SetReadDeadline(time.Now().Add(idleReadTimeout))
n, err := conn.Read(buf)
if err != nil {
close(queue)
<-doneSender
return
}
// Clear the deadline so it doesn't fire while we are blocked in
// queueHasSpace() on the next iteration (back-pressure path).
_ = conn.SetReadDeadline(time.Time{})
// Copy because buf is reused on next Read.
c := make([]byte, n)
copy(c, buf[:n])
// Guaranteed non-blocking thanks to queueHasSpace() above (we are
// the sole producer); the blocking fallback is just defensive.
select {
case queue <- chunk{data: c}:
default:
queue <- chunk{data: c}
}
}
}
func (s *Server) startStreamPump(ctx context.Context, sid uint16, conn net.Conn) {
if !s.markStreamPump(sid, conn) {
return
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
defer s.unmarkStreamPump(sid, conn)
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
data := s.mux.ReadStream(sid)
if len(data) > 0 {
if _, err := conn.Write(data); err != nil {
_ = s.mux.CloseStream(sid)
s.closeStreamConnectionIfCurrent(sid, conn)
return
}
}
if s.mux.StreamClosed(sid) {
s.closeStreamConnectionIfCurrent(sid, conn)
return
}
}
}
}()
}
func (s *Server) canSendData() bool {
for _, tr := range s.links {
if !tr.CanSend() {
return false
}
}
return true
}