Files
olcrtc/internal/server/server.go
zarazaex69 f469bd72af refactor: extract shared session runtime into internal/runtime
server.go and client.go each carried byte-identical copies of
smuxConfig (~20 lines), setupCipher (~18 lines), and the health
bookkeeping pair recordSession/Pong/Missed/Unhealthy/Reconnect plus a
private healthMu+status+notifyHealth scaffold. Same code, twice.

Add internal/runtime exposing:
- SetupCipher, SmuxConfig, MaxPayload — common construction helpers,
  ErrKeyRequired/ErrKeySize re-exported from runtime so existing
  errors.Is checks on server.ErrKeyRequired etc. keep working.
- HealthTracker — nil-safe wrapper around control.Status with
  RecordSession/Pong/Missed/Unhealthy/Reconnect that publishes through an
  OnHealth callback supplied at construction.

server and client now hold a *runtime.HealthTracker instead of their own
mu+status+notify scaffolds. recordX methods on Server/Client are now
one-liners that forward to the tracker. smuxConfig(0) replaces the prior
variadic smuxConfig() in test call sites; nil-safe Status()/update() on
HealthTracker means tests that build raw &Server{}/&Client{} no longer
need to wire up a tracker for the records to be no-ops.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-16 14:24:46 +03:00

691 lines
18 KiB
Go

// Package server implements the olcrtc tunnel server logic.
package server
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"strconv"
"sync"
"time"
"github.com/google/uuid"
"github.com/openlibrecommunity/olcrtc/internal/control"
"github.com/openlibrecommunity/olcrtc/internal/crypto"
"github.com/openlibrecommunity/olcrtc/internal/handshake"
"github.com/openlibrecommunity/olcrtc/internal/logger"
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
"github.com/openlibrecommunity/olcrtc/internal/names"
"github.com/openlibrecommunity/olcrtc/internal/runtime"
"github.com/openlibrecommunity/olcrtc/internal/transport"
"github.com/xtaci/smux"
)
// connectCommand is the only ConnectRequest.Cmd value that
// parseConnectRequest accepts.
const connectCommand = "connect"
// Sentinel errors exposed by the server package; match them with errors.Is.
var (
// ErrKeyRequired re-exports runtime.ErrKeyRequired for compatibility with
// pre-runtime callers that errors.Is-checked it.
ErrKeyRequired = runtime.ErrKeyRequired
// ErrKeySize re-exports runtime.ErrKeySize for the same reason.
ErrKeySize = runtime.ErrKeySize
// ErrSocks5AuthFailed is returned when SOCKS5 authentication fails, i.e.
// the proxy's method-selection reply is not version 5 with method 0.
ErrSocks5AuthFailed = errors.New("SOCKS5 auth failed")
// ErrSocks5ConnectFailed is returned when SOCKS5 connection fails. It is
// returned wrapped together with the proxy's reply code.
ErrSocks5ConnectFailed = errors.New("SOCKS5 connect failed")
)
// SessionOpenFunc is called after a successful handshake, before the server
// accepts tunnel streams on that session. It receives the session ID produced
// by the handshake plus the device ID and claims taken from the client's
// hello message.
type SessionOpenFunc func(sessionID, deviceID string, claims map[string]any)
// SessionCloseFunc is called when a session is torn down. Possible reasons:
// "reconnect" (carrier dropped and was reestablished), "closed" (graceful
// shutdown or ctx cancel). It is only invoked for sessions that completed a
// handshake, i.e. that have a non-empty session ID.
type SessionCloseFunc func(sessionID, reason string)
// TrafficFunc is called once per tunnel stream, after the copy loops finish.
// bytesIn counts client→target bytes; bytesOut counts target→client bytes.
// addr is the "host:port" target of the stream. The hook is not called when
// dialing the target fails or the initial acknowledgement cannot be written.
type TrafficFunc func(sessionID, addr string, bytesIn, bytesOut uint64)
// HealthFunc is called when the server control health snapshot changes.
// Run passes Config.OnHealth to runtime.NewHealthTracker, which owns the
// publishing.
type HealthFunc func(control.Status)
// Server handles incoming tunnel connections and proxies their traffic.
//
// A Server holds at most one smux session at a time; the session is replaced
// wholesale on reconnect (see reinstallSession) and dropped on shutdown (see
// closeSession).
type Server struct {
// ln is the transport link created by bringUpLink.
ln transport.Transport
// cipher is built from Config.KeyHex and handed to every muxconn.
cipher *crypto.Cipher
// conn is the muxconn backing the current session; onData pushes inbound
// transport data into it.
conn *muxconn.Conn
// session is the current smux session, or nil after closeSession.
session *smux.Session
// controlStop cancels the context of the current control loop, if any.
controlStop context.CancelFunc
// sessMu guards conn, session, controlStop, deviceID and sessionID.
sessMu sync.RWMutex
// reinstallMu serializes reinstallSession calls.
reinstallMu sync.Mutex
// wg tracks the connection watcher, stream handlers and control loops so
// Run can wait for them on exit.
wg sync.WaitGroup
// authHook authorizes clients during the handshake.
authHook handshake.AuthFunc
// onOpen, onClose and onTraffic are the user hooks from Config; Run
// replaces nil hooks with no-ops.
onOpen SessionOpenFunc
onClose SessionCloseFunc
onTraffic TrafficFunc
// deviceID and sessionID describe the handshaken client of the current
// session. An empty sessionID means no handshake has completed yet.
deviceID string
sessionID string
// dnsServer is the address the custom resolver dials for every lookup.
dnsServer string
// resolver is used for direct (non-proxied) outbound dials.
resolver *net.Resolver
// socksProxyAddr and socksProxyPort select an upstream SOCKS5 proxy for
// outbound dials; an empty address means dial directly.
socksProxyAddr string
socksProxyPort int
// liveness is the control-loop configuration; startControlLoop copies it
// and wraps its callbacks for each session.
liveness control.Config
// health records session/pong/reconnect events and publishes snapshots.
health *runtime.HealthTracker
}
// ConnectRequest is a message from the client to establish a new connection.
// It is sent as JSON at the start of every tunnel stream.
type ConnectRequest struct {
// Cmd must equal connectCommand ("connect") for the request to be accepted.
Cmd string `json:"cmd"`
// Addr is the target host; it is joined with Port via net.JoinHostPort.
Addr string `json:"addr"`
// Port is the target TCP port.
Port int `json:"port"`
}
// Config holds runtime configuration for [Run].
type Config struct {
// Transport names the transport implementation passed to transport.New.
Transport string
// Carrier, RoomURL and ChannelID are forwarded unchanged into
// transport.Config.
Carrier string
RoomURL string
ChannelID string
// KeyHex is the key material handed to setupCipher.
KeyHex string
// DNSServer is the address dialed by the server's resolver for lookups;
// it is also forwarded to the transport.
DNSServer string
// SOCKSProxyAddr and SOCKSProxyPort, when the address is non-empty, route
// outbound target dials through a SOCKS5 proxy. They are also forwarded to
// the transport as its proxy settings.
SOCKSProxyAddr string
SOCKSProxyPort int
// TransportOptions is forwarded as transport.Config.Options.
TransportOptions transport.Options
// Engine, URL and Token are forwarded unchanged into transport.Config.
Engine string
URL string
Token string
// Liveness configures the per-session control loop. Its OnPong,
// OnMissedPong and OnUnhealthy callbacks are still invoked, after the
// server has recorded the event.
Liveness control.Config
// Traffic is forwarded as transport.Config.Traffic.
Traffic transport.TrafficConfig
// AuthHook is invoked after CLIENT_HELLO to authorize the client and
// return a session ID. If nil, every client is admitted with a random UUID.
AuthHook handshake.AuthFunc
// OnSessionOpen fires after a successful handshake. Nil means no-op.
OnSessionOpen SessionOpenFunc
// OnSessionClose fires when the session is torn down (reconnect, closed). Nil means no-op.
OnSessionClose SessionCloseFunc
// OnTraffic fires once per tunnel stream after both copy loops finish. Nil means no-op.
OnTraffic TrafficFunc
// OnHealth fires when liveness/reconnect status changes. Nil means no-op.
OnHealth HealthFunc
}
// Run starts the server with the given configuration.
//
// It builds the cipher from cfg.KeyHex, brings the transport link up and then
// drives the accept loop until ctx is cancelled or the link reports that the
// conference ended. An error is returned only when the cipher or the link
// cannot be set up; a regular stop returns nil.
func Run(ctx context.Context, cfg Config) error {
	runCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	cipher, err := setupCipher(cfg.KeyHex)
	if err != nil {
		return fmt.Errorf("setupCipher failed: %w", err)
	}

	// cfg is a by-value copy, so missing hooks can be defaulted in place.
	if cfg.AuthHook == nil {
		cfg.AuthHook = defaultAuthHook
	}
	if cfg.OnSessionOpen == nil {
		cfg.OnSessionOpen = func(string, string, map[string]any) {}
	}
	if cfg.OnSessionClose == nil {
		cfg.OnSessionClose = func(string, string) {}
	}
	if cfg.OnTraffic == nil {
		cfg.OnTraffic = func(string, string, uint64, uint64) {}
	}

	srv := &Server{
		cipher:         cipher,
		authHook:       cfg.AuthHook,
		onOpen:         cfg.OnSessionOpen,
		onClose:        cfg.OnSessionClose,
		onTraffic:      cfg.OnTraffic,
		dnsServer:      cfg.DNSServer,
		socksProxyAddr: cfg.SOCKSProxyAddr,
		socksProxyPort: cfg.SOCKSProxyPort,
		liveness:       cfg.Liveness,
		health:         runtime.NewHealthTracker(cfg.OnHealth),
	}
	srv.setupResolver()

	// The teardown is registered before the link is brought up on purpose:
	// if bringUpLink fails half-way (link created, Connect timed out) the
	// link must still be closed, otherwise the already-joined MUC presence
	// lingers as a ghost participant in the room. shutdown is idempotent and
	// may run before serve ever started.
	defer func() {
		srv.shutdown()
		srv.wg.Wait()
	}()

	if err := srv.bringUpLink(runCtx, cfg, cancel); err != nil {
		return err
	}

	// Closing the session on cancellation unblocks the accept loop below.
	go func() {
		<-runCtx.Done()
		srv.closeSession()
	}()

	srv.serve(runCtx)
	return nil
}
// setupCipher builds the tunnel cipher from keyHex by delegating to
// runtime.SetupCipher; the result and error are returned unchanged.
func setupCipher(keyHex string) (*crypto.Cipher, error) {
return runtime.SetupCipher(keyHex)
}
// setupResolver installs a pure-Go resolver whose every lookup connection is
// dialed to s.dnsServer (read at dial time), ignoring the server address the
// resolver itself would have picked.
func (s *Server) setupResolver() {
	const dnsDialTimeout = 3 * time.Second

	dialDNS := func(ctx context.Context, network, _ string) (net.Conn, error) {
		dialer := net.Dialer{Timeout: dnsDialTimeout}
		return dialer.DialContext(ctx, network, s.dnsServer)
	}

	s.resolver = &net.Resolver{PreferGo: true, Dial: dialDNS}
}
// smuxConfig returns the smux configuration for a link with the given
// maximum wire payload by delegating to runtime.SmuxConfig.
func smuxConfig(maxWirePayload int) *smux.Config {
return runtime.SmuxConfig(maxWirePayload)
}
// linkMaxPayload returns runtime.MaxPayload(tr); the value is fed to
// smuxConfig whenever a session is created.
func linkMaxPayload(tr transport.Transport) int {
return runtime.MaxPayload(tr)
}
// bringUpLink creates the transport link, wires its lifecycle callbacks,
// connects it, installs the first smux session and starts the connection
// watcher. cancel is invoked when the link reports that the conference ended.
// Any failure to create or connect the link is returned wrapped.
func (s *Server) bringUpLink(
	ctx context.Context,
	cfg Config,
	cancel context.CancelFunc,
) error {
	linkCfg := transport.Config{
		Carrier:   cfg.Carrier,
		RoomURL:   cfg.RoomURL,
		Engine:    cfg.Engine,
		URL:       cfg.URL,
		Token:     cfg.Token,
		ChannelID: cfg.ChannelID,
		DeviceID:  "",
		Name:      names.Generate(),
		OnData:    s.onData,
		DNSServer: s.dnsServer,
		ProxyAddr: s.socksProxyAddr,
		ProxyPort: s.socksProxyPort,
		Options:   cfg.TransportOptions,
		Traffic:   cfg.Traffic,
	}
	link, err := transport.New(ctx, cfg.Transport, linkCfg)
	if err != nil {
		return fmt.Errorf("failed to create transport: %w", err)
	}
	// Stored before Connect so a later shutdown can close a half-open link.
	s.ln = link

	stillRunning := func() bool { return ctx.Err() == nil }

	link.SetEndedCallback(func(reason string) {
		logger.Infof("Server link reported conference end: %s", reason)
		cancel()
	})
	link.SetShouldReconnect(stillRunning)
	link.SetReconnectCallback(func() {
		// A reconnect that lands after cancellation must not rebuild state.
		if stillRunning() {
			s.handleReconnect()
		}
	})

	logger.Infof("Connecting transport=%s carrier=%s ...", cfg.Transport, cfg.Carrier)
	if err = link.Connect(ctx); err != nil {
		return fmt.Errorf("failed to connect link: %w", err)
	}
	logger.Infof("Link connected")

	s.installSession()

	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		link.WatchConnection(ctx)
	}()
	return nil
}
// installSession builds the first muxconn/smux pair on top of the link and
// publishes it as the current session. When smux cannot be initialised the
// failure is logged and the server is left without a session.
func (s *Server) installSession() {
	link := muxconn.New(s.ln, s.cipher)
	cfg := smuxConfig(linkMaxPayload(s.ln))

	mux, err := smux.Server(link, cfg)
	if err != nil {
		logger.Warnf("smux server init failed: %v", err)
		return
	}

	s.sessMu.Lock()
	defer s.sessMu.Unlock()
	s.conn, s.session = link, mux
}
// handleReconnect runs when the carrier reconnected: it records the event
// and replaces whatever smux session is current at that moment.
func (s *Server) handleReconnect() {
	s.recordReconnect()
	logger.Infof("server reconnect reason=carrier - tearing down smux session")

	var stale *smux.Session
	s.sessMu.RLock()
	stale = s.session
	s.sessMu.RUnlock()

	s.reinstallSession(stale)
}
// reinstallSession replaces the session identified by dead with a freshly
// built muxconn/smux pair and tears the old one down.
//
// dead is the session the caller observed as broken. If the current session
// is no longer dead (another caller already swapped it), the freshly built
// pair is discarded and nothing else changes. On a successful swap the
// handshake state (sessionID, deviceID, controlStop) is reset, the old
// control loop is cancelled, the old session and conn are closed, and the
// close hook is fired with reason "reconnect" if the old session had
// completed a handshake.
func (s *Server) reinstallSession(dead *smux.Session) {
// Only one reinstall may build and swap at a time.
s.reinstallMu.Lock()
defer s.reinstallMu.Unlock()
// Pre-build the replacement so we can swap atomically below.
newConn := muxconn.New(s.ln, s.cipher)
newSess, err := smux.Server(newConn, smuxConfig(linkMaxPayload(s.ln)))
if err != nil {
logger.Warnf("smux server init failed: %v", err)
_ = newConn.Close()
return
}
s.sessMu.Lock()
if s.session != dead {
// Someone else already reinstalled — discard our build.
s.sessMu.Unlock()
_ = newSess.Close()
_ = newConn.Close()
return
}
// Capture the old state and publish the new one under the same lock.
oldSess := s.session
oldConn := s.conn
oldControlStop := s.controlStop
oldSID := s.sessionID
s.session = newSess
s.conn = newConn
s.controlStop = nil
// Clearing sessionID makes handshakeReady false for the new session.
s.sessionID = ""
s.deviceID = ""
s.sessMu.Unlock()
// Teardown and the user hook run outside sessMu.
if oldControlStop != nil {
oldControlStop()
}
if oldSess != nil {
_ = oldSess.Close()
}
if oldConn != nil {
_ = oldConn.Close()
}
if oldSID != "" {
s.onClose(oldSID, "reconnect")
}
}
// closeSession detaches the current session state under the lock and then
// tears it down outside the lock: the control loop is cancelled, the smux
// session and muxconn are closed, and the close hook fires with reason
// "closed" when a handshake had completed. Calling it again is harmless
// because every field is already cleared.
func (s *Server) closeSession() {
	s.sessMu.Lock()
	prevSess, prevConn := s.session, s.conn
	stopControl, prevSID := s.controlStop, s.sessionID
	s.session, s.conn, s.controlStop = nil, nil, nil
	s.sessionID, s.deviceID = "", ""
	s.sessMu.Unlock()

	if stopControl != nil {
		stopControl()
	}
	if prevSess != nil {
		_ = prevSess.Close()
	}
	if prevConn != nil {
		_ = prevConn.Close()
	}
	if prevSID != "" {
		s.onClose(prevSID, "closed")
	}
}
// onData is the transport's inbound-data callback. The payload is pushed
// into the muxconn of the current session; it is dropped when no session
// is installed.
func (s *Server) onData(data []byte) {
	s.sessMu.RLock()
	current := s.conn
	s.sessMu.RUnlock()

	if current == nil {
		return
	}
	current.Push(data)
}
// serve runs the smux accept loop until ctx is done. On every session the
// first accepted stream is the control stream and carries the handshake;
// each later stream is a tunnel stream handled in its own goroutine. Accept
// errors cause the session to be reinstalled.
func (s *Server) serve(ctx context.Context) {
	const noSessionBackoff = 50 * time.Millisecond

	for !contextDone(ctx) {
		s.sessMu.RLock()
		sess := s.session
		s.sessMu.RUnlock()

		// No session installed yet: back off briefly, then look again.
		if sess == nil {
			select {
			case <-ctx.Done():
				return
			case <-time.After(noSessionBackoff):
			}
			continue
		}

		// A session without a completed handshake must handshake first.
		if !s.handshakeReady() && !s.acceptHandshake(ctx, sess) {
			continue
		}

		stream, err := sess.AcceptStream()
		if err != nil {
			if contextDone(ctx) {
				return
			}
			logger.Debugf("AcceptStream returned %v - reinstalling session", err)
			s.reinstallSession(sess)
			continue
		}

		s.wg.Add(1)
		go func() {
			defer s.wg.Done()
			s.handleStream(ctx, stream)
		}()
	}
}
// contextDone reports, without blocking, whether ctx's Done channel is
// already closed.
func contextDone(ctx context.Context) bool {
	select {
	case <-ctx.Done():
	default:
		return false
	}
	return true
}
// handshakeReady tells whether the current session already finished its
// handshake, which is the case exactly when a session ID is recorded. The
// ID is cleared whenever the session is replaced, so the answer is read
// fresh on every call.
func (s *Server) handshakeReady() bool {
	s.sessMu.RLock()
	sid := s.sessionID
	s.sessMu.RUnlock()
	return sid != ""
}
// acceptHandshake accepts the control stream of sess and runs the server
// side of the handshake on it, bounded by handshake.DefaultTimeout.
//
// It returns true once the session is marked as handshaken, the open hook
// has fired and the control loop has been started on the stream. It returns
// false when ctx is done, when accepting or handshaking fails (the session is
// then reinstalled), or when sess was replaced while the handshake was in
// flight.
func (s *Server) acceptHandshake(ctx context.Context, sess *smux.Session) bool {
	stream, err := sess.AcceptStream()
	if err != nil {
		if contextDone(ctx) {
			return false
		}
		logger.Debugf("AcceptStream(control) returned %v - reinstalling session", err)
		s.reinstallSession(sess)
		return false
	}

	// The deadline only covers the handshake; the control loop that takes
	// over the stream afterwards must not inherit it.
	_ = stream.SetDeadline(time.Now().Add(handshake.DefaultTimeout))
	hello, sid, err := handshake.Server(stream, s.authHook)
	_ = stream.SetDeadline(time.Time{})
	if err != nil {
		logger.Warnf("handshake failed: %v", err)
		_ = stream.Close()
		s.reinstallSession(sess)
		return false
	}

	s.sessMu.Lock()
	if s.session != sess {
		// sess was swapped out (reconnect) or dropped (shutdown) while the
		// handshake ran. Recording the session ID now would mark the
		// replacement session as handshaken although its own control stream
		// was never accepted, so its handshake would be handled as a tunnel
		// stream. Drop this result and let serve start over.
		s.sessMu.Unlock()
		logger.Debugf("handshake completed on a replaced session - discarding")
		_ = stream.Close()
		return false
	}
	s.deviceID = hello.DeviceID
	s.sessionID = sid
	s.sessMu.Unlock()

	s.recordSession(sid)
	s.onOpen(sid, hello.DeviceID, hello.Claims)
	logger.Infof("session %s opened (device=%s)", sid, hello.DeviceID)
	s.startControlLoop(ctx, sess, stream)
	return true
}
// startControlLoop runs the liveness control loop for sess on the given
// control stream in a background goroutine tracked by s.wg.
//
// The loop's context is a child of ctx whose cancel function is stored in
// s.controlStop, so replacing or closing the session stops the loop. When
// control.Run returns without either context being cancelled, the event is
// recorded as a reconnect and the session is reinstalled.
func (s *Server) startControlLoop(ctx context.Context, sess *smux.Session, stream *smux.Stream) {
controlCtx, stop := context.WithCancel(ctx)
s.sessMu.Lock()
s.controlStop = stop
s.sessMu.Unlock()
// Work on a copy of the configured liveness settings and wrap each callback
// so the health tracker and the log are updated before the caller's own
// callback (if any) runs.
liveness := s.liveness
onPong := liveness.OnPong
onMissedPong := liveness.OnMissedPong
onUnhealthy := liveness.OnUnhealthy
liveness.OnPong = func(h control.Health) {
// The session ID is read at callback time and used for logging only.
s.sessMu.RLock()
sid := s.sessionID
s.sessMu.RUnlock()
s.recordPong(h)
logger.Debugf("control alive session=%s rtt=%v seq=%d", sid, h.RTT, h.Seq)
if onPong != nil {
onPong(h)
}
}
liveness.OnMissedPong = func(missed int) {
s.recordMissed(missed)
logger.Warnf("control missed pong on server: missed_pongs=%d", missed)
if onMissedPong != nil {
onMissedPong(missed)
}
}
liveness.OnUnhealthy = func(missed int) {
s.recordUnhealthy(missed)
logger.Warnf("control stream unhealthy on server: missed_pongs=%d", missed)
if onUnhealthy != nil {
onUnhealthy(missed)
}
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
defer func() { _ = stream.Close() }()
err := control.Run(controlCtx, stream, liveness)
// A cancelled context means the loop was stopped on purpose (session
// replaced, session closed or server shutting down): nothing to repair.
if controlCtx.Err() != nil || ctx.Err() != nil {
return
}
if err != nil {
logger.Warnf("server control stream ended: %v", err)
}
// The control stream ended on its own: treat the session as dead.
s.recordReconnect()
logger.Infof("server reconnect reason=liveness - reinstalling smux session")
s.reinstallSession(sess)
}()
}
// Status returns the latest server-side control health snapshot, as held by
// the server's health tracker.
func (s *Server) Status() control.Status {
return s.health.Status()
}
// recordSession forwards a newly handshaken session ID to the health tracker.
func (s *Server) recordSession(sessionID string) { s.health.RecordSession(sessionID) }
// recordPong forwards a received pong's health sample to the health tracker.
func (s *Server) recordPong(h control.Health) { s.health.RecordPong(h) }
// recordMissed forwards the current missed-pong count to the health tracker.
func (s *Server) recordMissed(missed int) { s.health.RecordMissed(missed) }
// recordUnhealthy forwards an unhealthy-stream event to the health tracker.
func (s *Server) recordUnhealthy(missed int) { s.health.RecordUnhealthy(missed) }
// recordReconnect tells the health tracker that the session is being rebuilt.
func (s *Server) recordReconnect() { s.health.RecordReconnect() }
// shutdown closes the current session and then the transport link, if one
// was ever created. It is safe to call before the link exists.
func (s *Server) shutdown() {
	s.closeSession()

	if s.ln == nil {
		return
	}
	_ = s.ln.Close()
}
// handleStream serves one tunnel stream: it reads the JSON connect request
// from the start of the stream and hands the stream to dispatch. The stream
// is always closed on return.
//
// The request normally arrives in a single read, but fragmentation is
// tolerated: bytes are accumulated and re-parsed after every read until the
// request parses, a read fails (including the 15s read deadline), or the
// accumulated header grows beyond maxConnReq.
func (s *Server) handleStream(_ context.Context, stream *smux.Stream) {
	defer func() { _ = stream.Close() }()

	const (
		maxConnReq  = 4096
		chunkSize   = 256
		readTimeout = 15 * time.Second
	)
	var (
		header = make([]byte, 0, chunkSize)
		chunk  = make([]byte, chunkSize)
	)

	_ = stream.SetReadDeadline(time.Now().Add(readTimeout))
	for len(header) <= maxConnReq {
		n, err := stream.Read(chunk)
		if n > 0 {
			header = append(header, chunk[:n]...)
			req, ok := parseConnectRequest(header)
			if ok {
				// The deadline only applies to the request header.
				_ = stream.SetReadDeadline(time.Time{})
				s.dispatch(stream, req)
				return
			}
		}
		if err != nil {
			return
		}
	}
}
// parseConnectRequest decodes buf as a JSON ConnectRequest. The boolean is
// true only when buf is valid JSON and the command equals connectCommand;
// the request is returned as decoded in either case.
func parseConnectRequest(buf []byte) (ConnectRequest, bool) {
	var req ConnectRequest
	err := json.Unmarshal(buf, &req)
	accepted := err == nil && req.Cmd == connectCommand
	return req, accepted
}
// defaultAuthHook admits every client and assigns a random session ID.
// Replace it via [Config.AuthHook] to plug in real authorization.
// Both arguments are ignored and the returned error is always nil.
func defaultAuthHook(_ string, _ map[string]any) (string, error) {
return uuid.NewString(), nil
}
// dispatch connects the tunnel stream to the target named in req. After a
// successful dial it acknowledges the stream with a single 0x00 byte and
// copies data in both directions until both copies have ended, then reports
// the byte counts through the traffic hook. If the dial or the
// acknowledgement fails it returns without calling the hook.
func (s *Server) dispatch(stream *smux.Stream, req ConnectRequest) {
	target := net.JoinHostPort(req.Addr, strconv.Itoa(req.Port))
	streamID := stream.ID()
	logger.Infof("sid=%d connect %s", streamID, target)

	// Pin the session ID now; the session may be replaced while copying.
	s.sessMu.RLock()
	sessionID := s.sessionID
	s.sessMu.RUnlock()

	started := time.Now()
	upstream, err := s.dial(req)
	took := time.Since(started)
	if err != nil {
		logger.Infof("sid=%d dial %s failed (%v): %v", streamID, target, took, err)
		return
	}
	defer func() { _ = upstream.Close() }()
	logger.Infof("sid=%d connected %s in %v", streamID, target, took)

	ack := []byte{0x00}
	if _, err := stream.Write(ack); err != nil {
		return
	}

	// target→client runs in the background and reports its byte count once
	// it has closed the stream.
	downstream := make(chan int64, 1)
	go func() {
		copied, _ := io.Copy(stream, upstream)
		_ = stream.Close()
		downstream <- copied
	}()

	// client→target runs here; closing the target afterwards ends the
	// background copy as well.
	sent, _ := io.Copy(upstream, stream)
	_ = upstream.Close()
	received := <-downstream

	var bytesIn, bytesOut uint64
	if sent > 0 {
		bytesIn = uint64(sent)
	}
	if received > 0 {
		bytesOut = uint64(received)
	}
	if s.onTraffic != nil {
		s.onTraffic(sessionID, target, bytesIn, bytesOut)
	}
}
// dial opens the outbound TCP (IPv4) connection for req. Without a
// configured SOCKS5 proxy the target is dialed directly using the server's
// resolver; otherwise the proxy is dialed and asked to connect to the target.
// On a failed proxy handshake the proxy connection is closed before the
// error is returned.
func (s *Server) dial(req ConnectRequest) (net.Conn, error) {
	dialer := &net.Dialer{
		Timeout:   10 * time.Second,
		KeepAlive: 30 * time.Second,
	}

	if s.socksProxyAddr != "" {
		proxy := net.JoinHostPort(s.socksProxyAddr, strconv.Itoa(s.socksProxyPort))
		conn, err := dialer.Dial("tcp4", proxy)
		if err != nil {
			return nil, fmt.Errorf("failed to dial proxy: %w", err)
		}
		if err := s.socks5Connect(conn, req.Addr, req.Port); err != nil {
			_ = conn.Close()
			return nil, err
		}
		return conn, nil
	}

	// Only direct dials go through the custom resolver.
	dialer.Resolver = s.resolver
	target := net.JoinHostPort(req.Addr, strconv.Itoa(req.Port))
	conn, err := dialer.Dial("tcp4", target)
	if err != nil {
		return nil, fmt.Errorf("dial failed: %w", err)
	}
	return conn, nil
}
// socks5Connect performs a no-authentication SOCKS5 CONNECT exchange
// (RFC 1928) on conn, asking the proxy to connect to targetAddr:targetPort.
// The target is always sent as a domain-name address (ATYP 3).
//
// It returns ErrSocks5AuthFailed when the proxy does not select the
// "no authentication" method, and an error wrapping ErrSocks5ConnectFailed
// when the target cannot be encoded in a request or the proxy rejects it.
// I/O failures are returned wrapped. On success the complete reply has been
// consumed, so the next byte read from conn belongs to the tunneled
// connection.
func (s *Server) socks5Connect(conn net.Conn, targetAddr string, targetPort int) error {
	// Validate before touching the wire. The address length travels in one
	// octet and the port in two; truncating either would silently make the
	// proxy connect somewhere other than the requested target.
	if len(targetAddr) == 0 || len(targetAddr) > 255 {
		return fmt.Errorf("%w: target host length %d outside 1..255",
			ErrSocks5ConnectFailed, len(targetAddr))
	}
	if targetPort < 0 || targetPort > 65535 {
		return fmt.Errorf("%w: target port %d out of range",
			ErrSocks5ConnectFailed, targetPort)
	}

	// Greeting: version 5, one method offered, method 0 (no authentication).
	if _, err := conn.Write([]byte{5, 1, 0}); err != nil {
		return fmt.Errorf("failed to write socks5 auth: %w", err)
	}
	method := make([]byte, 2)
	if _, err := io.ReadFull(conn, method); err != nil {
		return fmt.Errorf("failed to read socks5 auth resp: %w", err)
	}
	if method[0] != 5 || method[1] != 0 {
		return ErrSocks5AuthFailed
	}

	// Request: VER=5 CMD=1 (CONNECT) RSV=0 ATYP=3 (domain) LEN HOST PORT.
	req := make([]byte, 0, 7+len(targetAddr))
	req = append(req, 5, 1, 0, 3, byte(len(targetAddr)))
	req = append(req, targetAddr...)
	req = append(req, byte(targetPort>>8), byte(targetPort)) //nolint:gosec // G115: range checked above
	if _, err := conn.Write(req); err != nil {
		return fmt.Errorf("failed to write socks5 connect req: %w", err)
	}

	// Reply: VER REP RSV ATYP BND.ADDR BND.PORT. The bound address is
	// variable-length, so read the fixed head first and size the rest from
	// ATYP instead of assuming an IPv4 reply.
	head := make([]byte, 4)
	if _, err := io.ReadFull(conn, head); err != nil {
		return fmt.Errorf("failed to read socks5 connect resp: %w", err)
	}
	if head[0] != 5 || head[1] != 0 {
		return fmt.Errorf("%w: %d", ErrSocks5ConnectFailed, head[1])
	}

	var bindLen int
	switch head[3] {
	case 1: // IPv4
		bindLen = 4
	case 4: // IPv6
		bindLen = 16
	case 3: // domain name, prefixed by a one-octet length
		var nameLen [1]byte
		if _, err := io.ReadFull(conn, nameLen[:]); err != nil {
			return fmt.Errorf("failed to read socks5 connect resp: %w", err)
		}
		bindLen = int(nameLen[0])
	default:
		return fmt.Errorf("%w: unsupported bind address type %d",
			ErrSocks5ConnectFailed, head[3])
	}

	// Consume the bound address and port; their values are not needed.
	bound := make([]byte, bindLen+2)
	if _, err := io.ReadFull(conn, bound); err != nil {
		return fmt.Errorf("failed to read socks5 connect resp: %w", err)
	}
	return nil
}