mirror of
https://github.com/openlibrecommunity/olcrtc.git
synced 2026-05-26 07:08:11 +00:00
refactor: extract shared session runtime into internal/runtime
server.go and client.go each carried byte-identical copies of
smuxConfig (~20 lines), setupCipher (~18 lines), and the health
bookkeeping pair recordSession/Pong/Missed/Unhealthy/Reconnect plus a
private healthMu+status+notifyHealth scaffold. Same code, twice.
Add internal/runtime exposing:
- SetupCipher, SmuxConfig, MaxPayload — common construction helpers,
ErrKeyRequired/ErrKeySize re-exported from runtime so existing
errors.Is checks on server.ErrKeyRequired etc. keep working.
- HealthTracker — nil-safe wrapper around control.Status with
RecordSession/Pong/Missed/Unhealthy/Reconnect that publishes through an
OnHealth callback supplied at construction.
server and client now hold a *runtime.HealthTracker instead of their own
mu+status+notify scaffolds. recordX methods on Server/Client are now
one-liners that forward to the tracker. smuxConfig(0) replaces the prior
variadic smuxConfig() in test call sites; nil-safe Status()/update() on
HealthTracker means tests that build raw &Server{}/&Client{} no longer
need to wire up a tracker for the records to be no-ops.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -4,7 +4,6 @@ package client
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -23,6 +22,7 @@ import (
|
||||
"github.com/openlibrecommunity/olcrtc/internal/logger"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/names"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/runtime"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/transport"
|
||||
"github.com/xtaci/smux"
|
||||
)
|
||||
@@ -33,7 +33,8 @@ var (
|
||||
// ErrProxyAuth is returned when SOCKS proxy authentication fails.
|
||||
ErrProxyAuth = errors.New("SOCKS proxy auth failed")
|
||||
// ErrKeySize is returned when the encryption key is not 32 bytes.
|
||||
ErrKeySize = errors.New("key must be 32 bytes")
|
||||
// Re-exported from runtime for compatibility with errors.Is callers.
|
||||
ErrKeySize = runtime.ErrKeySize
|
||||
// ErrInvalidSOCKSVersion is returned when the SOCKS version is not 5.
|
||||
ErrInvalidSOCKSVersion = errors.New("invalid socks version")
|
||||
// ErrUnsupportedSOCKSCommand is returned for unsupported SOCKS commands.
|
||||
@@ -58,9 +59,7 @@ type Client struct {
|
||||
controlStop context.CancelFunc
|
||||
sessMu sync.RWMutex
|
||||
reconnectMu sync.Mutex
|
||||
healthMu sync.RWMutex
|
||||
health control.Status
|
||||
onHealth HealthFunc
|
||||
health *runtime.HealthTracker
|
||||
deviceID string
|
||||
sessionID string
|
||||
claims map[string]any
|
||||
@@ -134,7 +133,7 @@ func RunWithReady(ctx context.Context, cfg Config, onReady func()) error {
|
||||
dnsServer: cfg.DNSServer,
|
||||
socksUser: cfg.SOCKSUser,
|
||||
socksPass: cfg.SOCKSPass,
|
||||
onHealth: cfg.OnHealth,
|
||||
health: runtime.NewHealthTracker(cfg.OnHealth),
|
||||
}
|
||||
|
||||
// shutdown is registered BEFORE bringUpLink so we always close any
|
||||
@@ -303,27 +302,12 @@ func resolveDeviceID(deviceID, path string) (string, error) {
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// smuxConfig returns the tuned smux config used on both ends.
|
||||
func smuxConfig(maxWirePayload ...int) *smux.Config {
|
||||
cfg := smux.DefaultConfig()
|
||||
cfg.Version = 2
|
||||
cfg.KeepAliveDisabled = true
|
||||
cfg.MaxFrameSize = 32768
|
||||
if len(maxWirePayload) > 0 && maxWirePayload[0] > crypto.WireOverhead {
|
||||
maxFrameSize := maxWirePayload[0] - crypto.WireOverhead
|
||||
if maxFrameSize < cfg.MaxFrameSize {
|
||||
cfg.MaxFrameSize = maxFrameSize
|
||||
}
|
||||
}
|
||||
cfg.MaxReceiveBuffer = 16 * 1024 * 1024
|
||||
cfg.MaxStreamBuffer = 1024 * 1024
|
||||
cfg.KeepAliveInterval = 10 * time.Second
|
||||
cfg.KeepAliveTimeout = 60 * time.Second
|
||||
return cfg
|
||||
func smuxConfig(maxWirePayload int) *smux.Config {
|
||||
return runtime.SmuxConfig(maxWirePayload)
|
||||
}
|
||||
|
||||
func linkMaxPayload(tr transport.Transport) int {
|
||||
return tr.Features().MaxPayloadSize
|
||||
return runtime.MaxPayload(tr)
|
||||
}
|
||||
|
||||
func (c *Client) handleReconnect(ctx context.Context, cfg Config, cancel context.CancelFunc, reason string) bool {
|
||||
@@ -481,61 +465,14 @@ func (c *Client) startControlLoop(
|
||||
|
||||
// Status returns the latest client-side control health snapshot.
|
||||
func (c *Client) Status() control.Status {
|
||||
c.healthMu.RLock()
|
||||
defer c.healthMu.RUnlock()
|
||||
return c.health
|
||||
return c.health.Status()
|
||||
}
|
||||
|
||||
func (c *Client) recordSession(sessionID string) {
|
||||
c.healthMu.Lock()
|
||||
c.health.SessionID = sessionID
|
||||
c.health.MissedPongs = 0
|
||||
status := c.health
|
||||
c.healthMu.Unlock()
|
||||
c.notifyHealth(status)
|
||||
}
|
||||
|
||||
func (c *Client) recordPong(h control.Health) {
|
||||
c.healthMu.Lock()
|
||||
c.health.LastPong = h.LastSeen
|
||||
c.health.LastRTT = h.RTT
|
||||
c.health.MissedPongs = 0
|
||||
status := c.health
|
||||
c.healthMu.Unlock()
|
||||
c.notifyHealth(status)
|
||||
}
|
||||
|
||||
func (c *Client) recordMissed(missed int) {
|
||||
c.healthMu.Lock()
|
||||
c.health.MissedPongs = missed
|
||||
status := c.health
|
||||
c.healthMu.Unlock()
|
||||
c.notifyHealth(status)
|
||||
}
|
||||
|
||||
func (c *Client) recordUnhealthy(missed int) {
|
||||
c.healthMu.Lock()
|
||||
c.health.MissedPongs = missed
|
||||
c.health.UnhealthyEvents++
|
||||
c.health.LastUnhealthy = time.Now()
|
||||
status := c.health
|
||||
c.healthMu.Unlock()
|
||||
c.notifyHealth(status)
|
||||
}
|
||||
|
||||
func (c *Client) recordReconnect() {
|
||||
c.healthMu.Lock()
|
||||
c.health.Reconnects++
|
||||
status := c.health
|
||||
c.healthMu.Unlock()
|
||||
c.notifyHealth(status)
|
||||
}
|
||||
|
||||
func (c *Client) notifyHealth(status control.Status) {
|
||||
if c.onHealth != nil {
|
||||
c.onHealth(status)
|
||||
}
|
||||
}
|
||||
func (c *Client) recordSession(sessionID string) { c.health.RecordSession(sessionID) }
|
||||
func (c *Client) recordPong(h control.Health) { c.health.RecordPong(h) }
|
||||
func (c *Client) recordMissed(missed int) { c.health.RecordMissed(missed) }
|
||||
func (c *Client) recordUnhealthy(missed int) { c.health.RecordUnhealthy(missed) }
|
||||
func (c *Client) recordReconnect() { c.health.RecordReconnect() }
|
||||
|
||||
func (c *Client) shutdown() {
|
||||
c.sessMu.Lock()
|
||||
@@ -567,19 +504,7 @@ func (c *Client) shutdown() {
|
||||
}
|
||||
|
||||
func setupCipher(keyHex string) (*crypto.Cipher, error) {
|
||||
key, err := hex.DecodeString(keyHex)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode key: %w", err)
|
||||
}
|
||||
if len(key) != 32 {
|
||||
return nil, fmt.Errorf("%w: got %d", ErrKeySize, len(key))
|
||||
}
|
||||
|
||||
cipher, err := crypto.NewCipher(string(key))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
return cipher, nil
|
||||
return runtime.SetupCipher(keyHex)
|
||||
}
|
||||
|
||||
func (c *Client) onData(data []byte) {
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/openlibrecommunity/olcrtc/internal/control"
|
||||
cryptopkg "github.com/openlibrecommunity/olcrtc/internal/crypto"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/runtime"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/transport"
|
||||
"github.com/xtaci/smux"
|
||||
)
|
||||
@@ -46,9 +47,9 @@ func TestSetupCipherRejectsBadInput(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSmuxConfig(t *testing.T) {
|
||||
cfg := smuxConfig()
|
||||
cfg := smuxConfig(0)
|
||||
if cfg.Version != 2 || !cfg.KeepAliveDisabled || cfg.MaxFrameSize != 32768 || cfg.MaxReceiveBuffer != 16*1024*1024 {
|
||||
t.Fatalf("smuxConfig() = %+v", cfg)
|
||||
t.Fatalf("smuxConfig(0) = %+v", cfg)
|
||||
}
|
||||
capped := smuxConfig(4096)
|
||||
if capped.MaxFrameSize != 4096-cryptopkg.WireOverhead {
|
||||
@@ -403,12 +404,12 @@ func TestSendConnectRequestOverSmux(t *testing.T) {
|
||||
_ = b.Close()
|
||||
}()
|
||||
|
||||
serverSess, err := smux.Server(a, smuxConfig())
|
||||
serverSess, err := smux.Server(a, smuxConfig(0))
|
||||
if err != nil {
|
||||
t.Fatalf("smux.Server() error = %v", err)
|
||||
}
|
||||
defer func() { _ = serverSess.Close() }()
|
||||
clientSess, err := smux.Client(b, smuxConfig())
|
||||
clientSess, err := smux.Client(b, smuxConfig(0))
|
||||
if err != nil {
|
||||
t.Fatalf("smux.Client() error = %v", err)
|
||||
}
|
||||
@@ -457,12 +458,12 @@ func TestSendConnectRequestRejectsBadAck(t *testing.T) {
|
||||
_ = a.Close()
|
||||
_ = b.Close()
|
||||
}()
|
||||
serverSess, err := smux.Server(a, smuxConfig())
|
||||
serverSess, err := smux.Server(a, smuxConfig(0))
|
||||
if err != nil {
|
||||
t.Fatalf("smux.Server() error = %v", err)
|
||||
}
|
||||
defer func() { _ = serverSess.Close() }()
|
||||
clientSess, err := smux.Client(b, smuxConfig())
|
||||
clientSess, err := smux.Client(b, smuxConfig(0))
|
||||
if err != nil {
|
||||
t.Fatalf("smux.Client() error = %v", err)
|
||||
}
|
||||
@@ -534,12 +535,12 @@ func TestStartControlLoopReportsPong(t *testing.T) {
|
||||
_ = b.Close()
|
||||
}()
|
||||
|
||||
serverSess, err := smux.Server(a, smuxConfig())
|
||||
serverSess, err := smux.Server(a, smuxConfig(0))
|
||||
if err != nil {
|
||||
t.Fatalf("smux.Server() error = %v", err)
|
||||
}
|
||||
defer func() { _ = serverSess.Close() }()
|
||||
clientSess, err := smux.Client(b, smuxConfig())
|
||||
clientSess, err := smux.Client(b, smuxConfig(0))
|
||||
if err != nil {
|
||||
t.Fatalf("smux.Client() error = %v", err)
|
||||
}
|
||||
@@ -562,7 +563,7 @@ func TestStartControlLoopReportsPong(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
got := make(chan control.Health, 1)
|
||||
c := &Client{sessionID: "sid-control"}
|
||||
c := &Client{sessionID: "sid-control", health: runtime.NewHealthTracker(nil)}
|
||||
c.recordSession("sid-control")
|
||||
c.startControlLoop(ctx, Config{
|
||||
Liveness: control.Config{
|
||||
@@ -604,7 +605,7 @@ func TestStartControlLoopReportsPong(t *testing.T) {
|
||||
|
||||
func TestStatusRecordsReconnectAndUnhealthy(t *testing.T) {
|
||||
updates := 0
|
||||
c := &Client{onHealth: func(control.Status) { updates++ }}
|
||||
c := &Client{health: runtime.NewHealthTracker(func(control.Status) { updates++ })}
|
||||
c.recordSession("sid-1")
|
||||
c.recordMissed(2)
|
||||
c.recordUnhealthy(3)
|
||||
|
||||
154
internal/runtime/runtime.go
Normal file
154
internal/runtime/runtime.go
Normal file
@@ -0,0 +1,154 @@
|
||||
// Package runtime holds infrastructure shared by the olcrtc server and
|
||||
// client: smux tuning, cipher setup, and control-stream health bookkeeping.
|
||||
// The lifecycle differences between server and client (accept loop / SOCKS5
|
||||
// dial vs. SOCKS5 listener / tunnel) live in their respective packages.
|
||||
package runtime
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/control"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/crypto"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/transport"
|
||||
"github.com/xtaci/smux"
|
||||
)
|
||||
|
||||
// ErrKeyRequired is returned when no encryption key is provided.
|
||||
var ErrKeyRequired = errors.New("key required (use -key <hex>)")
|
||||
|
||||
// ErrKeySize is returned when the encryption key is not 32 bytes.
|
||||
var ErrKeySize = errors.New("key must be 32 bytes")
|
||||
|
||||
// SetupCipher decodes a 64-char hex key and instantiates the AEAD cipher.
|
||||
func SetupCipher(keyHex string) (*crypto.Cipher, error) {
|
||||
if keyHex == "" {
|
||||
return nil, ErrKeyRequired
|
||||
}
|
||||
key, err := hex.DecodeString(keyHex)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode key: %w", err)
|
||||
}
|
||||
if len(key) != 32 {
|
||||
return nil, fmt.Errorf("%w, got %d", ErrKeySize, len(key))
|
||||
}
|
||||
cipher, err := crypto.NewCipher(string(key))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
return cipher, nil
|
||||
}
|
||||
|
||||
// SmuxConfig returns the tuned smux config used on both ends. Both peers
|
||||
// must agree on Version and MaxFrameSize. maxWirePayload, when > 0,
|
||||
// constrains the max frame size to fit under the transport's per-message
|
||||
// payload cap minus the AEAD wire overhead.
|
||||
func SmuxConfig(maxWirePayload int) *smux.Config {
|
||||
cfg := smux.DefaultConfig()
|
||||
cfg.Version = 2
|
||||
cfg.KeepAliveDisabled = true
|
||||
cfg.MaxFrameSize = 32768
|
||||
if maxWirePayload > crypto.WireOverhead {
|
||||
maxFrameSize := maxWirePayload - crypto.WireOverhead
|
||||
if maxFrameSize < cfg.MaxFrameSize {
|
||||
cfg.MaxFrameSize = maxFrameSize
|
||||
}
|
||||
}
|
||||
cfg.MaxReceiveBuffer = 16 * 1024 * 1024
|
||||
cfg.MaxStreamBuffer = 1024 * 1024
|
||||
cfg.KeepAliveInterval = 10 * time.Second
|
||||
cfg.KeepAliveTimeout = 60 * time.Second
|
||||
return cfg
|
||||
}
|
||||
|
||||
// MaxPayload reports the transport's per-message payload limit. Returns 0
|
||||
// when the transport sets no explicit limit; the caller treats 0 as "use
|
||||
// SmuxConfig's default frame size".
|
||||
func MaxPayload(tr transport.Transport) int {
|
||||
return tr.Features().MaxPayloadSize
|
||||
}
|
||||
|
||||
// HealthTracker holds the live snapshot of one side's control-stream
|
||||
// health: last pong time, last RTT, miss counts, reconnect counts.
|
||||
// Server and client both embed a HealthTracker to avoid open-coding the
|
||||
// same record* methods on both sides.
|
||||
type HealthTracker struct {
|
||||
mu sync.RWMutex
|
||||
status control.Status
|
||||
notify func(control.Status)
|
||||
}
|
||||
|
||||
// NewHealthTracker creates a HealthTracker that publishes the latest
|
||||
// snapshot through notify whenever it changes. notify may be nil.
|
||||
func NewHealthTracker(notify func(control.Status)) *HealthTracker {
|
||||
if notify == nil {
|
||||
notify = func(control.Status) {}
|
||||
}
|
||||
return &HealthTracker{notify: notify}
|
||||
}
|
||||
|
||||
// Status returns the latest health snapshot. A nil tracker reports a zero
|
||||
// value, which lets tests instantiate stub Server/Client structs without
|
||||
// wiring up a real tracker.
|
||||
func (h *HealthTracker) Status() control.Status {
|
||||
if h == nil {
|
||||
return control.Status{}
|
||||
}
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
return h.status
|
||||
}
|
||||
|
||||
// RecordSession resets miss counters and stamps the session id.
|
||||
func (h *HealthTracker) RecordSession(id string) {
|
||||
h.update(func(s *control.Status) {
|
||||
s.SessionID = id
|
||||
s.MissedPongs = 0
|
||||
})
|
||||
}
|
||||
|
||||
// RecordPong updates LastPong/LastRTT and clears MissedPongs.
|
||||
func (h *HealthTracker) RecordPong(p control.Health) {
|
||||
h.update(func(s *control.Status) {
|
||||
s.LastPong = p.LastSeen
|
||||
s.LastRTT = p.RTT
|
||||
s.MissedPongs = 0
|
||||
})
|
||||
}
|
||||
|
||||
// RecordMissed bumps the missed-pong count.
|
||||
func (h *HealthTracker) RecordMissed(missed int) {
|
||||
h.update(func(s *control.Status) {
|
||||
s.MissedPongs = missed
|
||||
})
|
||||
}
|
||||
|
||||
// RecordUnhealthy bumps the unhealthy-event count and stamps the time.
|
||||
func (h *HealthTracker) RecordUnhealthy(missed int) {
|
||||
h.update(func(s *control.Status) {
|
||||
s.MissedPongs = missed
|
||||
s.UnhealthyEvents++
|
||||
s.LastUnhealthy = time.Now()
|
||||
})
|
||||
}
|
||||
|
||||
// RecordReconnect bumps the reconnect counter.
|
||||
func (h *HealthTracker) RecordReconnect() {
|
||||
h.update(func(s *control.Status) {
|
||||
s.Reconnects++
|
||||
})
|
||||
}
|
||||
|
||||
func (h *HealthTracker) update(mutate func(*control.Status)) {
|
||||
if h == nil {
|
||||
return
|
||||
}
|
||||
h.mu.Lock()
|
||||
mutate(&h.status)
|
||||
snapshot := h.status
|
||||
h.mu.Unlock()
|
||||
h.notify(snapshot)
|
||||
}
|
||||
84
internal/runtime/runtime_test.go
Normal file
84
internal/runtime/runtime_test.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package runtime_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/control"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/runtime"
|
||||
)
|
||||
|
||||
func TestSetupCipherErrors(t *testing.T) {
|
||||
if _, err := runtime.SetupCipher(""); !errors.Is(err, runtime.ErrKeyRequired) {
|
||||
t.Fatalf("empty key error = %v, want ErrKeyRequired", err)
|
||||
}
|
||||
if _, err := runtime.SetupCipher("notHex"); err == nil {
|
||||
t.Fatalf("bad hex error = nil")
|
||||
}
|
||||
if _, err := runtime.SetupCipher("00"); !errors.Is(err, runtime.ErrKeySize) {
|
||||
t.Fatalf("short key error = %v, want ErrKeySize", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupCipherSuccess(t *testing.T) {
|
||||
key := "00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff"
|
||||
c, err := runtime.SetupCipher(key)
|
||||
if err != nil {
|
||||
t.Fatalf("SetupCipher() error = %v", err)
|
||||
}
|
||||
if c == nil {
|
||||
t.Fatal("SetupCipher() returned nil cipher")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSmuxConfigDefault(t *testing.T) {
|
||||
cfg := runtime.SmuxConfig(0)
|
||||
if cfg.Version != 2 || cfg.MaxFrameSize != 32768 {
|
||||
t.Fatalf("SmuxConfig(0) = %+v", cfg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSmuxConfigShrinks(t *testing.T) {
|
||||
// 100-byte wire payload minus crypto overhead is far below default 32768,
|
||||
// so MaxFrameSize must shrink.
|
||||
cfg := runtime.SmuxConfig(100)
|
||||
if cfg.MaxFrameSize >= 32768 {
|
||||
t.Fatalf("MaxFrameSize = %d, want shrunk", cfg.MaxFrameSize)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthTrackerEmitsOnEveryChange(t *testing.T) {
|
||||
var got []control.Status
|
||||
h := runtime.NewHealthTracker(func(s control.Status) {
|
||||
got = append(got, s)
|
||||
})
|
||||
|
||||
h.RecordSession("s1")
|
||||
h.RecordPong(control.Health{LastSeen: time.Unix(100, 0), RTT: time.Millisecond})
|
||||
h.RecordMissed(2)
|
||||
h.RecordReconnect()
|
||||
h.RecordUnhealthy(3)
|
||||
|
||||
if len(got) != 5 {
|
||||
t.Fatalf("notify count = %d, want 5", len(got))
|
||||
}
|
||||
if got[0].SessionID != "s1" {
|
||||
t.Fatalf("first snapshot session id = %q", got[0].SessionID)
|
||||
}
|
||||
if got[1].LastRTT != time.Millisecond {
|
||||
t.Fatalf("second snapshot rtt = %v", got[1].LastRTT)
|
||||
}
|
||||
final := h.Status()
|
||||
if final.Reconnects != 1 || final.UnhealthyEvents != 1 || final.MissedPongs != 3 {
|
||||
t.Fatalf("final snapshot = %+v", final)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthTrackerNilNotifyOK(t *testing.T) {
|
||||
h := runtime.NewHealthTracker(nil)
|
||||
h.RecordSession("s") // must not panic
|
||||
if h.Status().SessionID != "s" {
|
||||
t.Fatal("Status() did not record without notify")
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -20,6 +19,7 @@ import (
|
||||
"github.com/openlibrecommunity/olcrtc/internal/logger"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/names"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/runtime"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/transport"
|
||||
"github.com/xtaci/smux"
|
||||
)
|
||||
@@ -27,10 +27,11 @@ import (
|
||||
const connectCommand = "connect"
|
||||
|
||||
var (
|
||||
// ErrKeyRequired is returned when no encryption key is provided.
|
||||
ErrKeyRequired = errors.New("key required (use -key <hex>)")
|
||||
// ErrKeySize is returned when the encryption key is not 32 bytes.
|
||||
ErrKeySize = errors.New("key must be 32 bytes")
|
||||
// ErrKeyRequired re-exports runtime.ErrKeyRequired for compatibility with
|
||||
// pre-runtime callers that errors.Is-checked it.
|
||||
ErrKeyRequired = runtime.ErrKeyRequired
|
||||
// ErrKeySize re-exports runtime.ErrKeySize for the same reason.
|
||||
ErrKeySize = runtime.ErrKeySize
|
||||
// ErrSocks5AuthFailed is returned when SOCKS5 authentication fails.
|
||||
ErrSocks5AuthFailed = errors.New("SOCKS5 auth failed")
|
||||
// ErrSocks5ConnectFailed is returned when SOCKS5 connection fails.
|
||||
@@ -62,13 +63,11 @@ type Server struct {
|
||||
controlStop context.CancelFunc
|
||||
sessMu sync.RWMutex
|
||||
reinstallMu sync.Mutex
|
||||
healthMu sync.RWMutex
|
||||
wg sync.WaitGroup
|
||||
authHook handshake.AuthFunc
|
||||
onOpen SessionOpenFunc
|
||||
onClose SessionCloseFunc
|
||||
onTraffic TrafficFunc
|
||||
onHealth HealthFunc
|
||||
deviceID string
|
||||
sessionID string
|
||||
dnsServer string
|
||||
@@ -76,7 +75,7 @@ type Server struct {
|
||||
socksProxyAddr string
|
||||
socksProxyPort int
|
||||
liveness control.Config
|
||||
health control.Status
|
||||
health *runtime.HealthTracker
|
||||
}
|
||||
|
||||
// ConnectRequest is a message from the client to establish a new connection.
|
||||
@@ -143,22 +142,17 @@ func Run(ctx context.Context, cfg Config) error {
|
||||
if onTraffic == nil {
|
||||
onTraffic = func(string, string, uint64, uint64) {}
|
||||
}
|
||||
onHealth := cfg.OnHealth
|
||||
if onHealth == nil {
|
||||
onHealth = func(control.Status) {}
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
cipher: cipher,
|
||||
authHook: hook,
|
||||
onOpen: onOpen,
|
||||
onClose: onClose,
|
||||
onTraffic: onTraffic,
|
||||
onHealth: onHealth,
|
||||
dnsServer: cfg.DNSServer,
|
||||
socksProxyAddr: cfg.SOCKSProxyAddr,
|
||||
socksProxyPort: cfg.SOCKSProxyPort,
|
||||
liveness: cfg.Liveness,
|
||||
health: runtime.NewHealthTracker(cfg.OnHealth),
|
||||
}
|
||||
s.setupResolver()
|
||||
|
||||
@@ -189,23 +183,7 @@ func Run(ctx context.Context, cfg Config) error {
|
||||
}
|
||||
|
||||
func setupCipher(keyHex string) (*crypto.Cipher, error) {
|
||||
if keyHex == "" {
|
||||
return nil, ErrKeyRequired
|
||||
}
|
||||
|
||||
key, err := hex.DecodeString(keyHex)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode key: %w", err)
|
||||
}
|
||||
if len(key) != 32 {
|
||||
return nil, fmt.Errorf("%w, got %d", ErrKeySize, len(key))
|
||||
}
|
||||
|
||||
cipher, err := crypto.NewCipher(string(key))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
return cipher, nil
|
||||
return runtime.SetupCipher(keyHex)
|
||||
}
|
||||
|
||||
func (s *Server) setupResolver() {
|
||||
@@ -218,28 +196,12 @@ func (s *Server) setupResolver() {
|
||||
}
|
||||
}
|
||||
|
||||
// smuxConfig mirrors the client side. Both peers must agree on Version and
|
||||
// MaxFrameSize.
|
||||
func smuxConfig(maxWirePayload ...int) *smux.Config {
|
||||
cfg := smux.DefaultConfig()
|
||||
cfg.Version = 2
|
||||
cfg.KeepAliveDisabled = true
|
||||
cfg.MaxFrameSize = 32768
|
||||
if len(maxWirePayload) > 0 && maxWirePayload[0] > crypto.WireOverhead {
|
||||
maxFrameSize := maxWirePayload[0] - crypto.WireOverhead
|
||||
if maxFrameSize < cfg.MaxFrameSize {
|
||||
cfg.MaxFrameSize = maxFrameSize
|
||||
}
|
||||
}
|
||||
cfg.MaxReceiveBuffer = 16 * 1024 * 1024
|
||||
cfg.MaxStreamBuffer = 1024 * 1024
|
||||
cfg.KeepAliveInterval = 10 * time.Second
|
||||
cfg.KeepAliveTimeout = 60 * time.Second
|
||||
return cfg
|
||||
func smuxConfig(maxWirePayload int) *smux.Config {
|
||||
return runtime.SmuxConfig(maxWirePayload)
|
||||
}
|
||||
|
||||
func linkMaxPayload(tr transport.Transport) int {
|
||||
return tr.Features().MaxPayloadSize
|
||||
return runtime.MaxPayload(tr)
|
||||
}
|
||||
|
||||
func (s *Server) bringUpLink(
|
||||
@@ -548,61 +510,14 @@ func (s *Server) startControlLoop(ctx context.Context, sess *smux.Session, strea
|
||||
|
||||
// Status returns the latest server-side control health snapshot.
|
||||
func (s *Server) Status() control.Status {
|
||||
s.healthMu.RLock()
|
||||
defer s.healthMu.RUnlock()
|
||||
return s.health
|
||||
return s.health.Status()
|
||||
}
|
||||
|
||||
func (s *Server) recordSession(sessionID string) {
|
||||
s.healthMu.Lock()
|
||||
s.health.SessionID = sessionID
|
||||
s.health.MissedPongs = 0
|
||||
status := s.health
|
||||
s.healthMu.Unlock()
|
||||
s.notifyHealth(status)
|
||||
}
|
||||
|
||||
func (s *Server) recordPong(h control.Health) {
|
||||
s.healthMu.Lock()
|
||||
s.health.LastPong = h.LastSeen
|
||||
s.health.LastRTT = h.RTT
|
||||
s.health.MissedPongs = 0
|
||||
status := s.health
|
||||
s.healthMu.Unlock()
|
||||
s.notifyHealth(status)
|
||||
}
|
||||
|
||||
func (s *Server) recordMissed(missed int) {
|
||||
s.healthMu.Lock()
|
||||
s.health.MissedPongs = missed
|
||||
status := s.health
|
||||
s.healthMu.Unlock()
|
||||
s.notifyHealth(status)
|
||||
}
|
||||
|
||||
func (s *Server) recordUnhealthy(missed int) {
|
||||
s.healthMu.Lock()
|
||||
s.health.MissedPongs = missed
|
||||
s.health.UnhealthyEvents++
|
||||
s.health.LastUnhealthy = time.Now()
|
||||
status := s.health
|
||||
s.healthMu.Unlock()
|
||||
s.notifyHealth(status)
|
||||
}
|
||||
|
||||
func (s *Server) recordReconnect() {
|
||||
s.healthMu.Lock()
|
||||
s.health.Reconnects++
|
||||
status := s.health
|
||||
s.healthMu.Unlock()
|
||||
s.notifyHealth(status)
|
||||
}
|
||||
|
||||
func (s *Server) notifyHealth(status control.Status) {
|
||||
if s.onHealth != nil {
|
||||
s.onHealth(status)
|
||||
}
|
||||
}
|
||||
func (s *Server) recordSession(sessionID string) { s.health.RecordSession(sessionID) }
|
||||
func (s *Server) recordPong(h control.Health) { s.health.RecordPong(h) }
|
||||
func (s *Server) recordMissed(missed int) { s.health.RecordMissed(missed) }
|
||||
func (s *Server) recordUnhealthy(missed int) { s.health.RecordUnhealthy(missed) }
|
||||
func (s *Server) recordReconnect() { s.health.RecordReconnect() }
|
||||
|
||||
func (s *Server) shutdown() {
|
||||
s.closeSession()
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/openlibrecommunity/olcrtc/internal/control"
|
||||
cryptopkg "github.com/openlibrecommunity/olcrtc/internal/crypto"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/runtime"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/transport"
|
||||
"github.com/xtaci/smux"
|
||||
)
|
||||
@@ -47,9 +48,9 @@ func TestSetupCipherRejectsBadInput(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSmuxConfig(t *testing.T) {
|
||||
cfg := smuxConfig()
|
||||
cfg := smuxConfig(0)
|
||||
if cfg.Version != 2 || !cfg.KeepAliveDisabled || cfg.MaxFrameSize != 32768 || cfg.MaxReceiveBuffer != 16*1024*1024 {
|
||||
t.Fatalf("smuxConfig() = %+v", cfg)
|
||||
t.Fatalf("smuxConfig(0) = %+v", cfg)
|
||||
}
|
||||
capped := smuxConfig(4096)
|
||||
if capped.MaxFrameSize != 4096-cryptopkg.WireOverhead {
|
||||
@@ -321,12 +322,12 @@ func TestHandleStreamDispatchAfterConnect(t *testing.T) {
|
||||
_ = b.Close()
|
||||
}()
|
||||
|
||||
serverSess, err := smux.Server(a, smuxConfig())
|
||||
serverSess, err := smux.Server(a, smuxConfig(0))
|
||||
if err != nil {
|
||||
t.Fatalf("smux.Server() error = %v", err)
|
||||
}
|
||||
defer func() { _ = serverSess.Close() }()
|
||||
clientSess, err := smux.Client(b, smuxConfig())
|
||||
clientSess, err := smux.Client(b, smuxConfig(0))
|
||||
if err != nil {
|
||||
t.Fatalf("smux.Client() error = %v", err)
|
||||
}
|
||||
@@ -389,12 +390,12 @@ func TestStartControlLoopReportsPong(t *testing.T) {
|
||||
_ = b.Close()
|
||||
}()
|
||||
|
||||
serverSess, err := smux.Server(a, smuxConfig())
|
||||
serverSess, err := smux.Server(a, smuxConfig(0))
|
||||
if err != nil {
|
||||
t.Fatalf("smux.Server() error = %v", err)
|
||||
}
|
||||
defer func() { _ = serverSess.Close() }()
|
||||
clientSess, err := smux.Client(b, smuxConfig())
|
||||
clientSess, err := smux.Client(b, smuxConfig(0))
|
||||
if err != nil {
|
||||
t.Fatalf("smux.Client() error = %v", err)
|
||||
}
|
||||
@@ -418,6 +419,7 @@ func TestStartControlLoopReportsPong(t *testing.T) {
|
||||
got := make(chan control.Health, 1)
|
||||
s := &Server{
|
||||
sessionID: "sid-control",
|
||||
health: runtime.NewHealthTracker(nil),
|
||||
liveness: control.Config{
|
||||
Interval: 10 * time.Millisecond,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
@@ -463,7 +465,7 @@ func TestStartControlLoopReportsPong(t *testing.T) {
|
||||
|
||||
func TestStatusRecordsReconnectAndUnhealthy(t *testing.T) {
|
||||
updates := 0
|
||||
s := &Server{onHealth: func(control.Status) { updates++ }}
|
||||
s := &Server{health: runtime.NewHealthTracker(func(control.Status) { updates++ })}
|
||||
s.recordSession("sid-1")
|
||||
s.recordMissed(2)
|
||||
s.recordUnhealthy(3)
|
||||
@@ -504,12 +506,12 @@ func TestDispatchFiresOnTraffic(t *testing.T) {
|
||||
_ = b.Close()
|
||||
}()
|
||||
|
||||
serverSess, err := smux.Server(a, smuxConfig())
|
||||
serverSess, err := smux.Server(a, smuxConfig(0))
|
||||
if err != nil {
|
||||
t.Fatalf("smux.Server() error = %v", err)
|
||||
}
|
||||
defer func() { _ = serverSess.Close() }()
|
||||
clientSess, err := smux.Client(b, smuxConfig())
|
||||
clientSess, err := smux.Client(b, smuxConfig(0))
|
||||
if err != nil {
|
||||
t.Fatalf("smux.Client() error = %v", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user