// Package server implements the olcrtc tunnel server logic: it runs smux
// sessions over a transport link, authenticates each client with a handshake
// on the first (control) stream of a session, and proxies every later tunnel
// stream to the TCP target named in its connect request, either directly or
// through an upstream SOCKS5 proxy.
package server

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"strconv"
	"sync"
	"time"

	"github.com/google/uuid"
	"github.com/openlibrecommunity/olcrtc/internal/control"
	"github.com/openlibrecommunity/olcrtc/internal/crypto"
	"github.com/openlibrecommunity/olcrtc/internal/handshake"
	"github.com/openlibrecommunity/olcrtc/internal/logger"
	"github.com/openlibrecommunity/olcrtc/internal/muxconn"
	"github.com/openlibrecommunity/olcrtc/internal/names"
	"github.com/openlibrecommunity/olcrtc/internal/runtime"
	"github.com/openlibrecommunity/olcrtc/internal/transport"
	"github.com/xtaci/smux"
)

// connectCommand is the only ConnectRequest.Cmd value the server dispatches;
// any other command is treated as an incomplete/invalid request.
const connectCommand = "connect"

var (
	// ErrKeyRequired re-exports runtime.ErrKeyRequired for compatibility with
	// pre-runtime callers that errors.Is-checked it.
	ErrKeyRequired = runtime.ErrKeyRequired
	// ErrKeySize re-exports runtime.ErrKeySize for the same reason.
	ErrKeySize = runtime.ErrKeySize
	// ErrSocks5AuthFailed is returned when the upstream SOCKS5 proxy rejects
	// the no-authentication method negotiation.
	ErrSocks5AuthFailed = errors.New("SOCKS5 auth failed")
	// ErrSocks5ConnectFailed is returned (wrapped, with detail) when the
	// upstream SOCKS5 proxy refuses or cannot complete the CONNECT request.
	ErrSocks5ConnectFailed = errors.New("SOCKS5 connect failed")
)

// SessionOpenFunc is called after a successful handshake, before the server
// accepts tunnel streams on that session. claims are the client-supplied
// claims from the handshake hello.
type SessionOpenFunc func(sessionID, deviceID string, claims map[string]any)

// SessionCloseFunc is called when a session that completed its handshake is
// torn down. Reasons emitted by this package:
//   - "reconnect": the session was replaced after a carrier reconnect or a
//     failed control-stream liveness check.
//   - "closed": graceful shutdown, context cancellation, or the peer's
//     session ending.
type SessionCloseFunc func(sessionID, reason string)

// TrafficFunc is called once per tunnel stream, after the copy loops finish.
// bytesIn counts client→target bytes; bytesOut counts target→client bytes.
// Streams whose target dial fails do not report traffic.
type TrafficFunc func(sessionID, addr string, bytesIn, bytesOut uint64)

// HealthFunc is called when the server control health snapshot changes.
type HealthFunc func(control.Status) // Server handles incoming tunnel connections and proxies their traffic. type Server struct { ln transport.Transport peerLn transport.PeerTransport cipher *crypto.Cipher conn *muxconn.Conn session *smux.Session controlStrm *smux.Stream controlStop context.CancelFunc sessMu sync.RWMutex peerSessions map[string]*peerSession reinstallMu sync.Mutex wg sync.WaitGroup authHook handshake.AuthFunc onOpen SessionOpenFunc onClose SessionCloseFunc onTraffic TrafficFunc deviceID string sessionID string dnsServer string resolver *net.Resolver socksProxyAddr string socksProxyPort int liveness control.Config health *runtime.HealthTracker done chan struct{} doneOnce sync.Once } type peerSession struct { peerID string conn *muxconn.Conn session *smux.Session controlStrm *smux.Stream controlStop context.CancelFunc sessionID string deviceID string } // ConnectRequest is a message from the client to establish a new connection. type ConnectRequest struct { Cmd string `json:"cmd"` Addr string `json:"addr"` Port int `json:"port"` } // Config holds runtime configuration for [Run]. type Config struct { Transport string Carrier string RoomURL string ChannelID string KeyHex string DNSServer string SOCKSProxyAddr string SOCKSProxyPort int TransportOptions transport.Options Engine string URL string Token string Liveness control.Config Traffic transport.TrafficConfig // AuthHook is invoked after CLIENT_HELLO to authorize the client and // return a session ID. If nil, every client is admitted with a random UUID. AuthHook handshake.AuthFunc // OnSessionOpen fires after a successful handshake. Nil means no-op. OnSessionOpen SessionOpenFunc // OnSessionClose fires when the session is torn down (reconnect, closed). Nil means no-op. OnSessionClose SessionCloseFunc // OnTraffic fires once per tunnel stream after both copy loops finish. Nil means no-op. OnTraffic TrafficFunc // OnHealth fires when liveness/reconnect status changes. Nil means no-op. 
OnHealth HealthFunc } // Run starts the server with the given configuration. func Run(ctx context.Context, cfg Config) error { runCtx, cancel := context.WithCancel(ctx) defer cancel() cipher, err := setupCipher(cfg.KeyHex) if err != nil { return fmt.Errorf("setupCipher failed: %w", err) } hook := cfg.AuthHook if hook == nil { hook = defaultAuthHook } onOpen := cfg.OnSessionOpen if onOpen == nil { onOpen = func(string, string, map[string]any) {} } onClose := cfg.OnSessionClose if onClose == nil { onClose = func(string, string) {} } onTraffic := cfg.OnTraffic if onTraffic == nil { onTraffic = func(string, string, uint64, uint64) {} } s := &Server{ cipher: cipher, authHook: hook, onOpen: onOpen, onClose: onClose, onTraffic: onTraffic, dnsServer: cfg.DNSServer, socksProxyAddr: cfg.SOCKSProxyAddr, socksProxyPort: cfg.SOCKSProxyPort, liveness: cfg.Liveness, health: runtime.NewHealthTracker(cfg.OnHealth), peerSessions: make(map[string]*peerSession), done: make(chan struct{}), } s.setupResolver() // Register shutdown BEFORE bringUpLink so a partial setup (e.g. // link.New succeeded but ln.Connect timed out) still tears the // link down and sends MUC presence-unavailable. Without this, an // early bringUpLink error returns straight to the caller and the // already-joined MUC presence stays behind as a ghost participant // for subsequent tests against the same room. shutdown is // idempotent and safe to call before s.serve runs. 
defer func() { s.shutdown() s.wg.Wait() }() if err := s.bringUpLink(runCtx, cfg, cancel); err != nil { return err } go func() { <-runCtx.Done() s.closeSession() }() s.serve(runCtx) return nil } func setupCipher(keyHex string) (*crypto.Cipher, error) { cipher, err := runtime.SetupCipher(keyHex) if err != nil { return nil, fmt.Errorf("server: %w", err) } return cipher, nil } func (s *Server) setupResolver() { s.resolver = &net.Resolver{ PreferGo: true, Dial: func(ctx context.Context, network, _ string) (net.Conn, error) { d := net.Dialer{Timeout: 3 * time.Second} return d.DialContext(ctx, network, s.dnsServer) }, } } func smuxConfig(maxWirePayload int) *smux.Config { return runtime.SmuxConfig(maxWirePayload) } func linkMaxPayload(tr transport.Transport) int { return runtime.MaxPayload(tr) } func (s *Server) bringUpLink( ctx context.Context, cfg Config, cancel context.CancelFunc, ) error { ln, err := transport.New(ctx, cfg.Transport, transport.Config{ Carrier: cfg.Carrier, RoomURL: cfg.RoomURL, Engine: cfg.Engine, URL: cfg.URL, Token: cfg.Token, ChannelID: cfg.ChannelID, DeviceID: "", Name: names.Generate(), OnData: s.onData, OnPeerData: s.onPeerData, DNSServer: s.dnsServer, ProxyAddr: s.socksProxyAddr, ProxyPort: s.socksProxyPort, Options: cfg.TransportOptions, Traffic: cfg.Traffic, }) if err != nil { return fmt.Errorf("failed to create transport: %w", err) } s.ln = ln if peerLn, ok := ln.(transport.PeerTransport); ok && peerLn.SupportsPeerRouting() { s.peerLn = peerLn } ln.SetEndedCallback(func(reason string) { logger.Infof("Server link reported conference end: %s", reason) cancel() }) ln.SetShouldReconnect(func() bool { return ctx.Err() == nil }) ln.SetReconnectCallback(func() { if ctx.Err() != nil { return } s.handleReconnect() }) logger.Infof("Connecting transport=%s carrier=%s ...", cfg.Transport, cfg.Carrier) if s.peerLn == nil { s.installSession() } if err := ln.Connect(ctx); err != nil { return fmt.Errorf("failed to connect link: %w", err) } 
logger.Infof("Link connected") s.wg.Add(1) go func() { defer s.wg.Done() ln.WatchConnection(ctx) }() return nil } func (s *Server) installSession() { conn := muxconn.New(s.ln, s.cipher) sess, err := smux.Server(conn, smuxConfig(linkMaxPayload(s.ln))) if err != nil { logger.Warnf("smux server init failed: %v", err) return } s.sessMu.Lock() s.conn = conn s.session = sess s.sessMu.Unlock() } func (s *Server) handleReconnect() { s.recordReconnect() logger.Infof("server reconnect reason=carrier - tearing down smux session") s.sessMu.RLock() current := s.session s.sessMu.RUnlock() s.reinstallSession(current) } func (s *Server) reinstallSession(dead *smux.Session) { s.reinstallMu.Lock() defer s.reinstallMu.Unlock() // Pre-build the replacement so we can swap atomically below. newConn := muxconn.New(s.ln, s.cipher) newSess, err := smux.Server(newConn, smuxConfig(linkMaxPayload(s.ln))) if err != nil { logger.Warnf("smux server init failed: %v", err) _ = newConn.Close() return } s.sessMu.Lock() if s.session != dead { // Someone else already reinstalled — discard our build. 
s.sessMu.Unlock() _ = newSess.Close() _ = newConn.Close() return } oldSess := s.session oldConn := s.conn oldControl := s.controlStrm oldControlStop := s.controlStop oldSID := s.sessionID s.session = newSess s.conn = newConn s.controlStrm = nil s.controlStop = nil s.sessionID = "" s.deviceID = "" s.sessMu.Unlock() if oldControlStop != nil { oldControlStop() } if oldSess != nil { _ = oldSess.Close() } if oldConn != nil { _ = oldConn.Close() } if oldControl != nil { _ = oldControl.Close() } if oldSID != "" { s.onClose(oldSID, "reconnect") } } func (s *Server) closeSession() { s.sessMu.Lock() sess := s.session conn := s.conn control := s.controlStrm controlStop := s.controlStop peers := s.peerSessions s.peerSessions = make(map[string]*peerSession) s.session = nil s.conn = nil s.controlStrm = nil s.controlStop = nil oldSID := s.sessionID s.sessionID = "" s.deviceID = "" s.sessMu.Unlock() if controlStop != nil { controlStop() } notifyControlClose(control) if sess != nil { _ = sess.Close() } if conn != nil { _ = conn.Close() } if oldSID != "" { s.onClose(oldSID, "closed") } for _, ps := range peers { s.closePeerSession(ps, "closed") } } func (s *Server) removePeerSession(peerID, reason string) { s.sessMu.Lock() ps := s.peerSessions[peerID] delete(s.peerSessions, peerID) s.sessMu.Unlock() if ps != nil { s.closePeerSession(ps, reason) } } func (s *Server) closePeerSession(ps *peerSession, reason string) { if ps.controlStop != nil { ps.controlStop() } notifyControlClose(ps.controlStrm) if ps.session != nil { _ = ps.session.Close() } if ps.conn != nil { _ = ps.conn.Close() } if ps.controlStrm != nil { _ = ps.controlStrm.Close() } if ps.sessionID != "" { s.onClose(ps.sessionID, reason) } } func notifyControlClose(stream *smux.Stream) { if stream == nil { return } _ = stream.SetWriteDeadline(time.Now().Add(2 * time.Second)) if err := control.SendClose(stream); err == nil { time.Sleep(200 * time.Millisecond) } _ = stream.SetWriteDeadline(time.Time{}) _ = stream.CloseWrite() } 
func (s *Server) onData(data []byte) { s.sessMu.RLock() conn := s.conn s.sessMu.RUnlock() if conn != nil { conn.Push(data) } } func (s *Server) onPeerData(peerID string, data []byte) { ps := s.getPeerSession(peerID) if ps == nil { return } ps.conn.Push(data) } func (s *Server) getPeerSession(peerID string) *peerSession { if peerID == "" || s.peerLn == nil { return nil } s.sessMu.Lock() if ps := s.peerSessions[peerID]; ps != nil { s.sessMu.Unlock() return ps } conn := muxconn.NewPeer(s.peerLn, s.cipher, peerID) sess, err := smux.Server(conn, smuxConfig(linkMaxPayload(s.ln))) if err != nil { s.sessMu.Unlock() logger.Warnf("smux server init failed for peer %s: %v", peerID, err) _ = conn.Close() return nil } ps := &peerSession{peerID: peerID, conn: conn, session: sess} s.peerSessions[peerID] = ps s.sessMu.Unlock() s.wg.Add(1) go func() { defer s.wg.Done() s.servePeer(ps) }() return ps } // serve drives the smux Accept loop. The first accepted stream on a given // smux session is the control stream — the handshake runs there. Subsequent // streams are tunnel streams and proxy traffic. 
func (s *Server) serve(ctx context.Context) { if s.peerLn != nil { <-ctx.Done() return } s.serveSingle(ctx) } func (s *Server) serveSingle(ctx context.Context) { for { if contextDone(ctx) { return } s.sessMu.RLock() sess := s.session s.sessMu.RUnlock() if sess == nil { select { case <-ctx.Done(): return case <-time.After(50 * time.Millisecond): continue } } if !s.handshakeReady() { if !s.acceptHandshake(ctx, sess) { continue } } stream, err := sess.AcceptStream() if err != nil { if contextDone(ctx) { return } logger.Debugf("AcceptStream returned %v - reinstalling session", err) s.reinstallSession(sess) continue } s.wg.Add(1) go func() { defer s.wg.Done() s.handleStream(ctx, stream, s.currentSessionID()) }() } } func (s *Server) currentSessionID() string { s.sessMu.RLock() defer s.sessMu.RUnlock() return s.sessionID } func contextDone(ctx context.Context) bool { select { case <-ctx.Done(): return true default: return false } } // handshakeReady reports whether the current session has completed its // handshake. The session is reset on reconnect, so this is recomputed. 
func (s *Server) handshakeReady() bool { s.sessMu.RLock() defer s.sessMu.RUnlock() return s.sessionID != "" } func (s *Server) acceptHandshake(ctx context.Context, sess *smux.Session) bool { stream, err := sess.AcceptStream() if err != nil { select { case <-ctx.Done(): return false default: } logger.Debugf("AcceptStream(control) returned %v - reinstalling session", err) s.resetLinkPeer() s.reinstallSession(sess) return false } _ = stream.SetDeadline(time.Now().Add(handshake.DefaultTimeout)) hello, sid, err := handshake.Server(stream, s.authHook) _ = stream.SetDeadline(time.Time{}) if err != nil { logger.Warnf("handshake failed: %v", err) _ = stream.Close() s.resetLinkPeer() s.reinstallSession(sess) return false } s.sessMu.Lock() s.deviceID = hello.DeviceID s.sessionID = sid s.sessMu.Unlock() s.recordSession(sid) s.onOpen(sid, hello.DeviceID, hello.Claims) logger.Infof("session %s opened (device=%s)", sid, hello.DeviceID) s.startControlLoop(ctx, sess, stream) return true } func (s *Server) servePeer(ps *peerSession) { if !s.acceptPeerHandshake(ps) { s.removePeerSession(ps.peerID, "closed") return } for { if s.stopping() { return } stream, err := ps.session.AcceptStream() if err != nil { if s.stopping() { return } logger.Debugf("AcceptStream(peer=%s) returned %v - closing peer session", ps.peerID, err) s.removePeerSession(ps.peerID, "closed") return } s.wg.Add(1) go func() { defer s.wg.Done() s.handleStream(context.Background(), stream, ps.sessionID) }() } } func (s *Server) acceptPeerHandshake(ps *peerSession) bool { stream, err := ps.session.AcceptStream() if err != nil { if !s.stopping() { logger.Debugf("AcceptStream(control peer=%s) returned %v", ps.peerID, err) } return false } _ = stream.SetDeadline(time.Now().Add(handshake.DefaultTimeout)) hello, sid, err := handshake.Server(stream, s.authHook) _ = stream.SetDeadline(time.Time{}) if err != nil { logger.Warnf("handshake failed peer=%s: %v", ps.peerID, err) _ = stream.Close() return false } ps.controlStrm = 
stream ps.deviceID = hello.DeviceID ps.sessionID = sid s.recordSession(sid) s.onOpen(sid, hello.DeviceID, hello.Claims) logger.Infof("session %s opened (device=%s peer=%s)", sid, hello.DeviceID, ps.peerID) s.startPeerControlLoop(ps, stream) return true } func (s *Server) resetLinkPeer() { s.sessMu.RLock() ln := s.ln s.sessMu.RUnlock() if resetter, ok := ln.(interface{ ResetPeer() }); ok { resetter.ResetPeer() } } func (s *Server) startControlLoop(ctx context.Context, sess *smux.Session, stream *smux.Stream) { controlCtx, stop := context.WithCancel(ctx) s.sessMu.Lock() s.controlStrm = stream s.controlStop = stop s.sessMu.Unlock() liveness := s.liveness onPong := liveness.OnPong onMissedPong := liveness.OnMissedPong onUnhealthy := liveness.OnUnhealthy liveness.OnPong = func(h control.Health) { s.sessMu.RLock() sid := s.sessionID s.sessMu.RUnlock() s.recordPong(h) logger.Debugf("control alive session=%s rtt=%v seq=%d", sid, h.RTT, h.Seq) if onPong != nil { onPong(h) } } liveness.OnMissedPong = func(missed int) { s.recordMissed(missed) logger.Warnf("control missed pong on server: missed_pongs=%d", missed) if onMissedPong != nil { onMissedPong(missed) } } liveness.OnUnhealthy = func(missed int) { s.recordUnhealthy(missed) logger.Warnf("control stream unhealthy on server: missed_pongs=%d", missed) if onUnhealthy != nil { onUnhealthy(missed) } } s.wg.Add(1) go func() { defer s.wg.Done() defer func() { _ = stream.Close() }() err := control.Run(controlCtx, stream, liveness) if controlCtx.Err() != nil || ctx.Err() != nil { return } if err != nil { logger.Warnf("server control stream ended: %v", err) } s.recordReconnect() logger.Infof("server reconnect reason=liveness - reinstalling smux session") s.resetLinkPeer() s.reinstallSession(sess) }() } func (s *Server) startPeerControlLoop(ps *peerSession, stream *smux.Stream) { controlCtx, stop := context.WithCancel(context.Background()) ps.controlStop = stop liveness := s.liveness onPong := liveness.OnPong onMissedPong := 
liveness.OnMissedPong onUnhealthy := liveness.OnUnhealthy liveness.OnPong = func(h control.Health) { s.recordPong(h) logger.Debugf("control alive session=%s peer=%s rtt=%v seq=%d", ps.sessionID, ps.peerID, h.RTT, h.Seq) if onPong != nil { onPong(h) } } liveness.OnMissedPong = func(missed int) { s.recordMissed(missed) logger.Warnf("control missed pong on server: session=%s peer=%s missed_pongs=%d", ps.sessionID, ps.peerID, missed) if onMissedPong != nil { onMissedPong(missed) } } liveness.OnUnhealthy = func(missed int) { s.recordUnhealthy(missed) logger.Warnf("control stream unhealthy on server: session=%s peer=%s missed_pongs=%d", ps.sessionID, ps.peerID, missed) if onUnhealthy != nil { onUnhealthy(missed) } } s.wg.Add(1) go func() { defer s.wg.Done() defer func() { _ = stream.Close() }() err := control.Run(controlCtx, stream, liveness) if controlCtx.Err() != nil || s.stopping() { return } if err != nil { logger.Warnf("server control stream ended session=%s peer=%s: %v", ps.sessionID, ps.peerID, err) } s.recordReconnect() s.removePeerSession(ps.peerID, "reconnect") }() } func (s *Server) stopping() bool { select { case <-s.done: return true default: return false } } // Status returns the latest server-side control health snapshot. 
func (s *Server) Status() control.Status { return s.health.Status() } func (s *Server) recordSession(sessionID string) { s.health.RecordSession(sessionID) } func (s *Server) recordPong(h control.Health) { s.health.RecordPong(h) } func (s *Server) recordMissed(missed int) { s.health.RecordMissed(missed) } func (s *Server) recordUnhealthy(missed int) { s.health.RecordUnhealthy(missed) } func (s *Server) recordReconnect() { s.health.RecordReconnect() } func (s *Server) shutdown() { if s.done != nil { s.doneOnce.Do(func() { close(s.done) }) } s.closeSession() if s.ln != nil { _ = s.ln.Close() } } func (s *Server) handleStream(_ context.Context, stream *smux.Stream, sessionID string) { defer func() { _ = stream.Close() }() if sessionID == "" { sessionID = s.currentSessionID() } // Read the connect JSON. The client writes the whole JSON in one // stream.Write so it usually arrives intact; tolerate fragmentation // by reading incrementally up to a sane cap. const maxConnReq = 4096 header := make([]byte, 0, 256) tmp := make([]byte, 256) _ = stream.SetReadDeadline(time.Now().Add(15 * time.Second)) for { n, err := stream.Read(tmp) if n > 0 { header = append(header, tmp[:n]...) if req, ok := parseConnectRequest(header); ok { _ = stream.SetReadDeadline(time.Time{}) s.dispatch(stream, req, sessionID) return } } if err != nil { return } if len(header) > maxConnReq { return } } } func parseConnectRequest(buf []byte) (ConnectRequest, bool) { var req ConnectRequest if err := json.Unmarshal(buf, &req); err != nil { return req, false } if req.Cmd != connectCommand { return req, false } return req, true } // defaultAuthHook admits every client and assigns a random session ID. // Replace it via [Config.AuthHook] to plug in real authorization. 
func defaultAuthHook(_ string, _ map[string]any) (string, error) { return uuid.NewString(), nil } func (s *Server) dispatch(stream *smux.Stream, req ConnectRequest, sessionID string) { addr := net.JoinHostPort(req.Addr, strconv.Itoa(req.Port)) logger.Infof("sid=%d connect %s", stream.ID(), addr) dialStart := time.Now() conn, err := s.dial(req) dialElapsed := time.Since(dialStart) if err != nil { logger.Infof("sid=%d dial %s failed (%v): %v", stream.ID(), addr, dialElapsed, err) return } defer func() { _ = conn.Close() }() logger.Infof("sid=%d connected %s in %v", stream.ID(), addr, dialElapsed) if _, err := stream.Write([]byte{0x00}); err != nil { return } var bytesOut uint64 done := make(chan struct{}) go func() { n, _ := io.Copy(stream, conn) if n > 0 { bytesOut = uint64(n) } _ = stream.Close() close(done) }() in, _ := io.Copy(conn, stream) _ = conn.Close() <-done bytesIn := uint64(0) if in > 0 { bytesIn = uint64(in) } if s.onTraffic != nil { s.onTraffic(sessionID, addr, bytesIn, bytesOut) } } func (s *Server) dial(req ConnectRequest) (net.Conn, error) { addr := net.JoinHostPort(req.Addr, strconv.Itoa(req.Port)) if s.socksProxyAddr == "" { dialer := &net.Dialer{ Timeout: 10 * time.Second, KeepAlive: 30 * time.Second, Resolver: s.resolver, } conn, err := dialer.Dial("tcp4", addr) if err != nil { return nil, fmt.Errorf("dial failed: %w", err) } return conn, nil } proxyAddr := net.JoinHostPort(s.socksProxyAddr, strconv.Itoa(s.socksProxyPort)) dialer := &net.Dialer{ Timeout: 10 * time.Second, KeepAlive: 30 * time.Second, } conn, err := dialer.Dial("tcp4", proxyAddr) if err != nil { return nil, fmt.Errorf("failed to dial proxy: %w", err) } if err := s.socks5Connect(conn, req.Addr, req.Port); err != nil { _ = conn.Close() return nil, err } return conn, nil } func (s *Server) socks5Connect(conn net.Conn, targetAddr string, targetPort int) error { if _, err := conn.Write([]byte{5, 1, 0}); err != nil { return fmt.Errorf("failed to write socks5 auth: %w", err) } resp := 
make([]byte, 2) if _, err := io.ReadFull(conn, resp); err != nil { return fmt.Errorf("failed to read socks5 auth resp: %w", err) } if resp[0] != 5 || resp[1] != 0 { return ErrSocks5AuthFailed } addrLen := len(targetAddr) if addrLen > 255 { addrLen = 255 targetAddr = targetAddr[:255] } req := make([]byte, 0, 7+addrLen) req = append(req, 5, 1, 0, 3, byte(addrLen)) req = append(req, []byte(targetAddr)...) req = append(req, byte(targetPort>>8), byte(targetPort)) //nolint:gosec,lll // G115: bounded conversion verified by surrounding logic if _, err := conn.Write(req); err != nil { return fmt.Errorf("failed to write socks5 connect req: %w", err) } resp = make([]byte, 10) if _, err := io.ReadFull(conn, resp); err != nil { return fmt.Errorf("failed to read socks5 connect resp: %w", err) } if resp[0] != 5 || resp[1] != 0 { return fmt.Errorf("%w: %d", ErrSocks5ConnectFailed, resp[1]) } return nil }