// Package server implements the olcrtc tunnel server logic.
//
// The server receives encrypted, multiplexed frames over its links, opens an
// upstream TCP connection (directly or through a SOCKS5 proxy) for every
// stream whose first data is a connect request, and pumps bytes in both
// directions between that stream and the upstream connection.
package server

import (
	"context"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"strconv"
	"sync"
	"sync/atomic"
	"time"

	"github.com/openlibrecommunity/olcrtc/internal/crypto"
	"github.com/openlibrecommunity/olcrtc/internal/link"
	"github.com/openlibrecommunity/olcrtc/internal/logger"
	"github.com/openlibrecommunity/olcrtc/internal/mux"
	"github.com/openlibrecommunity/olcrtc/internal/names"
)

var (
	// ErrKeySize is returned when the encryption key is not 32 bytes.
	ErrKeySize = errors.New("key must be 32 bytes")
	// ErrKeyStringLength is returned when the encryption key string length is not 32.
	ErrKeyStringLength = errors.New("key string length must be 32")
	// ErrSocks5AuthFailed is returned when SOCKS5 authentication fails.
	ErrSocks5AuthFailed = errors.New("SOCKS5 auth failed")
	// ErrSocks5ConnectFailed is returned when SOCKS5 connection fails.
	ErrSocks5ConnectFailed = errors.New("SOCKS5 connect failed")
	// ErrNoLinks is returned when no links are available.
	ErrNoLinks = errors.New("no links available")
	// ErrDialProxy is returned when dialing the proxy fails.
	ErrDialProxy = errors.New("failed to dial proxy")
	// ErrEncryptFailed is returned when encryption fails.
	ErrEncryptFailed = errors.New("encrypt failed")
)

// Server handles incoming tunnel connections and proxies their traffic.
type Server struct {
	// links carries encrypted mux frames; outgoing frames are spread over
	// them round-robin using linkIdx.
	links []link.Link
	// cipher encrypts outgoing mux frames and decrypts incoming link data.
	cipher *crypto.Cipher
	// mux multiplexes client streams over the links.
	mux *mux.Multiplexer
	// connections maps a stream ID to its upstream connection; guarded by connMu.
	connections map[uint16]net.Conn
	connMu      sync.RWMutex
	// streamPumps maps a stream ID to the upstream connection its
	// stream-to-upstream pump goroutine is writing to; guarded by pumpMu.
	streamPumps map[uint16]net.Conn
	pumpMu      sync.Mutex
	// linkIdx is the round-robin counter used to pick the next link.
	linkIdx atomic.Uint32
	// activeClients counts upstream connections whose upstream-to-mux pump
	// has been started and has not yet finished.
	activeClients atomic.Int32
	// wg tracks goroutines that Run waits for before returning.
	wg sync.WaitGroup
	// dnsServer is the DNS server address dialed by resolver; it is also
	// handed to the link configuration.
	dnsServer string
	// resolver resolves target host names for direct (non-proxied) dials.
	resolver *net.Resolver
	// socksProxyAddr and socksProxyPort select an upstream SOCKS5 proxy;
	// an empty socksProxyAddr means targets are dialed directly.
	socksProxyAddr string
	socksProxyPort int
}

// ConnectRequest is a message from the client to establish a new connection.
// It is JSON-decoded from the first data read on a stream that has no
// upstream connection yet; only Cmd == "connect" is acted upon, and
// Addr/Port name the target to dial.
type ConnectRequest struct {
	Cmd  string `json:"cmd"`
	Addr string `json:"addr"`
	Port int    `json:"port"`
}

// Run starts the server with the specified parameters.
func Run( ctx context.Context, linkName, transportName, carrierName, roomURL, keyHex string, dnsServer, socksProxyAddr string, socksProxyPort int, videoWidth int, videoHeight int, videoFPS int, videoBitrate string, videoHW string, videoQRSize int, videoQRRecovery string, videoCodec string, videoTileModule int, videoTileRS int, vp8FPS int, vp8BatchSize int, ) error { runCtx, cancel := context.WithCancel(ctx) defer cancel() cipher, err := setupCipher(keyHex) if err != nil { return fmt.Errorf("setupCipher failed: %w", err) } s := &Server{ cipher: cipher, connections: make(map[uint16]net.Conn), streamPumps: make(map[uint16]net.Conn), links: make([]link.Link, 0), dnsServer: dnsServer, socksProxyAddr: socksProxyAddr, socksProxyPort: socksProxyPort, } s.setupResolver() s.setupMux() const linkCount = 1 for i := range linkCount { if err := s.addLink(runCtx, linkName, transportName, carrierName, roomURL, i, cancel, videoWidth, videoHeight, videoFPS, videoBitrate, videoHW, videoQRSize, videoQRRecovery, videoCodec, videoTileModule, videoTileRS, vp8FPS, vp8BatchSize); err != nil { return fmt.Errorf("addLink failed: %w", err) } } err = s.runLoop(runCtx) s.shutdown() s.wg.Wait() return err } func setupCipher(keyHex string) (*crypto.Cipher, error) { if keyHex == "" { return nil, errors.New("key required (use -key )") } key, err := hex.DecodeString(keyHex) if err != nil { return nil, fmt.Errorf("failed to decode key: %w", err) } if len(key) != 32 { return nil, fmt.Errorf("%w, got %d", ErrKeySize, len(key)) } keyStr := string(key) if len(keyStr) != 32 { return nil, fmt.Errorf("%w, got %d", ErrKeyStringLength, len(keyStr)) } cipher, err := crypto.NewCipher(keyStr) if err != nil { return nil, fmt.Errorf("failed to create cipher: %w", err) } return cipher, nil } func (s *Server) setupResolver() { s.resolver = &net.Resolver{ PreferGo: true, Dial: func(ctx context.Context, network, _ string) (net.Conn, error) { d := net.Dialer{Timeout: 3 * time.Second} return d.DialContext(ctx, network, 
s.dnsServer) }, } } func (s *Server) setupMux() { s.mux = mux.New(0, func(frame []byte) error { for { canSend := true for _, ln := range s.links { if !ln.CanSend() { canSend = false break } } if canSend { break } time.Sleep(10 * time.Millisecond) } encrypted, err := s.cipher.Encrypt(frame) if err != nil { return fmt.Errorf("%w: %w", ErrEncryptFailed, err) } if len(s.links) == 0 { return ErrNoLinks } idx := s.linkIdx.Add(1) % uint32(len(s.links)) //nolint:gosec return s.links[idx].Send(encrypted) }) } func (s *Server) addLink( ctx context.Context, linkName, transportName, carrierName, roomURL string, linkID int, cancel context.CancelFunc, videoWidth, videoHeight, videoFPS int, videoBitrate, videoHW string, videoQRSize int, videoQRRecovery string, videoCodec string, videoTileModule int, videoTileRS int, vp8FPS int, vp8BatchSize int, ) error { ln, err := link.New(ctx, linkName, link.Config{ Transport: transportName, Carrier: carrierName, RoomURL: roomURL, Name: names.Generate(), OnData: s.onData, DNSServer: s.dnsServer, ProxyAddr: s.socksProxyAddr, ProxyPort: s.socksProxyPort, VideoWidth: videoWidth, VideoHeight: videoHeight, VideoFPS: videoFPS, VideoBitrate: videoBitrate, VideoHW: videoHW, VideoQRSize: videoQRSize, VideoQRRecovery: videoQRRecovery, VideoCodec: videoCodec, VideoTileModule: videoTileModule, VideoTileRS: videoTileRS, VP8FPS: vp8FPS, VP8BatchSize: vp8BatchSize, }) if err != nil { return fmt.Errorf("failed to create link: %w", err) } ln.SetEndedCallback(func(reason string) { logger.Infof("Server link %d reported conference end: %s", linkID, reason) cancel() }) s.links = append(s.links, ln) ln.SetReconnectCallback(func() { s.handleLinkReconnect(linkID) }) logger.Infof("Connecting link %d via %s/%s/%s...", linkID, linkName, transportName, carrierName) if err := ln.Connect(ctx); err != nil { return fmt.Errorf("failed to connect link: %w", err) } logger.Infof("Link %d connected", linkID) s.wg.Add(1) go func() { defer s.wg.Done() ln.WatchConnection(ctx) }() 
return nil } func (s *Server) handleLinkReconnect(linkID int) { logger.Infof("link %d reconnect event", linkID) s.connMu.Lock() for sid, conn := range s.connections { if conn != nil { _ = conn.Close() } delete(s.connections, sid) } s.connMu.Unlock() s.mux.UpdateSendFunc(func(frame []byte) error { encrypted, err := s.cipher.Encrypt(frame) if err != nil { return fmt.Errorf("%w: %w", ErrEncryptFailed, err) } if len(s.links) == 0 { return ErrNoLinks } idx := s.linkIdx.Add(1) % uint32(len(s.links)) //nolint:gosec return s.links[idx].Send(encrypted) }) s.mux.Reset() } func (s *Server) socks5Connect(conn net.Conn, targetAddr string, targetPort int) error { if _, err := conn.Write([]byte{5, 1, 0}); err != nil { return fmt.Errorf("failed to write socks5 auth: %w", err) } resp := make([]byte, 2) if _, err := io.ReadFull(conn, resp); err != nil { return fmt.Errorf("failed to read socks5 auth resp: %w", err) } if resp[0] != 5 || resp[1] != 0 { return ErrSocks5AuthFailed } addrLen := len(targetAddr) if addrLen > 255 { addrLen = 255 targetAddr = targetAddr[:255] } req := make([]byte, 0, 7+addrLen) req = append(req, 5, 1, 0, 3, byte(addrLen)) req = append(req, []byte(targetAddr)...) 
req = append(req, byte(targetPort>>8), byte(targetPort)) //nolint:gosec if _, err := conn.Write(req); err != nil { return fmt.Errorf("failed to write socks5 connect req: %w", err) } resp = make([]byte, 10) if _, err := io.ReadFull(conn, resp); err != nil { return fmt.Errorf("failed to read socks5 connect resp: %w", err) } if resp[0] != 5 || resp[1] != 0 { return fmt.Errorf("%w: %d", ErrSocks5ConnectFailed, resp[1]) } return nil } func (s *Server) onData(data []byte) { plaintext, err := s.cipher.Decrypt(data) if err != nil { logger.Debugf("Decrypt error: %v", err) return } if control, ok := mux.ParseControlFrame(plaintext); ok && control.Type == mux.ControlResetClient { logger.Infof("Received reset signal from client (clientID=%d)", control.ClientID) s.closeClientConnections(control.ClientID) } s.mux.HandleFrame(plaintext) } func (s *Server) closeClientConnections(clientID uint32) { s.connMu.Lock() defer s.connMu.Unlock() for streamSid, conn := range s.connections { stream := s.mux.GetStream(streamSid) if stream != nil && stream.ClientID == clientID { if conn != nil { _ = conn.Close() } delete(s.connections, streamSid) } } } func (s *Server) runLoop(ctx context.Context) error { ticker := time.NewTicker(10 * time.Millisecond) defer ticker.Stop() for { select { case <-ctx.Done(): return nil case <-ticker.C: s.processMuxStreams(ctx) } } } func (s *Server) shutdown() { s.connMu.Lock() for _, conn := range s.connections { if conn != nil { _ = conn.Close() } } s.connMu.Unlock() s.pumpMu.Lock() for _, conn := range s.streamPumps { if conn != nil { _ = conn.Close() } } s.pumpMu.Unlock() for i, tr := range s.links { logger.Infof("closing link %d", i) _ = tr.Close() } } func (s *Server) processMuxStreams(ctx context.Context) { sids := s.mux.GetStreams() for _, sid := range sids { if s.mux.StreamClosed(sid) { s.closeStreamConnection(sid) continue } if s.hasConnection(sid) { continue } data := s.mux.ReadStream(sid) if len(data) == 0 { continue } var req ConnectRequest if err := 
json.Unmarshal(data, &req); err == nil && req.Cmd == "connect" { logger.Infof("sid=%d connect %s:%d", sid, req.Addr, req.Port) s.closeStreamConnection(sid) go s.handleConnect(ctx, sid, req) } } } func (s *Server) hasConnection(sid uint16) bool { s.connMu.RLock() defer s.connMu.RUnlock() return s.connections[sid] != nil } func (s *Server) closeStreamConnection(sid uint16) { s.connMu.Lock() conn := s.connections[sid] if conn != nil { _ = conn.Close() delete(s.connections, sid) } s.connMu.Unlock() } func (s *Server) closeStreamConnectionIfCurrent(sid uint16, expected net.Conn) { s.connMu.Lock() conn := s.connections[sid] if conn == expected { _ = conn.Close() delete(s.connections, sid) } s.connMu.Unlock() } func (s *Server) markStreamPump(sid uint16, conn net.Conn) bool { s.pumpMu.Lock() defer s.pumpMu.Unlock() if current := s.streamPumps[sid]; current == conn { return false } else if current != nil { _ = current.Close() } s.streamPumps[sid] = conn return true } func (s *Server) unmarkStreamPump(sid uint16, conn net.Conn) { s.pumpMu.Lock() if s.streamPumps[sid] == conn { delete(s.streamPumps, sid) } s.pumpMu.Unlock() } func (s *Server) handleConnect(ctx context.Context, sid uint16, req ConnectRequest) { addr := net.JoinHostPort(req.Addr, strconv.Itoa(req.Port)) s.closeStreamConnection(sid) dialStart := time.Now() conn, err := s.dial(req) dialElapsed := time.Since(dialStart) if err != nil { logger.Infof("sid=%d dial %s failed (%v): %v", sid, addr, dialElapsed, err) _ = s.mux.CloseStream(sid) return } s.connMu.Lock() s.connections[sid] = conn s.connMu.Unlock() logger.Infof("sid=%d connected %s in %v", sid, addr, dialElapsed) s.activeClients.Add(1) _ = s.mux.SendData(sid, []byte{0x00}) s.startStreamPump(ctx, sid, conn) go s.pumpToMux(sid, conn) } func (s *Server) dial(req ConnectRequest) (net.Conn, error) { addr := net.JoinHostPort(req.Addr, strconv.Itoa(req.Port)) if s.socksProxyAddr == "" { dialer := &net.Dialer{ Timeout: 10 * time.Second, KeepAlive: 30 * time.Second, 
Resolver: s.resolver, } conn, err := dialer.Dial("tcp4", addr) if err != nil { return nil, fmt.Errorf("dial failed: %w", err) } return conn, nil } proxyAddr := net.JoinHostPort(s.socksProxyAddr, strconv.Itoa(s.socksProxyPort)) dialer := &net.Dialer{ Timeout: 10 * time.Second, KeepAlive: 30 * time.Second, } conn, err := dialer.Dial("tcp4", proxyAddr) if err != nil { return nil, fmt.Errorf("failed to dial proxy: %w", err) } if err := s.socks5Connect(conn, req.Addr, req.Port); err != nil { _ = conn.Close() return nil, err } return conn, nil } func (s *Server) pumpToMux(sid uint16, conn net.Conn) { defer func() { s.activeClients.Add(-1) _ = s.mux.CloseStream(sid) s.connMu.Lock() delete(s.connections, sid) s.connMu.Unlock() }() // Decoupling queue: Read goroutine pushes here, sender goroutine drains // to mux.SendData. Without this, slow channel back-pressure stalls the // upstream Read which can cause TCP receive window to collapse to zero // and effectively wedge the connection (peer stops sending and never // resumes even though our channel is healthy). type chunk struct{ data []byte } queue := make(chan chunk, 64) doneSender := make(chan struct{}) go func() { defer close(doneSender) for c := range queue { for !s.canSendData() { time.Sleep(20 * time.Millisecond) } if err := s.mux.SendData(sid, c.data); err != nil { return } } }() // queueHasSpace blocks until the decoupling queue has room or the // sender goroutine has exited. We wait here *before* arming the // upstream read deadline so that channel back-pressure isn't billed // to the socket as idle time and doesn't trip a spurious i/o timeout. queueHasSpace := func() bool { for { if len(queue) < cap(queue) { return true } select { case <-doneSender: return false case <-time.After(10 * time.Millisecond): } } } buf := make([]byte, 16384) // Idle timeout for genuinely dead upstreams. Only armed when we are // actively waiting on the socket (queue has space). 
During internal // back-pressure the deadline is not in effect, so flow-control pauses // don't get mis-classified as remote death. const idleReadTimeout = 60 * time.Second for { if !queueHasSpace() { close(queue) <-doneSender return } // Arm the deadline only when we actually want bytes from the peer. _ = conn.SetReadDeadline(time.Now().Add(idleReadTimeout)) n, err := conn.Read(buf) if err != nil { close(queue) <-doneSender return } // Clear the deadline so it doesn't fire while we are blocked in // queueHasSpace() on the next iteration (back-pressure path). _ = conn.SetReadDeadline(time.Time{}) // Copy because buf is reused on next Read. c := make([]byte, n) copy(c, buf[:n]) // Guaranteed non-blocking thanks to queueHasSpace() above (we are // the sole producer); the blocking fallback is just defensive. select { case queue <- chunk{data: c}: default: queue <- chunk{data: c} } } } func (s *Server) startStreamPump(ctx context.Context, sid uint16, conn net.Conn) { if !s.markStreamPump(sid, conn) { return } s.wg.Add(1) go func() { defer s.wg.Done() defer s.unmarkStreamPump(sid, conn) ticker := time.NewTicker(10 * time.Millisecond) defer ticker.Stop() for { select { case <-ctx.Done(): return case <-ticker.C: data := s.mux.ReadStream(sid) if len(data) > 0 { if _, err := conn.Write(data); err != nil { _ = s.mux.CloseStream(sid) s.closeStreamConnectionIfCurrent(sid, conn) return } } if s.mux.StreamClosed(sid) { s.closeStreamConnectionIfCurrent(sid, conn) return } } } }() } func (s *Server) canSendData() bool { for _, tr := range s.links { if !tr.CanSend() { return false } } return true }