diff --git a/go.mod b/go.mod index 0b41c77..ee5e830 100644 --- a/go.mod +++ b/go.mod @@ -68,6 +68,7 @@ require ( github.com/tjfoc/gmsm v1.4.1 // indirect github.com/twitchtv/twirp v8.1.3+incompatible // indirect github.com/wlynxg/anet v0.0.5 // indirect + github.com/xtaci/smux v1.5.57 // indirect github.com/zeebo/xxh3 v1.1.0 // indirect go.opentelemetry.io/otel v1.40.0 // indirect go.uber.org/atomic v1.11.0 // indirect diff --git a/go.sum b/go.sum index 3d6d2e2..4ac3421 100644 --- a/go.sum +++ b/go.sum @@ -226,6 +226,8 @@ github.com/xtaci/kcp-go/v5 v5.6.72 h1:FLaQPalgpufJYQRk0OK+gErEhXGLUPjv6FSRPrFR8L github.com/xtaci/kcp-go/v5 v5.6.72/go.mod h1:9O3D8WR+cyyUjGiTILYfg17vn72otWuXK2AFfqIe6CM= github.com/xtaci/lossyconn v0.0.0-20190602105132-8df528c0c9ae h1:J0GxkO96kL4WF+AIT3M4mfUVinOCPgf2uUWYFUzN0sM= github.com/xtaci/lossyconn v0.0.0-20190602105132-8df528c0c9ae/go.mod h1:gXtu8J62kEgmN++bm9BVICuT/e8yiLI2KFobd/TRFsE= +github.com/xtaci/smux v1.5.57 h1:N72VbGoSYxgcm6mPOYX0QzEZNVD3UI/JlVvAtXF+WrY= +github.com/xtaci/smux v1.5.57/go.mod h1:IGQ9QYrBphmb/4aTnLEcJby0TNr3NV+OslIOMrX825Q= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/zarazaex69/b v0.0.0-20260423064626-c0bd20863b89 h1:ytA0RfQZTYfjqFA9lBJMX1DTnXpTuKg0nf4udgdpunE= github.com/zarazaex69/b v0.0.0-20260423064626-c0bd20863b89/go.mod h1:OUqzZNoXsg+ccaiAnSe0t4f8qc0W/cFx6io0lWsE1Gw= diff --git a/internal/client/client.go b/internal/client/client.go index 51bc4d1..ef8ab6e 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -3,7 +3,6 @@ package client import ( "context" - "crypto/rand" "encoding/binary" "encoding/hex" "encoding/json" @@ -17,8 +16,9 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/crypto" "github.com/openlibrecommunity/olcrtc/internal/link" "github.com/openlibrecommunity/olcrtc/internal/logger" - "github.com/openlibrecommunity/olcrtc/internal/mux" + "github.com/openlibrecommunity/olcrtc/internal/muxconn" "github.com/openlibrecommunity/olcrtc/internal/names" + "github.com/xtaci/smux" ) var ( @@ -26,21 +26,16 @@ var ( ErrConnectFailed = errors.New("tunnel connection failed") // ErrProxyAuth is returned when SOCKS proxy authentication fails. ErrProxyAuth = errors.New("SOCKS proxy auth failed") - // ErrMuxExited is returned when the multiplexer loop exits unexpectedly. - ErrMuxExited = errors.New("multiplexer loop exited") - // ErrNoAvailableLinks is returned when no links are ready for sending. - ErrNoAvailableLinks = errors.New("no available links") ) // Client handles local SOCKS5 connections and tunnels them to the server. type Client struct { - links []link.Link - cipher *crypto.Cipher - mux *mux.Multiplexer - connections map[uint16]net.Conn - connMu sync.RWMutex - clientID uint32 - dnsServer string + ln link.Link + cipher *crypto.Cipher + conn *muxconn.Conn + session *smux.Session + sessMu sync.RWMutex + dnsServer string } // Run starts the client with the specified parameters. @@ -105,37 +100,27 @@ func RunWithReady( return fmt.Errorf("setupCipher failed: %w", err) } - clientIDBytes := make([]byte, 4) - if _, err := rand.Read(clientIDBytes); err != nil { - return fmt.Errorf("failed to generate client ID: %w", err) - } - clientID := binary.BigEndian.Uint32(clientIDBytes) + c := &Client{cipher: cipher, dnsServer: dnsServer} - c := &Client{ - cipher: cipher, - connections: make(map[uint16]net.Conn), - links: make([]link.Link, 0), - clientID: clientID, - dnsServer: dnsServer, - } - - c.setupMux() - - const linkCount = 1 - for i := range linkCount { - if err := c.addLink(runCtx, linkName, transportName, carrierName, roomURL, i, cancel, dnsServer, "", 0, videoWidth, videoHeight, videoFPS, videoBitrate, videoHW, videoQRSize, videoQRRecovery, videoCodec, videoTileModule, videoTileRS, vp8FPS, vp8BatchSize); err != nil { - return fmt.Errorf("addLink failed: %w", err) - } + if err := c.bringUpLink( + runCtx, linkName, transportName, carrierName, roomURL, cancel, + dnsServer, "", 0, + videoWidth, videoHeight, videoFPS, videoBitrate, videoHW, + videoQRSize, videoQRRecovery, videoCodec, videoTileModule, videoTileRS, + vp8FPS, vp8BatchSize, + ); err != nil { + return err } + defer c.shutdown() lc := net.ListenConfig{} - ln, err := lc.Listen(runCtx, "tcp4", localAddr) + listener, err := lc.Listen(runCtx, "tcp4", localAddr) if err != nil { return fmt.Errorf("failed to listen on %s: %w", localAddr, err) } - defer ln.Close() + defer listener.Close() - logger.Infof("SOCKS5 server listening on %s (ClientID: %d)", localAddr, clientID) + logger.Infof("SOCKS5 server listening on %s", localAddr) if onReady != nil { onReady() @@ -143,96 +128,30 @@ func RunWithReady( errCh := make(chan error, 1) go func() { - errCh <- c.acceptLoop(runCtx, ln) + errCh <- c.acceptLoop(runCtx, listener) }() select { case <-runCtx.Done(): - c.shutdown() return nil case err := <-errCh: return err } } -func (c *Client) shutdown() { - c.connMu.Lock() - for _, conn := range c.connections { - if conn != nil { - _ = conn.Close() - } - } - c.connMu.Unlock() - - for i, ln := range c.links { - logger.Infof("closing link %d", i) - _ = ln.Close() - } -} - -func setupCipher(keyHex string) (*crypto.Cipher, error) { - key, err := hex.DecodeString(keyHex) - if err != nil { - return nil, fmt.Errorf("failed to decode key: %w", err) - } - if len(key) != 32 { - return nil, fmt.Errorf("key must be 32 bytes, got %d", len(key)) - } - - cipher, err := crypto.NewCipher(string(key)) - if err != nil { - return nil, fmt.Errorf("failed to create cipher: %w", err) - } - return cipher, nil -} - -func (c *Client) setupMux() { - c.mux = mux.New(c.clientID, func(frame []byte) error { - for { - canSend := true - for _, ln := range c.links { - if !ln.CanSend() { - canSend = false - break - } - } - if canSend { - break - } - time.Sleep(10 * time.Millisecond) - } - - encrypted, err := c.cipher.Encrypt(frame) - if err != nil { - return err - } - if len(c.links) == 0 { - return ErrNoAvailableLinks - } - return c.links[0].Send(encrypted) - }) -} - -func (c *Client) addLink( +func (c *Client) bringUpLink( ctx context.Context, - linkName, - transportName, - carrierName, - roomURL string, - linkID int, + linkName, transportName, carrierName, roomURL string, cancel context.CancelFunc, - dnsServer, - socksProxyAddr string, + dnsServer, socksProxyAddr string, socksProxyPort int, videoWidth, videoHeight, videoFPS int, videoBitrate, videoHW string, videoQRSize int, videoQRRecovery string, videoCodec string, - videoTileModule int, - videoTileRS int, - vp8FPS int, - vp8BatchSize int, + videoTileModule, videoTileRS int, + vp8FPS, vp8BatchSize int, ) error { ln, err := link.New(ctx, linkName, link.Config{ Transport: transportName, @@ -259,56 +178,104 @@ func (c *Client) addLink( if err != nil { return fmt.Errorf("failed to create link: %w", err) } + c.ln = ln ln.SetEndedCallback(func(reason string) { - logger.Infof("Client link %d reported conference end: %s", linkID, reason) + logger.Infof("Client link reported conference end: %s", reason) cancel() }) - c.links = append(c.links, ln) - - ln.SetReconnectCallback(func() { - c.handleLinkReconnect(linkID) - }) + ln.SetReconnectCallback(func() { c.handleReconnect() }) if err := ln.Connect(ctx); err != nil { return fmt.Errorf("failed to connect link: %w", err) } + c.conn = muxconn.New(ln, c.cipher) + sess, err := smux.Client(c.conn, smuxConfig()) + if err != nil { + return fmt.Errorf("smux client: %w", err) + } + c.sessMu.Lock() + c.session = sess + c.sessMu.Unlock() + go ln.WatchConnection(ctx) return nil } -func (c *Client) handleLinkReconnect(linkID int) { - logger.Infof("link %d reconnect event", linkID) - c.sendResetSignal() - - c.connMu.Lock() - for sid, conn := range c.connections { - if conn != nil { - _ = conn.Close() - } - delete(c.connections, sid) - } - c.connMu.Unlock() - - c.mux.UpdateSendFunc(func(frame []byte) error { - encrypted, err := c.cipher.Encrypt(frame) - if err != nil { - return err - } - if len(c.links) == 0 { - return ErrNoAvailableLinks - } - return c.links[0].Send(encrypted) - }) - c.mux.Reset() +// smuxConfig returns the tuned smux config used on both ends. +func smuxConfig() *smux.Config { + cfg := smux.DefaultConfig() + cfg.Version = 2 + cfg.MaxFrameSize = 32768 + cfg.MaxReceiveBuffer = 16 * 1024 * 1024 + cfg.MaxStreamBuffer = 1024 * 1024 + cfg.KeepAliveInterval = 10 * time.Second + cfg.KeepAliveTimeout = 60 * time.Second + return cfg } -func (c *Client) sendResetSignal() { - resetFrame := mux.BuildControlFrame(c.clientID, mux.ControlResetClient) - encrypted, _ := c.cipher.Encrypt(resetFrame) - if len(c.links) > 0 { - _ = c.links[0].Send(encrypted) +func (c *Client) handleReconnect() { + logger.Infof("client link reconnect — tearing down smux session") + c.sessMu.Lock() + if c.session != nil { + _ = c.session.Close() + c.session = nil + } + if c.conn != nil { + _ = c.conn.Close() + c.conn = nil + } + c.sessMu.Unlock() + // New SOCKS5 connections will fail until the link comes back up; the + // caller will reissue them. Existing streams die with the smux session. + c.conn = muxconn.New(c.ln, c.cipher) + sess, err := smux.Client(c.conn, smuxConfig()) + if err != nil { + logger.Warnf("smux re-init failed: %v", err) + return + } + c.sessMu.Lock() + c.session = sess + c.sessMu.Unlock() +} + +func (c *Client) shutdown() { + c.sessMu.Lock() + if c.session != nil { + _ = c.session.Close() + } + if c.conn != nil { + _ = c.conn.Close() + } + c.sessMu.Unlock() + if c.ln != nil { + _ = c.ln.Close() + } +} + +func setupCipher(keyHex string) (*crypto.Cipher, error) { + key, err := hex.DecodeString(keyHex) + if err != nil { + return nil, fmt.Errorf("failed to decode key: %w", err) + } + if len(key) != 32 { + return nil, fmt.Errorf("key must be 32 bytes, got %d", len(key)) + } + + cipher, err := crypto.NewCipher(string(key)) + if err != nil { + return nil, fmt.Errorf("failed to create cipher: %w", err) + } + return cipher, nil +} + +func (c *Client) onData(data []byte) { + c.sessMu.RLock() + conn := c.conn + c.sessMu.RUnlock() + if conn != nil { + conn.Push(data) } } @@ -340,19 +307,23 @@ func (c *Client) handleSocks5(ctx context.Context, conn net.Conn) { return } - sid := c.mux.OpenStream() - defer c.mux.CloseStream(sid) + c.sessMu.RLock() + sess := c.session + c.sessMu.RUnlock() + if sess == nil || sess.IsClosed() { + _, _ = conn.Write(replyHostUnreachable()) + return + } - c.connMu.Lock() - c.connections[sid] = conn - c.connMu.Unlock() - defer func() { - c.connMu.Lock() - delete(c.connections, sid) - c.connMu.Unlock() - }() + stream, err := sess.OpenStream() + if err != nil { + logger.Warnf("OpenStream failed: %v", err) + _, _ = conn.Write(replyHostUnreachable()) + return + } + defer stream.Close() - logger.Infof("sid=%d tunnel to %s:%d", sid, targetAddr, targetPort) + logger.Infof("sid=%d tunnel to %s:%d", stream.ID(), targetAddr, targetPort) connectReq, _ := json.Marshal(map[string]any{ "cmd": "connect", @@ -360,45 +331,34 @@ func (c *Client) handleSocks5(ctx context.Context, conn net.Conn) { "port": targetPort, }) - if err := c.mux.SendData(sid, connectReq); err != nil { - logger.Warnf("sid=%d tunnel setup failed: %v", sid, err) + _ = stream.SetWriteDeadline(time.Now().Add(10 * time.Second)) + if _, err := stream.Write(connectReq); err != nil { + logger.Warnf("sid=%d connect req failed: %v", stream.ID(), err) _, _ = conn.Write(replyHostUnreachable()) return } + _ = stream.SetWriteDeadline(time.Time{}) - readyTimer := time.NewTimer(10 * time.Second) - defer readyTimer.Stop() - - dataReady := c.mux.WaitForData(sid) - - var initialData []byte - select { - case <-readyTimer.C: - logger.Warnf("sid=%d tunnel setup failed: timeout waiting for remote ready", sid) + ack := make([]byte, 1) + _ = stream.SetReadDeadline(time.Now().Add(15 * time.Second)) + if _, err := io.ReadFull(stream, ack); err != nil || ack[0] != 0x00 { + logger.Warnf("sid=%d remote ready failed: err=%v ack=%v", stream.ID(), err, ack) _, _ = conn.Write(replyHostUnreachable()) return - case <-dataReady: - initialData = c.mux.ReadStream(sid) - if len(initialData) == 0 || initialData[0] != 0x00 { - logger.Warnf("sid=%d tunnel setup failed: invalid remote ready", sid) - _, _ = conn.Write(replyHostUnreachable()) - return - } } + _ = stream.SetReadDeadline(time.Time{}) if _, err := conn.Write(replySuccess()); err != nil { return } - // Handle the rest of initialData if any (unlikely for 0x00 packet) - if len(initialData) > 1 { - if _, err := conn.Write(initialData[1:]); err != nil { - return - } - } + go func() { + _, _ = io.Copy(stream, conn) + _ = stream.Close() + }() + _, _ = io.Copy(conn, stream) - go c.pumpFromMux(ctx, sid, conn) - c.pumpToMux(sid, conn) + _ = ctx // keep signature } func (c *Client) socks5Handshake(conn net.Conn) error { @@ -459,62 +419,6 @@ func (c *Client) socks5Request(conn net.Conn) (string, int, error) { return addr, port, nil } -func (c *Client) pumpToMux(sid uint16, conn net.Conn) { - buf := make([]byte, 16384) - for { - n, err := conn.Read(buf) - if err != nil { - return - } - - for !c.canSendData() { - time.Sleep(20 * time.Millisecond) - } - - if err := c.mux.SendData(sid, buf[:n]); err != nil { - return - } - } -} - -func (c *Client) pumpFromMux(ctx context.Context, sid uint16, conn net.Conn) { - defer c.mux.CleanupDataChannel(sid) - dataReady := c.mux.WaitForData(sid) - for { - select { - case <-ctx.Done(): - return - case <-dataReady: - data := c.mux.ReadStream(sid) - if len(data) > 0 { - if _, err := conn.Write(data); err != nil { - return - } - } - if c.mux.StreamClosed(sid) { - return - } - } - } -} - -func (c *Client) onData(data []byte) { - plaintext, err := c.cipher.Decrypt(data) - if err != nil { - return - } - c.mux.HandleFrame(plaintext) -} - -func (c *Client) canSendData() bool { - for _, tr := range c.links { - if !tr.CanSend() { - return false - } - } - return true -} - func replySuccess() []byte { return []byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0} } diff --git a/internal/mux/mux.go b/internal/mux/mux.go deleted file mode 100644 index cfb26d2..0000000 --- a/internal/mux/mux.go +++ /dev/null @@ -1,477 +0,0 @@ -// Package mux provides a multiplexer for multiple streams over a single connection. -package mux - -import ( - "encoding/binary" - "errors" - "fmt" - "math" - "sync" - - "github.com/openlibrecommunity/olcrtc/internal/logger" -) - -var ( - // ErrClientResetID is returned when a client reset is attempted with a zero client ID. - ErrClientResetID = errors.New("client reset requires a non-zero client id") - // ErrDataTooLarge is returned when a data chunk exceeds the maximum frame size. - ErrDataTooLarge = errors.New("data chunk too large") -) - -const ( - // HeaderSize is the size of the frame header in bytes. - HeaderSize = 12 - - // ControlStreamID is a special stream ID used for control frames. - ControlStreamID uint16 = 0xFFFF - - // ControlResetClient is a control frame type used to signal a client reset. - ControlResetClient uint32 = 1 - - // FrameTypeData is a marker for data frames. - FrameTypeData uint16 = 0 - // FrameTypeControl is a marker for control frames. - FrameTypeControl uint16 = 0xFFFF -) - -// ControlFrame represents a control message between multiplexers. -type ControlFrame struct { - ClientID uint32 - Type uint32 -} - -// Stream represents a single multiplexed data stream. -type Stream struct { - ID uint16 - ClientID uint32 - recvBuf []byte - closed bool - mu sync.Mutex - nextSeq uint32 - outOfOrder map[uint32][]byte -} - -// RecvBuf returns the current receive buffer content. -func (s *Stream) RecvBuf() []byte { - s.mu.Lock() - defer s.mu.Unlock() - return s.recvBuf -} - -// Multiplexer coordinates multiple Streams over a single transport channel. -type Multiplexer struct { - streams map[uint16]*Stream - nextID uint16 - clientID uint32 - onSend func([]byte) error - mu sync.RWMutex - maxStreams int - maxBufferSize int - dataReady map[uint16]chan struct{} - dataReadyMu sync.Mutex - sendSeq map[uint16]uint32 - sendSeqMu sync.Mutex - - // bufferCond is used to wait for space in receive buffers - bufferCond *sync.Cond -} - -// New creates a new Multiplexer instance. -func New(clientID uint32, onSend func([]byte) error) *Multiplexer { - m := &Multiplexer{ - streams: make(map[uint16]*Stream), - nextID: 1, - clientID: clientID, - onSend: onSend, - maxStreams: 10000, - maxBufferSize: 32 * 1024 * 1024, - dataReady: make(map[uint16]chan struct{}), - sendSeq: make(map[uint16]uint32), - } - m.bufferCond = sync.NewCond(&m.mu) - return m -} - -// OpenStream allocates and returns a new unique stream ID. -func (m *Multiplexer) OpenStream() uint16 { - m.mu.Lock() - defer m.mu.Unlock() - - for { - sid := m.nextID - m.nextID++ - if m.nextID == 0 { - m.nextID = 1 - } - - if _, exists := m.streams[sid]; !exists { - m.streams[sid] = &Stream{ - ID: sid, - recvBuf: make([]byte, 0), - nextSeq: 0, - outOfOrder: make(map[uint32][]byte), - } - return sid - } - } -} - -// SendData fragments and sends data over a specific stream. -func (m *Multiplexer) SendData(sid uint16, data []byte) error { - m.mu.RLock() - stream, exists := m.streams[sid] - m.mu.RUnlock() - - if !exists || stream.closed { - return nil - } - - const chunkSize = 7000 - - for i := 0; i < len(data); i += chunkSize { - end := i + chunkSize - if end > len(data) { - end = len(data) - } - - chunk := data[i:end] - - m.sendSeqMu.Lock() - seq := m.sendSeq[sid] - m.sendSeq[sid]++ - m.sendSeqMu.Unlock() - - if len(chunk) > math.MaxUint16 { - return ErrDataTooLarge - } - - frame := make([]byte, HeaderSize+len(chunk)) - binary.BigEndian.PutUint32(frame[0:4], m.clientID) - binary.BigEndian.PutUint16(frame[4:6], sid) - binary.BigEndian.PutUint16(frame[6:8], uint16(len(chunk))) //nolint:gosec // Length checked above - binary.BigEndian.PutUint32(frame[8:12], seq) - copy(frame[HeaderSize:], chunk) - - if err := m.onSend(frame); err != nil { - return fmt.Errorf("onSend failed: %w", err) - } - } - - return nil -} - -// CloseStream signals that a stream should be terminated. -func (m *Multiplexer) CloseStream(sid uint16) error { - m.mu.Lock() - defer m.mu.Unlock() - - if stream, exists := m.streams[sid]; exists { - stream.closed = true - } - - m.sendSeqMu.Lock() - delete(m.sendSeq, sid) - m.sendSeqMu.Unlock() - - // Notify anyone waiting for buffer space that a stream is closed - m.bufferCond.Broadcast() - - frame := make([]byte, HeaderSize) - binary.BigEndian.PutUint32(frame[0:4], m.clientID) - binary.BigEndian.PutUint16(frame[4:6], sid) - binary.BigEndian.PutUint16(frame[6:8], 0) - binary.BigEndian.PutUint32(frame[8:12], 0) - - if err := m.onSend(frame); err != nil { - return fmt.Errorf("onSend failed: %w", err) - } - return nil -} - -// SendClientReset sends a control frame to reset all streams for this client. -func (m *Multiplexer) SendClientReset() error { - if m.clientID == 0 { - return ErrClientResetID - } - if err := m.onSend(BuildControlFrame(m.clientID, ControlResetClient)); err != nil { - return fmt.Errorf("onSend failed: %w", err) - } - return nil -} - -// BuildControlFrame constructs a raw control frame. -func BuildControlFrame(clientID uint32, controlType uint32) []byte { - frame := make([]byte, HeaderSize) - binary.BigEndian.PutUint32(frame[0:4], clientID) - binary.BigEndian.PutUint16(frame[4:6], ControlStreamID) - binary.BigEndian.PutUint16(frame[6:8], 0xFFFF) // Use 0xFFFF as a marker for control - binary.BigEndian.PutUint32(frame[8:12], controlType) - return frame -} - -// ParseControlFrame attempts to extract control information from a frame. -func ParseControlFrame(frame []byte) (ControlFrame, bool) { - if len(frame) < HeaderSize { - return ControlFrame{}, false - } - - sid := binary.BigEndian.Uint16(frame[4:6]) - length := binary.BigEndian.Uint16(frame[6:8]) - if sid != ControlStreamID || length != 0xFFFF { - return ControlFrame{}, false - } - - return ControlFrame{ - ClientID: binary.BigEndian.Uint32(frame[0:4]), - Type: binary.BigEndian.Uint32(frame[8:12]), - }, true -} - -// HandleFrame processes an incoming frame from the transport. -func (m *Multiplexer) HandleFrame(frame []byte) { - control, ok := ParseControlFrame(frame) - if ok { - m.handleControlFrame(control) - return - } - - if len(frame) < HeaderSize { - return - } - - clientID := binary.BigEndian.Uint32(frame[0:4]) - sid := binary.BigEndian.Uint16(frame[4:6]) - length := binary.BigEndian.Uint16(frame[6:8]) - seq := binary.BigEndian.Uint32(frame[8:12]) - - if length == 0 { - m.handleCloseStreamFrame(sid, clientID) - return - } - - if len(frame) < HeaderSize+int(length) { - return - } - - m.processDataFrame(sid, clientID, seq, frame[HeaderSize:HeaderSize+int(length)]) -} - -func (m *Multiplexer) handleCloseStreamFrame(sid uint16, clientID uint32) { - m.mu.Lock() - defer m.mu.Unlock() - if stream, exists := m.streams[sid]; exists && stream.ClientID == clientID { - stream.closed = true - m.bufferCond.Broadcast() - } -} - -func (m *Multiplexer) processDataFrame(sid uint16, clientID uint32, seq uint32, data []byte) { - m.mu.Lock() - defer m.mu.Unlock() - - stream := m.getOrCreateStream(sid, clientID) - if stream == nil { - return - } - - if seq == stream.nextSeq { - if s := m.waitForBufferSpace(sid, clientID, len(data)); s != nil { - s.recvBuf = append(s.recvBuf, data...) - s.nextSeq++ - m.applyOutOfOrder(s, sid, clientID) - m.notifyDataReady(sid) - } - } else if seq > stream.nextSeq { - if len(stream.outOfOrder) < 100 { - stream.outOfOrder[seq] = append([]byte(nil), data...) - } - } -} - -func (m *Multiplexer) getOrCreateStream(sid uint16, clientID uint32) *Stream { - stream, exists := m.streams[sid] - if !exists { - if len(m.streams) >= m.maxStreams { - return nil - } - stream = &Stream{ - ID: sid, - ClientID: clientID, - recvBuf: make([]byte, 0), - nextSeq: 0, - outOfOrder: make(map[uint32][]byte), - } - m.streams[sid] = stream - return stream - } - - if stream.ClientID != clientID { - stream.ClientID = clientID - stream.recvBuf = make([]byte, 0) - stream.closed = false - stream.nextSeq = 0 - stream.outOfOrder = make(map[uint32][]byte) - m.bufferCond.Broadcast() - } - return stream -} - -func (m *Multiplexer) applyOutOfOrder(stream *Stream, sid uint16, clientID uint32) { - for { - nextData, ok := stream.outOfOrder[stream.nextSeq] - if !ok { - break - } - if s := m.waitForBufferSpace(sid, clientID, len(nextData)); s == nil { - return - } - stream.recvBuf = append(stream.recvBuf, nextData...) - delete(stream.outOfOrder, stream.nextSeq) - stream.nextSeq++ - logger.Verbosef("Applied out-of-order packet sid=%d seq=%d", sid, stream.nextSeq-1) - } -} - -func (m *Multiplexer) notifyDataReady(sid uint16) { - m.dataReadyMu.Lock() - defer m.dataReadyMu.Unlock() - if ch, ok := m.dataReady[sid]; ok { - select { - case ch <- struct{}{}: - default: - } - } -} - -func (m *Multiplexer) handleControlFrame(control ControlFrame) { - switch control.Type { - case ControlResetClient: - m.ResetClient(control.ClientID) - default: - logger.Debugf("Unknown mux control frame type=%d clientID=%d", control.Type, control.ClientID) - } -} - -// ResetClient closes and removes all streams associated with a client ID. -func (m *Multiplexer) ResetClient(clientID uint32) { - m.mu.Lock() - defer m.mu.Unlock() - - for streamSid, stream := range m.streams { - if stream.ClientID == clientID { - stream.closed = true - delete(m.streams, streamSid) - } - } - m.bufferCond.Broadcast() -} - -func (m *Multiplexer) waitForBufferSpace(sid uint16, clientID uint32, need int) *Stream { - for { - stream, ok := m.streams[sid] - if !ok || stream.ClientID != clientID || stream.closed { - return nil - } - if len(stream.recvBuf)+need <= m.maxBufferSize { - return stream - } - // Wait for space to become available - m.bufferCond.Wait() - } -} - -// ReadStream retrieves and clears the current receive buffer for a stream. -func (m *Multiplexer) ReadStream(sid uint16) []byte { - m.mu.Lock() - defer m.mu.Unlock() - - stream, exists := m.streams[sid] - if !exists || len(stream.recvBuf) == 0 { - return nil - } - - data := stream.recvBuf - stream.recvBuf = make([]byte, 0) - - // Notify producers that space is now available - m.bufferCond.Broadcast() - - return data -} - -// StreamClosed returns true if the stream is closed or doesn't exist. -func (m *Multiplexer) StreamClosed(sid uint16) bool { - m.mu.RLock() - defer m.mu.RUnlock() - - stream, exists := m.streams[sid] - return !exists || stream.closed -} - -// GetStreams returns a list of all active stream IDs. -func (m *Multiplexer) GetStreams() []uint16 { - m.mu.RLock() - defer m.mu.RUnlock() - - sids := make([]uint16, 0, len(m.streams)) - for sid := range m.streams { - sids = append(sids, sid) - } - return sids -} - -// GetStream returns the Stream object for a given ID. -func (m *Multiplexer) GetStream(sid uint16) *Stream { - m.mu.RLock() - defer m.mu.RUnlock() - return m.streams[sid] -} - -// Reset clears all multiplexer state and closes all streams. -func (m *Multiplexer) Reset() { - m.mu.Lock() - defer m.mu.Unlock() - - for _, stream := range m.streams { - stream.closed = true - } - - m.streams = make(map[uint16]*Stream) - m.nextID = 1 - - m.sendSeqMu.Lock() - m.sendSeq = make(map[uint16]uint32) - m.sendSeqMu.Unlock() - - m.bufferCond.Broadcast() -} - -// UpdateSendFunc updates the function used to transmit raw frames. -func (m *Multiplexer) UpdateSendFunc(onSend func([]byte) error) { - m.mu.Lock() - defer m.mu.Unlock() - - m.onSend = onSend -} - -// WaitForData returns a channel that signals when new data is available for a stream. -func (m *Multiplexer) WaitForData(sid uint16) <-chan struct{} { - m.dataReadyMu.Lock() - defer m.dataReadyMu.Unlock() - - if _, ok := m.dataReady[sid]; !ok { - m.dataReady[sid] = make(chan struct{}, 1) - } - return m.dataReady[sid] -} - -// CleanupDataChannel removes the data notification channel for a stream. -func (m *Multiplexer) CleanupDataChannel(sid uint16) { - m.dataReadyMu.Lock() - defer m.dataReadyMu.Unlock() - - if ch, ok := m.dataReady[sid]; ok { - close(ch) - delete(m.dataReady, sid) - } -} diff --git a/internal/muxconn/conn.go b/internal/muxconn/conn.go new file mode 100644 index 0000000..b895610 --- /dev/null +++ b/internal/muxconn/conn.go @@ -0,0 +1,119 @@ +// Package muxconn adapts a link.Link into an io.ReadWriteCloser suitable for +// driving a smux session. The wrapper applies AEAD on every wire-bound write +// and inverts it on every received message before exposing the bytes as a +// byte stream. +// +// Link semantics are message-oriented: each Send produces exactly one OnData +// on the peer. smux operates on a pure byte stream (header + payload may be +// glued or split across reads). We bridge by: +// +// - Treating each Push as an opaque chunk appended to an internal byte +// buffer that Read drains in arbitrary slices. +// - Letting smux's sendLoop call Write once per frame; we encrypt and hand +// the whole buffer to the link as a single message. Length boundaries +// are preserved end-to-end by the transport (KCP length-prefix framing +// in vp8channel, native message boundaries in datachannel). +package muxconn + +import ( + "errors" + "io" + "sync" + "time" + + "github.com/openlibrecommunity/olcrtc/internal/crypto" + "github.com/openlibrecommunity/olcrtc/internal/link" +) + +// ErrClosed is returned from Read/Write after the conn has been closed. +var ErrClosed = errors.New("muxconn: closed") + +// Conn is an io.ReadWriteCloser over a link.Link with optional AEAD wrapping. +type Conn struct { + ln link.Link + cipher *crypto.Cipher + + mu sync.Mutex + cond *sync.Cond + buf []byte + closed bool +} + +// New wires a Conn over the given link. Push must be set as the link's OnData +// callback before this conn is used. +func New(ln link.Link, cipher *crypto.Cipher) *Conn { + c := &Conn{ln: ln, cipher: cipher} + c.cond = sync.NewCond(&c.mu) + return c +} + +// Push hands an encrypted wire payload (one OnData event) to the conn. +func (c *Conn) Push(ciphertext []byte) { + pt, err := c.cipher.Decrypt(ciphertext) + if err != nil { + return + } + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return + } + c.buf = append(c.buf, pt...) + c.cond.Broadcast() +} + +// Read implements io.Reader. Blocks until at least one byte is available. +func (c *Conn) Read(p []byte) (int, error) { + c.mu.Lock() + defer c.mu.Unlock() + for !c.closed && len(c.buf) == 0 { + c.cond.Wait() + } + if len(c.buf) == 0 { + return 0, io.EOF + } + n := copy(p, c.buf) + c.buf = c.buf[n:] + return n, nil +} + +// Write encrypts p and ships it to the link as a single message. Blocks while +// the link signals back-pressure. +func (c *Conn) Write(p []byte) (int, error) { + for { + if c.isClosed() { + return 0, ErrClosed + } + if c.ln.CanSend() { + break + } + time.Sleep(10 * time.Millisecond) + } + + enc, err := c.cipher.Encrypt(p) + if err != nil { + return 0, err + } + if err := c.ln.Send(enc); err != nil { + return 0, err + } + return len(p), nil +} + +// Close unblocks any pending Read with io.EOF. +func (c *Conn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return nil + } + c.closed = true + c.cond.Broadcast() + return nil +} + +func (c *Conn) isClosed() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.closed +} diff --git a/internal/server/server.go b/internal/server/server.go index a56f285..5e42a69 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -11,44 +11,32 @@ import ( "net" "strconv" "sync" - "sync/atomic" "time" "github.com/openlibrecommunity/olcrtc/internal/crypto" "github.com/openlibrecommunity/olcrtc/internal/link" "github.com/openlibrecommunity/olcrtc/internal/logger" - "github.com/openlibrecommunity/olcrtc/internal/mux" + "github.com/openlibrecommunity/olcrtc/internal/muxconn" "github.com/openlibrecommunity/olcrtc/internal/names" + "github.com/xtaci/smux" ) var ( // ErrKeySize is returned when the encryption key is not 32 bytes. ErrKeySize = errors.New("key must be 32 bytes") - // ErrKeyStringLength is returned when the encryption key string length is not 32. - ErrKeyStringLength = errors.New("key string length must be 32") // ErrSocks5AuthFailed is returned when SOCKS5 authentication fails. ErrSocks5AuthFailed = errors.New("SOCKS5 auth failed") // ErrSocks5ConnectFailed is returned when SOCKS5 connection fails. ErrSocks5ConnectFailed = errors.New("SOCKS5 connect failed") - // ErrNoLinks is returned when no links are available. - ErrNoLinks = errors.New("no links available") - // ErrDialProxy is returned when dialing the proxy fails. - ErrDialProxy = errors.New("failed to dial proxy") - // ErrEncryptFailed is returned when encryption fails. - ErrEncryptFailed = errors.New("encrypt failed") ) // Server handles incoming tunnel connections and proxies their traffic. type Server struct { - links []link.Link + ln link.Link cipher *crypto.Cipher - mux *mux.Multiplexer - connections map[uint16]net.Conn - connMu sync.RWMutex - streamPumps map[uint16]net.Conn - pumpMu sync.Mutex - linkIdx atomic.Uint32 - activeClients atomic.Int32 + conn *muxconn.Conn + session *smux.Session + sessMu sync.RWMutex wg sync.WaitGroup dnsServer string resolver *net.Resolver @@ -97,25 +85,22 @@ func Run( s := &Server{ cipher: cipher, - connections: make(map[uint16]net.Conn), - streamPumps: make(map[uint16]net.Conn), - links: make([]link.Link, 0), dnsServer: dnsServer, socksProxyAddr: socksProxyAddr, socksProxyPort: socksProxyPort, } - s.setupResolver() - s.setupMux() - const linkCount = 1 - for i := range linkCount { - if err := s.addLink(runCtx, linkName, transportName, carrierName, roomURL, i, cancel, videoWidth, videoHeight, videoFPS, videoBitrate, videoHW, videoQRSize, videoQRRecovery, videoCodec, videoTileModule, videoTileRS, vp8FPS, vp8BatchSize); err != nil { - return fmt.Errorf("addLink failed: %w", err) - } + if err := s.bringUpLink( + runCtx, linkName, transportName, carrierName, roomURL, cancel, + videoWidth, videoHeight, videoFPS, videoBitrate, videoHW, + videoQRSize, videoQRRecovery, videoCodec, videoTileModule, videoTileRS, + vp8FPS, vp8BatchSize, + ); err != nil { + return err } - err = s.runLoop(runCtx) + err = s.serve(runCtx) s.shutdown() s.wg.Wait() @@ -136,12 +121,7 @@ func setupCipher(keyHex string) (*crypto.Cipher, error) { return nil, fmt.Errorf("%w, got %d", ErrKeySize, len(key)) } - keyStr := string(key) - if len(keyStr) != 32 { - return nil, fmt.Errorf("%w, got %d", ErrKeyStringLength, len(keyStr)) - } - - cipher, err := crypto.NewCipher(keyStr) + cipher, err := crypto.NewCipher(string(key)) if err != nil { return nil, fmt.Errorf("failed to create cipher: %w", err) } @@ -158,51 +138,30 @@ func (s *Server) setupResolver() { } } -func (s *Server) setupMux() { - s.mux = mux.New(0, func(frame []byte) error { - for { - canSend := true - for _, ln := range s.links { - if !ln.CanSend() { - canSend = false - break - } - } - if canSend { - break - } - time.Sleep(10 * time.Millisecond) - } - - encrypted, err := s.cipher.Encrypt(frame) - if err != nil { - return fmt.Errorf("%w: %w", ErrEncryptFailed, err) - } - if len(s.links) == 0 { - return ErrNoLinks - } - idx := s.linkIdx.Add(1) % uint32(len(s.links)) //nolint:gosec - return s.links[idx].Send(encrypted) - }) +// smuxConfig mirrors the client side. Both peers must agree on Version and +// MaxFrameSize. +func smuxConfig() *smux.Config { + cfg := smux.DefaultConfig() + cfg.Version = 2 + cfg.MaxFrameSize = 32768 + cfg.MaxReceiveBuffer = 16 * 1024 * 1024 + cfg.MaxStreamBuffer = 1024 * 1024 + cfg.KeepAliveInterval = 10 * time.Second + cfg.KeepAliveTimeout = 60 * time.Second + return cfg } -func (s *Server) addLink( +func (s *Server) bringUpLink( ctx context.Context, - linkName, - transportName, - carrierName, - roomURL string, - linkID int, + linkName, transportName, carrierName, roomURL string, cancel context.CancelFunc, videoWidth, videoHeight, videoFPS int, videoBitrate, videoHW string, videoQRSize int, videoQRRecovery string, videoCodec string, - videoTileModule int, - videoTileRS int, - vp8FPS int, - vp8BatchSize int, + videoTileModule, videoTileRS int, + vp8FPS, vp8BatchSize int, ) error { ln, err := link.New(ctx, linkName, link.Config{ Transport: transportName, @@ -229,22 +188,21 @@ func (s *Server) addLink( if err != nil { return fmt.Errorf("failed to create link: %w", err) } + s.ln = ln ln.SetEndedCallback(func(reason string) { - logger.Infof("Server link %d reported conference end: %s", linkID, reason) + logger.Infof("Server link reported conference end: %s", reason) cancel() }) - s.links = append(s.links, ln) + ln.SetReconnectCallback(func() { s.handleReconnect() }) - ln.SetReconnectCallback(func() { - s.handleLinkReconnect(linkID) - }) - - logger.Infof("Connecting link %d via %s/%s/%s...", linkID, linkName, transportName, carrierName) + logger.Infof("Connecting link via %s/%s/%s...", linkName, transportName, carrierName) if err := ln.Connect(ctx); err != nil { return fmt.Errorf("failed to connect link: %w", err) } - logger.Infof("Link %d connected", linkID) + logger.Infof("Link connected") + + s.installSession() s.wg.Add(1) go func() { @@ -254,30 +212,195 @@ func (s *Server) addLink( return nil } -func (s *Server) handleLinkReconnect(linkID int) { - logger.Infof("link %d reconnect event", linkID) - - s.connMu.Lock() - for sid, conn := range s.connections { - if conn != nil { - _ = conn.Close() - } - delete(s.connections, sid) +func (s *Server) installSession() { + conn := muxconn.New(s.ln, s.cipher) + sess, err := smux.Server(conn, smuxConfig()) + if err != nil { + logger.Warnf("smux server init failed: %v", err) + return } - s.connMu.Unlock() + s.sessMu.Lock() + s.conn = conn + s.session = sess + s.sessMu.Unlock() +} - s.mux.UpdateSendFunc(func(frame []byte) error { - encrypted, err := s.cipher.Encrypt(frame) +func (s *Server) handleReconnect() { + logger.Infof("server link reconnect — tearing down smux session") + s.sessMu.Lock() + if s.session != nil { + _ = s.session.Close() + s.session = nil + } + if s.conn != nil { + _ = s.conn.Close() + s.conn = nil + } + s.sessMu.Unlock() + s.installSession() +} + +func (s *Server) onData(data []byte) { + s.sessMu.RLock() + conn := s.conn + s.sessMu.RUnlock() + if conn != nil { + conn.Push(data) + } +} + +// serve drives the smux Accept loop, spawning a tunnel per inbound stream. +// The loop tolerates session bounces (reconnects) by waiting until a fresh +// session is installed instead of terminating the server. +func (s *Server) serve(ctx context.Context) error { + for { + if ctx.Err() != nil { + return nil + } + + s.sessMu.RLock() + sess := s.session + s.sessMu.RUnlock() + if sess == nil { + select { + case <-ctx.Done(): + return nil + case <-time.After(50 * time.Millisecond): + continue + } + } + + stream, err := sess.AcceptStream() if err != nil { - return fmt.Errorf("%w: %w", ErrEncryptFailed, err) + // Session is torn down (reconnect or close). If we're shutting + // down, exit; otherwise wait for a new session and retry. + if ctx.Err() != nil { + return nil + } + logger.Infof("AcceptStream returned %v — waiting for new session", err) + time.Sleep(100 * time.Millisecond) + continue } - if len(s.links) == 0 { - return ErrNoLinks + + s.wg.Add(1) + go func() { + defer s.wg.Done() + s.handleStream(ctx, stream) + }() + } +} + +func (s *Server) shutdown() { + s.sessMu.Lock() + if s.session != nil { + _ = s.session.Close() + } + if s.conn != nil { + _ = s.conn.Close() + } + s.sessMu.Unlock() + if s.ln != nil { + _ = s.ln.Close() + } +} + +func (s *Server) handleStream(_ context.Context, stream *smux.Stream) { + defer stream.Close() + + // Read the connect JSON. The client writes the whole JSON in one + // stream.Write so it usually arrives intact; tolerate fragmentation + // by reading incrementally up to a sane cap. + const maxConnReq = 4096 + header := make([]byte, 0, 256) + tmp := make([]byte, 256) + _ = stream.SetReadDeadline(time.Now().Add(15 * time.Second)) + for { + n, err := stream.Read(tmp) + if n > 0 { + header = append(header, tmp[:n]...) + if req, ok := parseConnectRequest(header); ok { + _ = stream.SetReadDeadline(time.Time{}) + s.dispatch(stream, req) + return + } } - idx := s.linkIdx.Add(1) % uint32(len(s.links)) //nolint:gosec - return s.links[idx].Send(encrypted) - }) - s.mux.Reset() + if err != nil { + return + } + if len(header) > maxConnReq { + return + } + } +} + +func parseConnectRequest(buf []byte) (ConnectRequest, bool) { + var req ConnectRequest + if err := json.Unmarshal(buf, &req); err != nil { + return req, false + } + if req.Cmd != "connect" { + return req, false + } + return req, true +} + +func (s *Server) dispatch(stream *smux.Stream, req ConnectRequest) { + addr := net.JoinHostPort(req.Addr, strconv.Itoa(req.Port)) + logger.Infof("sid=%d connect %s", stream.ID(), addr) + + dialStart := time.Now() + conn, err := s.dial(req) + dialElapsed := time.Since(dialStart) + + if err != nil { + logger.Infof("sid=%d dial %s failed (%v): %v", stream.ID(), addr, dialElapsed, err) + return + } + defer conn.Close() + + logger.Infof("sid=%d connected %s in %v", stream.ID(), addr, dialElapsed) + + if _, err := stream.Write([]byte{0x00}); err != nil { + return + } + + go func() { + _, _ = io.Copy(stream, conn) + _ = stream.Close() + }() + _, _ = io.Copy(conn, stream) +} + +func (s *Server) dial(req ConnectRequest) (net.Conn, error) { + addr := net.JoinHostPort(req.Addr, strconv.Itoa(req.Port)) + if s.socksProxyAddr == "" { + dialer := &net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 30 * time.Second, + Resolver: s.resolver, + } + conn, err := dialer.Dial("tcp4", addr) + if err != nil { + return nil, fmt.Errorf("dial failed: %w", err) + } + return conn, nil + } + + proxyAddr := net.JoinHostPort(s.socksProxyAddr, strconv.Itoa(s.socksProxyPort)) + dialer := &net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 30 * time.Second, + } + conn, err := dialer.Dial("tcp4", proxyAddr) + if err != nil { + return nil, fmt.Errorf("failed to dial proxy: %w", err) + } + + if err := s.socks5Connect(conn, req.Addr, req.Port); err != nil { + _ = conn.Close() + return nil, err + } + return conn, nil } func (s *Server) socks5Connect(conn net.Conn, targetAddr string, targetPort int) error { @@ -318,336 +441,3 @@ func (s *Server) socks5Connect(conn net.Conn, targetAddr string, targetPort int) return nil } - -func (s *Server) onData(data []byte) { - plaintext, err := s.cipher.Decrypt(data) - if err != nil { - logger.Debugf("Decrypt error: %v", err) - return - } - - if control, ok := mux.ParseControlFrame(plaintext); ok && control.Type == mux.ControlResetClient { - logger.Infof("Received reset signal from client (clientID=%d)", control.ClientID) - s.closeClientConnections(control.ClientID) - } - - s.mux.HandleFrame(plaintext) -} - -func (s *Server) closeClientConnections(clientID uint32) { - s.connMu.Lock() - defer s.connMu.Unlock() - - for streamSid, conn := range s.connections { - stream := s.mux.GetStream(streamSid) - if stream != nil && stream.ClientID == clientID { - if conn != nil { - _ = conn.Close() - } - delete(s.connections, streamSid) - } - } -} - -func (s *Server) runLoop(ctx context.Context) error { - ticker := time.NewTicker(10 * time.Millisecond) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return nil - case <-ticker.C: - s.processMuxStreams(ctx) - } - } -} - -func (s *Server) shutdown() { - s.connMu.Lock() - for _, conn := range s.connections { - if conn != nil { - _ = conn.Close() - } - } - s.connMu.Unlock() - - s.pumpMu.Lock() - for _, conn := range s.streamPumps { - if conn != nil { - _ = conn.Close() - } - } - s.pumpMu.Unlock() - - for i, tr := range s.links { - logger.Infof("closing link %d", i) - _ = tr.Close() - } -} - -func (s *Server) processMuxStreams(ctx context.Context) { - sids := s.mux.GetStreams() - for _, sid := range sids { - if s.mux.StreamClosed(sid) { - s.closeStreamConnection(sid) - continue - } - - if s.hasConnection(sid) { - continue - } - - data := s.mux.ReadStream(sid) - if len(data) == 0 { - continue - } - - var req ConnectRequest - if err := json.Unmarshal(data, &req); err == nil && req.Cmd == "connect" { - logger.Infof("sid=%d connect %s:%d", sid, req.Addr, req.Port) - s.closeStreamConnection(sid) - go s.handleConnect(ctx, sid, req) - } - } -} - -func (s *Server) hasConnection(sid uint16) bool { - s.connMu.RLock() - defer s.connMu.RUnlock() - return s.connections[sid] != nil -} - -func (s *Server) closeStreamConnection(sid uint16) { - s.connMu.Lock() - conn := s.connections[sid] - if conn != nil { - _ = conn.Close() - delete(s.connections, sid) - } - s.connMu.Unlock() -} - -func (s *Server) closeStreamConnectionIfCurrent(sid uint16, expected net.Conn) { - s.connMu.Lock() - conn := s.connections[sid] - if conn == expected { - _ = conn.Close() - delete(s.connections, sid) - } - s.connMu.Unlock() -} - -func (s *Server) markStreamPump(sid uint16, conn net.Conn) bool { - s.pumpMu.Lock() - defer s.pumpMu.Unlock() - if current := s.streamPumps[sid]; current == conn { - return false - } else if current != nil { - _ = current.Close() - } - s.streamPumps[sid] = conn - return true -} - -func (s *Server) unmarkStreamPump(sid uint16, conn net.Conn) { - s.pumpMu.Lock() - if s.streamPumps[sid] == conn { - delete(s.streamPumps, sid) - } - s.pumpMu.Unlock() -} - -func (s *Server) handleConnect(ctx context.Context, sid uint16, req ConnectRequest) { - addr := net.JoinHostPort(req.Addr, strconv.Itoa(req.Port)) - - s.closeStreamConnection(sid) - - dialStart := time.Now() - conn, err := s.dial(req) - dialElapsed := time.Since(dialStart) - - if err != nil { - logger.Infof("sid=%d dial %s failed (%v): %v", sid, addr, dialElapsed, err) - _ = s.mux.CloseStream(sid) - return - } - - s.connMu.Lock() - s.connections[sid] = conn - s.connMu.Unlock() - - logger.Infof("sid=%d connected %s in %v", sid, addr, dialElapsed) - - s.activeClients.Add(1) - _ = s.mux.SendData(sid, []byte{0x00}) - s.startStreamPump(ctx, sid, conn) - - go s.pumpToMux(sid, conn) -} - -func (s *Server) dial(req ConnectRequest) (net.Conn, error) { - addr := net.JoinHostPort(req.Addr, strconv.Itoa(req.Port)) - if s.socksProxyAddr == "" { - dialer := &net.Dialer{ - Timeout: 10 * time.Second, - KeepAlive: 30 * time.Second, - Resolver: s.resolver, - } - conn, err := dialer.Dial("tcp4", addr) - if err != nil { - return nil, fmt.Errorf("dial failed: %w", err) - } - return conn, nil - } - - proxyAddr := net.JoinHostPort(s.socksProxyAddr, strconv.Itoa(s.socksProxyPort)) - dialer := &net.Dialer{ - Timeout: 10 * time.Second, - KeepAlive: 30 * time.Second, - } - conn, err := dialer.Dial("tcp4", proxyAddr) - if err != nil { - return nil, fmt.Errorf("failed to dial proxy: %w", err) - } - - if err := s.socks5Connect(conn, req.Addr, req.Port); err != nil { - _ = conn.Close() - return nil, err - } - return conn, nil -} - -func (s *Server) pumpToMux(sid uint16, conn net.Conn) { - defer func() { - s.activeClients.Add(-1) - _ = s.mux.CloseStream(sid) - s.connMu.Lock() - delete(s.connections, sid) - s.connMu.Unlock() - }() - - // Decoupling queue: Read goroutine pushes here, sender goroutine drains - // to mux.SendData. Without this, slow channel back-pressure stalls the - // upstream Read which can cause TCP receive window to collapse to zero - // and effectively wedge the connection (peer stops sending and never - // resumes even though our channel is healthy). - type chunk struct{ data []byte } - queue := make(chan chunk, 64) - doneSender := make(chan struct{}) - - go func() { - defer close(doneSender) - for c := range queue { - for !s.canSendData() { - time.Sleep(20 * time.Millisecond) - } - if err := s.mux.SendData(sid, c.data); err != nil { - return - } - } - }() - - // queueHasSpace blocks until the decoupling queue has room or the - // sender goroutine has exited. We wait here *before* arming the - // upstream read deadline so that channel back-pressure isn't billed - // to the socket as idle time and doesn't trip a spurious i/o timeout. - queueHasSpace := func() bool { - for { - if len(queue) < cap(queue) { - return true - } - select { - case <-doneSender: - return false - case <-time.After(10 * time.Millisecond): - } - } - } - - buf := make([]byte, 16384) - - // Idle timeout for genuinely dead upstreams. Only armed when we are - // actively waiting on the socket (queue has space). During internal - // back-pressure the deadline is not in effect, so flow-control pauses - // don't get mis-classified as remote death. - const idleReadTimeout = 60 * time.Second - - for { - if !queueHasSpace() { - close(queue) - <-doneSender - return - } - - // Arm the deadline only when we actually want bytes from the peer. - _ = conn.SetReadDeadline(time.Now().Add(idleReadTimeout)) - - n, err := conn.Read(buf) - if err != nil { - close(queue) - <-doneSender - return - } - - // Clear the deadline so it doesn't fire while we are blocked in - // queueHasSpace() on the next iteration (back-pressure path). - _ = conn.SetReadDeadline(time.Time{}) - - // Copy because buf is reused on next Read. - c := make([]byte, n) - copy(c, buf[:n]) - - // Guaranteed non-blocking thanks to queueHasSpace() above (we are - // the sole producer); the blocking fallback is just defensive. - select { - case queue <- chunk{data: c}: - default: - queue <- chunk{data: c} - } - } -} - -func (s *Server) startStreamPump(ctx context.Context, sid uint16, conn net.Conn) { - if !s.markStreamPump(sid, conn) { - return - } - - s.wg.Add(1) - go func() { - defer s.wg.Done() - defer s.unmarkStreamPump(sid, conn) - - ticker := time.NewTicker(10 * time.Millisecond) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - data := s.mux.ReadStream(sid) - if len(data) > 0 { - if _, err := conn.Write(data); err != nil { - _ = s.mux.CloseStream(sid) - s.closeStreamConnectionIfCurrent(sid, conn) - return - } - } - if s.mux.StreamClosed(sid) { - s.closeStreamConnectionIfCurrent(sid, conn) - return - } - } - } - }() -} - -func (s *Server) canSendData() bool { - for _, tr := range s.links { - if !tr.CanSend() { - return false - } - } - return true -}