diff --git a/internal/client/client.go b/internal/client/client.go index 0d53215..dca6c48 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -4,7 +4,6 @@ package client import ( "context" "encoding/binary" - "encoding/hex" "encoding/json" "errors" "fmt" @@ -23,6 +22,7 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/logger" "github.com/openlibrecommunity/olcrtc/internal/muxconn" "github.com/openlibrecommunity/olcrtc/internal/names" + "github.com/openlibrecommunity/olcrtc/internal/runtime" "github.com/openlibrecommunity/olcrtc/internal/transport" "github.com/xtaci/smux" ) @@ -33,7 +33,8 @@ var ( // ErrProxyAuth is returned when SOCKS proxy authentication fails. ErrProxyAuth = errors.New("SOCKS proxy auth failed") // ErrKeySize is returned when the encryption key is not 32 bytes. - ErrKeySize = errors.New("key must be 32 bytes") + // Re-exported from runtime for compatibility with errors.Is callers. + ErrKeySize = runtime.ErrKeySize // ErrInvalidSOCKSVersion is returned when the SOCKS version is not 5. ErrInvalidSOCKSVersion = errors.New("invalid socks version") // ErrUnsupportedSOCKSCommand is returned for unsupported SOCKS commands. @@ -58,9 +59,7 @@ type Client struct { controlStop context.CancelFunc sessMu sync.RWMutex reconnectMu sync.Mutex - healthMu sync.RWMutex - health control.Status - onHealth HealthFunc + health *runtime.HealthTracker deviceID string sessionID string claims map[string]any @@ -134,7 +133,7 @@ func RunWithReady(ctx context.Context, cfg Config, onReady func()) error { dnsServer: cfg.DNSServer, socksUser: cfg.SOCKSUser, socksPass: cfg.SOCKSPass, - onHealth: cfg.OnHealth, + health: runtime.NewHealthTracker(cfg.OnHealth), } // shutdown is registered BEFORE bringUpLink so we always close any @@ -303,27 +302,12 @@ func resolveDeviceID(deviceID, path string) (string, error) { return id, nil } -// smuxConfig returns the tuned smux config used on both ends. -func smuxConfig(maxWirePayload ...int) *smux.Config { - cfg := smux.DefaultConfig() - cfg.Version = 2 - cfg.KeepAliveDisabled = true - cfg.MaxFrameSize = 32768 - if len(maxWirePayload) > 0 && maxWirePayload[0] > crypto.WireOverhead { - maxFrameSize := maxWirePayload[0] - crypto.WireOverhead - if maxFrameSize < cfg.MaxFrameSize { - cfg.MaxFrameSize = maxFrameSize - } - } - cfg.MaxReceiveBuffer = 16 * 1024 * 1024 - cfg.MaxStreamBuffer = 1024 * 1024 - cfg.KeepAliveInterval = 10 * time.Second - cfg.KeepAliveTimeout = 60 * time.Second - return cfg +func smuxConfig(maxWirePayload int) *smux.Config { + return runtime.SmuxConfig(maxWirePayload) } func linkMaxPayload(tr transport.Transport) int { - return tr.Features().MaxPayloadSize + return runtime.MaxPayload(tr) } func (c *Client) handleReconnect(ctx context.Context, cfg Config, cancel context.CancelFunc, reason string) bool { @@ -481,61 +465,14 @@ func (c *Client) startControlLoop( // Status returns the latest client-side control health snapshot. func (c *Client) Status() control.Status { - c.healthMu.RLock() - defer c.healthMu.RUnlock() - return c.health + return c.health.Status() } -func (c *Client) recordSession(sessionID string) { - c.healthMu.Lock() - c.health.SessionID = sessionID - c.health.MissedPongs = 0 - status := c.health - c.healthMu.Unlock() - c.notifyHealth(status) -} - -func (c *Client) recordPong(h control.Health) { - c.healthMu.Lock() - c.health.LastPong = h.LastSeen - c.health.LastRTT = h.RTT - c.health.MissedPongs = 0 - status := c.health - c.healthMu.Unlock() - c.notifyHealth(status) -} - -func (c *Client) recordMissed(missed int) { - c.healthMu.Lock() - c.health.MissedPongs = missed - status := c.health - c.healthMu.Unlock() - c.notifyHealth(status) -} - -func (c *Client) recordUnhealthy(missed int) { - c.healthMu.Lock() - c.health.MissedPongs = missed - c.health.UnhealthyEvents++ - c.health.LastUnhealthy = time.Now() - status := c.health - c.healthMu.Unlock() - c.notifyHealth(status) -} - -func (c *Client) recordReconnect() { - c.healthMu.Lock() - c.health.Reconnects++ - status := c.health - c.healthMu.Unlock() - c.notifyHealth(status) -} - -func (c *Client) notifyHealth(status control.Status) { - if c.onHealth != nil { - c.onHealth(status) - } -} +func (c *Client) recordSession(sessionID string) { c.health.RecordSession(sessionID) } +func (c *Client) recordPong(h control.Health) { c.health.RecordPong(h) } +func (c *Client) recordMissed(missed int) { c.health.RecordMissed(missed) } +func (c *Client) recordUnhealthy(missed int) { c.health.RecordUnhealthy(missed) } +func (c *Client) recordReconnect() { c.health.RecordReconnect() } func (c *Client) shutdown() { c.sessMu.Lock() @@ -567,19 +504,7 @@ func (c *Client) shutdown() { } func setupCipher(keyHex string) (*crypto.Cipher, error) { - key, err := hex.DecodeString(keyHex) - if err != nil { - return nil, fmt.Errorf("failed to decode key: %w", err) - } - if len(key) != 32 { - return nil, fmt.Errorf("%w: got %d", ErrKeySize, len(key)) - } - - cipher, err := crypto.NewCipher(string(key)) - if err != nil { - return nil, fmt.Errorf("failed to create cipher: %w", err) - } - return cipher, nil + return runtime.SetupCipher(keyHex) } func (c *Client) onData(data []byte) { diff --git a/internal/client/client_test.go b/internal/client/client_test.go index d15229a..590d63e 100644 --- a/internal/client/client_test.go +++ b/internal/client/client_test.go @@ -14,6 +14,7 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/control" cryptopkg "github.com/openlibrecommunity/olcrtc/internal/crypto" "github.com/openlibrecommunity/olcrtc/internal/muxconn" + "github.com/openlibrecommunity/olcrtc/internal/runtime" "github.com/openlibrecommunity/olcrtc/internal/transport" "github.com/xtaci/smux" ) @@ -46,9 +47,9 @@ func TestSetupCipherRejectsBadInput(t *testing.T) { } func TestSmuxConfig(t *testing.T) { - cfg := smuxConfig() + cfg := smuxConfig(0) if cfg.Version != 2 || !cfg.KeepAliveDisabled || cfg.MaxFrameSize != 32768 || cfg.MaxReceiveBuffer != 16*1024*1024 { - t.Fatalf("smuxConfig() = %+v", cfg) + t.Fatalf("smuxConfig(0) = %+v", cfg) } capped := smuxConfig(4096) if capped.MaxFrameSize != 4096-cryptopkg.WireOverhead { @@ -403,12 +404,12 @@ func TestSendConnectRequestOverSmux(t *testing.T) { _ = b.Close() }() - serverSess, err := smux.Server(a, smuxConfig()) + serverSess, err := smux.Server(a, smuxConfig(0)) if err != nil { t.Fatalf("smux.Server() error = %v", err) } defer func() { _ = serverSess.Close() }() - clientSess, err := smux.Client(b, smuxConfig()) + clientSess, err := smux.Client(b, smuxConfig(0)) if err != nil { t.Fatalf("smux.Client() error = %v", err) } @@ -457,12 +458,12 @@ func TestSendConnectRequestRejectsBadAck(t *testing.T) { _ = a.Close() _ = b.Close() }() - serverSess, err := smux.Server(a, smuxConfig()) + serverSess, err := smux.Server(a, smuxConfig(0)) if err != nil { t.Fatalf("smux.Server() error = %v", err) } defer func() { _ = serverSess.Close() }() - clientSess, err := smux.Client(b, smuxConfig()) + clientSess, err := smux.Client(b, smuxConfig(0)) if err != nil { t.Fatalf("smux.Client() error = %v", err) } @@ -534,12 +535,12 @@ func TestStartControlLoopReportsPong(t *testing.T) { _ = b.Close() }() - serverSess, err := smux.Server(a, smuxConfig()) + serverSess, err := smux.Server(a, smuxConfig(0)) if err != nil { t.Fatalf("smux.Server() error = %v", err) } defer func() { _ = serverSess.Close() }() - clientSess, err := smux.Client(b, smuxConfig()) + clientSess, err := smux.Client(b, smuxConfig(0)) if err != nil { t.Fatalf("smux.Client() error = %v", err) } @@ -562,7 +563,7 @@ func TestStartControlLoopReportsPong(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() got := make(chan control.Health, 1) - c := &Client{sessionID: "sid-control"} + c := &Client{sessionID: "sid-control", health: runtime.NewHealthTracker(nil)} c.recordSession("sid-control") c.startControlLoop(ctx, Config{ Liveness: control.Config{ @@ -604,7 +605,7 @@ func TestStartControlLoopReportsPong(t *testing.T) { func TestStatusRecordsReconnectAndUnhealthy(t *testing.T) { updates := 0 - c := &Client{onHealth: func(control.Status) { updates++ }} + c := &Client{health: runtime.NewHealthTracker(func(control.Status) { updates++ })} c.recordSession("sid-1") c.recordMissed(2) c.recordUnhealthy(3) diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go new file mode 100644 index 0000000..1f9b838 --- /dev/null +++ b/internal/runtime/runtime.go @@ -0,0 +1,154 @@ +// Package runtime holds infrastructure shared by the olcrtc server and +// client: smux tuning, cipher setup, and control-stream health bookkeeping. +// The lifecycle differences between server and client (accept loop / SOCKS5 +// dial vs. SOCKS5 listener / tunnel) live in their respective packages. +package runtime + +import ( + "encoding/hex" + "errors" + "fmt" + "sync" + "time" + + "github.com/openlibrecommunity/olcrtc/internal/control" + "github.com/openlibrecommunity/olcrtc/internal/crypto" + "github.com/openlibrecommunity/olcrtc/internal/transport" + "github.com/xtaci/smux" +) + +// ErrKeyRequired is returned when no encryption key is provided. +var ErrKeyRequired = errors.New("key required (use -key )") + +// ErrKeySize is returned when the encryption key is not 32 bytes. +var ErrKeySize = errors.New("key must be 32 bytes") + +// SetupCipher decodes a 64-char hex key and instantiates the AEAD cipher. +func SetupCipher(keyHex string) (*crypto.Cipher, error) { + if keyHex == "" { + return nil, ErrKeyRequired + } + key, err := hex.DecodeString(keyHex) + if err != nil { + return nil, fmt.Errorf("failed to decode key: %w", err) + } + if len(key) != 32 { + return nil, fmt.Errorf("%w, got %d", ErrKeySize, len(key)) + } + cipher, err := crypto.NewCipher(string(key)) + if err != nil { + return nil, fmt.Errorf("failed to create cipher: %w", err) + } + return cipher, nil +} + +// SmuxConfig returns the tuned smux config used on both ends. Both peers +// must agree on Version and MaxFrameSize. maxWirePayload, when > 0, +// constrains the max frame size to fit under the transport's per-message +// payload cap minus the AEAD wire overhead. +func SmuxConfig(maxWirePayload int) *smux.Config { + cfg := smux.DefaultConfig() + cfg.Version = 2 + cfg.KeepAliveDisabled = true + cfg.MaxFrameSize = 32768 + if maxWirePayload > crypto.WireOverhead { + maxFrameSize := maxWirePayload - crypto.WireOverhead + if maxFrameSize < cfg.MaxFrameSize { + cfg.MaxFrameSize = maxFrameSize + } + } + cfg.MaxReceiveBuffer = 16 * 1024 * 1024 + cfg.MaxStreamBuffer = 1024 * 1024 + cfg.KeepAliveInterval = 10 * time.Second + cfg.KeepAliveTimeout = 60 * time.Second + return cfg +} + +// MaxPayload reports the transport's per-message payload limit. Returns 0 +// when the transport sets no explicit limit; the caller treats 0 as "use +// SmuxConfig's default frame size". +func MaxPayload(tr transport.Transport) int { + return tr.Features().MaxPayloadSize +} + +// HealthTracker holds the live snapshot of one side's control-stream +// health: last pong time, last RTT, miss counts, reconnect counts. +// Server and client both embed a HealthTracker to avoid open-coding the +// same record* methods on both sides. +type HealthTracker struct { + mu sync.RWMutex + status control.Status + notify func(control.Status) +} + +// NewHealthTracker creates a HealthTracker that publishes the latest +// snapshot through notify whenever it changes. notify may be nil. +func NewHealthTracker(notify func(control.Status)) *HealthTracker { + if notify == nil { + notify = func(control.Status) {} + } + return &HealthTracker{notify: notify} +} + +// Status returns the latest health snapshot. A nil tracker reports a zero +// value, which lets tests instantiate stub Server/Client structs without +// wiring up a real tracker. +func (h *HealthTracker) Status() control.Status { + if h == nil { + return control.Status{} + } + h.mu.RLock() + defer h.mu.RUnlock() + return h.status +} + +// RecordSession resets miss counters and stamps the session id. +func (h *HealthTracker) RecordSession(id string) { + h.update(func(s *control.Status) { + s.SessionID = id + s.MissedPongs = 0 + }) +} + +// RecordPong updates LastPong/LastRTT and clears MissedPongs. +func (h *HealthTracker) RecordPong(p control.Health) { + h.update(func(s *control.Status) { + s.LastPong = p.LastSeen + s.LastRTT = p.RTT + s.MissedPongs = 0 + }) +} + +// RecordMissed bumps the missed-pong count. +func (h *HealthTracker) RecordMissed(missed int) { + h.update(func(s *control.Status) { + s.MissedPongs = missed + }) +} + +// RecordUnhealthy bumps the unhealthy-event count and stamps the time. +func (h *HealthTracker) RecordUnhealthy(missed int) { + h.update(func(s *control.Status) { + s.MissedPongs = missed + s.UnhealthyEvents++ + s.LastUnhealthy = time.Now() + }) +} + +// RecordReconnect bumps the reconnect counter. +func (h *HealthTracker) RecordReconnect() { + h.update(func(s *control.Status) { + s.Reconnects++ + }) +} + +func (h *HealthTracker) update(mutate func(*control.Status)) { + if h == nil { + return + } + h.mu.Lock() + mutate(&h.status) + snapshot := h.status + h.mu.Unlock() + h.notify(snapshot) +} diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go new file mode 100644 index 0000000..a0f44eb --- /dev/null +++ b/internal/runtime/runtime_test.go @@ -0,0 +1,84 @@ +package runtime_test + +import ( + "errors" + "testing" + "time" + + "github.com/openlibrecommunity/olcrtc/internal/control" + "github.com/openlibrecommunity/olcrtc/internal/runtime" +) + +func TestSetupCipherErrors(t *testing.T) { + if _, err := runtime.SetupCipher(""); !errors.Is(err, runtime.ErrKeyRequired) { + t.Fatalf("empty key error = %v, want ErrKeyRequired", err) + } + if _, err := runtime.SetupCipher("notHex"); err == nil { + t.Fatalf("bad hex error = nil") + } + if _, err := runtime.SetupCipher("00"); !errors.Is(err, runtime.ErrKeySize) { + t.Fatalf("short key error = %v, want ErrKeySize", err) + } +} + +func TestSetupCipherSuccess(t *testing.T) { + key := "00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff" + c, err := runtime.SetupCipher(key) + if err != nil { + t.Fatalf("SetupCipher() error = %v", err) + } + if c == nil { + t.Fatal("SetupCipher() returned nil cipher") + } +} + +func TestSmuxConfigDefault(t *testing.T) { + cfg := runtime.SmuxConfig(0) + if cfg.Version != 2 || cfg.MaxFrameSize != 32768 { + t.Fatalf("SmuxConfig(0) = %+v", cfg) + } +} + +func TestSmuxConfigShrinks(t *testing.T) { + // 100-byte wire payload minus crypto overhead is far below default 32768, + // so MaxFrameSize must shrink. + cfg := runtime.SmuxConfig(100) + if cfg.MaxFrameSize >= 32768 { + t.Fatalf("MaxFrameSize = %d, want shrunk", cfg.MaxFrameSize) + } +} + +func TestHealthTrackerEmitsOnEveryChange(t *testing.T) { + var got []control.Status + h := runtime.NewHealthTracker(func(s control.Status) { + got = append(got, s) + }) + + h.RecordSession("s1") + h.RecordPong(control.Health{LastSeen: time.Unix(100, 0), RTT: time.Millisecond}) + h.RecordMissed(2) + h.RecordReconnect() + h.RecordUnhealthy(3) + + if len(got) != 5 { + t.Fatalf("notify count = %d, want 5", len(got)) + } + if got[0].SessionID != "s1" { + t.Fatalf("first snapshot session id = %q", got[0].SessionID) + } + if got[1].LastRTT != time.Millisecond { + t.Fatalf("second snapshot rtt = %v", got[1].LastRTT) + } + final := h.Status() + if final.Reconnects != 1 || final.UnhealthyEvents != 1 || final.MissedPongs != 3 { + t.Fatalf("final snapshot = %+v", final) + } +} + +func TestHealthTrackerNilNotifyOK(t *testing.T) { + h := runtime.NewHealthTracker(nil) + h.RecordSession("s") // must not panic + if h.Status().SessionID != "s" { + t.Fatal("Status() did not record without notify") + } +} diff --git a/internal/server/server.go b/internal/server/server.go index 882a8e8..df746c3 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -3,7 +3,6 @@ package server import ( "context" - "encoding/hex" "encoding/json" "errors" "fmt" @@ -20,6 +19,7 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/logger" "github.com/openlibrecommunity/olcrtc/internal/muxconn" "github.com/openlibrecommunity/olcrtc/internal/names" + "github.com/openlibrecommunity/olcrtc/internal/runtime" "github.com/openlibrecommunity/olcrtc/internal/transport" "github.com/xtaci/smux" ) @@ -27,10 +27,11 @@ import ( const connectCommand = "connect" var ( - // ErrKeyRequired is returned when no encryption key is provided. - ErrKeyRequired = errors.New("key required (use -key )") - // ErrKeySize is returned when the encryption key is not 32 bytes. - ErrKeySize = errors.New("key must be 32 bytes") + // ErrKeyRequired re-exports runtime.ErrKeyRequired for compatibility with + // pre-runtime callers that errors.Is-checked it. + ErrKeyRequired = runtime.ErrKeyRequired + // ErrKeySize re-exports runtime.ErrKeySize for the same reason. + ErrKeySize = runtime.ErrKeySize // ErrSocks5AuthFailed is returned when SOCKS5 authentication fails. ErrSocks5AuthFailed = errors.New("SOCKS5 auth failed") // ErrSocks5ConnectFailed is returned when SOCKS5 connection fails. @@ -62,13 +63,11 @@ type Server struct { controlStop context.CancelFunc sessMu sync.RWMutex reinstallMu sync.Mutex - healthMu sync.RWMutex wg sync.WaitGroup authHook handshake.AuthFunc onOpen SessionOpenFunc onClose SessionCloseFunc onTraffic TrafficFunc - onHealth HealthFunc deviceID string sessionID string dnsServer string @@ -76,7 +75,7 @@ type Server struct { socksProxyAddr string socksProxyPort int liveness control.Config - health control.Status + health *runtime.HealthTracker } // ConnectRequest is a message from the client to establish a new connection. @@ -143,22 +142,17 @@ func Run(ctx context.Context, cfg Config) error { if onTraffic == nil { onTraffic = func(string, string, uint64, uint64) {} } - onHealth := cfg.OnHealth - if onHealth == nil { - onHealth = func(control.Status) {} - } - s := &Server{ cipher: cipher, authHook: hook, onOpen: onOpen, onClose: onClose, onTraffic: onTraffic, - onHealth: onHealth, dnsServer: cfg.DNSServer, socksProxyAddr: cfg.SOCKSProxyAddr, socksProxyPort: cfg.SOCKSProxyPort, liveness: cfg.Liveness, + health: runtime.NewHealthTracker(cfg.OnHealth), } s.setupResolver() @@ -189,23 +183,7 @@ func Run(ctx context.Context, cfg Config) error { } func setupCipher(keyHex string) (*crypto.Cipher, error) { - if keyHex == "" { - return nil, ErrKeyRequired - } - - key, err := hex.DecodeString(keyHex) - if err != nil { - return nil, fmt.Errorf("failed to decode key: %w", err) - } - if len(key) != 32 { - return nil, fmt.Errorf("%w, got %d", ErrKeySize, len(key)) - } - - cipher, err := crypto.NewCipher(string(key)) - if err != nil { - return nil, fmt.Errorf("failed to create cipher: %w", err) - } - return cipher, nil + return runtime.SetupCipher(keyHex) } func (s *Server) setupResolver() { @@ -218,28 +196,12 @@ func (s *Server) setupResolver() { } } -// smuxConfig mirrors the client side. Both peers must agree on Version and -// MaxFrameSize. -func smuxConfig(maxWirePayload ...int) *smux.Config { - cfg := smux.DefaultConfig() - cfg.Version = 2 - cfg.KeepAliveDisabled = true - cfg.MaxFrameSize = 32768 - if len(maxWirePayload) > 0 && maxWirePayload[0] > crypto.WireOverhead { - maxFrameSize := maxWirePayload[0] - crypto.WireOverhead - if maxFrameSize < cfg.MaxFrameSize { - cfg.MaxFrameSize = maxFrameSize - } - } - cfg.MaxReceiveBuffer = 16 * 1024 * 1024 - cfg.MaxStreamBuffer = 1024 * 1024 - cfg.KeepAliveInterval = 10 * time.Second - cfg.KeepAliveTimeout = 60 * time.Second - return cfg +func smuxConfig(maxWirePayload int) *smux.Config { + return runtime.SmuxConfig(maxWirePayload) } func linkMaxPayload(tr transport.Transport) int { - return tr.Features().MaxPayloadSize + return runtime.MaxPayload(tr) } func (s *Server) bringUpLink( @@ -548,61 +510,14 @@ func (s *Server) startControlLoop(ctx context.Context, sess *smux.Session, strea // Status returns the latest server-side control health snapshot. func (s *Server) Status() control.Status { - s.healthMu.RLock() - defer s.healthMu.RUnlock() - return s.health + return s.health.Status() } -func (s *Server) recordSession(sessionID string) { - s.healthMu.Lock() - s.health.SessionID = sessionID - s.health.MissedPongs = 0 - status := s.health - s.healthMu.Unlock() - s.notifyHealth(status) -} - -func (s *Server) recordPong(h control.Health) { - s.healthMu.Lock() - s.health.LastPong = h.LastSeen - s.health.LastRTT = h.RTT - s.health.MissedPongs = 0 - status := s.health - s.healthMu.Unlock() - s.notifyHealth(status) -} - -func (s *Server) recordMissed(missed int) { - s.healthMu.Lock() - s.health.MissedPongs = missed - status := s.health - s.healthMu.Unlock() - s.notifyHealth(status) -} - -func (s *Server) recordUnhealthy(missed int) { - s.healthMu.Lock() - s.health.MissedPongs = missed - s.health.UnhealthyEvents++ - s.health.LastUnhealthy = time.Now() - status := s.health - s.healthMu.Unlock() - s.notifyHealth(status) -} - -func (s *Server) recordReconnect() { - s.healthMu.Lock() - s.health.Reconnects++ - status := s.health - s.healthMu.Unlock() - s.notifyHealth(status) -} - -func (s *Server) notifyHealth(status control.Status) { - if s.onHealth != nil { - s.onHealth(status) - } -} +func (s *Server) recordSession(sessionID string) { s.health.RecordSession(sessionID) } +func (s *Server) recordPong(h control.Health) { s.health.RecordPong(h) } +func (s *Server) recordMissed(missed int) { s.health.RecordMissed(missed) } +func (s *Server) recordUnhealthy(missed int) { s.health.RecordUnhealthy(missed) } +func (s *Server) recordReconnect() { s.health.RecordReconnect() } func (s *Server) shutdown() { s.closeSession() diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 67ce828..9512f8d 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -14,6 +14,7 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/control" cryptopkg "github.com/openlibrecommunity/olcrtc/internal/crypto" "github.com/openlibrecommunity/olcrtc/internal/muxconn" + "github.com/openlibrecommunity/olcrtc/internal/runtime" "github.com/openlibrecommunity/olcrtc/internal/transport" "github.com/xtaci/smux" ) @@ -47,9 +48,9 @@ func TestSetupCipherRejectsBadInput(t *testing.T) { } func TestSmuxConfig(t *testing.T) { - cfg := smuxConfig() + cfg := smuxConfig(0) if cfg.Version != 2 || !cfg.KeepAliveDisabled || cfg.MaxFrameSize != 32768 || cfg.MaxReceiveBuffer != 16*1024*1024 { - t.Fatalf("smuxConfig() = %+v", cfg) + t.Fatalf("smuxConfig(0) = %+v", cfg) } capped := smuxConfig(4096) if capped.MaxFrameSize != 4096-cryptopkg.WireOverhead { @@ -321,12 +322,12 @@ func TestHandleStreamDispatchAfterConnect(t *testing.T) { _ = b.Close() }() - serverSess, err := smux.Server(a, smuxConfig()) + serverSess, err := smux.Server(a, smuxConfig(0)) if err != nil { t.Fatalf("smux.Server() error = %v", err) } defer func() { _ = serverSess.Close() }() - clientSess, err := smux.Client(b, smuxConfig()) + clientSess, err := smux.Client(b, smuxConfig(0)) if err != nil { t.Fatalf("smux.Client() error = %v", err) } @@ -389,12 +390,12 @@ func TestStartControlLoopReportsPong(t *testing.T) { _ = b.Close() }() - serverSess, err := smux.Server(a, smuxConfig()) + serverSess, err := smux.Server(a, smuxConfig(0)) if err != nil { t.Fatalf("smux.Server() error = %v", err) } defer func() { _ = serverSess.Close() }() - clientSess, err := smux.Client(b, smuxConfig()) + clientSess, err := smux.Client(b, smuxConfig(0)) if err != nil { t.Fatalf("smux.Client() error = %v", err) } @@ -418,6 +419,7 @@ func TestStartControlLoopReportsPong(t *testing.T) { got := make(chan control.Health, 1) s := &Server{ sessionID: "sid-control", + health: runtime.NewHealthTracker(nil), liveness: control.Config{ Interval: 10 * time.Millisecond, Timeout: 100 * time.Millisecond, @@ -463,7 +465,7 @@ func TestStartControlLoopReportsPong(t *testing.T) { func TestStatusRecordsReconnectAndUnhealthy(t *testing.T) { updates := 0 - s := &Server{onHealth: func(control.Status) { updates++ }} + s := &Server{health: runtime.NewHealthTracker(func(control.Status) { updates++ })} s.recordSession("sid-1") s.recordMissed(2) s.recordUnhealthy(3) @@ -504,12 +506,12 @@ func TestDispatchFiresOnTraffic(t *testing.T) { _ = b.Close() }() - serverSess, err := smux.Server(a, smuxConfig()) + serverSess, err := smux.Server(a, smuxConfig(0)) if err != nil { t.Fatalf("smux.Server() error = %v", err) } defer func() { _ = serverSess.Close() }() - clientSess, err := smux.Client(b, smuxConfig()) + clientSess, err := smux.Client(b, smuxConfig(0)) if err != nil { t.Fatalf("smux.Client() error = %v", err) }