From a86f5c6948c1dddab4a84d59b64729029caf64d7 Mon Sep 17 00:00:00 2001 From: cyber-debug Date: Fri, 15 May 2026 23:49:14 +0300 Subject: [PATCH 1/8] feat: add reconnect hardening and failover profiles --- .gitignore | 1 + cmd/olcrtc/main.go | 115 ++++++- cmd/olcrtc/main_test.go | 119 +++++++ docs/about.md | 4 +- docs/client.example.yaml | 1 + docs/configuration.md | 54 +++- docs/failover.example.yaml | 34 ++ docs/manual.md | 2 +- docs/project-map.md | 400 ++++++++++++++++++++++++ docs/server.example.yaml | 1 + docs/settings.md | 20 +- internal/app/session/session.go | 158 ++++++++-- internal/app/session/session_test.go | 112 +++++++ internal/config/config.go | 160 +++++++++- internal/config/config_test.go | 152 +++++++++ internal/e2e/tunnel_test.go | 193 ++++++++++++ internal/engine/livekit/livekit.go | 385 +++++++++++++++++++---- internal/engine/livekit/livekit_test.go | 306 ++++++++++++++++++ internal/supervisor/supervisor.go | 96 ++++++ internal/supervisor/supervisor_test.go | 85 +++++ 20 files changed, 2280 insertions(+), 118 deletions(-) create mode 100644 docs/failover.example.yaml create mode 100644 docs/project-map.md create mode 100644 internal/engine/livekit/livekit_test.go create mode 100644 internal/supervisor/supervisor.go create mode 100644 internal/supervisor/supervisor_test.go diff --git a/.gitignore b/.gitignore index d0b6a3c..61fcb93 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ # Prerequisites *.d +.DS_Store # Object files *.o diff --git a/cmd/olcrtc/main.go b/cmd/olcrtc/main.go index b8c2bdf..777949b 100644 --- a/cmd/olcrtc/main.go +++ b/cmd/olcrtc/main.go @@ -24,6 +24,7 @@ import ( configpkg "github.com/openlibrecommunity/olcrtc/internal/config" "github.com/openlibrecommunity/olcrtc/internal/logger" "github.com/openlibrecommunity/olcrtc/internal/names" + "github.com/openlibrecommunity/olcrtc/internal/supervisor" "github.com/openlibrecommunity/olcrtc/internal/transport/videochannel" ) @@ -35,6 +36,9 @@ var ErrConfigPathRequired = errors.New("usage: olcrtc ") // ErrDataDirRequired is returned when the YAML config does not specify a data directory. var ErrDataDirRequired = errors.New("data directory required (set 'data:' in YAML)") +// ErrProfilesUnsupportedForGen is returned when failover profiles are configured for gen mode. +var ErrProfilesUnsupportedForGen = errors.New("profiles are only supported for srv and cnc modes") + //nolint:gochecknoglobals // Tests replace the long-running session runner with a bounded function. var runSession = session.Run @@ -44,11 +48,18 @@ var runGen = execGen // loadedConfig bundles the parsed YAML file and the derived session config. type loadedConfig struct { scfg session.Config + profiles []supervisor.Profile + failover failoverConfig dataDir string debug bool ffmpegPath string } +type failoverConfig struct { + retryDelay time.Duration + maxCycles int +} + func main() { if err := run(); err != nil { logger.Error(err) @@ -79,14 +90,44 @@ func loadConfig(path string) (loadedConfig, error) { if err != nil { return loadedConfig{}, fmt.Errorf("load config: %w", err) } + base := configpkg.Apply(session.Config{}, f) + profiles := make([]supervisor.Profile, 0, len(f.Profiles)) + for i, profile := range f.Profiles { + name := profile.Name + if name == "" { + name = fmt.Sprintf("profile-%d", i+1) + } + profiles = append(profiles, supervisor.Profile{ + Name: name, + Config: configpkg.ApplyProfile(base, profile), + }) + } + failover, err := parseFailoverConfig(f.Failover) + if err != nil { + return loadedConfig{}, err + } return loadedConfig{ - scfg: configpkg.Apply(session.Config{}, f), + scfg: base, + profiles: profiles, + failover: failover, dataDir: f.Data, debug: f.Debug, ffmpegPath: f.FFmpeg, }, nil } +func parseFailoverConfig(f configpkg.Failover) (failoverConfig, error) { + retryDelay := supervisor.DefaultRetryDelay + if f.RetryDelay != "" { + parsed, err := time.ParseDuration(f.RetryDelay) + if err != nil { + return failoverConfig{}, fmt.Errorf("parse failover.retry_delay: %w", err) + } + retryDelay = parsed + } + return failoverConfig{retryDelay: retryDelay, maxCycles: f.MaxCycles}, nil +} + func runWithConfig(cfg loadedConfig) error { configureLogging(cfg.debug) @@ -98,19 +139,85 @@ func runWithConfig(cfg loadedConfig) error { if err != nil { return fmt.Errorf("validate config: %w", err) } + scfg = session.ApplyTransportDefaults(scfg) if scfg.Mode == modeGen { + if len(cfg.profiles) > 0 { + return ErrProfilesUnsupportedForGen + } return runGen(scfg) } + if len(cfg.profiles) > 0 { + profiles, err := prepareProfiles(cfg.profiles) + if err != nil { + return err + } + return runFailoverSessionMode(cfg.dataDir, profiles, cfg.failover) + } + return runSessionMode(cfg.dataDir, scfg) } +func prepareProfiles(profiles []supervisor.Profile) ([]supervisor.Profile, error) { + out := make([]supervisor.Profile, 0, len(profiles)) + for _, profile := range profiles { + scfg, err := session.ApplyAuthDefaults(profile.Config) + if err != nil { + return nil, fmt.Errorf("validate profile %q: %w", profile.Name, err) + } + profile.Config = session.ApplyTransportDefaults(scfg) + out = append(out, profile) + } + return out, nil +} + func runSessionMode(dataDir string, scfg session.Config) error { if err := session.Validate(scfg); err != nil { return fmt.Errorf("validate config: %w", err) } + if err := prepareRuntimeData(dataDir); err != nil { + return err + } + + return runManaged(func(ctx context.Context) error { + return runSession(ctx, scfg) + }) +} + +func runFailoverSessionMode(dataDir string, profiles []supervisor.Profile, failover failoverConfig) error { + for _, profile := range profiles { + if err := session.Validate(profile.Config); err != nil { + return fmt.Errorf("validate profile %q: %w", profile.Name, err) + } + } + + if err := prepareRuntimeData(dataDir); err != nil { + return err + } + + return runManaged(func(ctx context.Context) error { + return supervisor.Run(ctx, supervisor.Config{ + Profiles: profiles, + RetryDelay: failover.retryDelay, + MaxCycles: failover.maxCycles, + OnProfileStart: func(profile supervisor.Profile, cycle int) { + logger.Infof("failover cycle=%d starting profile=%s carrier=%s transport=%s", + cycle, profile.Name, profile.Config.Auth, profile.Config.Transport) + }, + OnProfileEnd: func(profile supervisor.Profile, cycle int, err error) { + if err != nil { + logger.Warnf("failover cycle=%d profile=%s ended with error: %v", cycle, profile.Name, err) + return + } + logger.Warnf("failover cycle=%d profile=%s ended", cycle, profile.Name) + }, + }, runSession) + }) +} + +func prepareRuntimeData(dataDir string) error { if dataDir == "" { return ErrDataDirRequired } @@ -124,6 +231,10 @@ func runSessionMode(dataDir string, scfg session.Config) error { return err } + return nil +} + +func runManaged(run func(context.Context) error) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -132,7 +243,7 @@ func runSessionMode(dataDir string, scfg session.Config) error { errCh := make(chan error, 1) go func() { - errCh <- runSession(ctx, scfg) + errCh <- run(ctx) }() select { diff --git a/cmd/olcrtc/main_test.go b/cmd/olcrtc/main_test.go index acb6a1d..96a4aeb 100644 --- a/cmd/olcrtc/main_test.go +++ b/cmd/olcrtc/main_test.go @@ -9,6 +9,7 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/app/session" "github.com/openlibrecommunity/olcrtc/internal/logger" + "github.com/openlibrecommunity/olcrtc/internal/supervisor" ) var errBoom = errors.New("boom") @@ -149,6 +150,112 @@ data: `+dir+` } } +func TestRunWithArgsAppliesTransportDefaults(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "names"), []byte("A\n"), 0o600); err != nil { + t.Fatalf("WriteFile(names) error = %v", err) + } + if err := os.WriteFile(filepath.Join(dir, "surnames"), []byte("B\n"), 0o600); err != nil { + t.Fatalf("WriteFile(surnames) error = %v", err) + } + + oldRunSession := runSession + t.Cleanup(func() { runSession = oldRunSession }) + runSession = func(ctx context.Context, cfg session.Config) error { + if cfg.VP8FPS != 25 || cfg.VP8BatchSize != 1 { + t.Fatalf("VP8 defaults = fps %d batch %d, want 25/1", cfg.VP8FPS, cfg.VP8BatchSize) + } + return nil + } + + yamlPath := writeYAML(t, ` +mode: srv +link: direct +auth: + provider: wbstream +room: + id: room +crypto: + key: key +net: + transport: vp8channel + dns: 1.1.1.1:53 +data: `+dir+` +`) + + if err := runWithArgs([]string{yamlPath}); err != nil { + t.Fatalf("runWithArgs() error = %v", err) + } +} + +func TestRunWithArgsFailoverProfiles(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "names"), []byte("A\n"), 0o600); err != nil { + t.Fatalf("WriteFile(names) error = %v", err) + } + if err := os.WriteFile(filepath.Join(dir, "surnames"), []byte("B\n"), 0o600); err != nil { + t.Fatalf("WriteFile(surnames) error = %v", err) + } + + oldRunSession := runSession + t.Cleanup(func() { runSession = oldRunSession }) + var seen []string + runSession = func(ctx context.Context, cfg session.Config) error { + seen = append(seen, cfg.Auth+"/"+cfg.Transport) + if cfg.Auth == "wbstream" && (cfg.VP8FPS != 25 || cfg.VP8BatchSize != 1) { + t.Fatalf("VP8 defaults = fps %d batch %d, want 25/1", cfg.VP8FPS, cfg.VP8BatchSize) + } + return errBoom + } + + yamlPath := writeYAML(t, ` +mode: srv +link: direct +crypto: + key: key +net: + dns: 1.1.1.1:53 +profiles: + - name: wb-primary + auth: + provider: wbstream + room: + id: room + net: + transport: vp8channel + - name: jitsi-backup + auth: + provider: jitsi + room: + id: https://meet.example/room + net: + transport: datachannel +failover: + retry_delay: -1ns + max_cycles: 1 +data: `+dir+` +`) + + err := runWithArgs([]string{yamlPath}) + if !errors.Is(err, supervisor.ErrMaxCyclesExceeded) { + t.Fatalf("runWithArgs() error = %v, want %v", err, supervisor.ErrMaxCyclesExceeded) + } + want := []string{"wbstream/vp8channel", "jitsi/datachannel"} + if !equalStrings(seen, want) { + t.Fatalf("seen profiles = %v, want %v", seen, want) + } +} + +func TestRunWithConfigRejectsProfilesInGenMode(t *testing.T) { + cfg := loadedConfig{ + scfg: session.Config{Mode: modeGen}, + profiles: []supervisor.Profile{{Name: "one"}}, + } + if err := runWithConfig(cfg); !errors.Is(err, ErrProfilesUnsupportedForGen) { + t.Fatalf("runWithConfig() error = %v, want %v", err, ErrProfilesUnsupportedForGen) + } +} + func TestConfigureLogging(t *testing.T) { t.Setenv("PION_LOG_DISABLE", "") logger.SetVerbose(false) @@ -170,6 +277,18 @@ func TestConfigureLogging(t *testing.T) { } } +func equalStrings(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + func TestResolveDataDir(t *testing.T) { abs := filepath.Join(t.TempDir(), "data") got, err := resolveDataDir(abs) diff --git a/docs/about.md b/docs/about.md index 112c6bb..a67149d 100644 --- a/docs/about.md +++ b/docs/about.md @@ -234,7 +234,7 @@ internal/e2e/ E2E тесты на реальных провайдер | Файл | Что делает | |---|---| -| `session.go` | Главная точка конфигурации. `RegisterDefaults()` регистрирует все carriers, links, transports. `Validate()` проверяет все настройки. `Run()` роутит в `server.Run` или `client.Run`. `Gen()` генерирует Room ID для jazz с ретраями (wbstream больше не поддерживает автогенерацию - руму нужно создавать вручную через stream.wb.ru) | +| `session.go` | Главная точка конфигурации. `RegisterDefaults()` регистрирует все carriers, links, transports. `Validate()` проверяет все настройки. `Run()` роутит в `server.Run` или `client.Run`. `Gen()` генерирует Room ID для auth-провайдеров с `RoomCreator` и ретраями | | `session_test.go` | Тесты валидации конфига | ### `internal/config/` @@ -452,7 +452,7 @@ Carrier - это WebRTC сервис видеозвонков, через кот - Минимальная прослойка, почти прямой relay - Работает с vp8channel, seichannel, videochannel - DataChannel **не работает** в обычном guest flow: WB Stream выдаёт токены с `canPublishData=false`, DC не маршрутизирует данные (expected fail в E2E тестах) -- Room ID нужно создавать вручную через stream.wb.ru +- Room ID можно создать вручную через stream.wb.ru или через `mode: gen` - Инициализация звонка автоматически --- diff --git a/docs/client.example.yaml b/docs/client.example.yaml index 5ec8792..fe83e0d 100644 --- a/docs/client.example.yaml +++ b/docs/client.example.yaml @@ -14,6 +14,7 @@ room: id: "https://meet.cryptopro.ru/REPLACE_WITH_ROOM_NAME" crypto: + # Or use key_file: "./olcrtc.key" to keep the secret out of this file. key: "REPLACE_ME_WITH_64_HEX_CHARS" # must match the server net: diff --git a/docs/configuration.md b/docs/configuration.md index 97d77fd..46edd07 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -11,6 +11,7 @@ olcrtc /etc/olcrtc/server.yaml - [`server.example.yaml`](./server.example.yaml) - [`client.example.yaml`](./client.example.yaml) +- [`failover.example.yaml`](./failover.example.yaml) ## Схема @@ -20,7 +21,7 @@ olcrtc /etc/olcrtc/server.yaml | `link` | `direct` | | `auth.provider` | `jitsi`, `telemost`, `jazz`, `wbstream`, `none` | | `room.id` | conference room id | -| `crypto.key` | 64-char hex (32 bytes) | +| `crypto.key` / `crypto.key_file` | 64-char hex (32 bytes), inline or read from file | | `net.transport` | `datachannel`, `videochannel`, `seichannel`, `vp8channel` | | `net.dns` | resolver `host:port` | | `socks.host` / `.port` | client-side listener | @@ -31,6 +32,57 @@ olcrtc /etc/olcrtc/server.yaml | `vp8.*` | vp8channel tuning | | `sei.fps` / `.batch_size` / `.fragment_size` / `.ack_timeout_ms` | seichannel tuning | | `gen.amount` | gen mode: number of rooms to create | +| `profiles[]` | ordered srv/cnc failover profiles | +| `failover.retry_delay` | delay before trying the next profile, e.g. `2s` | +| `failover.max_cycles` | stop after N full profile-list passes; `0` = forever | | `data` | path to data directory | | `debug` | verbose logging | | `ffmpeg` | path to ffmpeg binary | + +`mode: cnc` refuses non-loopback `socks.host` values unless both +`socks.user` and `socks.pass` are set. + +`crypto.key_file` is resolved relative to the YAML file. Do not set it +together with `crypto.key`. + +## Failover Profiles + +`mode: srv` and `mode: cnc` can define `profiles`. Top-level fields are used +as common defaults; each profile overrides only the fields it sets. The CLI +runs profiles in order. If a profile fails or ends while the process is still +alive, olcrtc waits `failover.retry_delay` and starts the next profile. + +```yaml +mode: srv +link: direct +crypto: + key_file: ./olcrtc.key +net: + dns: "1.1.1.1:53" +data: data + +profiles: + - name: wb-vp8 + auth: + provider: wbstream + room: + id: "WB_ROOM_ID" + net: + transport: vp8channel + + - name: jitsi-dc + auth: + provider: jitsi + room: + id: "https://meet.example.org/olcrtc-room" + net: + transport: datachannel + +failover: + retry_delay: 2s + max_cycles: 0 +``` + +Both peers must use compatible profile order and room settings. This first +failover layer rebuilds the session on the next profile; active smux streams +do not migrate, but new connections can recover on the next profile. diff --git a/docs/failover.example.yaml b/docs/failover.example.yaml new file mode 100644 index 0000000..7aa8149 --- /dev/null +++ b/docs/failover.example.yaml @@ -0,0 +1,34 @@ +# olcrtc failover config example +# Use the same profile order on both peers. + +mode: srv +link: direct + +crypto: + key_file: "./olcrtc.key" + +net: + dns: "1.1.1.1:53" + +data: data + +profiles: + - name: wb-vp8 + auth: + provider: wbstream + room: + id: "REPLACE_WITH_WB_ROOM_ID" + net: + transport: vp8channel + + - name: jitsi-datachannel + auth: + provider: jitsi + room: + id: "https://meet.example.org/REPLACE_WITH_ROOM_NAME" + net: + transport: datachannel + +failover: + retry_delay: 2s + max_cycles: 0 diff --git a/docs/manual.md b/docs/manual.md index a2a3a21..d623d86 100644 --- a/docs/manual.md +++ b/docs/manual.md @@ -177,7 +177,7 @@ data: data ### wbstream + vp8channel (альтернатива) -Сначала создай руму вручную через сайт [wbstream](https://stream.wb.ru) (автогенерация через `mode: gen` для wbstream больше не поддерживается) и сохрани её ID. +Создай руму через сайт [wbstream](https://stream.wb.ru) или заранее сгенерируй ID через `mode: gen` с `auth.provider: wbstream`. `wbstream + datachannel` **не работает** в обычном guest flow — WB Stream выдаёт токены с `canPublishData=false`, и DC не маршрутизирует данные. Для обычного использования выбирай `vp8channel`. diff --git a/docs/project-map.md b/docs/project-map.md new file mode 100644 index 0000000..c4c8791 --- /dev/null +++ b/docs/project-map.md @@ -0,0 +1,400 @@ +# olcRTC Project Map + +This is a developer map for finding the useful parts of the project quickly. +It focuses on code ownership, runtime flow, extension points, and areas that +are worth deeper work. + +## One-Sentence Model + +olcRTC is an encrypted TCP-over-WebRTC tunnel: the client exposes a local +SOCKS5 listener, the server dials requested TCP targets, and both sides carry +the smux byte stream through a selected WebRTC carrier and transport. + +## Runtime Stack + +```text +YAML config + -> cmd/olcrtc + -> internal/config + -> internal/app/session + -> internal/server or internal/client + -> internal/link/direct + -> internal/transport/{datachannel,vp8channel,seichannel,videochannel} + -> internal/carrier/builtin + -> internal/auth/ + internal/engine/ + -> external service SFU / signaling +``` + +Tunnel data path: + +```text +local app + -> client SOCKS5 + -> smux stream + -> muxconn AEAD encrypt + -> link.Send + -> transport encoding + -> carrier/engine + -> SFU/service + -> peer engine/carrier + -> transport decoding + -> muxconn AEAD decrypt + -> smux stream + -> server TCP dial + -> target host +``` + +## Entrypoints + +| Path | Purpose | +|---|---| +| `cmd/olcrtc/main.go` | Main CLI. Accepts one YAML file, applies auth and transport defaults, starts `srv`, `cnc`, or `gen`. | +| `cmd/olcrtc-cgo/main.go` | Small c-shared entrypoint for desktop/native consumers. | +| `pkg/olcrtc` | Embeddable lower-level API that returns a `net.Conn`-like handle over an engine data path. | +| `pkg/olcrtc/tunnel` | Embeddable server-side tunnel API with auth and traffic hooks. | +| `mobile/mobile.go` | gomobile API for Android clients, including VPN socket protection. | +| `script/srv.sh`, `script/cnc.sh` | Interactive shell launchers that generate YAML and run/build the app. | +| `Dockerfile`, `script/docker/*` | Container build and server entrypoint/healthcheck. | + +## Config And Session Layer + +`internal/config` owns YAML parsing and file-backed secret loading. + +Important fields: + +| YAML | Runtime field | Notes | +|---|---|---| +| `mode` | `session.Config.Mode` | `srv`, `cnc`, or `gen`. | +| `auth.provider` | `Auth` | `jitsi`, `telemost`, `jazz`, `wbstream`, or `none`. | +| `room.id` | `RoomID` | Carrier-specific room reference. | +| `crypto.key` / `crypto.key_file` | `KeyHex` | Shared 32-byte key encoded as 64 hex chars. | +| `net.transport` | `Transport` | `datachannel`, `vp8channel`, `seichannel`, or `videochannel`. | +| `net.dns` | `DNSServer` | Resolver used by server-side target dials and provider HTTP where wired. | +| `socks.*` | SOCKS fields | Client listener and optional server egress proxy. | +| `engine.*` | direct engine fields | Used only with `auth.provider: none`. | + +`internal/app/session` is the main router: + +1. Registers built-ins via `RegisterDefaults`. +2. Applies auth defaults: auth provider decides engine and default service URL. +3. Applies transport defaults: documented defaults for `vp8`, `sei`, and `video`. +4. Validates mode, auth, link, transport, room, key, DNS, transport options, and SOCKS listener safety. +5. Runs `server.Run`, `client.Run`, or `Gen`. + +## Server Side + +`internal/server` accepts encrypted smux sessions from the peer and proxies +each smux stream to a TCP target. + +Core pieces: + +| Symbol | Role | +|---|---| +| `server.Run` | Creates cipher, link, smux server, and serve loop. | +| `bringUpLink` | Builds `link.Link`, wires reconnect callbacks, connects carrier. | +| `installSession` / `reinstallSession` | Creates or replaces `muxconn + smux.Session`. | +| `acceptHandshake` | First smux stream; runs `handshake.Server`. | +| `handleStream` | Reads connect JSON and dispatches a tunnel stream. | +| `dispatch` | Dials target, sends ready byte, copies both directions. | +| `AuthHook` | Embedders can authorize clients after `CLIENT_HELLO`. | +| `OnSessionOpen`, `OnSessionClose`, `OnTraffic` | Observability hooks. | + +Server risk areas: + +- Target dialing is powerful by design. Any real product wrapper should add + an `AuthHook` and probably destination policy. +- `defaultAuthHook` admits everyone who knows the room and key. +- Reconnect rebuilds smux sessions; active streams are sacrificed. + +## Client Side + +`internal/client` exposes a local SOCKS5 listener and opens one smux stream +per SOCKS CONNECT request. + +Core pieces: + +| Symbol | Role | +|---|---| +| `RunWithReady` | Starts link, opens smux client, listens on local SOCKS. | +| `openControlStream` | First smux stream; runs `handshake.Client`. | +| `handleSocks5` | SOCKS method negotiation and CONNECT parsing. | +| `sendConnectRequest` | Sends server-side target JSON and waits for ready byte. | +| `handleReconnect` | Rebuilds smux and control stream after carrier reconnect. | +| `resolveDeviceID` | Optional persistent client identity for hooks. | + +Client risk areas: + +- A non-loopback SOCKS listener must require `socks.user` and `socks.pass`. +- SOCKS credentials are simple static credentials, not a full account system. +- Existing streams do not survive reconnect; new SOCKS connections can recover. + +## Wire Protocol Above WebRTC + +`internal/muxconn` adapts `link.Link` to `io.ReadWriteCloser`. + +- Every smux write is encrypted with `internal/crypto`. +- Every inbound link message is decrypted and appended to an internal byte buffer. +- Bad AEAD frames are dropped. +- `CanSend` provides backpressure before encrypting and sending. + +`internal/crypto` uses XChaCha20-Poly1305 with a random nonce prepended to +each ciphertext. + +`internal/handshake` runs on the first smux stream: + +```text +CLIENT_HELLO { version, device_id, claims } +SERVER_WELCOME { version, session_id } +or +SERVER_REJECT { version, reason } +``` + +The handshake has a 64 KiB frame cap and a default 15 second timeout. + +## Registries And Plugin Shape + +The universal-carrier refactor centers on small registries: + +| Registry | Package | Registers | +|---|---|---| +| Auth providers | `internal/auth` | Service-specific credential and room creation flows. | +| Engines | `internal/engine` | Wire-level SFU protocol implementations. | +| Carriers | `internal/carrier` | Auth + engine adapters exposed as byte/video capability providers. | +| Transports | `internal/transport` | Byte transport strategy over carrier primitives. | +| Links | `internal/link` | Higher-level link abstraction; currently only `direct`. | + +`internal/carrier/builtin` connects the auth and engine worlds: + +```text +carrier "wbstream" -> auth/wbstream -> engine/livekit +carrier "jazz" -> auth/salutejazz -> engine/salutejazz +carrier "telemost"-> auth/telemost -> engine/goolom +carrier "jitsi" -> auth/jitsi -> engine/jitsi +carrier "none" -> direct user-supplied engine/url/token +``` + +## Auth Providers + +| Provider | Engine | Room generation | Notes | +|---|---|---:|---| +| `jitsi` | `jitsi` | No | Parses host/room from a public or self-hosted Jitsi URL. No HTTP auth. | +| `telemost` | `goolom` | No | Calls Telemost room-info flow and returns Goolom credentials. | +| `wbstream` | `livekit` | Yes | Registers guest, optionally creates room, joins room, fetches LiveKit token. | +| `jazz` / `salutejazz` | `salutejazz` | Yes | Creates or joins SaluteJazz room and returns room/password tuple. | +| `none` | chosen by config | No | Direct engine mode for downstream tools or self-hosted SFUs. | + +## Engines + +Engines expose the low-level service/SFU protocol. + +| Engine | Package | Byte stream | Video track | Main job | +|---|---|---:|---:|---| +| `livekit` | `internal/engine/livekit` | Yes | Yes | LiveKit SDK room, data packets, local/remote tracks, reconnect with credential refresh. | +| `goolom` | `internal/engine/goolom` | Yes | Yes | Yandex Telemost/Goolom signaling, split publisher/subscriber peer connections, telemetry/keepalive. | +| `jitsi` | `internal/engine/jitsi` | Yes | Best effort | Jitsi MUC/Jingle/colibri-ws plus optional video track negotiation. | +| `salutejazz` | `internal/engine/salutejazz` | Yes | Yes | SaluteJazz WebSocket signaling and split media peer connections. | + +Engine work is where most provider breakage and reconnect complexity lives. + +## Transports + +Transports decide how raw tunnel bytes are carried once the carrier provides +either a byte stream or a video track. + +| Transport | Primitive | Reliability model | Best fit | Notes | +|---|---|---|---|---| +| `datachannel` | Carrier byte stream | Native reliable ordered messages | Jitsi, direct engines, some Jazz cases | Simple pass-through with 12 KiB message cap. | +| `vp8channel` | VP8 video track | KCP over VP8-looking frames | WB Stream and Telemost-style video paths | Highest-performance video-path transport. Uses epochs and binding tokens to survive restarts/loopback. | +| `seichannel` | H264 SEI video track | Custom fragments + ACK/retry | WB Stream fallback | Carries data in SEI NAL units with fragmentation, CRC, ACK. | +| `videochannel` | Visual frames via ffmpeg | QR/tile frames + ACK/retry | Experimental/inspection-friendly path | Encodes visual payload frames, requires ffmpeg, supports QR and tile codecs. | + +Transport work is where throughput, loss recovery, and adaptive tuning should +happen. + +## Public/Embedding Surfaces + +| Package | User | +|---|---| +| `pkg/olcrtc` | Go programs that want a `net.Conn` over a selected auth/engine. | +| `pkg/olcrtc/tunnel` | Go programs that want to embed the server-side tunnel with auth/traffic hooks. | +| `mobile` | Android app bindings. Wraps client mode, VPN socket protection, logging, simple health checks. | +| `cmd/olcrtc-cgo` | Native desktop/client integrations using c-shared Go export. | + +These surfaces are important if the CLI becomes only one frontend among many. + +## Tests + +The project has broad unit coverage: + +- Config/session validation and defaults. +- Auth provider HTTP flows with test servers. +- Engine helper logic and reconnect paths. +- SOCKS parsing, smux handshake, server dispatch. +- Crypto, muxconn, names, protect, logging. +- Transport frame codecs, ACK paths, KCP loopback, ffmpeg helpers. +- Memory-backed E2E tunnel tests and optional real-provider E2E matrix. + +Useful commands: + +```sh +go test -count=1 ./... +go test -race -count=1 ./cmd/olcrtc ./internal/app/session ./internal/config ./internal/engine/livekit +go test -race -count=1 -v ./internal/e2e +E2E_CARRIERS=wbstream E2E_TRANSPORTS=vp8channel mage e2e +go build -trimpath -o build/olcrtc ./cmd/olcrtc +``` + +## High-Value Coding Areas + +### 1. Supervisor And Multi-Profile Failover + +The first supervisor layer exists in `internal/supervisor`: the CLI can run a +prioritized list of carrier/transport profiles and move to the next profile +when the active one fails or ends. + +```yaml +mode: srv +link: direct +crypto: + key_file: ./olcrtc.key +net: + dns: "1.1.1.1:53" +profiles: + - name: wb-vp8 + auth: + provider: wbstream + room: + id: WB_ROOM_ID + net: + transport: vp8channel + - name: jitsi-dc + auth: + provider: jitsi + room: + id: https://meet.example.org/olcrtc-room + net: + transport: datachannel +failover: + retry_delay: 2s + max_cycles: 0 +``` + +Implemented: + +- Config schema for `profiles[]`. +- Ordered supervisor loop. +- `failover.retry_delay`. +- `failover.max_cycles`. +- Profile start/end logs. + +Still valuable: + +- Health scoring per profile. +- Control-stream coordination before switching. +- Stream draining and migration instead of dropping active smux streams. +- Shared status output for the active profile and failover history. + +Likely files: + +- `internal/config/config.go` +- `internal/app/session/session.go` +- `internal/supervisor` +- `internal/server` +- `internal/client` +- `docs/configuration.md` +- `internal/e2e/tunnel_test.go` + +### 2. Transport Telemetry And Adaptive Tuning + +Add metrics from transport to link/session: + +- Send queue depth. +- ACK latency. +- Retries. +- Reconnect count. +- Dropped/decrypt-failed frames. +- KCP RTT/loss where available. + +Then make `vp8.batch_size`, `sei.fragment_size`, ACK timeout, and pacing +adaptive instead of static YAML knobs. + +### 3. Control Stream Protocol + +The first smux stream is parked after handshake. It is the natural place for: + +- Ping/pong and peer liveness. +- Server policy updates. +- Graceful reconnect notifications. +- Drain/start markers for failover. +- Per-session stats. + +Likely files: + +- `internal/handshake` +- `internal/server` +- `internal/client` + +### 4. Destination Policy And Real Auth + +The tunnel can dial arbitrary server-side TCP targets. A production wrapper +should use `AuthHook` and enforce: + +- Allowed destination CIDRs/domains/ports. +- Per-device or per-plan policy. +- Session expiration. +- Traffic accounting limits. +- Sanitized rejection reasons. + +This mostly belongs in `pkg/olcrtc/tunnel` and `internal/server`. + +### 5. Provider Hardening + +Provider APIs can drift. Worth adding: + +- Better typed errors from auth providers. +- Provider health probes. +- Fixture-based contract tests for API response changes. +- Per-provider rate/backoff policy. +- Safer secret/log redaction. + +Likely files: + +- `internal/auth/*` +- `internal/engine/*` +- `internal/carrier/builtin` + +### 6. Codebase Hygiene + +Some public-facing text and comments are not suitable for a serious external +project. Cleaning that up would improve maintainability and downstream trust. +The most obvious targets are top-level docs and a large hostile block comment +in `internal/transport/vp8channel/transport.go`. + +## Where To Look First + +| Goal | Start here | +|---|---| +| Change YAML schema | `internal/config/config.go`, `cmd/olcrtc/main.go`, docs examples. | +| Change validation/defaults | `internal/app/session/session.go`. | +| Add a new auth provider | `internal/auth`, then register in `internal/carrier/builtin/register.go`. | +| Add a new SFU protocol | `internal/engine`, then connect through auth/carrier. | +| Add a new byte transport | `internal/transport`, then register in `session.RegisterDefaults`. | +| Add link behavior above transports | `internal/link`; currently only `direct`. | +| Improve SOCKS behavior | `internal/client`. | +| Improve server target dialing or policy | `internal/server`, `pkg/olcrtc/tunnel`. | +| Improve reconnect | Engines first, then `internal/client` and `internal/server` smux rebuild behavior. | +| Improve Android app integration | `mobile`, `internal/protect`, `client.RunWithReady`. | + +## Mental Model For Big Changes + +Prefer to keep the layer boundaries: + +- Auth creates credentials; it should not know transport details. +- Engine speaks service/SFU protocol; it should not know SOCKS or smux. +- Carrier adapts auth+engine into byte/video capabilities. +- Transport turns byte/video capabilities into reliable-ish tunnel bytes. +- Link is policy above transport. +- Client/server own SOCKS, smux, handshake, target dialing, and session hooks. + +If a change crosses more than two layers, it probably deserves a new +orchestrator package instead of pushing more state into an engine or transport. diff --git a/docs/server.example.yaml b/docs/server.example.yaml index 7a5f638..9f5ee38 100644 --- a/docs/server.example.yaml +++ b/docs/server.example.yaml @@ -16,6 +16,7 @@ room: crypto: # 32-byte hex (64 chars). Generate with: openssl rand -hex 32 + # Or use key_file: "./olcrtc.key" to keep the secret out of this file. key: "REPLACE_ME_WITH_64_HEX_CHARS" net: diff --git a/docs/settings.md b/docs/settings.md index f9d6c80..28855ce 100644 --- a/docs/settings.md +++ b/docs/settings.md @@ -48,7 +48,7 @@ | `auth.provider` | `telemost`, `jazz`, `wbstream` или `jitsi` | | `net.transport` | `datachannel`, `vp8channel`, `seichannel` или `videochannel` | | `room.id` | Room ID | -| `crypto.key` | Ключ шифрования hex 64 символа. Генерация: `openssl rand -hex 32` | +| `crypto.key` или `crypto.key_file` | Ключ шифрования hex 64 символа. Генерация: `openssl rand -hex 32` | | `link` | Всегда `direct` | | `data` | Всегда `data` | | `net.dns` | DNS-сервер, например `1.1.1.1:53` | @@ -60,18 +60,27 @@ | YAML поле | Описание | |-----------|----------| | `debug` | `true` для подробных логов соединений | +| `profiles` | Список профилей failover для `srv`/`cnc` | +| `failover.retry_delay` | Пауза перед следующим профилем, например `2s` | +| `failover.max_cycles` | Сколько полных проходов по профилям сделать; `0` = бесконечно | + +`crypto.key_file` читается относительно YAML-файла. Не указывай `crypto.key` и `crypto.key_file` одновременно. + +Если задан `profiles`, поля верхнего уровня становятся общими defaults, а +каждый профиль переопределяет только свои `auth`, `room`, `net`, `engine` и +настройки транспорта. Порядок профилей должен совпадать на сервере и клиенте. --- ## mode: gen -Генерирует Room ID заранее, не запуская сервер. Поддерживается только для `jazz`. Для `wbstream` создавай руму вручную через [stream.wb.ru](https://stream.wb.ru) (автогенерация отключена со стороны WB). +Генерирует Room ID заранее, не запуская сервер. Поддерживается для auth-провайдеров с автосозданием комнат: `jazz` и `wbstream`. Для `telemost` комнату нужно создавать вручную через сайт. **Обязательные поля:** | YAML поле | Описание | |-----------|----------| -| `auth.provider` | `jazz` | +| `auth.provider` | `jazz` или `wbstream` | | `net.dns` | DNS-сервер | | `gen.amount` | Количество комнат | @@ -79,7 +88,7 @@ # gen.yaml mode: gen auth: - provider: jazz + provider: wbstream net: dns: "1.1.1.1:53" gen: @@ -116,6 +125,9 @@ gen: Если `socks.user` не задан - аутентификация отключена (любой локальный клиент может подключиться). Если задан - клиент принимает только подключения с правильным логином и паролем (RFC 1929). +Если `socks.host` не loopback (`127.0.0.1`, `::1`, `localhost`), `socks.user` и `socks.pass` обязательны. +Это защита от случайного открытого SOCKS5-прокси в локальной сети или интернете. + --- ## datachannel diff --git a/internal/app/session/session.go b/internal/app/session/session.go index 89900bf..89de5f5 100644 --- a/internal/app/session/session.go +++ b/internal/app/session/session.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "net" "slices" "time" @@ -37,18 +38,33 @@ const ( videoCodecTile = "tile" ) +const ( + defaultVideoWidth = 1920 + defaultVideoHeight = 1080 + defaultVideoFPS = 30 + defaultVideoBitrate = "2M" + defaultVideoHW = "none" + defaultVideoQRRecovery = "low" + defaultVP8FPS = 25 + defaultVP8BatchSize = 1 + defaultSEIFPS = 60 + defaultSEIBatchSize = 64 + defaultSEIFragmentSize = 900 + defaultSEIAckTimeoutMS = 2000 +) + var ( // ErrRoomIDRequired indicates that a room id is required for the selected carrier. - ErrRoomIDRequired = errors.New("room ID required (use -id )") + ErrRoomIDRequired = errors.New("room ID required (set room.id)") // ErrModeRequired indicates that mode is not one of the supported values. - ErrModeRequired = errors.New("mode required (use -mode srv, -mode cnc or -mode gen)") - // ErrAmountRequired indicates that -amount is required for gen mode. - ErrAmountRequired = errors.New("amount required for gen mode (use -amount )") + ErrModeRequired = errors.New("mode required (set mode to srv, cnc or gen)") + // ErrAmountRequired indicates that gen.amount is required for gen mode. + ErrAmountRequired = errors.New("amount required for gen mode (set gen.amount)") // ErrAuthRequired indicates that no auth provider was selected. ErrAuthRequired = errors.New( - "auth provider required (use -auth jitsi, -auth telemost, -auth jazz, -auth wbstream or -auth none)") - // ErrURLRequired indicates that -url must be provided when the auth provider has no default URL. - ErrURLRequired = errors.New("SFU URL required (use -url wss://...)") + "auth provider required (set auth.provider to jitsi, telemost, jazz, wbstream or none)") + // ErrURLRequired indicates that auth.url must be provided when the auth provider has no default URL. + ErrURLRequired = errors.New("SFU URL required (set auth.url)") // ErrUnsupportedCarrier indicates that carrier is not registered. ErrUnsupportedCarrier = errors.New("unsupported carrier") // ErrUnsupportedLink indicates that link is not registered. @@ -57,51 +73,53 @@ var ( ErrUnsupportedTransport = errors.New("unsupported transport") // ErrLinkRequired indicates that link is not provided. - ErrLinkRequired = errors.New("link required (use -link direct)") + ErrLinkRequired = errors.New("link required (set link to direct)") // ErrTransportRequired indicates that transport is not provided. ErrTransportRequired = errors.New( - "transport required (use -transport datachannel, -transport videochannel, " + - "-transport seichannel or -transport vp8channel)") + "transport required (set transport to datachannel, videochannel, seichannel or vp8channel)") // ErrKeyRequired indicates that encryption key is not provided. - ErrKeyRequired = errors.New("key required (use -key )") + ErrKeyRequired = errors.New("key required (set crypto.key)") // ErrDNSServerRequired indicates that dns server is not provided. - ErrDNSServerRequired = errors.New("dns server required (use -dns 1.1.1.1:53)") + ErrDNSServerRequired = errors.New("dns server required (set net.dns)") // ErrVideoWidthRequired indicates that video width is required for videochannel. - ErrVideoWidthRequired = errors.New("video width required for videochannel (use -video-w)") + ErrVideoWidthRequired = errors.New("video width required for videochannel (set video.width)") // ErrVideoHeightRequired indicates that video height is required for videochannel. - ErrVideoHeightRequired = errors.New("video height required for videochannel (use -video-h)") + ErrVideoHeightRequired = errors.New("video height required for videochannel (set video.height)") // ErrVideoFPSRequired indicates that video fps is required for videochannel. - ErrVideoFPSRequired = errors.New("video fps required for videochannel (use -video-fps)") + ErrVideoFPSRequired = errors.New("video fps required for videochannel (set video.fps)") // ErrVideoBitrateRequired indicates that video bitrate is required for videochannel. ErrVideoBitrateRequired = errors.New( - "video bitrate required for videochannel (use -video-bitrate)") + "video bitrate required for videochannel (set video.bitrate)") // ErrVideoHWRequired indicates that video hardware acceleration is required. ErrVideoHWRequired = errors.New( - "video hardware acceleration required for videochannel (use -video-hw none/nvenc)") + "video hardware acceleration required for videochannel (set video.hw to none or nvenc)") // ErrVideoCodecInvalid indicates that the video codec is not valid. ErrVideoCodecInvalid = errors.New( - "invalid video codec for videochannel (use -video-codec qrcode or -video-codec tile)") + "invalid video codec for videochannel (set video.codec to qrcode or tile)") // ErrTileCodecDimensions indicates that tile codec requires 1080x1080 dimensions. - ErrTileCodecDimensions = errors.New("tile codec requires -video-w 1080 -video-h 1080") + ErrTileCodecDimensions = errors.New("tile codec requires video.width: 1080 and video.height: 1080") // ErrVP8FPSRequired indicates that vp8 fps is required for vp8channel. - ErrVP8FPSRequired = errors.New("vp8 fps required for vp8channel (use -vp8-fps)") + ErrVP8FPSRequired = errors.New("vp8 fps required for vp8channel (set vp8.fps)") // ErrVP8BatchSizeRequired indicates that vp8 batch size is required for vp8channel. - ErrVP8BatchSizeRequired = errors.New("vp8 batch size required for vp8channel (use -vp8-batch)") + ErrVP8BatchSizeRequired = errors.New("vp8 batch size required for vp8channel (set vp8.batch_size)") // ErrSEIFPSRequired indicates that seichannel fps is required. - ErrSEIFPSRequired = errors.New("fps required for seichannel (use -fps)") + ErrSEIFPSRequired = errors.New("fps required for seichannel (set sei.fps)") // ErrSEIBatchSizeRequired indicates that seichannel batch size is required. - ErrSEIBatchSizeRequired = errors.New("batch size required for seichannel (use -batch)") + ErrSEIBatchSizeRequired = errors.New("batch size required for seichannel (set sei.batch_size)") // ErrSEIFragmentSizeRequired indicates that seichannel fragment size is required. - ErrSEIFragmentSizeRequired = errors.New("fragment size required for seichannel (use -frag)") + ErrSEIFragmentSizeRequired = errors.New("fragment size required for seichannel (set sei.fragment_size)") // ErrSEIAckTimeoutRequired indicates that seichannel ack timeout is required. - ErrSEIAckTimeoutRequired = errors.New("ack timeout required for seichannel (use -ack-ms)") + ErrSEIAckTimeoutRequired = errors.New("ack timeout required for seichannel (set sei.ack_timeout_ms)") // ErrSOCKSHostRequired indicates that socks host is required for cnc mode. - ErrSOCKSHostRequired = errors.New("socks host required for cnc mode (use -socks-host)") + ErrSOCKSHostRequired = errors.New("socks host required for cnc mode (set socks.host)") // ErrSOCKSPortRequired indicates that socks port is required for cnc mode. - ErrSOCKSPortRequired = errors.New("socks port required for cnc mode (use -socks-port)") + ErrSOCKSPortRequired = errors.New("socks port required for cnc mode (set socks.port)") + // ErrSOCKSAuthRequired indicates that a non-loopback SOCKS listener requires authentication. + ErrSOCKSAuthRequired = errors.New( + "socks auth required when binding outside loopback (set socks.user and socks.pass)") ) // Config holds runtime session settings. @@ -180,6 +198,80 @@ func ApplyAuthDefaults(cfg Config) (Config, error) { return cfg, nil } +// ApplyTransportDefaults fills documented transport defaults without changing core routing fields. +func ApplyTransportDefaults(cfg Config) Config { + switch cfg.Transport { + case transportVideo: + return applyVideoDefaults(cfg) + case transportVP8: + return applyVP8Defaults(cfg) + case transportSEI: + return applySEIDefaults(cfg) + default: + return cfg + } +} + +func applyVideoDefaults(cfg Config) Config { + if cfg.VideoCodec == "" { + cfg.VideoCodec = videoCodecQRCode + } + if cfg.VideoCodec == videoCodecTile { + if cfg.VideoWidth == 0 { + cfg.VideoWidth = 1080 + } + if cfg.VideoHeight == 0 { + cfg.VideoHeight = 1080 + } + } else { + if cfg.VideoWidth == 0 { + cfg.VideoWidth = defaultVideoWidth + } + if cfg.VideoHeight == 0 { + cfg.VideoHeight = defaultVideoHeight + } + } + if cfg.VideoFPS == 0 { + cfg.VideoFPS = defaultVideoFPS + } + if cfg.VideoBitrate == "" { + cfg.VideoBitrate = defaultVideoBitrate + } + if cfg.VideoHW == "" { + cfg.VideoHW = defaultVideoHW + } + if cfg.VideoQRRecovery == "" { + cfg.VideoQRRecovery = defaultVideoQRRecovery + } + return cfg +} + +func applyVP8Defaults(cfg Config) Config { + if cfg.VP8FPS == 0 { + cfg.VP8FPS = defaultVP8FPS + } + if cfg.VP8BatchSize == 0 { + cfg.VP8BatchSize = defaultVP8BatchSize + } + return cfg +} + +func applySEIDefaults(cfg Config) Config { + if cfg.SEIFPS == 0 { + cfg.SEIFPS = defaultSEIFPS + } + if cfg.SEIBatchSize == 0 { + cfg.SEIBatchSize = defaultSEIBatchSize + } + if cfg.SEIFragmentSize == 0 { + cfg.SEIFragmentSize = defaultSEIFragmentSize + } + if cfg.SEIAckTimeoutMS == 0 { + cfg.SEIAckTimeoutMS = defaultSEIAckTimeoutMS + } + return cfg +} + // Validate verifies that the runtime config refers to registered components and all required fields are present. func Validate(cfg Config) error { if err := validateMode(cfg); err != nil { @@ -333,11 +425,23 @@ func validateModeConfig(cfg Config) error { if cfg.SOCKSPort == 0 { return ErrSOCKSPortRequired } + if !isLoopbackListenHost(cfg.SOCKSHost) && (cfg.SOCKSUser == "" || cfg.SOCKSPass == "") { + return ErrSOCKSAuthRequired + } return nil } +func isLoopbackListenHost(host string) bool { + if host == "localhost" { + return true + } + ip := net.ParseIP(host) + return ip != nil && ip.IsLoopback() +} + // Run starts the configured mode. func Run(ctx context.Context, cfg Config) error { + cfg = ApplyTransportDefaults(cfg) roomURL := cfg.RoomID switch cfg.Mode { diff --git a/internal/app/session/session_test.go b/internal/app/session/session_test.go index 6ca3f79..f20e70d 100644 --- a/internal/app/session/session_test.go +++ b/internal/app/session/session_test.go @@ -6,6 +6,85 @@ import ( "testing" ) +func TestApplyTransportDefaults(t *testing.T) { + tests := []struct { + name string + in Config + want Config + }{ + { + name: "vp8", + in: Config{Transport: transportVP8}, + want: Config{Transport: transportVP8, VP8FPS: 25, VP8BatchSize: 1}, + }, + { + name: "sei", + in: Config{Transport: transportSEI}, + want: Config{ + Transport: transportSEI, + SEIFPS: 60, + SEIBatchSize: 64, + SEIFragmentSize: 900, + SEIAckTimeoutMS: 2000, + }, + }, + { + name: "video qrcode", + in: Config{Transport: transportVideo}, + want: Config{ + Transport: transportVideo, + VideoWidth: 1920, + VideoHeight: 1080, + VideoFPS: 30, + VideoBitrate: "2M", + VideoHW: "none", + VideoQRRecovery: "low", + VideoCodec: videoCodecQRCode, + }, + }, + { + name: "video tile dimensions", + in: Config{Transport: transportVideo, VideoCodec: videoCodecTile}, + want: Config{ + Transport: transportVideo, + VideoWidth: 1080, + VideoHeight: 1080, + VideoFPS: 30, + VideoBitrate: "2M", + VideoHW: "none", + VideoQRRecovery: "low", + VideoCodec: videoCodecTile, + }, + }, + { + name: "keeps explicit values", + in: Config{ + Transport: transportSEI, + SEIFPS: 10, + SEIBatchSize: 2, + SEIFragmentSize: 300, + SEIAckTimeoutMS: 1500, + }, + want: Config{ + Transport: transportSEI, + SEIFPS: 10, + SEIBatchSize: 2, + SEIFragmentSize: 300, + SEIAckTimeoutMS: 1500, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ApplyTransportDefaults(tt.in) + if got != tt.want { + t.Fatalf("ApplyTransportDefaults() = %+v, want %+v", got, tt.want) + } + }) + } +} + //nolint:maintidx // table-driven validation test naturally has many cases func TestValidate(t *testing.T) { RegisterDefaults() @@ -310,6 +389,39 @@ func TestValidate(t *testing.T) { }(), want: ErrSOCKSPortRequired, }, + { + name: "cnc rejects unauthenticated wildcard socks bind", + cfg: func() Config { + cfg := base + cfg.Mode = modeCNC + cfg.SOCKSHost = "0.0.0.0" + cfg.SOCKSPort = 1080 + return cfg + }(), + want: ErrSOCKSAuthRequired, + }, + { + name: "cnc allows authenticated wildcard socks bind", + cfg: func() Config { + cfg := base + cfg.Mode = modeCNC + cfg.SOCKSHost = "0.0.0.0" + cfg.SOCKSPort = 1080 + cfg.SOCKSUser = "user" + cfg.SOCKSPass = "pass" + return cfg + }(), + }, + { + name: "cnc allows localhost socks bind without auth", + cfg: func() Config { + cfg := base + cfg.Mode = modeCNC + cfg.SOCKSHost = "localhost" + cfg.SOCKSPort = 1080 + return cfg + }(), + }, } for _, tt := range tests { diff --git a/internal/config/config.go b/internal/config/config.go index 49b0f60..5fe206c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,10 +1,9 @@ // Package config loads olcrtc runtime configuration from YAML files. // // The YAML schema mirrors [session.Config]. Fields left unset in the file -// remain at their zero value, allowing CLI flags to fill them in. Use -// [Apply] to merge a parsed [File] onto an existing [session.Config]; -// non-zero fields in the session config (typically populated from CLI flags) -// take precedence over the YAML values. +// remain at their zero value. Use [Apply] to map a parsed [File] onto an +// existing [session.Config]; non-zero fields in the session config take +// precedence over the YAML values. // //nolint:tagliatelle // YAML keys are the documented config file schema. package config @@ -13,17 +12,46 @@ import ( "errors" "fmt" "os" + "path/filepath" + "strings" "github.com/openlibrecommunity/olcrtc/internal/app/session" "gopkg.in/yaml.v3" ) -// ErrConfigNotFound is returned when a config file path is set but the file does not exist. -var ErrConfigNotFound = errors.New("config file not found") +var ( + // ErrConfigNotFound is returned when a config file path is set but the file does not exist. + ErrConfigNotFound = errors.New("config file not found") + // ErrCryptoKeyConflict is returned when both inline and file-backed keys are configured. + ErrCryptoKeyConflict = errors.New("crypto.key and crypto.key_file cannot both be set") + // ErrCryptoKeyFileEmpty is returned when crypto.key_file points to an empty file. + ErrCryptoKeyFileEmpty = errors.New("crypto key file is empty") +) // File is the on-disk YAML schema. type File struct { - Mode string `yaml:"mode"` + Mode string `yaml:"mode"` + Link string `yaml:"link"` + Auth Auth `yaml:"auth"` + Room Room `yaml:"room"` + Crypto Crypto `yaml:"crypto"` + Net Net `yaml:"net"` + SOCKS SOCKS `yaml:"socks"` + Engine Engine `yaml:"engine"` + Video Video `yaml:"video"` + VP8 VP8 `yaml:"vp8"` + SEI SEI `yaml:"sei"` + Gen Gen `yaml:"gen"` + Profiles []Profile `yaml:"profiles"` + Failover Failover `yaml:"failover"` + Data string `yaml:"data"` + Debug bool `yaml:"debug"` + FFmpeg string `yaml:"ffmpeg"` +} + +// Profile is a failover entry that overrides top-level runtime fields. +type Profile struct { + Name string `yaml:"name"` Link string `yaml:"link"` Auth Auth `yaml:"auth"` Room Room `yaml:"room"` @@ -34,10 +62,12 @@ type File struct { Video Video `yaml:"video"` VP8 VP8 `yaml:"vp8"` SEI SEI `yaml:"sei"` - Gen Gen `yaml:"gen"` - Data string `yaml:"data"` - Debug bool `yaml:"debug"` - FFmpeg string `yaml:"ffmpeg"` +} + +// Failover controls ordered profile failover. +type Failover struct { + RetryDelay string `yaml:"retry_delay"` + MaxCycles int `yaml:"max_cycles"` } // Auth selects the auth provider. @@ -52,7 +82,8 @@ type Room struct { // Crypto holds the shared secret used to authenticate and encrypt the tunnel. type Crypto struct { - Key string `yaml:"key"` // 64-char hex (32 bytes) + Key string `yaml:"key"` // 64-char hex (32 bytes) + KeyFile string `yaml:"key_file"` // path to a file containing crypto.key } // Net groups network and transport selection. @@ -125,9 +156,63 @@ func Load(path string) (File, error) { if err := yaml.Unmarshal(data, &f); err != nil { return File{}, fmt.Errorf("parse config %s: %w", path, err) } + if err := loadExternalSecrets(path, &f); err != nil { + return File{}, err + } return f, nil } +func loadExternalSecrets(configPath string, f *File) error { + if f.Crypto.KeyFile == "" { + return loadProfileSecrets(configPath, f.Profiles) + } + if f.Crypto.Key != "" { + return ErrCryptoKeyConflict + } + + key, err := readKeyFile(configPath, f.Crypto.KeyFile) + if err != nil { + return err + } + f.Crypto.Key = key + return loadProfileSecrets(configPath, f.Profiles) +} + +func loadProfileSecrets(configPath string, profiles []Profile) error { + for i := range profiles { + if profiles[i].Crypto.KeyFile == "" { + continue + } + if profiles[i].Crypto.Key != "" { + return fmt.Errorf("profiles[%d]: %w", i, ErrCryptoKeyConflict) + } + key, err := readKeyFile(configPath, profiles[i].Crypto.KeyFile) + if err != nil { + return fmt.Errorf("profiles[%d]: %w", i, err) + } + profiles[i].Crypto.Key = key + } + return nil +} + +func readKeyFile(configPath, keyFile string) (string, error) { + keyPath := keyFile + if !filepath.IsAbs(keyPath) { + keyPath = filepath.Join(filepath.Dir(configPath), keyPath) + } + + // #nosec G304 -- key_file is an explicit path in the user's config file. + data, err := os.ReadFile(keyPath) + if err != nil { + return "", fmt.Errorf("read crypto key file %s: %w", keyPath, err) + } + key := strings.TrimSpace(string(data)) + if key == "" { + return "", ErrCryptoKeyFileEmpty + } + return key, nil +} + // Apply merges f onto dst. CLI-set fields (non-zero values in dst) win; // YAML values fill in the rest. func Apply(dst session.Config, f File) session.Config { @@ -167,6 +252,43 @@ func Apply(dst session.Config, f File) session.Config { return dst } +// ApplyProfile overlays a failover profile onto an already-applied base config. +func ApplyProfile(base session.Config, p Profile) session.Config { + dst := base + dst.Link = overlayString(dst.Link, p.Link) + dst.Transport = overlayString(dst.Transport, p.Net.Transport) + dst.Auth = overlayString(dst.Auth, p.Auth.Provider) + dst.Engine = overlayString(dst.Engine, p.Engine.Name) + dst.URL = overlayString(dst.URL, p.Engine.URL) + dst.Token = overlayString(dst.Token, p.Engine.Token) + dst.RoomID = overlayString(dst.RoomID, p.Room.ID) + dst.KeyHex = overlayString(dst.KeyHex, p.Crypto.Key) + dst.SOCKSHost = overlayString(dst.SOCKSHost, p.SOCKS.Host) + dst.SOCKSPort = overlayInt(dst.SOCKSPort, p.SOCKS.Port) + dst.SOCKSUser = overlayString(dst.SOCKSUser, p.SOCKS.User) + dst.SOCKSPass = overlayString(dst.SOCKSPass, p.SOCKS.Pass) + dst.DNSServer = overlayString(dst.DNSServer, p.Net.DNS) + dst.SOCKSProxyAddr = overlayString(dst.SOCKSProxyAddr, p.SOCKS.ProxyAddr) + dst.SOCKSProxyPort = overlayInt(dst.SOCKSProxyPort, p.SOCKS.ProxyPort) + dst.VideoWidth = overlayInt(dst.VideoWidth, p.Video.Width) + dst.VideoHeight = overlayInt(dst.VideoHeight, p.Video.Height) + dst.VideoFPS = overlayInt(dst.VideoFPS, p.Video.FPS) + dst.VideoBitrate = overlayString(dst.VideoBitrate, p.Video.Bitrate) + dst.VideoHW = overlayString(dst.VideoHW, p.Video.HW) + dst.VideoQRSize = overlayInt(dst.VideoQRSize, p.Video.QRSize) + dst.VideoQRRecovery = overlayString(dst.VideoQRRecovery, p.Video.QRRecovery) + dst.VideoCodec = overlayString(dst.VideoCodec, p.Video.Codec) + dst.VideoTileModule = overlayInt(dst.VideoTileModule, p.Video.TileModule) + dst.VideoTileRS = overlayInt(dst.VideoTileRS, p.Video.TileRS) + dst.VP8FPS = overlayInt(dst.VP8FPS, p.VP8.FPS) + dst.VP8BatchSize = overlayInt(dst.VP8BatchSize, p.VP8.BatchSize) + dst.SEIFPS = overlayInt(dst.SEIFPS, p.SEI.FPS) + dst.SEIBatchSize = overlayInt(dst.SEIBatchSize, p.SEI.BatchSize) + dst.SEIFragmentSize = overlayInt(dst.SEIFragmentSize, p.SEI.FragmentSize) + dst.SEIAckTimeoutMS = overlayInt(dst.SEIAckTimeoutMS, p.SEI.AckTimeoutMS) + return dst +} + func pickString(cli, yamlVal string) string { if cli != "" { return cli @@ -180,3 +302,17 @@ func pickInt(cli, yamlVal int) int { } return yamlVal } + +func overlayString(base, override string) string { + if override != "" { + return override + } + return base +} + +func overlayInt(base, override int) int { + if override != 0 { + return override + } + return base +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 95c4d9b..7504110 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1,6 +1,7 @@ package config import ( + "errors" "os" "path/filepath" "testing" @@ -121,6 +122,157 @@ func TestApplyCLIWins(t *testing.T) { } } +func TestLoadAndApplyProfile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "olcrtc.yaml") + body := ` +mode: srv +link: direct +crypto: + key: shared-key +net: + dns: 1.1.1.1:53 +profiles: + - name: wb-vp8 + auth: + provider: wbstream + room: + id: wb-room + net: + transport: vp8channel + vp8: + fps: 30 + - name: jitsi-dc + auth: + provider: jitsi + room: + id: https://meet.example/room + net: + transport: datachannel + dns: 8.8.8.8:53 +failover: + retry_delay: 100ms + max_cycles: 2 +` + if err := os.WriteFile(path, []byte(body), 0o600); err != nil { + t.Fatalf("write config: %v", err) + } + + f, err := Load(path) + if err != nil { + t.Fatalf("Load: %v", err) + } + if len(f.Profiles) != 2 { + t.Fatalf("profiles = %d, want 2", len(f.Profiles)) + } + if f.Failover.RetryDelay != "100ms" || f.Failover.MaxCycles != 2 { + t.Fatalf("Failover = %+v, want retry_delay 100ms max_cycles 2", f.Failover) + } + + base := Apply(session.Config{}, f) + first := ApplyProfile(base, f.Profiles[0]) + if first.Auth != "wbstream" || first.Transport != "vp8channel" || first.RoomID != "wb-room" { + t.Fatalf("first profile = %+v", first) + } + if first.KeyHex != "shared-key" || first.DNSServer != "1.1.1.1:53" || first.VP8FPS != 30 { + t.Fatalf("first inherited/overlaid fields = %+v", first) + } + second := ApplyProfile(base, f.Profiles[1]) + if second.Auth != "jitsi" || second.Transport != "datachannel" || + second.RoomID != "https://meet.example/room" || second.DNSServer != "8.8.8.8:53" { + t.Fatalf("second profile = %+v", second) + } +} + +func TestLoadProfileCryptoKeyFile(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "profile.key"), []byte(testCryptoKey+"\n"), 0o600); err != nil { + t.Fatalf("write key: %v", err) + } + path := filepath.Join(dir, "olcrtc.yaml") + body := ` +profiles: + - name: file-key + crypto: + key_file: profile.key +` + if err := os.WriteFile(path, []byte(body), 0o600); err != nil { + t.Fatalf("write config: %v", err) + } + + f, err := Load(path) + if err != nil { + t.Fatalf("Load: %v", err) + } + if got := f.Profiles[0].Crypto.Key; got != testCryptoKey { + t.Fatalf("profile key = %q, want %q", got, testCryptoKey) + } +} + +func TestLoadCryptoKeyFileRelativeToConfig(t *testing.T) { + dir := t.TempDir() + keyPath := filepath.Join(dir, "secret.key") + if err := os.WriteFile(keyPath, []byte(testCryptoKey+"\n"), 0o600); err != nil { + t.Fatalf("write key: %v", err) + } + path := filepath.Join(dir, "olcrtc.yaml") + body := ` +mode: srv +crypto: + key_file: secret.key +` + if err := os.WriteFile(path, []byte(body), 0o600); err != nil { + t.Fatalf("write config: %v", err) + } + + f, err := Load(path) + if err != nil { + t.Fatalf("Load: %v", err) + } + if f.Crypto.Key != testCryptoKey { + t.Fatalf("Crypto.Key = %q, want %q", f.Crypto.Key, testCryptoKey) + } +} + +func TestLoadCryptoKeyFileConflict(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "olcrtc.yaml") + body := ` +crypto: + key: deadbeef + key_file: secret.key +` + if err := os.WriteFile(path, []byte(body), 0o600); err != nil { + t.Fatalf("write config: %v", err) + } + + _, err := Load(path) + if !errors.Is(err, ErrCryptoKeyConflict) { + t.Fatalf("Load() error = %v, want %v", err, ErrCryptoKeyConflict) + } +} + +func TestLoadCryptoKeyFileEmpty(t *testing.T) { + dir := t.TempDir() + keyPath := filepath.Join(dir, "secret.key") + if err := os.WriteFile(keyPath, []byte("\n"), 0o600); err != nil { + t.Fatalf("write key: %v", err) + } + path := filepath.Join(dir, "olcrtc.yaml") + body := ` +crypto: + key_file: secret.key +` + if err := os.WriteFile(path, []byte(body), 0o600); err != nil { + t.Fatalf("write config: %v", err) + } + + _, err := Load(path) + if !errors.Is(err, ErrCryptoKeyFileEmpty) { + t.Fatalf("Load() error = %v, want %v", err, ErrCryptoKeyFileEmpty) + } +} + func TestLoadMissing(t *testing.T) { _, err := Load(filepath.Join(t.TempDir(), "nope.yaml")) if err == nil { diff --git a/internal/e2e/tunnel_test.go b/internal/e2e/tunnel_test.go index 835bd65..b5cf0dd 100644 --- a/internal/e2e/tunnel_test.go +++ b/internal/e2e/tunnel_test.go @@ -24,6 +24,7 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/client" "github.com/openlibrecommunity/olcrtc/internal/link" "github.com/openlibrecommunity/olcrtc/internal/server" + "github.com/openlibrecommunity/olcrtc/internal/supervisor" "github.com/pion/webrtc/v4" ) @@ -47,6 +48,7 @@ var ( errSocksUnexpectedReply = errors.New("unexpected SOCKS5 reply") errSocksUnexpectedHello = errors.New("unexpected SOCKS5 greeting") errPayloadMismatchOffset = errors.New("payload mismatch at offset") + errFailoverCarrier = errors.New("intentional failover carrier failure") ) var ( @@ -347,6 +349,17 @@ func registerMemoryCarrierAs(t *testing.T, name string) { }) } +func registerFailingCarrier(t *testing.T) string { + t.Helper() + session.RegisterDefaults() + + name := "e2e-fail-" + t.Name() + carrier.Register(name, func(context.Context, carrier.Config) (carrier.Session, error) { + return nil, errFailoverCarrier + }) + return name +} + func builtInCarrierNames() []string { return []string{"jazz", "telemost", "wbstream", "jitsi"} //nolint:goconst // test literal, repetition is intentional } @@ -1163,6 +1176,186 @@ func TestFrequentReconnectsStillAllowNewSOCKSConnections(t *testing.T) { } } +func TestSupervisorFailoverProfilesReachWorkingSOCKS(t *testing.T) { + echoAddr := startEchoServer(t) + failingCarrier := registerFailingCarrier(t) + memoryCarrier, room := registerMemoryCarrier(t) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + socksAddr := freeLocalAddr(ctx, t) + socksHost, socksPort := splitHostPort(t, socksAddr) + + serverProfiles := []supervisor.Profile{ + {Name: "failing-server", Config: failoverSessionConfig("srv", failingCarrier, "", 0)}, + {Name: "memory-server", Config: failoverSessionConfig("srv", memoryCarrier, "", 0)}, + } + clientProfiles := []supervisor.Profile{ + {Name: "failing-client", Config: failoverSessionConfig("cnc", failingCarrier, socksHost, socksPort)}, + {Name: "memory-client", Config: failoverSessionConfig("cnc", memoryCarrier, socksHost, socksPort)}, + } + + started := make(chan string, 8) + serverErr := make(chan error, 1) + go func() { + serverErr <- supervisor.Run(ctx, failoverE2EConfig(serverProfiles, started, "server"), session.Run) + }() + room.waitConnected(t, 1) + + ready := make(chan struct{}) + var readyOnce sync.Once + clientErr := make(chan error, 1) + go func() { + clientErr <- supervisor.Run(ctx, failoverE2EConfig(clientProfiles, started, "client"), func(ctx context.Context, cfg session.Config) error { + return client.RunWithReady(ctx, clientConfigFromSession(cfg, socksAddr), func() { + if cfg.Auth == memoryCarrier { + readyOnce.Do(func() { close(ready) }) + } + }) + }) + }() + + waitForReady(t, ready) + conn := eventuallyConnectViaSOCKS(t, socksAddr, echoAddr) + defer func() { _ = conn.Close() }() + + payload := []byte("olcrtc-failover-e2e\n") + if _, err := conn.Write(payload); err != nil { + t.Fatalf("write failover payload: %v", err) + } + if err := conn.SetReadDeadline(time.Now().Add(3 * time.Second)); err != nil { + t.Fatalf("set failover read deadline: %v", err) + } + line, err := bufio.NewReader(conn).ReadBytes('\n') + if err != nil { + t.Fatalf("read failover echo: %v", err) + } + if !bytes.Equal(line, payload) { + t.Fatalf("failover echo = %q, want %q", line, payload) + } + + requireStartedProfiles(t, started, []string{ + "server:failing-server", + "server:memory-server", + "client:failing-client", + "client:memory-client", + }) + + cancel() + waitSupervisorStopped(t, "client", clientErr) + waitSupervisorStopped(t, "server", serverErr) +} + +func failoverSessionConfig(mode, carrierName, socksHost string, socksPort int) session.Config { + cfg := session.Config{ + Mode: mode, + Link: linkDirect, + Transport: transportData, + Auth: carrierName, + RoomID: testRoom, + KeyHex: testKeyHex, + DNSServer: localDNSServer, + } + if mode == "cnc" { + cfg.SOCKSHost = socksHost + cfg.SOCKSPort = socksPort + } + return cfg +} + +func clientConfigFromSession(cfg session.Config, socksAddr string) client.Config { + return client.Config{ + Link: cfg.Link, + Transport: cfg.Transport, + Carrier: cfg.Auth, + RoomURL: cfg.RoomID, + KeyHex: cfg.KeyHex, + LocalAddr: socksAddr, + DNSServer: cfg.DNSServer, + DeviceID: testClientDeviceID, + VideoWidth: cfg.VideoWidth, + VideoHeight: cfg.VideoHeight, + VideoFPS: cfg.VideoFPS, + VideoBitrate: cfg.VideoBitrate, + VideoHW: cfg.VideoHW, + VideoQRSize: cfg.VideoQRSize, + VideoQRRecovery: cfg.VideoQRRecovery, + VideoCodec: cfg.VideoCodec, + VideoTileModule: cfg.VideoTileModule, + VideoTileRS: cfg.VideoTileRS, + VP8FPS: cfg.VP8FPS, + VP8BatchSize: cfg.VP8BatchSize, + SEIFPS: cfg.SEIFPS, + SEIBatchSize: cfg.SEIBatchSize, + SEIFragmentSize: cfg.SEIFragmentSize, + SEIAckTimeoutMS: cfg.SEIAckTimeoutMS, + Engine: cfg.Engine, + URL: cfg.URL, + Token: cfg.Token, + } +} + +func failoverE2EConfig( + profiles []supervisor.Profile, + started chan<- string, + side string, +) supervisor.Config { + return supervisor.Config{ + Profiles: profiles, + RetryDelay: time.Millisecond, + OnProfileStart: func(profile supervisor.Profile, _ int) { + select { + case started <- side + ":" + profile.Name: + default: + } + }, + } +} + +func splitHostPort(t *testing.T, addr string) (string, int) { + t.Helper() + host, portText, err := net.SplitHostPort(addr) + if err != nil { + t.Fatalf("split host port %q: %v", addr, err) + } + port, err := strconv.Atoi(portText) + if err != nil { + t.Fatalf("parse port %q: %v", portText, err) + } + return host, port +} + +func requireStartedProfiles(t *testing.T, started <-chan string, want []string) { + t.Helper() + seen := make(map[string]bool) + deadline := time.After(3 * time.Second) + for len(seen) < len(want) { + select { + case item := <-started: + seen[item] = true + case <-deadline: + t.Fatalf("started profiles = %v, want all %v", seen, want) + } + } + for _, item := range want { + if !seen[item] { + t.Fatalf("started profiles = %v, missing %s", seen, item) + } + } +} + +func waitSupervisorStopped(t *testing.T, name string, ch <-chan error) { + t.Helper() + select { + case err := <-ch: + if err != nil { + t.Fatalf("%s supervisor returned error: %v", name, err) + } + case <-time.After(3 * time.Second): + t.Fatalf("%s supervisor did not stop", name) + } +} + func TestEndedCallbackStopsClientAndServer(t *testing.T) { rt := startTunnel(t) rt.room.triggerEnded("conference ended") diff --git a/internal/engine/livekit/livekit.go b/internal/engine/livekit/livekit.go index 24c62bd..ad7e64d 100644 --- a/internal/engine/livekit/livekit.go +++ b/internal/engine/livekit/livekit.go @@ -19,13 +19,17 @@ import ( protoLogger "github.com/livekit/protocol/logger" lksdk "github.com/livekit/server-sdk-go/v2" "github.com/openlibrecommunity/olcrtc/internal/engine" + "github.com/openlibrecommunity/olcrtc/internal/logger" "github.com/pion/webrtc/v4" ) const ( - defaultSendQueueSize = 5000 - dataPublishTopic = "olcrtc" - videoTrackName = "videochannel" + defaultSendQueueSize = 5000 + defaultSendQueueCapHard = 4000 + dataPublishTopic = "olcrtc" + videoTrackName = "videochannel" + reconnectWindow = 5 * time.Minute + maxReconnects = 10 ) var ( @@ -41,20 +45,98 @@ var ( ErrTokenRequired = errors.New("livekit access token required") ) +type roomHandle interface { + publishData([]byte) error + publishTrack(webrtc.TrackLocal) error + unpublishLocalTracks() + disconnect() + connectionState() lksdk.ConnectionState +} + +type sdkRoom struct { + room *lksdk.Room +} + +func (r *sdkRoom) publishData(data []byte) error { + return r.room.LocalParticipant.PublishDataPacket( + lksdk.UserData(data), + lksdk.WithDataPublishTopic(dataPublishTopic), + lksdk.WithDataPublishReliable(true), + ) +} + +func (r *sdkRoom) publishTrack(track webrtc.TrackLocal) error { + _, err := r.room.LocalParticipant.PublishTrack(track, &lksdk.TrackPublicationOptions{Name: videoTrackName}) + return err +} + +func (r *sdkRoom) unpublishLocalTracks() { + if r.room == nil || r.room.LocalParticipant == nil { + return + } + for _, publication := range r.room.LocalParticipant.TrackPublications() { + if publication.SID() == "" { + continue + } + if err := r.room.LocalParticipant.UnpublishTrack(publication.SID()); err != nil { + log.Printf("livekit unpublish track error: %v", err) + } + } +} + +func (r *sdkRoom) disconnect() { + r.room.Disconnect() + // LiveKit's Disconnect returns after local SDK teardown, before the + // server necessarily evicts the participant. Give the signalling path a + // short grace period so immediate reconnects do not inherit stale room + // state from a ghost participant. + time.Sleep(2 * time.Second) +} + +func (r *sdkRoom) connectionState() lksdk.ConnectionState { + return r.room.ConnectionState() +} + +type connectRoomFunc func(url, token string, callback *lksdk.RoomCallback) (roomHandle, error) + +func connectSDKRoom(url, token string, callback *lksdk.RoomCallback) (roomHandle, error) { + room, err := lksdk.ConnectToRoomWithToken( + url, + token, + callback, + lksdk.WithAutoSubscribe(true), + lksdk.WithLogger(protoLogger.GetDiscardLogger()), + ) + if err != nil { + return nil, err + } + return &sdkRoom{room: room}, nil +} + // Session is the LiveKit engine handle. type Session struct { url string token string name string - room *lksdk.Room + refresh func(ctx context.Context) (engine.Credentials, error) + connectRoom connectRoomFunc + room roomHandle + roomMu sync.RWMutex onData func([]byte) onReconnect func(*webrtc.DataChannel) shouldReconnect func() bool onEnded func(string) + reconnectCh chan struct{} + closeCh chan struct{} + lastReconnect time.Time + reconnectCount int sendQueue chan []byte closed atomic.Bool + reconnecting atomic.Bool done chan struct{} cancel context.CancelFunc + shutdownOnce sync.Once + sendWorkerOnce sync.Once videoTrackMu sync.RWMutex videoTracks []webrtc.TrackLocal onVideoTrack func(*webrtc.TrackRemote, *webrtc.RTPReceiver) @@ -71,13 +153,17 @@ func New(ctx context.Context, cfg engine.Config) (engine.Session, error) { } _, cancel := context.WithCancel(ctx) return &Session{ - url: cfg.URL, - token: cfg.Token, - name: cfg.Name, - onData: cfg.OnData, - sendQueue: make(chan []byte, defaultSendQueueSize), - done: make(chan struct{}), - cancel: cancel, + url: cfg.URL, + token: cfg.Token, + name: cfg.Name, + refresh: cfg.Refresh, + connectRoom: connectSDKRoom, + onData: cfg.OnData, + reconnectCh: make(chan struct{}, 1), + closeCh: make(chan struct{}), + sendQueue: make(chan []byte, defaultSendQueueSize), + done: make(chan struct{}), + cancel: cancel, }, nil } @@ -87,7 +173,16 @@ func (s *Session) Capabilities() engine.Capabilities { } // Connect joins the LiveKit room. -func (s *Session) Connect(_ context.Context) error { +func (s *Session) Connect(ctx context.Context) error { + s.closed.Store(false) + if err := s.connectSession(ctx); err != nil { + return err + } + s.startSendWorker() + return nil +} + +func (s *Session) connectSession(_ context.Context) error { roomCB := &lksdk.RoomCallback{ ParticipantCallback: lksdk.ParticipantCallback{ OnDataReceived: func(data []byte, _ lksdk.DataReceiveParams) { @@ -108,45 +203,49 @@ func (s *Session) Connect(_ context.Context) error { }, }, OnDisconnected: func() { - if !s.closed.Load() && s.onEnded != nil { - s.onEnded("disconnected from livekit") + if s.closed.Load() || s.reconnecting.Load() { + return + } + if !s.queueReconnect() { + s.signalEnded("disconnected from livekit") } }, } - room, err := lksdk.ConnectToRoomWithToken( - s.url, - s.token, - roomCB, - lksdk.WithAutoSubscribe(true), - lksdk.WithLogger(protoLogger.GetDiscardLogger()), - ) + room, err := s.connectRoom(s.url, s.token, roomCB) if err != nil { return fmt.Errorf("connect to room: %w", err) } - s.room = room + s.setRoom(room) if err := s.publishPendingTracks(); err != nil { return err } - s.wg.Add(1) - go s.processSendQueue() return nil } func (s *Session) publishPendingTracks() error { + room := s.currentRoom() + if room == nil { + return ErrRoomNotConnected + } s.videoTrackMu.RLock() defer s.videoTrackMu.RUnlock() for _, track := range s.videoTracks { - if _, err := s.room.LocalParticipant.PublishTrack(track, &lksdk.TrackPublicationOptions{ - Name: videoTrackName, - }); err != nil { + if err := room.publishTrack(track); err != nil { return fmt.Errorf("failed to publish track: %w", err) } } return nil } +func (s *Session) startSendWorker() { + s.sendWorkerOnce.Do(func() { + s.wg.Add(1) + go s.processSendQueue() + }) +} + func (s *Session) processSendQueue() { defer s.wg.Done() for { @@ -157,17 +256,33 @@ func (s *Session) processSendQueue() { if !ok { return } - if err := s.room.LocalParticipant.PublishDataPacket( - lksdk.UserData(data), - lksdk.WithDataPublishTopic(dataPublishTopic), - lksdk.WithDataPublishReliable(true), - ); err != nil { + room := s.waitForConnectedRoom() + if room == nil { + return + } + if err := room.publishData(data); err != nil { log.Printf("livekit publish data error: %v", err) } } } } +func (s *Session) waitForConnectedRoom() roomHandle { + ticker := time.NewTicker(50 * time.Millisecond) + defer ticker.Stop() + for { + room := s.currentRoom() + if room != nil && room.connectionState() == lksdk.ConnectionStateConnected { + return room + } + select { + case <-s.done: + return nil + case <-ticker.C: + } + } +} + // Send queues data for transmission. func (s *Session) Send(data []byte) error { if s.closed.Load() { @@ -183,55 +298,160 @@ func (s *Session) Send(data []byte) error { // Close terminates the session. func (s *Session) Close() error { - if s.closed.CompareAndSwap(false, true) { - s.cancel() - close(s.done) - if s.room != nil { - s.unpublishLocalTracks() - s.room.Disconnect() - // LiveKit's Disconnect() returns once the local SDK state - // is torn down, not when the server has actually evicted - // the participant. Without giving the signalling channel - // time to flush the LEAVE_REQUEST and the server to act on - // it, a back-to-back reconnect from the same identity in - // the same room sees a still-alive ghost participant on - // the SFU and inherits stale publication state. - time.Sleep(2 * time.Second) - } - close(s.sendQueue) - s.wg.Wait() - } + s.closed.Store(true) + s.shutdown() return nil } -func (s *Session) unpublishLocalTracks() { - if s.room == nil || s.room.LocalParticipant == nil { - return - } - for _, publication := range s.room.LocalParticipant.TrackPublications() { - if publication.SID() == "" { - continue +func (s *Session) shutdown() { + s.shutdownOnce.Do(func() { + if s.cancel != nil { + s.cancel() } - if err := s.room.LocalParticipant.UnpublishTrack(publication.SID()); err != nil { - log.Printf("livekit unpublish track error: %v", err) + closeSignal(s.closeCh) + closeSignal(s.done) + if room := s.swapRoom(nil); room != nil { + room.unpublishLocalTracks() + room.disconnect() } - } + s.wg.Wait() + }) } -// SetReconnectCallback stores the reconnect callback (LiveKit reconnects internally; this is kept for API parity). +// SetReconnectCallback stores the reconnect callback. func (s *Session) SetReconnectCallback(cb func(*webrtc.DataChannel)) { s.onReconnect = cb } -// SetShouldReconnect stores the reconnect predicate (kept for API parity). +// SetShouldReconnect stores the reconnect predicate. func (s *Session) SetShouldReconnect(fn func() bool) { s.shouldReconnect = fn } // SetEndedCallback registers a function to call when the session ends. func (s *Session) SetEndedCallback(cb func(string)) { s.onEnded = cb } -// WatchConnection is a no-op; LiveKit handles connection supervision itself. -func (s *Session) WatchConnection(_ context.Context) {} +// WatchConnection monitors the connection lifecycle and reconnects as needed. +func (s *Session) WatchConnection(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case <-s.closeCh: + return + case <-s.reconnectCh: + if s.handleReconnectAttempt(ctx) { + return + } + } + } +} + +func (s *Session) handleReconnectAttempt(ctx context.Context) bool { + if time.Since(s.lastReconnect) > reconnectWindow { + s.reconnectCount = 0 + } + s.reconnectCount++ + s.lastReconnect = time.Now() + + if s.reconnectCount > maxReconnects { + s.signalEnded("reconnect limit reached") + return true + } + + backoff := time.Duration(s.reconnectCount) * 2 * time.Second + if backoff > 30*time.Second { + backoff = 30 * time.Second + } + + for { + if err := s.reconnect(ctx); err != nil { + logger.Debugf("livekit reconnect failed: %v", err) + select { + case <-ctx.Done(): + return true + case <-s.closeCh: + return true + case <-time.After(backoff): + continue + } + } + s.drainReconnectQueue() + return false + } +} + +func (s *Session) reconnect(ctx context.Context) error { + s.reconnecting.Store(true) + defer s.reconnecting.Store(false) + + if room := s.swapRoom(nil); room != nil { + room.unpublishLocalTracks() + room.disconnect() + } + + if s.refresh != nil { + creds, err := s.refresh(ctx) + if err != nil { + return fmt.Errorf("refresh credentials: %w", err) + } + s.applyRefreshedCredentials(creds) + } + + if err := s.connectSession(ctx); err != nil { + return err + } + if s.onReconnect != nil { + s.onReconnect(nil) + } + return nil +} + +func (s *Session) applyRefreshedCredentials(creds engine.Credentials) { + if creds.URL != "" { + s.url = creds.URL + } + if creds.Token != "" { + s.token = creds.Token + } +} + +func (s *Session) queueReconnect() bool { + if s.closed.Load() || s.reconnecting.Load() { + return false + } + if s.shouldReconnect != nil && !s.shouldReconnect() { + return false + } + select { + case s.reconnectCh <- struct{}{}: + default: + } + return true +} + +func (s *Session) drainReconnectQueue() { + for { + select { + case <-s.reconnectCh: + default: + return + } + } +} + +func (s *Session) signalEnded(reason string) { + s.closed.Store(true) + s.shutdown() + if s.onEnded != nil { + s.onEnded(reason) + } +} // CanSend reports whether the session is ready to accept data. -func (s *Session) CanSend() bool { return !s.closed.Load() && s.room != nil } +func (s *Session) CanSend() bool { + if s.closed.Load() || s.reconnecting.Load() || len(s.sendQueue) >= defaultSendQueueCapHard { + return false + } + room := s.currentRoom() + return room != nil && room.connectionState() == lksdk.ConnectionStateConnected +} // GetSendQueue exposes the outbound queue. func (s *Session) GetSendQueue() chan []byte { return s.sendQueue } @@ -245,12 +465,11 @@ func (s *Session) AddVideoTrack(track webrtc.TrackLocal) error { s.videoTracks = append(s.videoTracks, track) s.videoTrackMu.Unlock() - if s.room == nil || s.room.LocalParticipant == nil { + room := s.currentRoom() + if room == nil { return nil } - if _, err := s.room.LocalParticipant.PublishTrack(track, &lksdk.TrackPublicationOptions{ - Name: videoTrackName, - }); err != nil { + if err := room.publishTrack(track); err != nil { return fmt.Errorf("failed to publish track: %w", err) } return nil @@ -263,6 +482,34 @@ func (s *Session) SetVideoTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPR s.onVideoTrack = cb } +func (s *Session) currentRoom() roomHandle { + s.roomMu.RLock() + defer s.roomMu.RUnlock() + return s.room +} + +func (s *Session) setRoom(room roomHandle) { + s.roomMu.Lock() + defer s.roomMu.Unlock() + s.room = room +} + +func (s *Session) swapRoom(room roomHandle) roomHandle { + s.roomMu.Lock() + defer s.roomMu.Unlock() + old := s.room + s.room = room + return old +} + +func closeSignal(ch chan struct{}) { + select { + case <-ch: + default: + close(ch) + } +} + func init() { //nolint:gochecknoinits // engine registration is the canonical Go pattern for plugins engine.Register("livekit", New) } diff --git a/internal/engine/livekit/livekit_test.go b/internal/engine/livekit/livekit_test.go new file mode 100644 index 0000000..7a46fd5 --- /dev/null +++ b/internal/engine/livekit/livekit_test.go @@ -0,0 +1,306 @@ +package livekit + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + lksdk "github.com/livekit/server-sdk-go/v2" + "github.com/openlibrecommunity/olcrtc/internal/engine" + "github.com/pion/webrtc/v4" +) + +type fakeRoom struct { + mu sync.Mutex + state lksdk.ConnectionState + published [][]byte + tracks int + unpublished int + disconnected int +} + +func newFakeRoom() *fakeRoom { + return &fakeRoom{state: lksdk.ConnectionStateConnected} +} + +func (r *fakeRoom) publishData(data []byte) error { + r.mu.Lock() + defer r.mu.Unlock() + r.published = append(r.published, append([]byte(nil), data...)) + return nil +} + +func (r *fakeRoom) publishTrack(webrtc.TrackLocal) error { + r.mu.Lock() + defer r.mu.Unlock() + r.tracks++ + return nil +} + +func (r *fakeRoom) unpublishLocalTracks() { + r.mu.Lock() + defer r.mu.Unlock() + r.unpublished++ +} + +func (r *fakeRoom) disconnect() { + r.mu.Lock() + defer r.mu.Unlock() + r.disconnected++ + r.state = lksdk.ConnectionStateDisconnected +} + +func (r *fakeRoom) connectionState() lksdk.ConnectionState { + r.mu.Lock() + defer r.mu.Unlock() + return r.state +} + +type fakeConnector struct { + mu sync.Mutex + urls []string + tokens []string + callbacks []*lksdk.RoomCallback + rooms []*fakeRoom + connected chan struct{} + err error +} + +func newFakeConnector() *fakeConnector { + return &fakeConnector{connected: make(chan struct{}, 8)} +} + +func (c *fakeConnector) connect(url, token string, cb *lksdk.RoomCallback) (roomHandle, error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.err != nil { + return nil, c.err + } + room := newFakeRoom() + c.urls = append(c.urls, url) + c.tokens = append(c.tokens, token) + c.callbacks = append(c.callbacks, cb) + c.rooms = append(c.rooms, room) + c.connected <- struct{}{} + return room, nil +} + +func (c *fakeConnector) count() int { + c.mu.Lock() + defer c.mu.Unlock() + return len(c.rooms) +} + +func (c *fakeConnector) callback(i int) *lksdk.RoomCallback { + c.mu.Lock() + defer c.mu.Unlock() + return c.callbacks[i] +} + +func (c *fakeConnector) room(i int) *fakeRoom { + c.mu.Lock() + defer c.mu.Unlock() + return c.rooms[i] +} + +func (c *fakeConnector) snapshot() ([]string, []string) { + c.mu.Lock() + defer c.mu.Unlock() + return append([]string(nil), c.urls...), append([]string(nil), c.tokens...) +} + +func waitFor(t *testing.T, cond func() bool) { + t.Helper() + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if cond() { + return + } + time.Sleep(10 * time.Millisecond) + } + t.Fatal("condition was not met before timeout") +} + +func TestReconnectRefreshesCredentialsAndReplacesRoom(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + refreshes := 0 + sess, err := New(ctx, engine.Config{ + URL: "wss://old", + Token: "old-token", + Refresh: func(context.Context) (engine.Credentials, error) { + refreshes++ + return engine.Credentials{URL: "wss://new", Token: "new-token"}, nil + }, + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + s := sess.(*Session) + connector := newFakeConnector() + s.connectRoom = connector.connect + + reconnected := make(chan struct{}, 1) + s.SetReconnectCallback(func(*webrtc.DataChannel) { + reconnected <- struct{}{} + }) + + if err := s.Connect(ctx); err != nil { + t.Fatalf("Connect() error = %v", err) + } + go s.WatchConnection(ctx) + + connector.callback(0).OnDisconnected() + + waitFor(t, func() bool { return connector.count() == 2 }) + select { + case <-reconnected: + case <-time.After(time.Second): + t.Fatal("reconnect callback was not called") + } + + urls, tokens := connector.snapshot() + if got, want := urls, []string{"wss://old", "wss://new"}; !equalStrings(got, want) { + t.Fatalf("connect urls = %v, want %v", got, want) + } + if got, want := tokens, []string{"old-token", "new-token"}; !equalStrings(got, want) { + t.Fatalf("connect tokens = %v, want %v", got, want) + } + if refreshes != 1 { + t.Fatalf("refreshes = %d, want 1", refreshes) + } + oldRoom := connector.room(0) + oldRoom.mu.Lock() + if oldRoom.disconnected != 1 || oldRoom.unpublished != 1 { + t.Fatalf("old room cleanup disconnected=%d unpublished=%d, want 1/1", + oldRoom.disconnected, oldRoom.unpublished) + } + oldRoom.mu.Unlock() + if !s.CanSend() { + t.Fatal("CanSend() = false after reconnect, want true") + } + + if err := s.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } +} + +func TestDisconnectedEndsWhenReconnectDisallowed(t *testing.T) { + ctx := context.Background() + sess, err := New(ctx, engine.Config{URL: "wss://old", Token: "old-token"}) + if err != nil { + t.Fatalf("New() error = %v", err) + } + s := sess.(*Session) + connector := newFakeConnector() + s.connectRoom = connector.connect + s.SetShouldReconnect(func() bool { return false }) + + ended := make(chan string, 1) + s.SetEndedCallback(func(reason string) { + ended <- reason + }) + + if err := s.Connect(ctx); err != nil { + t.Fatalf("Connect() error = %v", err) + } + connector.callback(0).OnDisconnected() + + select { + case reason := <-ended: + if reason != "disconnected from livekit" { + t.Fatalf("ended reason = %q, want disconnected from livekit", reason) + } + case <-time.After(time.Second): + t.Fatal("ended callback was not called") + } + if !s.closed.Load() { + t.Fatal("closed = false after terminal disconnect") + } + if connector.count() != 1 { + t.Fatalf("connect count = %d, want 1", connector.count()) + } + room := connector.room(0) + room.mu.Lock() + if room.disconnected != 1 || room.unpublished != 1 { + t.Fatalf("terminal room cleanup disconnected=%d unpublished=%d, want 1/1", + room.disconnected, room.unpublished) + } + room.mu.Unlock() + + if err := s.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + room.mu.Lock() + if room.disconnected != 1 || room.unpublished != 1 { + t.Fatalf("second close cleanup disconnected=%d unpublished=%d, want still 1/1", + room.disconnected, room.unpublished) + } + room.mu.Unlock() +} + +func TestCanSendRequiresConnectedRoomAndQueueHeadroom(t *testing.T) { + s := &Session{ + sendQueue: make(chan []byte, defaultSendQueueSize), + done: make(chan struct{}), + closeCh: make(chan struct{}), + } + if s.CanSend() { + t.Fatal("CanSend() = true without room") + } + + room := newFakeRoom() + room.state = lksdk.ConnectionStateDisconnected + s.setRoom(room) + if s.CanSend() { + t.Fatal("CanSend() = true for disconnected room") + } + + room.state = lksdk.ConnectionStateConnected + if !s.CanSend() { + t.Fatal("CanSend() = false for connected room") + } + + for i := 0; i < defaultSendQueueCapHard; i++ { + s.sendQueue <- []byte("x") + } + if s.CanSend() { + t.Fatal("CanSend() = true at queue high watermark") + } +} + +func TestReconnectFailureRetriesUntilContextDone(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := &Session{ + url: "wss://old", + token: "old-token", + connectRoom: func(string, string, *lksdk.RoomCallback) (roomHandle, error) { + cancel() + return nil, errors.New("boom") + }, + reconnectCh: make(chan struct{}, 1), + closeCh: make(chan struct{}), + sendQueue: make(chan []byte, defaultSendQueueSize), + done: make(chan struct{}), + } + if terminal := s.handleReconnectAttempt(ctx); !terminal { + t.Fatal("handleReconnectAttempt() = false after context cancellation") + } +} + +func equalStrings(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/internal/supervisor/supervisor.go b/internal/supervisor/supervisor.go new file mode 100644 index 0000000..929fed6 --- /dev/null +++ b/internal/supervisor/supervisor.go @@ -0,0 +1,96 @@ +// Package supervisor runs ordered session profiles with failover. +package supervisor + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/openlibrecommunity/olcrtc/internal/app/session" +) + +const DefaultRetryDelay = 2 * time.Second + +var ( + // ErrNoProfiles is returned when the supervisor is started without profiles. + ErrNoProfiles = errors.New("supervisor: no profiles configured") + // ErrMaxCyclesExceeded is returned after MaxCycles complete profile-list passes. + ErrMaxCyclesExceeded = errors.New("supervisor: max failover cycles exceeded") +) + +// Profile is one runnable session configuration in an ordered failover list. +type Profile struct { + Name string + Config session.Config +} + +// Runner starts one session profile and blocks until it ends or fails. +type Runner func(ctx context.Context, cfg session.Config) error + +// Config controls ordered failover behavior. +type Config struct { + Profiles []Profile + RetryDelay time.Duration + MaxCycles int + + OnProfileStart func(profile Profile, cycle int) + OnProfileEnd func(profile Profile, cycle int, err error) +} + +// Run starts profiles in order. If a profile exits while ctx is still active, +// the supervisor waits RetryDelay and advances to the next profile. +func Run(ctx context.Context, cfg Config, run Runner) error { + if len(cfg.Profiles) == 0 { + return ErrNoProfiles + } + if cfg.RetryDelay == 0 { + cfg.RetryDelay = DefaultRetryDelay + } + + var lastErr error + for cycle := 1; ; cycle++ { + for i, profile := range cfg.Profiles { + if ctx.Err() != nil { + return nil + } + if cfg.OnProfileStart != nil { + cfg.OnProfileStart(profile, cycle) + } + + err := run(ctx, profile.Config) + if ctx.Err() != nil { + return nil + } + if err != nil { + lastErr = fmt.Errorf("profile %q: %w", profile.Name, err) + } else { + lastErr = fmt.Errorf("profile %q ended", profile.Name) + } + if cfg.OnProfileEnd != nil { + cfg.OnProfileEnd(profile, cycle, err) + } + + if cfg.MaxCycles > 0 && cycle >= cfg.MaxCycles && i == len(cfg.Profiles)-1 { + return fmt.Errorf("%w after %d cycle(s): %w", ErrMaxCyclesExceeded, cycle, lastErr) + } + if err := waitRetryDelay(ctx, cfg.RetryDelay); err != nil { + return nil + } + } + } +} + +func waitRetryDelay(ctx context.Context, delay time.Duration) error { + if delay <= 0 { + return nil + } + timer := time.NewTimer(delay) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} diff --git a/internal/supervisor/supervisor_test.go b/internal/supervisor/supervisor_test.go new file mode 100644 index 0000000..aab0dee --- /dev/null +++ b/internal/supervisor/supervisor_test.go @@ -0,0 +1,85 @@ +package supervisor + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/openlibrecommunity/olcrtc/internal/app/session" +) + +var errRunnerBoom = errors.New("boom") + +func TestRunRequiresProfiles(t *testing.T) { + err := Run(context.Background(), Config{}, func(context.Context, session.Config) error { return nil }) + if !errors.Is(err, ErrNoProfiles) { + t.Fatalf("Run() error = %v, want %v", err, ErrNoProfiles) + } +} + +func TestRunAdvancesProfilesAndStopsAtMaxCycles(t *testing.T) { + profiles := []Profile{ + {Name: "first", Config: session.Config{Auth: "wbstream"}}, + {Name: "second", Config: session.Config{Auth: "jitsi"}}, + } + var started []string + var ended []string + err := Run(context.Background(), Config{ + Profiles: profiles, + RetryDelay: -1, + MaxCycles: 1, + OnProfileStart: func(profile Profile, cycle int) { + started = append(started, profile.Name) + if cycle != 1 { + t.Fatalf("cycle = %d, want 1", cycle) + } + }, + OnProfileEnd: func(profile Profile, _ int, err error) { + ended = append(ended, profile.Name) + if !errors.Is(err, errRunnerBoom) { + t.Fatalf("profile %s err = %v, want %v", profile.Name, err, errRunnerBoom) + } + }, + }, func(_ context.Context, cfg session.Config) error { + if cfg.Auth == "" { + t.Fatal("runner received empty auth") + } + return errRunnerBoom + }) + if !errors.Is(err, ErrMaxCyclesExceeded) { + t.Fatalf("Run() error = %v, want %v", err, ErrMaxCyclesExceeded) + } + if got, want := started, []string{"first", "second"}; !equalStrings(got, want) { + t.Fatalf("started = %v, want %v", got, want) + } + if got, want := ended, []string{"first", "second"}; !equalStrings(got, want) { + t.Fatalf("ended = %v, want %v", got, want) + } +} + +func TestRunReturnsNilOnContextCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + err := Run(ctx, Config{ + Profiles: []Profile{{Name: "one"}}, + RetryDelay: time.Hour, + }, func(context.Context, session.Config) error { + cancel() + return nil + }) + if err != nil { + t.Fatalf("Run() error = %v, want nil", err) + } +} + +func equalStrings(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} From b0fc3bd0f1ff668ab52d5654102ca8139cc1c59e Mon Sep 17 00:00:00 2001 From: cyber-debug Date: Sat, 16 May 2026 00:25:24 +0300 Subject: [PATCH 2/8] feat: add control stream liveness --- cmd/olcrtc/main.go | 3 +- docs/client.example.yaml | 5 + docs/configuration.md | 22 ++ docs/failover.example.yaml | 5 + docs/project-map.md | 19 +- docs/server.example.yaml | 5 + docs/settings.md | 10 +- internal/app/session/session.go | 150 ++++++++++--- internal/app/session/session_test.go | 47 ++++ internal/client/client.go | 91 +++++++- internal/client/client_test.go | 68 ++++++ internal/config/config.go | 37 ++- internal/config/config_test.go | 47 ++-- internal/control/control.go | 321 +++++++++++++++++++++++++++ internal/control/control_test.go | 128 +++++++++++ internal/handshake/handshake.go | 4 +- internal/server/server.go | 67 ++++-- internal/server/server_test.go | 72 ++++++ 18 files changed, 1012 insertions(+), 89 deletions(-) create mode 100644 internal/control/control.go create mode 100644 internal/control/control_test.go diff --git a/cmd/olcrtc/main.go b/cmd/olcrtc/main.go index 777949b..af7b87f 100644 --- a/cmd/olcrtc/main.go +++ b/cmd/olcrtc/main.go @@ -140,6 +140,7 @@ func runWithConfig(cfg loadedConfig) error { return fmt.Errorf("validate config: %w", err) } scfg = session.ApplyTransportDefaults(scfg) + scfg = session.ApplyLivenessDefaults(scfg) if scfg.Mode == modeGen { if len(cfg.profiles) > 0 { @@ -166,7 +167,7 @@ func prepareProfiles(profiles []supervisor.Profile) ([]supervisor.Profile, error if err != nil { return nil, fmt.Errorf("validate profile %q: %w", profile.Name, err) } - profile.Config = session.ApplyTransportDefaults(scfg) + profile.Config = session.ApplyLivenessDefaults(session.ApplyTransportDefaults(scfg)) out = append(out, profile) } return out, nil diff --git a/docs/client.example.yaml b/docs/client.example.yaml index fe83e0d..a074a6a 100644 --- a/docs/client.example.yaml +++ b/docs/client.example.yaml @@ -21,6 +21,11 @@ net: transport: datachannel # must match the server dns: "8.8.8.8:53" +liveness: + interval: 10s + timeout: 5s + failures: 3 + # Local SOCKS5 listener exposed to applications socks: host: "127.0.0.1" diff --git a/docs/configuration.md b/docs/configuration.md index 46edd07..8c067ad 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -31,6 +31,9 @@ olcrtc /etc/olcrtc/server.yaml | `video.*` | videochannel tuning | | `vp8.*` | vp8channel tuning | | `sei.fps` / `.batch_size` / `.fragment_size` / `.ack_timeout_ms` | seichannel tuning | +| `liveness.interval` | control-stream ping interval, default `10s` | +| `liveness.timeout` | pong timeout, default `5s` | +| `liveness.failures` | missed pongs before reconnect, default `3` | | `gen.amount` | gen mode: number of rooms to create | | `profiles[]` | ordered srv/cnc failover profiles | | `failover.retry_delay` | delay before trying the next profile, e.g. `2s` | @@ -45,6 +48,25 @@ olcrtc /etc/olcrtc/server.yaml `crypto.key_file` is resolved relative to the YAML file. Do not set it together with `crypto.key`. +## Liveness + +After `CLIENT_HELLO` / `SERVER_WELCOME`, the first smux stream stays open as +an encrypted control stream. olcrtc now sends `CONTROL_PING` / `CONTROL_PONG` +messages over that stream to prove the real tunnel path still round-trips. +This detects states where a provider or WebRTC layer looks connected but the +encrypted smux path is no longer usable. + +```yaml +liveness: + interval: 10s + timeout: 5s + failures: 3 +``` + +When the failure threshold is reached, the current smux session is rebuilt. +In failover mode, a profile that exits after liveness-triggered reconnect +failure lets the supervisor advance to the next profile. + ## Failover Profiles `mode: srv` and `mode: cnc` can define `profiles`. Top-level fields are used diff --git a/docs/failover.example.yaml b/docs/failover.example.yaml index 7aa8149..e956a35 100644 --- a/docs/failover.example.yaml +++ b/docs/failover.example.yaml @@ -10,6 +10,11 @@ crypto: net: dns: "1.1.1.1:53" +liveness: + interval: 10s + timeout: 5s + failures: 3 + data: data profiles: diff --git a/docs/project-map.md b/docs/project-map.md index c4c8791..e1b2134 100644 --- a/docs/project-map.md +++ b/docs/project-map.md @@ -72,6 +72,7 @@ Important fields: | `net.dns` | `DNSServer` | Resolver used by server-side target dials and provider HTTP where wired. | | `socks.*` | SOCKS fields | Client listener and optional server egress proxy. | | `engine.*` | direct engine fields | Used only with `auth.provider: none`. | +| `liveness.*` | control liveness | Ping/pong interval, timeout, and missed-pong threshold. | `internal/app/session` is the main router: @@ -151,6 +152,18 @@ SERVER_REJECT { version, reason } The handshake has a 64 KiB frame cap and a default 15 second timeout. +After handshake, `internal/control` keeps that same encrypted smux stream open +and exchanges length-prefixed JSON control messages: + +```text +CONTROL_PING { version, seq, sent_unix_nano } +CONTROL_PONG { version, seq, sent_unix_nano } +``` + +Defaults are `liveness.interval: 10s`, `liveness.timeout: 5s`, and +`liveness.failures: 3`. Missed pongs mark the smux session unhealthy and +trigger a session rebuild/reconnect path. + ## Registries And Plugin Shape The universal-carrier refactor centers on small registries: @@ -320,9 +333,9 @@ adaptive instead of static YAML knobs. ### 3. Control Stream Protocol -The first smux stream is parked after handshake. It is the natural place for: +The first smux stream now carries control ping/pong after handshake. It is +still the natural place for: -- Ping/pong and peer liveness. - Server policy updates. - Graceful reconnect notifications. - Drain/start markers for failover. @@ -330,7 +343,7 @@ The first smux stream is parked after handshake. It is the natural place for: Likely files: -- `internal/handshake` +- `internal/control` - `internal/server` - `internal/client` diff --git a/docs/server.example.yaml b/docs/server.example.yaml index 9f5ee38..c20b1e5 100644 --- a/docs/server.example.yaml +++ b/docs/server.example.yaml @@ -23,6 +23,11 @@ net: transport: datachannel # datachannel | videochannel | seichannel | vp8channel dns: "8.8.8.8:53" +liveness: + interval: 10s + timeout: 5s + failures: 3 + # Outbound SOCKS5 proxy for server-side egress (optional) socks: proxy_addr: "" # e.g. "127.0.0.1" diff --git a/docs/settings.md b/docs/settings.md index 28855ce..2e2d78a 100644 --- a/docs/settings.md +++ b/docs/settings.md @@ -63,12 +63,20 @@ | `profiles` | Список профилей failover для `srv`/`cnc` | | `failover.retry_delay` | Пауза перед следующим профилем, например `2s` | | `failover.max_cycles` | Сколько полных проходов по профилям сделать; `0` = бесконечно | +| `liveness.interval` | Интервал ping по control stream, по умолчанию `10s` | +| `liveness.timeout` | Сколько ждать pong, по умолчанию `5s` | +| `liveness.failures` | Сколько pong можно пропустить перед rebuild, по умолчанию `3` | `crypto.key_file` читается относительно YAML-файла. Не указывай `crypto.key` и `crypto.key_file` одновременно. Если задан `profiles`, поля верхнего уровня становятся общими defaults, а каждый профиль переопределяет только свои `auth`, `room`, `net`, `engine` и -настройки транспорта. Порядок профилей должен совпадать на сервере и клиенте. +настройки транспорта/liveness. Порядок профилей должен совпадать на сервере и +клиенте. + +`liveness` проверяет именно зашифрованный smux control stream после handshake, +а не только статус WebRTC/provider соединения. Если pong не приходит несколько +раз подряд, текущая smux-сессия пересоздается. --- diff --git a/internal/app/session/session.go b/internal/app/session/session.go index 89de5f5..360d96a 100644 --- a/internal/app/session/session.go +++ b/internal/app/session/session.go @@ -13,6 +13,7 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/carrier" "github.com/openlibrecommunity/olcrtc/internal/carrier/builtin" "github.com/openlibrecommunity/olcrtc/internal/client" + "github.com/openlibrecommunity/olcrtc/internal/control" "github.com/openlibrecommunity/olcrtc/internal/link" "github.com/openlibrecommunity/olcrtc/internal/link/direct" "github.com/openlibrecommunity/olcrtc/internal/logger" @@ -120,43 +121,56 @@ var ( // ErrSOCKSAuthRequired indicates that a non-loopback SOCKS listener requires authentication. ErrSOCKSAuthRequired = errors.New( "socks auth required when binding outside loopback (set socks.user and socks.pass)") + + // ErrLivenessIntervalInvalid indicates that liveness.interval is not a positive duration. + ErrLivenessIntervalInvalid = errors.New( + "invalid liveness interval (set liveness.interval to a duration > 0)") + // ErrLivenessTimeoutInvalid indicates that liveness.timeout is not a positive duration. + ErrLivenessTimeoutInvalid = errors.New( + "invalid liveness timeout (set liveness.timeout to a duration > 0)") + // ErrLivenessFailuresInvalid indicates that liveness.failures is not positive. + ErrLivenessFailuresInvalid = errors.New( + "invalid liveness failures (set liveness.failures to a value > 0)") ) // Config holds runtime session settings. type Config struct { - Mode string - Link string - Transport string - Auth string - Engine string - URL string - Token string - RoomID string - KeyHex string - SOCKSHost string - SOCKSPort int - SOCKSUser string - SOCKSPass string - DNSServer string - SOCKSProxyAddr string - SOCKSProxyPort int - VideoWidth int - VideoHeight int - VideoFPS int - VideoBitrate string - VideoHW string - VideoQRSize int - VideoQRRecovery string - VideoCodec string - VideoTileModule int - VideoTileRS int - VP8FPS int - VP8BatchSize int - SEIFPS int - SEIBatchSize int - SEIFragmentSize int - SEIAckTimeoutMS int - Amount int + Mode string + Link string + Transport string + Auth string + Engine string + URL string + Token string + RoomID string + KeyHex string + SOCKSHost string + SOCKSPort int + SOCKSUser string + SOCKSPass string + DNSServer string + SOCKSProxyAddr string + SOCKSProxyPort int + VideoWidth int + VideoHeight int + VideoFPS int + VideoBitrate string + VideoHW string + VideoQRSize int + VideoQRRecovery string + VideoCodec string + VideoTileModule int + VideoTileRS int + VP8FPS int + VP8BatchSize int + SEIFPS int + SEIBatchSize int + SEIFragmentSize int + SEIAckTimeoutMS int + LivenessInterval string + LivenessTimeout string + LivenessFailures int + Amount int } // RegisterDefaults registers built-in carriers and transports. @@ -212,6 +226,20 @@ func ApplyTransportDefaults(cfg Config) Config { } } +// ApplyLivenessDefaults fills documented control-stream liveness defaults. +func ApplyLivenessDefaults(cfg Config) Config { + if cfg.LivenessInterval == "" { + cfg.LivenessInterval = control.DefaultInterval.String() + } + if cfg.LivenessTimeout == "" { + cfg.LivenessTimeout = control.DefaultTimeout.String() + } + if cfg.LivenessFailures == 0 { + cfg.LivenessFailures = control.DefaultFailures + } + return cfg +} + func applyVideoDefaults(cfg Config) Config { if cfg.VideoCodec == "" { cfg.VideoCodec = videoCodecQRCode @@ -292,6 +320,9 @@ func Validate(cfg Config) error { if err := validateTransportConfig(cfg); err != nil { return err } + if err := validateLivenessConfig(cfg); err != nil { + return err + } return validateModeConfig(cfg) } @@ -431,6 +462,52 @@ func validateModeConfig(cfg Config) error { return nil } +func validateLivenessConfig(cfg Config) error { + if _, err := parseLivenessDuration(cfg.LivenessInterval, control.DefaultInterval); err != nil { + return fmt.Errorf("%w: %v", ErrLivenessIntervalInvalid, err) + } + if _, err := parseLivenessDuration(cfg.LivenessTimeout, control.DefaultTimeout); err != nil { + return fmt.Errorf("%w: %v", ErrLivenessTimeoutInvalid, err) + } + if cfg.LivenessFailures < 0 { + return ErrLivenessFailuresInvalid + } + return nil +} + +func parseLivenessDuration(value string, def time.Duration) (time.Duration, error) { + if value == "" { + return def, nil + } + d, err := time.ParseDuration(value) + if err != nil { + return 0, err + } + if d <= 0 { + return 0, fmt.Errorf("duration must be > 0") + } + return d, nil +} + +func livenessConfig(cfg Config) (control.Config, error) { + interval, err := parseLivenessDuration(cfg.LivenessInterval, control.DefaultInterval) + if err != nil { + return control.Config{}, fmt.Errorf("%w: %v", ErrLivenessIntervalInvalid, err) + } + timeout, err := parseLivenessDuration(cfg.LivenessTimeout, control.DefaultTimeout) + if err != nil { + return control.Config{}, fmt.Errorf("%w: %v", ErrLivenessTimeoutInvalid, err) + } + failures := cfg.LivenessFailures + if failures == 0 { + failures = control.DefaultFailures + } + if failures < 0 { + return control.Config{}, ErrLivenessFailuresInvalid + } + return control.Config{Interval: interval, Timeout: timeout, Failures: failures}, nil +} + func isLoopbackListenHost(host string) bool { if host == "localhost" { return true @@ -442,7 +519,12 @@ func isLoopbackListenHost(host string) bool { // Run starts the configured mode. func Run(ctx context.Context, cfg Config) error { cfg = ApplyTransportDefaults(cfg) + cfg = ApplyLivenessDefaults(cfg) roomURL := cfg.RoomID + liveness, err := livenessConfig(cfg) + if err != nil { + return err + } switch cfg.Mode { case modeSRV: @@ -474,6 +556,7 @@ func Run(ctx context.Context, cfg Config) error { Engine: cfg.Engine, URL: cfg.URL, Token: cfg.Token, + Liveness: liveness, OnSessionOpen: func(sessionID, deviceID string, claims map[string]any) { logger.Infof("session opened: id=%s device=%s claims=%v", sessionID, deviceID, claims) }, @@ -517,6 +600,7 @@ func Run(ctx context.Context, cfg Config) error { Engine: cfg.Engine, URL: cfg.URL, Token: cfg.Token, + Liveness: liveness, }); err != nil { return fmt.Errorf("client: %w", err) } diff --git a/internal/app/session/session_test.go b/internal/app/session/session_test.go index f20e70d..95270b2 100644 --- a/internal/app/session/session_test.go +++ b/internal/app/session/session_test.go @@ -4,6 +4,8 @@ import ( "context" "errors" "testing" + + "github.com/openlibrecommunity/olcrtc/internal/control" ) func TestApplyTransportDefaults(t *testing.T) { @@ -85,6 +87,24 @@ func TestApplyTransportDefaults(t *testing.T) { } } +func TestApplyLivenessDefaults(t *testing.T) { + got := ApplyLivenessDefaults(Config{}) + if got.LivenessInterval != control.DefaultInterval.String() { + t.Fatalf("LivenessInterval = %q, want %q", got.LivenessInterval, control.DefaultInterval.String()) + } + if got.LivenessTimeout != control.DefaultTimeout.String() { + t.Fatalf("LivenessTimeout = %q, want %q", got.LivenessTimeout, control.DefaultTimeout.String()) + } + if got.LivenessFailures != control.DefaultFailures { + t.Fatalf("LivenessFailures = %d, want %d", got.LivenessFailures, control.DefaultFailures) + } + + explicit := Config{LivenessInterval: "1s", LivenessTimeout: "500ms", LivenessFailures: 9} + if got := ApplyLivenessDefaults(explicit); got != explicit { + t.Fatalf("ApplyLivenessDefaults() = %+v, want %+v", got, explicit) + } +} + //nolint:maintidx // table-driven validation test naturally has many cases func TestValidate(t *testing.T) { RegisterDefaults() @@ -422,6 +442,33 @@ func TestValidate(t *testing.T) { return cfg }(), }, + { + name: "liveness rejects bad interval", + cfg: func() Config { + cfg := base + cfg.LivenessInterval = "nope" + return cfg + }(), + want: ErrLivenessIntervalInvalid, + }, + { + name: "liveness rejects zero timeout", + cfg: func() Config { + cfg := base + cfg.LivenessTimeout = "0s" + return cfg + }(), + want: ErrLivenessTimeoutInvalid, + }, + { + name: "liveness rejects negative failures", + cfg: func() Config { + cfg := base + cfg.LivenessFailures = -1 + return cfg + }(), + want: ErrLivenessFailuresInvalid, + }, } for _, tt := range tests { diff --git a/internal/client/client.go b/internal/client/client.go index 0d73bd9..13be135 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -17,6 +17,7 @@ import ( "time" "github.com/google/uuid" + "github.com/openlibrecommunity/olcrtc/internal/control" "github.com/openlibrecommunity/olcrtc/internal/crypto" "github.com/openlibrecommunity/olcrtc/internal/handshake" "github.com/openlibrecommunity/olcrtc/internal/link" @@ -54,7 +55,9 @@ type Client struct { conn *muxconn.Conn session *smux.Session controlStrm *smux.Stream + controlStop context.CancelFunc sessMu sync.RWMutex + reconnectMu sync.Mutex deviceID string sessionID string claims map[string]any @@ -93,6 +96,7 @@ type Config struct { Engine string URL string Token string + Liveness control.Config // DeviceID overrides the persistent client-side device identifier. Leave // empty to derive one from DeviceIDPath (or generate a random one if both @@ -217,7 +221,9 @@ func (c *Client) bringUpLink( if ctx.Err() != nil { return } - c.handleReconnect() + if !c.handleReconnect(ctx, cfg, cancel) { + cancel() + } }) if err := ln.Connect(ctx); err != nil { @@ -243,14 +249,15 @@ func (c *Client) bringUpLink( c.controlStrm = control c.sessionID = sid c.sessMu.Unlock() + c.startControlLoop(ctx, cfg, cancel, control) go ln.WatchConnection(ctx) return nil } // openControlStream opens stream #1 on sess and performs the handshake. -// The stream stays open for the lifetime of the smux session — the server -// holds it parked, and it would carry future control messages. +// The stream stays open for the lifetime of the smux session and carries +// post-handshake control messages. func openControlStream( sess *smux.Session, deviceID string, @@ -326,7 +333,10 @@ func smuxConfig() *smux.Config { return cfg } -func (c *Client) handleReconnect() { +func (c *Client) handleReconnect(ctx context.Context, cfg Config, cancel context.CancelFunc) bool { + c.reconnectMu.Lock() + defer c.reconnectMu.Unlock() + logger.Infof("client link reconnect - tearing down smux session") // Install a fresh muxconn immediately so onData never hits nil while @@ -336,14 +346,19 @@ func (c *Client) handleReconnect() { c.sessMu.Lock() oldControl := c.controlStrm + oldControlStop := c.controlStop oldSess := c.session oldConn := c.conn c.conn = newConn c.session = nil c.controlStrm = nil + c.controlStop = nil c.sessionID = "" c.sessMu.Unlock() + if oldControlStop != nil { + oldControlStop() + } if oldControl != nil { _ = oldControl.Close() } @@ -364,15 +379,25 @@ func (c *Client) handleReconnect() { attemptDelay = 300 * time.Millisecond ) for attempt := 1; attempt <= maxAttempts; attempt++ { - if c.tryReopenSession(attempt) { - return + if c.tryReopenSession(ctx, cfg, cancel, attempt) { + return true + } + select { + case <-ctx.Done(): + return false + case <-time.After(attemptDelay): } - time.Sleep(attemptDelay) } logger.Warnf("client reconnect: exhausted %d handshake attempts", maxAttempts) + return false } -func (c *Client) tryReopenSession(attempt int) bool { +func (c *Client) tryReopenSession( + ctx context.Context, + cfg Config, + cancel context.CancelFunc, + attempt int, +) bool { conn := muxconn.New(c.ln, c.cipher) c.sessMu.Lock() @@ -400,19 +425,69 @@ func (c *Client) tryReopenSession(attempt int) bool { c.controlStrm = control c.sessionID = sid c.sessMu.Unlock() + c.startControlLoop(ctx, cfg, cancel, control) return true } +func (c *Client) startControlLoop( + ctx context.Context, + cfg Config, + cancel context.CancelFunc, + stream *smux.Stream, +) { + controlCtx, stop := context.WithCancel(ctx) + c.sessMu.Lock() + c.controlStop = stop + c.sessMu.Unlock() + + liveness := cfg.Liveness + onPong := liveness.OnPong + onUnhealthy := liveness.OnUnhealthy + liveness.OnPong = func(h control.Health) { + c.sessMu.RLock() + sid := c.sessionID + c.sessMu.RUnlock() + logger.Debugf("control alive session=%s rtt=%v seq=%d", sid, h.RTT, h.Seq) + if onPong != nil { + onPong(h) + } + } + liveness.OnUnhealthy = func(missed int) { + logger.Warnf("control stream unhealthy on client: missed_pongs=%d", missed) + if onUnhealthy != nil { + onUnhealthy(missed) + } + } + + go func() { + err := control.Run(controlCtx, stream, liveness) + if controlCtx.Err() != nil || ctx.Err() != nil { + return + } + if err != nil { + logger.Warnf("client control stream ended: %v", err) + } + if !c.handleReconnect(ctx, cfg, cancel) { + cancel() + } + }() +} + func (c *Client) shutdown() { c.sessMu.Lock() control := c.controlStrm + controlStop := c.controlStop sess := c.session conn := c.conn c.controlStrm = nil + c.controlStop = nil c.session = nil c.conn = nil c.sessMu.Unlock() + if controlStop != nil { + controlStop() + } if conn != nil { _ = conn.Close() } diff --git a/internal/client/client_test.go b/internal/client/client_test.go index 48976fe..f5d836b 100644 --- a/internal/client/client_test.go +++ b/internal/client/client_test.go @@ -11,6 +11,7 @@ import ( "testing" "time" + "github.com/openlibrecommunity/olcrtc/internal/control" cryptopkg "github.com/openlibrecommunity/olcrtc/internal/crypto" "github.com/openlibrecommunity/olcrtc/internal/muxconn" "github.com/xtaci/smux" @@ -517,3 +518,70 @@ func TestShutdownClosesLinkAndConn(t *testing.T) { t.Fatal("shutdown() did not close link") } } + +func TestStartControlLoopReportsPong(t *testing.T) { + a, b := net.Pipe() + defer func() { + _ = a.Close() + _ = b.Close() + }() + + serverSess, err := smux.Server(a, smuxConfig()) + if err != nil { + t.Fatalf("smux.Server() error = %v", err) + } + defer func() { _ = serverSess.Close() }() + clientSess, err := smux.Client(b, smuxConfig()) + if err != nil { + t.Fatalf("smux.Client() error = %v", err) + } + defer func() { _ = clientSess.Close() }() + + peerStreamCh := make(chan *smux.Stream, 1) + go func() { + stream, err := serverSess.AcceptStream() + if err == nil { + peerStreamCh <- stream + } + }() + + stream, err := clientSess.OpenStream() + if err != nil { + t.Fatalf("OpenStream() error = %v", err) + } + peerStream := <-peerStreamCh + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + got := make(chan control.Health, 1) + c := &Client{sessionID: "sid-control"} + c.startControlLoop(ctx, Config{ + Liveness: control.Config{ + Interval: 10 * time.Millisecond, + Timeout: 100 * time.Millisecond, + Failures: 2, + OnPong: func(h control.Health) { + select { + case got <- h: + default: + } + }, + }, + }, cancel, stream) + go func() { + _ = control.Run(ctx, peerStream, control.Config{ + Interval: 10 * time.Millisecond, + Timeout: 100 * time.Millisecond, + Failures: 2, + }) + }() + + select { + case h := <-got: + if h.Seq == 0 { + t.Fatal("Health.Seq = 0") + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for control pong") + } +} diff --git a/internal/config/config.go b/internal/config/config.go index 5fe206c..9524363 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -41,6 +41,7 @@ type File struct { Video Video `yaml:"video"` VP8 VP8 `yaml:"vp8"` SEI SEI `yaml:"sei"` + Liveness Liveness `yaml:"liveness"` Gen Gen `yaml:"gen"` Profiles []Profile `yaml:"profiles"` Failover Failover `yaml:"failover"` @@ -51,17 +52,18 @@ type File struct { // Profile is a failover entry that overrides top-level runtime fields. type Profile struct { - Name string `yaml:"name"` - Link string `yaml:"link"` - Auth Auth `yaml:"auth"` - Room Room `yaml:"room"` - Crypto Crypto `yaml:"crypto"` - Net Net `yaml:"net"` - SOCKS SOCKS `yaml:"socks"` - Engine Engine `yaml:"engine"` - Video Video `yaml:"video"` - VP8 VP8 `yaml:"vp8"` - SEI SEI `yaml:"sei"` + Name string `yaml:"name"` + Link string `yaml:"link"` + Auth Auth `yaml:"auth"` + Room Room `yaml:"room"` + Crypto Crypto `yaml:"crypto"` + Net Net `yaml:"net"` + SOCKS SOCKS `yaml:"socks"` + Engine Engine `yaml:"engine"` + Video Video `yaml:"video"` + VP8 VP8 `yaml:"vp8"` + SEI SEI `yaml:"sei"` + Liveness Liveness `yaml:"liveness"` } // Failover controls ordered profile failover. @@ -137,6 +139,13 @@ type SEI struct { AckTimeoutMS int `yaml:"ack_timeout_ms"` } +// Liveness tunes the post-handshake control stream ping/pong checks. +type Liveness struct { + Interval string `yaml:"interval"` + Timeout string `yaml:"timeout"` + Failures int `yaml:"failures"` +} + // Gen controls room-generation mode. type Gen struct { Amount int `yaml:"amount"` @@ -248,6 +257,9 @@ func Apply(dst session.Config, f File) session.Config { dst.SEIBatchSize = pickInt(dst.SEIBatchSize, f.SEI.BatchSize) dst.SEIFragmentSize = pickInt(dst.SEIFragmentSize, f.SEI.FragmentSize) dst.SEIAckTimeoutMS = pickInt(dst.SEIAckTimeoutMS, f.SEI.AckTimeoutMS) + dst.LivenessInterval = pickString(dst.LivenessInterval, f.Liveness.Interval) + dst.LivenessTimeout = pickString(dst.LivenessTimeout, f.Liveness.Timeout) + dst.LivenessFailures = pickInt(dst.LivenessFailures, f.Liveness.Failures) dst.Amount = pickInt(dst.Amount, f.Gen.Amount) return dst } @@ -286,6 +298,9 @@ func ApplyProfile(base session.Config, p Profile) session.Config { dst.SEIBatchSize = overlayInt(dst.SEIBatchSize, p.SEI.BatchSize) dst.SEIFragmentSize = overlayInt(dst.SEIFragmentSize, p.SEI.FragmentSize) dst.SEIAckTimeoutMS = overlayInt(dst.SEIAckTimeoutMS, p.SEI.AckTimeoutMS) + dst.LivenessInterval = overlayString(dst.LivenessInterval, p.Liveness.Interval) + dst.LivenessTimeout = overlayString(dst.LivenessTimeout, p.Liveness.Timeout) + dst.LivenessFailures = overlayInt(dst.LivenessFailures, p.Liveness.Failures) return dst } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 7504110..b41604c 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -39,6 +39,10 @@ socks: vp8: fps: 25 batch_size: 4 +liveness: + interval: 2s + timeout: 500ms + failures: 4 gen: amount: 3 debug: true @@ -76,20 +80,23 @@ func requireLoadedFile(t *testing.T, f File) { func requireAppliedConfig(t *testing.T, got session.Config) { t.Helper() want := session.Config{ - Mode: testModeSrv, - Link: "direct", - Auth: testAuthProvider, - RoomID: testRoomID, - KeyHex: testCryptoKey, - Transport: "datachannel", - DNSServer: "1.1.1.1:53", - SOCKSHost: "127.0.0.1", - SOCKSPort: 1080, - SOCKSUser: "u", - SOCKSPass: "p", - VP8FPS: 25, - VP8BatchSize: 4, - Amount: 3, + Mode: testModeSrv, + Link: "direct", + Auth: testAuthProvider, + RoomID: testRoomID, + KeyHex: testCryptoKey, + Transport: "datachannel", + DNSServer: "1.1.1.1:53", + SOCKSHost: "127.0.0.1", + SOCKSPort: 1080, + SOCKSUser: "u", + SOCKSPass: "p", + VP8FPS: 25, + VP8BatchSize: 4, + LivenessInterval: "2s", + LivenessTimeout: "500ms", + LivenessFailures: 4, + Amount: 3, } if got != want { t.Fatalf("Apply produced wrong config: %+v, want %+v", got, want) @@ -132,6 +139,10 @@ crypto: key: shared-key net: dns: 1.1.1.1:53 +liveness: + interval: 5s + timeout: 2s + failures: 5 profiles: - name: wb-vp8 auth: @@ -142,6 +153,8 @@ profiles: transport: vp8channel vp8: fps: 30 + liveness: + interval: 1s - name: jitsi-dc auth: provider: jitsi @@ -174,7 +187,8 @@ failover: if first.Auth != "wbstream" || first.Transport != "vp8channel" || first.RoomID != "wb-room" { t.Fatalf("first profile = %+v", first) } - if first.KeyHex != "shared-key" || first.DNSServer != "1.1.1.1:53" || first.VP8FPS != 30 { + if first.KeyHex != "shared-key" || first.DNSServer != "1.1.1.1:53" || first.VP8FPS != 30 || + first.LivenessInterval != "1s" || first.LivenessTimeout != "2s" || first.LivenessFailures != 5 { t.Fatalf("first inherited/overlaid fields = %+v", first) } second := ApplyProfile(base, f.Profiles[1]) @@ -182,6 +196,9 @@ failover: second.RoomID != "https://meet.example/room" || second.DNSServer != "8.8.8.8:53" { t.Fatalf("second profile = %+v", second) } + if second.LivenessInterval != "5s" || second.LivenessTimeout != "2s" || second.LivenessFailures != 5 { + t.Fatalf("second liveness fields = %+v", second) + } } func TestLoadProfileCryptoKeyFile(t *testing.T) { diff --git a/internal/control/control.go b/internal/control/control.go new file mode 100644 index 0000000..a6bd50f --- /dev/null +++ b/internal/control/control.go @@ -0,0 +1,321 @@ +// Package control implements the post-handshake control stream protocol. +// +// The control stream is the first smux stream after the olcrtc handshake. It +// stays inside the encrypted muxconn path, so ping/pong proves that the actual +// tunnel path still round-trips, not merely that the provider connection is up. +// +// Wire format matches the handshake framing: a 4-byte big-endian length +// followed by a JSON message. +// +//nolint:tagliatelle // JSON keys are the stable wire protocol schema. +package control + +import ( + "context" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "io" + "sync" + "time" +) + +const ( + // ProtoVersion identifies the control stream wire format. + ProtoVersion = 1 + // MaxMessageSize caps one control frame. + MaxMessageSize = 16 * 1024 + // DefaultInterval is the default interval between ping probes. + DefaultInterval = 10 * time.Second + // DefaultTimeout is the default time to wait for a pong. + DefaultTimeout = 5 * time.Second + // DefaultFailures is the default number of consecutive missed pongs before + // the stream is marked unhealthy. + DefaultFailures = 3 +) + +// MsgType labels a control message. +type MsgType string + +const ( + // TypePing is sent periodically to prove control-stream liveness. + TypePing MsgType = "CONTROL_PING" + // TypePong replies to a ping with the same sequence and timestamp. + TypePong MsgType = "CONTROL_PONG" +) + +var ( + // ErrUnhealthy is returned when the stream misses too many pong replies. + ErrUnhealthy = errors.New("control stream unhealthy") + // ErrProtocolVersion is returned when the peer announces an incompatible version. + ErrProtocolVersion = errors.New("incompatible control protocol version") + // ErrUnexpectedMessage is returned for unknown or malformed control message types. + ErrUnexpectedMessage = errors.New("unexpected control message") + // ErrFrameTooLarge is returned when a frame exceeds [MaxMessageSize]. + ErrFrameTooLarge = errors.New("control frame too large") +) + +// Message is one control-stream frame. +type Message struct { + Version int `json:"version"` + Type MsgType `json:"type"` + Seq uint64 `json:"seq,omitempty"` + SentUnixNano int64 `json:"sent_unix_nano,omitempty"` +} + +// Health is reported when a ping round trip completes. +type Health struct { + Seq uint64 + RTT time.Duration + LastSeen time.Time +} + +// Config controls the liveness loop. +type Config struct { + Interval time.Duration + Timeout time.Duration + Failures int + + // OnPong is called after a matching pong is received. + OnPong func(Health) + // OnUnhealthy is called before Run returns [ErrUnhealthy]. + OnUnhealthy func(missed int) +} + +func (cfg Config) withDefaults() Config { + if cfg.Interval <= 0 { + cfg.Interval = DefaultInterval + } + if cfg.Timeout <= 0 { + cfg.Timeout = DefaultTimeout + } + if cfg.Failures <= 0 { + cfg.Failures = DefaultFailures + } + return cfg +} + +// Run drives bidirectional ping/pong liveness until ctx is canceled, rw closes, +// or the configured failure threshold is reached. +func Run(ctx context.Context, rw io.ReadWriteCloser, cfg Config) error { + cfg = cfg.withDefaults() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + state := &state{ + rw: rw, + cfg: cfg, + pending: make(map[uint64]time.Time), + now: time.Now, + out: make(chan Message, 16), + } + + errCh := make(chan error, 3) + go func() { + <-ctx.Done() + _ = rw.Close() + }() + go func() { errCh <- state.readLoop(ctx) }() + go func() { errCh <- state.probeLoop(ctx) }() + go func() { errCh <- state.writeLoop(ctx) }() + + err := <-errCh + cancel() + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil + } + return err +} + +type state struct { + rw io.ReadWriteCloser + cfg Config + now func() time.Time + + out chan Message + + mu sync.Mutex + pending map[uint64]time.Time + nextSeq uint64 + failures int +} + +func (s *state) readLoop(ctx context.Context) error { + for { + raw, err := readFrame(s.rw) + if err != nil { + if ctx.Err() != nil { + return ctx.Err() + } + return err + } + msg, err := parseMessage(raw) + if err != nil { + return err + } + switch msg.Type { + case TypePing: + if err := s.enqueue(ctx, Message{ + Version: ProtoVersion, + Type: TypePong, + Seq: msg.Seq, + SentUnixNano: msg.SentUnixNano, + }); err != nil { + if ctx.Err() != nil { + return ctx.Err() + } + return err + } + case TypePong: + s.handlePong(msg) + default: + return fmt.Errorf("%w: got %q", ErrUnexpectedMessage, msg.Type) + } + } +} + +func (s *state) probeLoop(ctx context.Context) error { + ticker := time.NewTicker(s.cfg.Interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + if err := s.sendProbe(ctx); err != nil { + return err + } + } + } +} + +func (s *state) sendProbe(ctx context.Context) error { + now := s.now() + + s.mu.Lock() + for seq, sent := range s.pending { + if now.Sub(sent) < s.cfg.Timeout { + continue + } + delete(s.pending, seq) + s.failures++ + } + if s.failures >= s.cfg.Failures { + missed := s.failures + s.mu.Unlock() + if s.cfg.OnUnhealthy != nil { + s.cfg.OnUnhealthy(missed) + } + return fmt.Errorf("%w: missed %d pong(s)", ErrUnhealthy, missed) + } + + s.nextSeq++ + seq := s.nextSeq + s.pending[seq] = now + s.mu.Unlock() + + return s.enqueue(ctx, Message{ + Version: ProtoVersion, + Type: TypePing, + Seq: seq, + SentUnixNano: now.UnixNano(), + }) +} + +func (s *state) handlePong(msg Message) { + now := s.now() + + s.mu.Lock() + sent, ok := s.pending[msg.Seq] + if ok { + delete(s.pending, msg.Seq) + s.failures = 0 + } + s.mu.Unlock() + + if !ok || s.cfg.OnPong == nil { + return + } + s.cfg.OnPong(Health{ + Seq: msg.Seq, + RTT: now.Sub(sent), + LastSeen: now, + }) +} + +func (s *state) enqueue(ctx context.Context, msg Message) error { + select { + case <-ctx.Done(): + return ctx.Err() + case s.out <- msg: + return nil + } +} + +func (s *state) writeLoop(ctx context.Context) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + case msg := <-s.out: + if err := writeFrame(s.rw, msg); err != nil { + if ctx.Err() != nil { + return ctx.Err() + } + return err + } + } + } +} + +func parseMessage(raw []byte) (Message, error) { + var msg Message + if err := json.Unmarshal(raw, &msg); err != nil { + return Message{}, fmt.Errorf("parse control message: %w", err) + } + if msg.Version != ProtoVersion { + return Message{}, fmt.Errorf("%w: peer v%d, local v%d", + ErrProtocolVersion, msg.Version, ProtoVersion) + } + if msg.Type != TypePing && msg.Type != TypePong { + return Message{}, fmt.Errorf("%w: got %q", ErrUnexpectedMessage, msg.Type) + } + return msg, nil +} + +func writeFrame(w io.Writer, msg Message) error { + body, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("marshal control message: %w", err) + } + if len(body) > MaxMessageSize { + return fmt.Errorf("%w: %d > %d", ErrFrameTooLarge, len(body), MaxMessageSize) + } + var hdr [4]byte + binary.BigEndian.PutUint32(hdr[:], uint32(len(body))) //nolint:gosec // len(body) bounded by MaxMessageSize + if _, err := w.Write(hdr[:]); err != nil { + return fmt.Errorf("write control hdr: %w", err) + } + if _, err := w.Write(body); err != nil { + return fmt.Errorf("write control body: %w", err) + } + return nil +} + +func readFrame(r io.Reader) ([]byte, error) { + var hdr [4]byte + if _, err := io.ReadFull(r, hdr[:]); err != nil { + return nil, fmt.Errorf("read control hdr: %w", err) + } + n := binary.BigEndian.Uint32(hdr[:]) + if n > MaxMessageSize { + return nil, fmt.Errorf("%w: %d > %d", ErrFrameTooLarge, n, MaxMessageSize) + } + buf := make([]byte, n) + if _, err := io.ReadFull(r, buf); err != nil { + return nil, fmt.Errorf("read control body: %w", err) + } + return buf, nil +} diff --git a/internal/control/control_test.go b/internal/control/control_test.go new file mode 100644 index 0000000..3c52bf6 --- /dev/null +++ b/internal/control/control_test.go @@ -0,0 +1,128 @@ +package control + +import ( + "context" + "encoding/binary" + "errors" + "io" + "net" + "testing" + "time" +) + +func controlPair(t *testing.T) (net.Conn, net.Conn) { + t.Helper() + a, b := net.Pipe() + t.Cleanup(func() { + _ = a.Close() + _ = b.Close() + }) + return a, b +} + +func TestRunPingPongReportsRTT(t *testing.T) { + a, b := controlPair(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + got := make(chan Health, 1) + cfg := Config{ + Interval: 10 * time.Millisecond, + Timeout: 100 * time.Millisecond, + Failures: 2, + OnPong: func(h Health) { + select { + case got <- h: + default: + } + }, + } + errCh := make(chan error, 2) + go func() { errCh <- Run(ctx, a, cfg) }() + go func() { errCh <- Run(ctx, b, cfg) }() + + select { + case h := <-got: + if h.Seq == 0 { + t.Fatal("Health.Seq = 0") + } + if h.RTT < 0 { + t.Fatalf("Health.RTT = %v", h.RTT) + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for pong health") + } + + cancel() + for range 2 { + if err := <-errCh; err != nil { + t.Fatalf("Run() after cancel = %v", err) + } + } +} + +func TestRunMarksUnhealthyAfterMissedPongs(t *testing.T) { + a, b := controlPair(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + _, _ = io.Copy(io.Discard, b) + }() + + missedCh := make(chan int, 1) + errCh := make(chan error, 1) + go func() { + errCh <- Run(ctx, a, Config{ + Interval: 10 * time.Millisecond, + Timeout: 5 * time.Millisecond, + Failures: 2, + OnUnhealthy: func(missed int) { missedCh <- missed }, + }) + }() + + select { + case err := <-errCh: + if !errors.Is(err, ErrUnhealthy) { + t.Fatalf("Run() error = %v, want ErrUnhealthy", err) + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for unhealthy result") + } + if missed := <-missedCh; missed < 2 { + t.Fatalf("missed = %d, want >= 2", missed) + } +} + +func TestRunRejectsBadProtocolVersion(t *testing.T) { + a, b := controlPair(t) + errCh := make(chan error, 1) + go func() { + errCh <- Run(context.Background(), a, Config{Interval: time.Hour}) + }() + if err := writeFrame(b, Message{Version: 999, Type: TypePing, Seq: 1}); err != nil { + t.Fatalf("writeFrame() error = %v", err) + } + + select { + case err := <-errCh: + if !errors.Is(err, ErrProtocolVersion) { + t.Fatalf("Run() error = %v, want ErrProtocolVersion", err) + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for protocol error") + } +} + +func TestReadFrameRejectsTooLarge(t *testing.T) { + a, b := controlPair(t) + go func() { + var hdr [4]byte + binary.BigEndian.PutUint32(hdr[:], MaxMessageSize+1) + _, _ = b.Write(hdr[:]) + }() + _, err := readFrame(a) + if !errors.Is(err, ErrFrameTooLarge) { + t.Fatalf("readFrame() error = %v, want ErrFrameTooLarge", err) + } +} diff --git a/internal/handshake/handshake.go b/internal/handshake/handshake.go index bec84a7..5d34f6f 100644 --- a/internal/handshake/handshake.go +++ b/internal/handshake/handshake.go @@ -13,8 +13,8 @@ // │ │ // // After the exchange the control stream stays open; tunnel traffic flows over -// additional smux streams opened by the client. The control stream may carry -// keepalives or future control messages. +// additional smux streams opened by the client. The control stream then +// carries ping/pong liveness and future control messages. // //nolint:tagliatelle // JSON keys are the stable wire protocol schema. package handshake diff --git a/internal/server/server.go b/internal/server/server.go index a720a25..4954ad4 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -14,6 +14,7 @@ import ( "time" "github.com/google/uuid" + "github.com/openlibrecommunity/olcrtc/internal/control" "github.com/openlibrecommunity/olcrtc/internal/crypto" "github.com/openlibrecommunity/olcrtc/internal/handshake" "github.com/openlibrecommunity/olcrtc/internal/link" @@ -55,6 +56,7 @@ type Server struct { cipher *crypto.Cipher conn *muxconn.Conn session *smux.Session + controlStop context.CancelFunc sessMu sync.RWMutex reinstallMu sync.Mutex wg sync.WaitGroup @@ -68,6 +70,7 @@ type Server struct { resolver *net.Resolver socksProxyAddr string socksProxyPort int + liveness control.Config } // ConnectRequest is a message from the client to establish a new connection. @@ -106,6 +109,7 @@ type Config struct { Engine string URL string Token string + Liveness control.Config // AuthHook is invoked after CLIENT_HELLO to authorize the client and // return a session ID. If nil, every client is admitted with a random UUID. @@ -155,6 +159,7 @@ func Run(ctx context.Context, cfg Config) error { dnsServer: cfg.DNSServer, socksProxyAddr: cfg.SOCKSProxyAddr, socksProxyPort: cfg.SOCKSProxyPort, + liveness: cfg.Liveness, } s.setupResolver() @@ -340,13 +345,18 @@ func (s *Server) reinstallSession(dead *smux.Session) { } oldSess := s.session oldConn := s.conn + oldControlStop := s.controlStop oldSID := s.sessionID s.session = newSess s.conn = newConn + s.controlStop = nil s.sessionID = "" s.deviceID = "" s.sessMu.Unlock() + if oldControlStop != nil { + oldControlStop() + } if oldSess != nil { _ = oldSess.Close() } @@ -362,13 +372,18 @@ func (s *Server) closeSession() { s.sessMu.Lock() sess := s.session conn := s.conn + controlStop := s.controlStop s.session = nil s.conn = nil + s.controlStop = nil oldSID := s.sessionID s.sessionID = "" s.deviceID = "" s.sessMu.Unlock() + if controlStop != nil { + controlStop() + } if conn != nil { _ = conn.Close() } @@ -478,26 +493,48 @@ func (s *Server) acceptHandshake(ctx context.Context, sess *smux.Session) bool { s.sessMu.Unlock() s.onOpen(sid, hello.DeviceID, hello.Claims) logger.Infof("session %s opened (device=%s)", sid, hello.DeviceID) - // The control stream stays open for the lifetime of the session; - // keep it parked in a goroutine so the smux session does not close it. - s.wg.Add(1) - go func() { - defer s.wg.Done() - s.parkControlStream(stream) - }() + s.startControlLoop(ctx, sess, stream) return true } -// parkControlStream blocks reading from the control stream until it closes. -// Future control messages (kick, rate updates, etc.) would be dispatched here. -func (s *Server) parkControlStream(stream *smux.Stream) { - defer func() { _ = stream.Close() }() - buf := make([]byte, 64) - for { - if _, err := stream.Read(buf); err != nil { - return +func (s *Server) startControlLoop(ctx context.Context, sess *smux.Session, stream *smux.Stream) { + controlCtx, stop := context.WithCancel(ctx) + s.sessMu.Lock() + s.controlStop = stop + s.sessMu.Unlock() + + liveness := s.liveness + onPong := liveness.OnPong + onUnhealthy := liveness.OnUnhealthy + liveness.OnPong = func(h control.Health) { + s.sessMu.RLock() + sid := s.sessionID + s.sessMu.RUnlock() + logger.Debugf("control alive session=%s rtt=%v seq=%d", sid, h.RTT, h.Seq) + if onPong != nil { + onPong(h) } } + liveness.OnUnhealthy = func(missed int) { + logger.Warnf("control stream unhealthy on server: missed_pongs=%d", missed) + if onUnhealthy != nil { + onUnhealthy(missed) + } + } + + s.wg.Add(1) + go func() { + defer s.wg.Done() + defer func() { _ = stream.Close() }() + err := control.Run(controlCtx, stream, liveness) + if controlCtx.Err() != nil || ctx.Err() != nil { + return + } + if err != nil { + logger.Warnf("server control stream ended: %v", err) + } + s.reinstallSession(sess) + }() } func (s *Server) shutdown() { diff --git a/internal/server/server_test.go b/internal/server/server_test.go index f6034bf..d5a6f6d 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -11,6 +11,7 @@ import ( "testing" "time" + "github.com/openlibrecommunity/olcrtc/internal/control" cryptopkg "github.com/openlibrecommunity/olcrtc/internal/crypto" "github.com/openlibrecommunity/olcrtc/internal/muxconn" "github.com/xtaci/smux" @@ -373,6 +374,77 @@ func TestReinstallSessionFiresOnClose(t *testing.T) { } } +func TestStartControlLoopReportsPong(t *testing.T) { + a, b := net.Pipe() + defer func() { + _ = a.Close() + _ = b.Close() + }() + + serverSess, err := smux.Server(a, smuxConfig()) + if err != nil { + t.Fatalf("smux.Server() error = %v", err) + } + defer func() { _ = serverSess.Close() }() + clientSess, err := smux.Client(b, smuxConfig()) + if err != nil { + t.Fatalf("smux.Client() error = %v", err) + } + defer func() { _ = clientSess.Close() }() + + serverStreamCh := make(chan *smux.Stream, 1) + go func() { + stream, err := serverSess.AcceptStream() + if err == nil { + serverStreamCh <- stream + } + }() + + clientStream, err := clientSess.OpenStream() + if err != nil { + t.Fatalf("OpenStream() error = %v", err) + } + serverStream := <-serverStreamCh + + ctx, cancel := context.WithCancel(context.Background()) + got := make(chan control.Health, 1) + s := &Server{ + sessionID: "sid-control", + liveness: control.Config{ + Interval: 10 * time.Millisecond, + Timeout: 100 * time.Millisecond, + Failures: 2, + OnPong: func(h control.Health) { + select { + case got <- h: + default: + } + }, + }, + } + defer func() { + cancel() + s.wg.Wait() + }() + s.startControlLoop(ctx, serverSess, serverStream) + go func() { + _ = control.Run(ctx, clientStream, control.Config{ + Interval: 10 * time.Millisecond, + Timeout: 100 * time.Millisecond, + Failures: 2, + }) + }() + + select { + case h := <-got: + if h.Seq == 0 { + t.Fatal("Health.Seq = 0") + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for control pong") + } +} + //nolint:cyclop // integration-style test needs setup, proxying, and traffic assertions together. func TestDispatchFiresOnTraffic(t *testing.T) { var lc net.ListenConfig From d16cd0686ae6e5bf20288bc809bee88de1e2f629 Mon Sep 17 00:00:00 2001 From: cyber-debug Date: Sat, 16 May 2026 00:30:04 +0300 Subject: [PATCH 3/8] feat: expose mobile liveness options --- mobile/mobile.go | 101 +++++++++++++++++++++++++++++++++++------- mobile/mobile_test.go | 60 +++++++++++++++++++++---- 2 files changed, 135 insertions(+), 26 deletions(-) diff --git a/mobile/mobile.go b/mobile/mobile.go index 0cf1a55..4ed9fc1 100644 --- a/mobile/mobile.go +++ b/mobile/mobile.go @@ -15,6 +15,7 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/app/session" "github.com/openlibrecommunity/olcrtc/internal/client" + "github.com/openlibrecommunity/olcrtc/internal/control" "github.com/openlibrecommunity/olcrtc/internal/logger" "github.com/openlibrecommunity/olcrtc/internal/protect" @@ -65,23 +66,26 @@ const ( ) var ( - mu sync.Mutex //nolint:gochecknoglobals // package-level state intentional - defaults mobileConfig //nolint:gochecknoglobals // package-level state intentional - defaultsSet sync.Once //nolint:gochecknoglobals // package-level state intentional - registerSet sync.Once //nolint:gochecknoglobals // package-level state intentional + mu sync.Mutex //nolint:gochecknoglobals // package-level state intentional + defaults mobileConfig //nolint:gochecknoglobals // package-level state intentional + defaultsSet sync.Once //nolint:gochecknoglobals // package-level state intentional + registerSet sync.Once //nolint:gochecknoglobals // package-level state intentional runClientWithReady = client.RunWithReady //nolint:gochecknoglobals // package-level state intentional - cancel context.CancelFunc //nolint:gochecknoglobals // package-level state intentional - done chan struct{} //nolint:gochecknoglobals // package-level state intentional - ready chan struct{} //nolint:gochecknoglobals // package-level state intentional + cancel context.CancelFunc //nolint:gochecknoglobals // package-level state intentional + done chan struct{} //nolint:gochecknoglobals // package-level state intentional + ready chan struct{} //nolint:gochecknoglobals // package-level state intentional errRun error ) type mobileConfig struct { - link string - transport string - dnsServer string - vp8FPS int - vp8BatchSize int + link string + transport string + dnsServer string + vp8FPS int + vp8BatchSize int + livenessInterval time.Duration + livenessTimeout time.Duration + livenessFailures int } // SetProtector sets the Android VPN socket protector. @@ -143,6 +147,21 @@ func SetVP8Options(fps, batchSize int) { defaults.vp8BatchSize = clampAtLeastOne(batchSize, 64) } +// SetLivenessOptions configures control-stream ping/pong checks. +// Values <= 0 reset that field to its default. Durations are milliseconds. +func SetLivenessOptions(intervalMillis, timeoutMillis, failures int) { + mu.Lock() + defer mu.Unlock() + ensureDefaultConfigLocked() + defaults.livenessInterval = durationFromMillisOrDefault(intervalMillis, control.DefaultInterval) + defaults.livenessTimeout = durationFromMillisOrDefault(timeoutMillis, control.DefaultTimeout) + if failures <= 0 { + defaults.livenessFailures = control.DefaultFailures + return + } + defaults.livenessFailures = failures +} + // SetDebug enables or disables verbose logging. func SetDebug(enabled bool) { logger.SetVerbose(enabled) @@ -195,6 +214,11 @@ func Check( vp8BatchSize int, ) (int64, error) { registerDefaults() + mu.Lock() + ensureDefaultConfigLocked() + cfg := defaults + mu.Unlock() + carrierName = normalizeCarrier(carrierName) transportName = normalizeTransport(transportName) if err := validateStartArgs(carrierName, roomID, clientID, keyHex); err != nil { @@ -227,6 +251,7 @@ func Check( DNSServer: defaultDNSServer, VP8FPS: clampAtLeastOne(vp8FPS, 120), VP8BatchSize: clampAtLeastOne(vp8BatchSize, 64), + Liveness: livenessConfig(cfg), }, func() { readyOnce.Do(func() { @@ -271,6 +296,11 @@ func Ping( vp8BatchSize int, ) (int64, error) { registerDefaults() + mu.Lock() + ensureDefaultConfigLocked() + cfg := defaults + mu.Unlock() + carrierName = normalizeCarrier(carrierName) transportName = normalizeTransport(transportName) @@ -310,6 +340,7 @@ func Ping( DNSServer: defaultDNSServer, VP8FPS: clampAtLeastOne(vp8FPS, 120), VP8BatchSize: clampAtLeastOne(vp8BatchSize, 64), + Liveness: livenessConfig(cfg), }, func() { readyOnce.Do(func() { @@ -557,6 +588,7 @@ func startWithConfig( SOCKSPass: socksPass, VP8FPS: cfg.vp8FPS, VP8BatchSize: cfg.vp8BatchSize, + Liveness: livenessConfig(cfg), }, func() { readyOnce.Do(func() { @@ -576,6 +608,7 @@ func startWithConfig( } // WaitReady blocks until the selected transport is connected and the local SOCKS5 listener is ready. +// //nolint:cyclop // straightforward state-machine waits with multiple terminal conditions func WaitReady(timeoutMillis int) error { mu.Lock() @@ -666,15 +699,38 @@ func waitForCheckDone(doneCh <-chan error) { func ensureDefaultConfigLocked() { defaultsSet.Do(func() { defaults = mobileConfig{ - link: defaultLink, - transport: defaultTransport, - dnsServer: defaultDNSServer, - vp8FPS: 60, - vp8BatchSize: 8, + link: defaultLink, + transport: defaultTransport, + dnsServer: defaultDNSServer, + vp8FPS: 60, + vp8BatchSize: 8, + livenessInterval: control.DefaultInterval, + livenessTimeout: control.DefaultTimeout, + livenessFailures: control.DefaultFailures, } }) } +func livenessConfig(cfg mobileConfig) control.Config { + interval := cfg.livenessInterval + if interval <= 0 { + interval = control.DefaultInterval + } + timeout := cfg.livenessTimeout + if timeout <= 0 { + timeout = control.DefaultTimeout + } + failures := cfg.livenessFailures + if failures <= 0 { + failures = control.DefaultFailures + } + return control.Config{ + Interval: interval, + Timeout: timeout, + Failures: failures, + } +} + func normalizeTransport(value string) string { switch value { case dataTransport, "data", "dc": @@ -734,6 +790,17 @@ func clampAtLeastOne(value, maxValue int) int { return value } +func durationFromMillisOrDefault(value int, def time.Duration) time.Duration { + if value <= 0 { + return def + } + d := time.Duration(value) * time.Millisecond + if d <= 0 { + return def + } + return d +} + // logBridge adapts LogWriter to io.Writer. type logBridge struct { w LogWriter diff --git a/mobile/mobile_test.go b/mobile/mobile_test.go index f22625b..2498103 100644 --- a/mobile/mobile_test.go +++ b/mobile/mobile_test.go @@ -10,6 +10,7 @@ import ( "time" "github.com/openlibrecommunity/olcrtc/internal/client" + "github.com/openlibrecommunity/olcrtc/internal/control" "github.com/openlibrecommunity/olcrtc/internal/logger" "github.com/openlibrecommunity/olcrtc/internal/protect" ) @@ -83,12 +84,15 @@ func TestDefaultsAndSetters(t *testing.T) { SetLink("direct") SetDNS("9.9.9.9:53") SetVP8Options(-1, 999) + SetLivenessOptions(2500, 750, -1) mu.Lock() got := defaults mu.Unlock() if got.transport != dataTransport || got.link != defaultLink || got.dnsServer != "9.9.9.9:53" || - got.vp8FPS != 1 || got.vp8BatchSize != 64 { + got.vp8FPS != 1 || got.vp8BatchSize != 64 || + got.livenessInterval != 2500*time.Millisecond || got.livenessTimeout != 750*time.Millisecond || + got.livenessFailures != control.DefaultFailures { t.Fatalf("defaults = %+v", got) } @@ -168,15 +172,19 @@ func TestStartWithInjectedRunnerLifecycle(t *testing.T) { t.Cleanup(func() { resetMobileGlobals(t) }) + SetLivenessOptions(2500, 750, 4) runClientWithReady = func(ctx context.Context, cfg client.Config, onReady func()) error { if cfg.Link != defaultLink || cfg.Transport != dataTransport || cfg.Carrier != carrierJazz || cfg.RoomURL != "any" || cfg.DeviceID != "client" || cfg.LocalAddr != "127.0.0.1:1080" || - cfg.DNSServer != defaultDNSServer || cfg.VP8FPS != 60 || cfg.VP8BatchSize != 8 { + cfg.DNSServer != defaultDNSServer || cfg.VP8FPS != 60 || cfg.VP8BatchSize != 8 || + cfg.Liveness.Interval != 2500*time.Millisecond || + cfg.Liveness.Timeout != 750*time.Millisecond || + cfg.Liveness.Failures != 4 { t.Fatalf( - "RunWithReady args mismatch: link=%q transport=%q carrier=%q room=%q client=%q local=%q dns=%q vp8=%d/%d", + "RunWithReady args mismatch: link=%q transport=%q carrier=%q room=%q client=%q local=%q dns=%q vp8=%d/%d liveness=%+v", cfg.Link, cfg.Transport, cfg.Carrier, cfg.RoomURL, cfg.DeviceID, - cfg.LocalAddr, cfg.DNSServer, cfg.VP8FPS, cfg.VP8BatchSize, + cfg.LocalAddr, cfg.DNSServer, cfg.VP8FPS, cfg.VP8BatchSize, cfg.Liveness, ) } onReady() @@ -208,9 +216,12 @@ func TestStartUsesDefaultsAndCheckWithInjectedRunner(t *testing.T) { runClientWithReady = func(ctx context.Context, cfg client.Config, onReady func()) error { if cfg.Transport != defaultTransport || cfg.RoomURL != "https://telemost.yandex.ru/j/room" || - cfg.LocalAddr != "127.0.0.1:1081" || cfg.SOCKSUser != "u" || cfg.SOCKSPass != "p" { - t.Fatalf("Start args mismatch: transport=%q room=%q local=%q user/pass=%q/%q", - cfg.Transport, cfg.RoomURL, cfg.LocalAddr, cfg.SOCKSUser, cfg.SOCKSPass) + cfg.LocalAddr != "127.0.0.1:1081" || cfg.SOCKSUser != "u" || cfg.SOCKSPass != "p" || + cfg.Liveness.Interval != control.DefaultInterval || + cfg.Liveness.Timeout != control.DefaultTimeout || + cfg.Liveness.Failures != control.DefaultFailures { + t.Fatalf("Start args mismatch: transport=%q room=%q local=%q user/pass=%q/%q liveness=%+v", + cfg.Transport, cfg.RoomURL, cfg.LocalAddr, cfg.SOCKSUser, cfg.SOCKSPass, cfg.Liveness) } onReady() <-ctx.Done() @@ -225,9 +236,14 @@ func TestStartUsesDefaultsAndCheckWithInjectedRunner(t *testing.T) { } Stop() + SetLivenessOptions(3000, 1000, 5) runClientWithReady = func(ctx context.Context, cfg client.Config, onReady func()) error { - if cfg.Transport != dataTransport || cfg.VP8FPS != 1 || cfg.VP8BatchSize != 64 { - t.Fatalf("Check args mismatch: transport=%q vp8=%d/%d", cfg.Transport, cfg.VP8FPS, cfg.VP8BatchSize) + if cfg.Transport != dataTransport || cfg.VP8FPS != 1 || cfg.VP8BatchSize != 64 || + cfg.Liveness.Interval != 3000*time.Millisecond || + cfg.Liveness.Timeout != time.Second || + cfg.Liveness.Failures != 5 { + t.Fatalf("Check args mismatch: transport=%q vp8=%d/%d liveness=%+v", + cfg.Transport, cfg.VP8FPS, cfg.VP8BatchSize, cfg.Liveness) } onReady() <-ctx.Done() @@ -242,6 +258,32 @@ func TestStartUsesDefaultsAndCheckWithInjectedRunner(t *testing.T) { } } +func TestPingPassesLiveness(t *testing.T) { + resetMobileGlobals(t) + t.Cleanup(func() { + resetMobileGlobals(t) + }) + SetLivenessOptions(4000, 1500, 6) + + seen := make(chan control.Config, 1) + runClientWithReady = func(ctx context.Context, cfg client.Config, onReady func()) error { + seen <- cfg.Liveness + onReady() + <-ctx.Done() + return nil + } + + _, _ = Ping("jazz", "dc", "", "client", "key", 1085, 100, "http://127.0.0.1/", 30, 1) + select { + case got := <-seen: + if got.Interval != 4000*time.Millisecond || got.Timeout != 1500*time.Millisecond || got.Failures != 6 { + t.Fatalf("Ping liveness = %+v", got) + } + default: + t.Fatal("Ping did not start client") + } +} + func TestCheckTimeoutAndRunError(t *testing.T) { resetMobileGlobals(t) t.Cleanup(func() { From 4c6bd2b838077c4686a2bff627e9d3649aff8eb6 Mon Sep 17 00:00:00 2001 From: cyber-debug Date: Sat, 16 May 2026 00:34:39 +0300 Subject: [PATCH 4/8] feat: expose control health status --- docs/project-map.md | 7 ++- internal/client/client.go | 90 ++++++++++++++++++++++++++++++-- internal/client/client_test.go | 26 +++++++++ internal/control/control.go | 24 ++++++++- internal/control/control_test.go | 16 ++++-- internal/server/server.go | 87 +++++++++++++++++++++++++++++- internal/server/server_test.go | 26 +++++++++ 7 files changed, 266 insertions(+), 10 deletions(-) diff --git a/docs/project-map.md b/docs/project-map.md index e1b2134..4481982 100644 --- a/docs/project-map.md +++ b/docs/project-map.md @@ -164,6 +164,11 @@ Defaults are `liveness.interval: 10s`, `liveness.timeout: 5s`, and `liveness.failures: 3`. Missed pongs mark the smux session unhealthy and trigger a session rebuild/reconnect path. +Client and server runtimes also maintain a `control.Status` snapshot with +session ID, last pong time, RTT, missed pongs, reconnect count, and unhealthy +event count. Embedders can consume it through the client/server health +callbacks. + ## Registries And Plugin Shape The universal-carrier refactor centers on small registries: @@ -339,7 +344,7 @@ still the natural place for: - Server policy updates. - Graceful reconnect notifications. - Drain/start markers for failover. -- Per-session stats. +- More per-session stats. Likely files: diff --git a/internal/client/client.go b/internal/client/client.go index 13be135..001cb4c 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -58,6 +58,9 @@ type Client struct { controlStop context.CancelFunc sessMu sync.RWMutex reconnectMu sync.Mutex + healthMu sync.RWMutex + health control.Status + onHealth HealthFunc deviceID string sessionID string claims map[string]any @@ -66,6 +69,9 @@ type Client struct { socksPass string } +// HealthFunc is called when the client control health snapshot changes. +type HealthFunc func(control.Status) + // Config holds runtime configuration for [Run] and [RunWithReady]. type Config struct { Link string @@ -110,6 +116,9 @@ type Config struct { // Claims is sent to the server in CLIENT_HELLO and forwarded verbatim to // the server's AuthHook. Free-form key/value bag for plan, user, region, etc. Claims map[string]any + + // OnHealth receives liveness/reconnect status updates. Nil means no-op. + OnHealth HealthFunc } // Run starts the client with the given configuration. @@ -139,6 +148,7 @@ func RunWithReady(ctx context.Context, cfg Config, onReady func()) error { dnsServer: cfg.DNSServer, socksUser: cfg.SOCKSUser, socksPass: cfg.SOCKSPass, + onHealth: cfg.OnHealth, } // shutdown is registered BEFORE bringUpLink so we always close any @@ -221,7 +231,7 @@ func (c *Client) bringUpLink( if ctx.Err() != nil { return } - if !c.handleReconnect(ctx, cfg, cancel) { + if !c.handleReconnect(ctx, cfg, cancel, "carrier") { cancel() } }) @@ -249,6 +259,7 @@ func (c *Client) bringUpLink( c.controlStrm = control c.sessionID = sid c.sessMu.Unlock() + c.recordSession(sid) c.startControlLoop(ctx, cfg, cancel, control) go ln.WatchConnection(ctx) @@ -333,11 +344,12 @@ func smuxConfig() *smux.Config { return cfg } -func (c *Client) handleReconnect(ctx context.Context, cfg Config, cancel context.CancelFunc) bool { +func (c *Client) handleReconnect(ctx context.Context, cfg Config, cancel context.CancelFunc, reason string) bool { c.reconnectMu.Lock() defer c.reconnectMu.Unlock() - logger.Infof("client link reconnect - tearing down smux session") + c.recordReconnect() + logger.Infof("client reconnect reason=%s - tearing down smux session", reason) // Install a fresh muxconn immediately so onData never hits nil while // the old session is being torn down. tryReopenSession will swap it @@ -379,6 +391,7 @@ func (c *Client) handleReconnect(ctx context.Context, cfg Config, cancel context attemptDelay = 300 * time.Millisecond ) for attempt := 1; attempt <= maxAttempts; attempt++ { + logger.Infof("client reconnect attempt=%d reason=%s", attempt, reason) if c.tryReopenSession(ctx, cfg, cancel, attempt) { return true } @@ -425,6 +438,7 @@ func (c *Client) tryReopenSession( c.controlStrm = control c.sessionID = sid c.sessMu.Unlock() + c.recordSession(sid) c.startControlLoop(ctx, cfg, cancel, control) return true } @@ -442,17 +456,27 @@ func (c *Client) startControlLoop( liveness := cfg.Liveness onPong := liveness.OnPong + onMissedPong := liveness.OnMissedPong onUnhealthy := liveness.OnUnhealthy liveness.OnPong = func(h control.Health) { c.sessMu.RLock() sid := c.sessionID c.sessMu.RUnlock() + c.recordPong(h) logger.Debugf("control alive session=%s rtt=%v seq=%d", sid, h.RTT, h.Seq) if onPong != nil { onPong(h) } } + liveness.OnMissedPong = func(missed int) { + c.recordMissed(missed) + logger.Warnf("control missed pong on client: missed_pongs=%d", missed) + if onMissedPong != nil { + onMissedPong(missed) + } + } liveness.OnUnhealthy = func(missed int) { + c.recordUnhealthy(missed) logger.Warnf("control stream unhealthy on client: missed_pongs=%d", missed) if onUnhealthy != nil { onUnhealthy(missed) @@ -467,12 +491,70 @@ func (c *Client) startControlLoop( if err != nil { logger.Warnf("client control stream ended: %v", err) } - if !c.handleReconnect(ctx, cfg, cancel) { + if !c.handleReconnect(ctx, cfg, cancel, "liveness") { cancel() } }() } +// Status returns the latest client-side control health snapshot. +func (c *Client) Status() control.Status { + c.healthMu.RLock() + defer c.healthMu.RUnlock() + return c.health +} + +func (c *Client) recordSession(sessionID string) { + c.healthMu.Lock() + c.health.SessionID = sessionID + c.health.MissedPongs = 0 + status := c.health + c.healthMu.Unlock() + c.notifyHealth(status) +} + +func (c *Client) recordPong(h control.Health) { + c.healthMu.Lock() + c.health.LastPong = h.LastSeen + c.health.LastRTT = h.RTT + c.health.MissedPongs = 0 + status := c.health + c.healthMu.Unlock() + c.notifyHealth(status) +} + +func (c *Client) recordMissed(missed int) { + c.healthMu.Lock() + c.health.MissedPongs = missed + status := c.health + c.healthMu.Unlock() + c.notifyHealth(status) +} + +func (c *Client) recordUnhealthy(missed int) { + c.healthMu.Lock() + c.health.MissedPongs = missed + c.health.UnhealthyEvents++ + c.health.LastUnhealthy = time.Now() + status := c.health + c.healthMu.Unlock() + c.notifyHealth(status) +} + +func (c *Client) recordReconnect() { + c.healthMu.Lock() + c.health.Reconnects++ + status := c.health + c.healthMu.Unlock() + c.notifyHealth(status) +} + +func (c *Client) notifyHealth(status control.Status) { + if c.onHealth != nil { + c.onHealth(status) + } +} + func (c *Client) shutdown() { c.sessMu.Lock() control := c.controlStrm diff --git a/internal/client/client_test.go b/internal/client/client_test.go index f5d836b..82d0099 100644 --- a/internal/client/client_test.go +++ b/internal/client/client_test.go @@ -555,6 +555,7 @@ func TestStartControlLoopReportsPong(t *testing.T) { defer cancel() got := make(chan control.Health, 1) c := &Client{sessionID: "sid-control"} + c.recordSession("sid-control") c.startControlLoop(ctx, Config{ Liveness: control.Config{ Interval: 10 * time.Millisecond, @@ -584,4 +585,29 @@ func TestStartControlLoopReportsPong(t *testing.T) { case <-time.After(time.Second): t.Fatal("timed out waiting for control pong") } + status := c.Status() + if status.SessionID != "sid-control" { + t.Fatalf("Status.SessionID = %q, want sid-control", status.SessionID) + } + if status.LastPong.IsZero() || status.LastRTT < 0 || status.MissedPongs != 0 { + t.Fatalf("Status() = %+v", status) + } +} + +func TestStatusRecordsReconnectAndUnhealthy(t *testing.T) { + updates := 0 + c := &Client{onHealth: func(control.Status) { updates++ }} + c.recordSession("sid-1") + c.recordMissed(2) + c.recordUnhealthy(3) + c.recordReconnect() + + status := c.Status() + if status.SessionID != "sid-1" || status.MissedPongs != 3 || + status.UnhealthyEvents != 1 || status.Reconnects != 1 || status.LastUnhealthy.IsZero() { + t.Fatalf("Status() = %+v", status) + } + if updates != 4 { + t.Fatalf("health updates = %d, want 4", updates) + } } diff --git a/internal/control/control.go b/internal/control/control.go index a6bd50f..d799518 100644 --- a/internal/control/control.go +++ b/internal/control/control.go @@ -71,6 +71,18 @@ type Health struct { LastSeen time.Time } +// Status is a point-in-time view of control-stream health maintained by +// callers that embed the control loop. +type Status struct { + SessionID string + LastPong time.Time + LastRTT time.Duration + MissedPongs int + Reconnects uint64 + UnhealthyEvents uint64 + LastUnhealthy time.Time +} + // Config controls the liveness loop. type Config struct { Interval time.Duration @@ -79,6 +91,8 @@ type Config struct { // OnPong is called after a matching pong is received. OnPong func(Health) + // OnMissedPong is called when one or more outstanding pongs time out. + OnMissedPong func(missed int) // OnUnhealthy is called before Run returns [ErrUnhealthy]. OnUnhealthy func(missed int) } @@ -195,16 +209,21 @@ func (s *state) sendProbe(ctx context.Context) error { now := s.now() s.mu.Lock() + missedNow := 0 for seq, sent := range s.pending { if now.Sub(sent) < s.cfg.Timeout { continue } delete(s.pending, seq) s.failures++ + missedNow++ } + missed := s.failures if s.failures >= s.cfg.Failures { - missed := s.failures s.mu.Unlock() + if missedNow > 0 && s.cfg.OnMissedPong != nil { + s.cfg.OnMissedPong(missed) + } if s.cfg.OnUnhealthy != nil { s.cfg.OnUnhealthy(missed) } @@ -215,6 +234,9 @@ func (s *state) sendProbe(ctx context.Context) error { seq := s.nextSeq s.pending[seq] = now s.mu.Unlock() + if missedNow > 0 && s.cfg.OnMissedPong != nil { + s.cfg.OnMissedPong(missed) + } return s.enqueue(ctx, Message{ Version: ProtoVersion, diff --git a/internal/control/control_test.go b/internal/control/control_test.go index 3c52bf6..8700027 100644 --- a/internal/control/control_test.go +++ b/internal/control/control_test.go @@ -71,12 +71,19 @@ func TestRunMarksUnhealthyAfterMissedPongs(t *testing.T) { }() missedCh := make(chan int, 1) + missedCallbackCh := make(chan int, 1) errCh := make(chan error, 1) go func() { errCh <- Run(ctx, a, Config{ - Interval: 10 * time.Millisecond, - Timeout: 5 * time.Millisecond, - Failures: 2, + Interval: 10 * time.Millisecond, + Timeout: 5 * time.Millisecond, + Failures: 2, + OnMissedPong: func(missed int) { + select { + case missedCallbackCh <- missed: + default: + } + }, OnUnhealthy: func(missed int) { missedCh <- missed }, }) }() @@ -92,6 +99,9 @@ func TestRunMarksUnhealthyAfterMissedPongs(t *testing.T) { if missed := <-missedCh; missed < 2 { t.Fatalf("missed = %d, want >= 2", missed) } + if missed := <-missedCallbackCh; missed < 1 { + t.Fatalf("missed callback = %d, want >= 1", missed) + } } func TestRunRejectsBadProtocolVersion(t *testing.T) { diff --git a/internal/server/server.go b/internal/server/server.go index 4954ad4..7dae4eb 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -50,6 +50,9 @@ type SessionCloseFunc func(sessionID, reason string) // bytesIn counts client→target bytes; bytesOut counts target→client bytes. type TrafficFunc func(sessionID, addr string, bytesIn, bytesOut uint64) +// HealthFunc is called when the server control health snapshot changes. +type HealthFunc func(control.Status) + // Server handles incoming tunnel connections and proxies their traffic. type Server struct { ln link.Link @@ -59,11 +62,13 @@ type Server struct { controlStop context.CancelFunc sessMu sync.RWMutex reinstallMu sync.Mutex + healthMu sync.RWMutex wg sync.WaitGroup authHook handshake.AuthFunc onOpen SessionOpenFunc onClose SessionCloseFunc onTraffic TrafficFunc + onHealth HealthFunc deviceID string sessionID string dnsServer string @@ -71,6 +76,7 @@ type Server struct { socksProxyAddr string socksProxyPort int liveness control.Config + health control.Status } // ConnectRequest is a message from the client to establish a new connection. @@ -121,6 +127,8 @@ type Config struct { OnSessionClose SessionCloseFunc // OnTraffic fires once per tunnel stream after both copy loops finish. Nil means no-op. OnTraffic TrafficFunc + // OnHealth fires when liveness/reconnect status changes. Nil means no-op. + OnHealth HealthFunc } // Run starts the server with the given configuration. @@ -149,6 +157,10 @@ func Run(ctx context.Context, cfg Config) error { if onTraffic == nil { onTraffic = func(string, string, uint64, uint64) {} } + onHealth := cfg.OnHealth + if onHealth == nil { + onHealth = func(control.Status) {} + } s := &Server{ cipher: cipher, @@ -156,6 +168,7 @@ func Run(ctx context.Context, cfg Config) error { onOpen: onOpen, onClose: onClose, onTraffic: onTraffic, + onHealth: onHealth, dnsServer: cfg.DNSServer, socksProxyAddr: cfg.SOCKSProxyAddr, socksProxyPort: cfg.SOCKSProxyPort, @@ -315,7 +328,8 @@ func (s *Server) installSession() { } func (s *Server) handleReconnect() { - logger.Infof("server link reconnect - tearing down smux session") + s.recordReconnect() + logger.Infof("server reconnect reason=carrier - tearing down smux session") s.sessMu.RLock() current := s.session s.sessMu.RUnlock() @@ -491,6 +505,7 @@ func (s *Server) acceptHandshake(ctx context.Context, sess *smux.Session) bool { s.deviceID = hello.DeviceID s.sessionID = sid s.sessMu.Unlock() + s.recordSession(sid) s.onOpen(sid, hello.DeviceID, hello.Claims) logger.Infof("session %s opened (device=%s)", sid, hello.DeviceID) s.startControlLoop(ctx, sess, stream) @@ -505,17 +520,27 @@ func (s *Server) startControlLoop(ctx context.Context, sess *smux.Session, strea liveness := s.liveness onPong := liveness.OnPong + onMissedPong := liveness.OnMissedPong onUnhealthy := liveness.OnUnhealthy liveness.OnPong = func(h control.Health) { s.sessMu.RLock() sid := s.sessionID s.sessMu.RUnlock() + s.recordPong(h) logger.Debugf("control alive session=%s rtt=%v seq=%d", sid, h.RTT, h.Seq) if onPong != nil { onPong(h) } } + liveness.OnMissedPong = func(missed int) { + s.recordMissed(missed) + logger.Warnf("control missed pong on server: missed_pongs=%d", missed) + if onMissedPong != nil { + onMissedPong(missed) + } + } liveness.OnUnhealthy = func(missed int) { + s.recordUnhealthy(missed) logger.Warnf("control stream unhealthy on server: missed_pongs=%d", missed) if onUnhealthy != nil { onUnhealthy(missed) @@ -533,10 +558,70 @@ func (s *Server) startControlLoop(ctx context.Context, sess *smux.Session, strea if err != nil { logger.Warnf("server control stream ended: %v", err) } + s.recordReconnect() + logger.Infof("server reconnect reason=liveness - reinstalling smux session") s.reinstallSession(sess) }() } +// Status returns the latest server-side control health snapshot. +func (s *Server) Status() control.Status { + s.healthMu.RLock() + defer s.healthMu.RUnlock() + return s.health +} + +func (s *Server) recordSession(sessionID string) { + s.healthMu.Lock() + s.health.SessionID = sessionID + s.health.MissedPongs = 0 + status := s.health + s.healthMu.Unlock() + s.notifyHealth(status) +} + +func (s *Server) recordPong(h control.Health) { + s.healthMu.Lock() + s.health.LastPong = h.LastSeen + s.health.LastRTT = h.RTT + s.health.MissedPongs = 0 + status := s.health + s.healthMu.Unlock() + s.notifyHealth(status) +} + +func (s *Server) recordMissed(missed int) { + s.healthMu.Lock() + s.health.MissedPongs = missed + status := s.health + s.healthMu.Unlock() + s.notifyHealth(status) +} + +func (s *Server) recordUnhealthy(missed int) { + s.healthMu.Lock() + s.health.MissedPongs = missed + s.health.UnhealthyEvents++ + s.health.LastUnhealthy = time.Now() + status := s.health + s.healthMu.Unlock() + s.notifyHealth(status) +} + +func (s *Server) recordReconnect() { + s.healthMu.Lock() + s.health.Reconnects++ + status := s.health + s.healthMu.Unlock() + s.notifyHealth(status) +} + +func (s *Server) notifyHealth(status control.Status) { + if s.onHealth != nil { + s.onHealth(status) + } +} + func (s *Server) shutdown() { s.closeSession() if s.ln != nil { diff --git a/internal/server/server_test.go b/internal/server/server_test.go index d5a6f6d..dc80b21 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -422,6 +422,7 @@ func TestStartControlLoopReportsPong(t *testing.T) { }, }, } + s.recordSession("sid-control") defer func() { cancel() s.wg.Wait() @@ -443,6 +444,31 @@ func TestStartControlLoopReportsPong(t *testing.T) { case <-time.After(time.Second): t.Fatal("timed out waiting for control pong") } + status := s.Status() + if status.SessionID != "sid-control" { + t.Fatalf("Status.SessionID = %q, want sid-control", status.SessionID) + } + if status.LastPong.IsZero() || status.LastRTT < 0 || status.MissedPongs != 0 { + t.Fatalf("Status() = %+v", status) + } +} + +func TestStatusRecordsReconnectAndUnhealthy(t *testing.T) { + updates := 0 + s := &Server{onHealth: func(control.Status) { updates++ }} + s.recordSession("sid-1") + s.recordMissed(2) + s.recordUnhealthy(3) + s.recordReconnect() + + status := s.Status() + if status.SessionID != "sid-1" || status.MissedPongs != 3 || + status.UnhealthyEvents != 1 || status.Reconnects != 1 || status.LastUnhealthy.IsZero() { + t.Fatalf("Status() = %+v", status) + } + if updates != 4 { + t.Fatalf("health updates = %d, want 4", updates) + } } //nolint:cyclop // integration-style test needs setup, proxying, and traffic assertions together. From 82b5741ab1ea6896c035a22331b9f0ba049f8f79 Mon Sep 17 00:00:00 2001 From: cyber-debug Date: Sat, 16 May 2026 00:49:52 +0300 Subject: [PATCH 5/8] feat: add planned session rotation --- docs/client.example.yaml | 4 + docs/configuration.md | 19 ++++ docs/failover.example.yaml | 4 + docs/project-map.md | 1 + docs/server.example.yaml | 4 + docs/settings.md | 8 ++ internal/app/session/session.go | 163 +++++++++++++++++++++------ internal/app/session/session_test.go | 53 +++++++++ internal/config/config.go | 69 +++++++----- internal/config/config_test.go | 49 ++++---- 10 files changed, 288 insertions(+), 86 deletions(-) diff --git a/docs/client.example.yaml b/docs/client.example.yaml index a074a6a..06b9b5e 100644 --- a/docs/client.example.yaml +++ b/docs/client.example.yaml @@ -26,6 +26,10 @@ liveness: timeout: 5s failures: 3 +# Optional planned rebuild for long-running calls. +# lifecycle: +# max_session_duration: 6h + # Local SOCKS5 listener exposed to applications socks: host: "127.0.0.1" diff --git a/docs/configuration.md b/docs/configuration.md index 8c067ad..41bdeaa 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -34,6 +34,7 @@ olcrtc /etc/olcrtc/server.yaml | `liveness.interval` | control-stream ping interval, default `10s` | | `liveness.timeout` | pong timeout, default `5s` | | `liveness.failures` | missed pongs before reconnect, default `3` | +| `lifecycle.max_session_duration` | planned session rebuild interval, e.g. `6h`; unset = off | | `gen.amount` | gen mode: number of rooms to create | | `profiles[]` | ordered srv/cnc failover profiles | | `failover.retry_delay` | delay before trying the next profile, e.g. `2s` | @@ -67,6 +68,24 @@ When the failure threshold is reached, the current smux session is rebuilt. In failover mode, a profile that exits after liveness-triggered reconnect failure lets the supervisor advance to the next profile. +## Lifecycle Rotation + +`lifecycle.max_session_duration` sets a planned upper bound for one provider +call/session. When the duration expires, olcrtc cancels the active server or +client session and starts a fresh one with the same config. While this option +is enabled, clean session endings are also restarted so the peer that did not +fire the timer can follow the rebuild. This is useful for long-running +deployments where provider calls get stale, accumulate media state, or should +be periodically re-created. + +```yaml +lifecycle: + max_session_duration: 6h +``` + +The field is optional and disabled when omitted. Values use Go duration syntax +such as `30m`, `2h`, or `6h`; zero and negative durations are rejected. + ## Failover Profiles `mode: srv` and `mode: cnc` can define `profiles`. Top-level fields are used diff --git a/docs/failover.example.yaml b/docs/failover.example.yaml index e956a35..298a847 100644 --- a/docs/failover.example.yaml +++ b/docs/failover.example.yaml @@ -15,6 +15,10 @@ liveness: timeout: 5s failures: 3 +# Optional planned rebuild for each active profile. +# lifecycle: +# max_session_duration: 6h + data: data profiles: diff --git a/docs/project-map.md b/docs/project-map.md index 4481982..55fd291 100644 --- a/docs/project-map.md +++ b/docs/project-map.md @@ -304,6 +304,7 @@ Implemented: - `failover.retry_delay`. - `failover.max_cycles`. - Profile start/end logs. +- Planned session rotation with `lifecycle.max_session_duration`. Still valuable: diff --git a/docs/server.example.yaml b/docs/server.example.yaml index c20b1e5..300f7cf 100644 --- a/docs/server.example.yaml +++ b/docs/server.example.yaml @@ -28,6 +28,10 @@ liveness: timeout: 5s failures: 3 +# Optional planned rebuild for long-running calls. +# lifecycle: +# max_session_duration: 6h + # Outbound SOCKS5 proxy for server-side egress (optional) socks: proxy_addr: "" # e.g. "127.0.0.1" diff --git a/docs/settings.md b/docs/settings.md index 2e2d78a..9f9d215 100644 --- a/docs/settings.md +++ b/docs/settings.md @@ -66,6 +66,7 @@ | `liveness.interval` | Интервал ping по control stream, по умолчанию `10s` | | `liveness.timeout` | Сколько ждать pong, по умолчанию `5s` | | `liveness.failures` | Сколько pong можно пропустить перед rebuild, по умолчанию `3` | +| `lifecycle.max_session_duration` | Плановый rebuild сессии после указанного времени, например `6h`; если поле не задано, выключено | `crypto.key_file` читается относительно YAML-файла. Не указывай `crypto.key` и `crypto.key_file` одновременно. @@ -78,6 +79,13 @@ а не только статус WebRTC/provider соединения. Если pong не приходит несколько раз подряд, текущая smux-сессия пересоздается. +`lifecycle.max_session_duration` ограничивает длительность одного звонка / +provider session. Когда таймер истекает, текущая `srv` или `cnc` сессия +закрывается и стартует заново с тем же конфигом. Пока эта настройка включена, +чистое завершение сессии тоже перезапускается, чтобы второй peer мог догнать +плановый rebuild. Формат значения: `30m`, `2h`, `6h`; `0s` и отрицательные +значения не принимаются. + --- ## mode: gen diff --git a/internal/app/session/session.go b/internal/app/session/session.go index 360d96a..0b48f50 100644 --- a/internal/app/session/session.go +++ b/internal/app/session/session.go @@ -7,6 +7,7 @@ import ( "fmt" "net" "slices" + "sync/atomic" "time" "github.com/openlibrecommunity/olcrtc/internal/auth" @@ -54,6 +55,8 @@ const ( defaultSEIAckTimeoutMS = 2000 ) +var sessionRestartDelay = 2 * time.Second + var ( // ErrRoomIDRequired indicates that a room id is required for the selected carrier. ErrRoomIDRequired = errors.New("room ID required (set room.id)") @@ -131,46 +134,50 @@ var ( // ErrLivenessFailuresInvalid indicates that liveness.failures is not positive. ErrLivenessFailuresInvalid = errors.New( "invalid liveness failures (set liveness.failures to a value > 0)") + // ErrLifecycleMaxSessionDurationInvalid indicates that lifecycle.max_session_duration is not a positive duration. + ErrLifecycleMaxSessionDurationInvalid = errors.New( + "invalid max session duration (set lifecycle.max_session_duration to a duration > 0)") ) // Config holds runtime session settings. type Config struct { - Mode string - Link string - Transport string - Auth string - Engine string - URL string - Token string - RoomID string - KeyHex string - SOCKSHost string - SOCKSPort int - SOCKSUser string - SOCKSPass string - DNSServer string - SOCKSProxyAddr string - SOCKSProxyPort int - VideoWidth int - VideoHeight int - VideoFPS int - VideoBitrate string - VideoHW string - VideoQRSize int - VideoQRRecovery string - VideoCodec string - VideoTileModule int - VideoTileRS int - VP8FPS int - VP8BatchSize int - SEIFPS int - SEIBatchSize int - SEIFragmentSize int - SEIAckTimeoutMS int - LivenessInterval string - LivenessTimeout string - LivenessFailures int - Amount int + Mode string + Link string + Transport string + Auth string + Engine string + URL string + Token string + RoomID string + KeyHex string + SOCKSHost string + SOCKSPort int + SOCKSUser string + SOCKSPass string + DNSServer string + SOCKSProxyAddr string + SOCKSProxyPort int + VideoWidth int + VideoHeight int + VideoFPS int + VideoBitrate string + VideoHW string + VideoQRSize int + VideoQRRecovery string + VideoCodec string + VideoTileModule int + VideoTileRS int + VP8FPS int + VP8BatchSize int + SEIFPS int + SEIBatchSize int + SEIFragmentSize int + SEIAckTimeoutMS int + LivenessInterval string + LivenessTimeout string + LivenessFailures int + MaxSessionDuration string + Amount int } // RegisterDefaults registers built-in carriers and transports. @@ -323,6 +330,9 @@ func Validate(cfg Config) error { if err := validateLivenessConfig(cfg); err != nil { return err } + if err := validateLifecycleConfig(cfg); err != nil { + return err + } return validateModeConfig(cfg) } @@ -475,6 +485,13 @@ func validateLivenessConfig(cfg Config) error { return nil } +func validateLifecycleConfig(cfg Config) error { + if _, err := maxSessionDuration(cfg); err != nil { + return err + } + return nil +} + func parseLivenessDuration(value string, def time.Duration) (time.Duration, error) { if value == "" { return def, nil @@ -508,6 +525,20 @@ func livenessConfig(cfg Config) (control.Config, error) { return control.Config{Interval: interval, Timeout: timeout, Failures: failures}, nil } +func maxSessionDuration(cfg Config) (time.Duration, error) { + if cfg.MaxSessionDuration == "" { + return 0, nil + } + d, err := time.ParseDuration(cfg.MaxSessionDuration) + if err != nil { + return 0, fmt.Errorf("%w: %v", ErrLifecycleMaxSessionDurationInvalid, err) + } + if d <= 0 { + return 0, ErrLifecycleMaxSessionDurationInvalid + } + return d, nil +} + func isLoopbackListenHost(host string) bool { if host == "localhost" { return true @@ -525,7 +556,21 @@ func Run(ctx context.Context, cfg Config) error { if err != nil { return err } + maxDuration, err := maxSessionDuration(cfg) + if err != nil { + return err + } + run := func(ctx context.Context) error { + return runOnce(ctx, cfg, roomURL, liveness) + } + if maxDuration > 0 { + return runWithSessionRotation(ctx, maxDuration, run) + } + return run(ctx) +} + +func runOnce(ctx context.Context, cfg Config, roomURL string, liveness control.Config) error { switch cfg.Mode { case modeSRV: if err := server.Run(ctx, server.Config{ @@ -610,6 +655,52 @@ func Run(ctx context.Context, cfg Config) error { } } +func runWithSessionRotation(ctx context.Context, maxDuration time.Duration, run func(context.Context) error) error { + for cycle := 1; ; cycle++ { + currentCycle := cycle + runCtx, cancel := context.WithCancel(ctx) + var rotated atomic.Bool + timer := time.AfterFunc(maxDuration, func() { + rotated.Store(true) + logger.Infof("session max duration reached: duration=%s cycle=%d", maxDuration, currentCycle) + cancel() + }) + + err := run(runCtx) + cancel() + timer.Stop() + if ctx.Err() != nil { + return nil + } + if rotated.Load() { + if err != nil { + logger.Warnf("session rotation ended with error: cycle=%d err=%v", currentCycle, err) + } + logger.Infof("session rotation restarting: next_cycle=%d", currentCycle+1) + if err := waitSessionRestart(ctx); err != nil { + return nil + } + continue + } + if err != nil { + return err + } + logger.Infof("session ended cleanly with lifecycle rotation enabled: next_cycle=%d", currentCycle+1) + if err := waitSessionRestart(ctx); err != nil { + return nil + } + } +} + +func waitSessionRestart(ctx context.Context) error { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(sessionRestartDelay): + return nil + } +} + // ValidateGen validates that the config contains enough fields to run gen mode. func ValidateGen(cfg Config) error { if cfg.Auth == "" { diff --git a/internal/app/session/session_test.go b/internal/app/session/session_test.go index 95270b2..5fc219d 100644 --- a/internal/app/session/session_test.go +++ b/internal/app/session/session_test.go @@ -3,7 +3,9 @@ package session import ( "context" "errors" + "sync/atomic" "testing" + "time" "github.com/openlibrecommunity/olcrtc/internal/control" ) @@ -105,6 +107,31 @@ func TestApplyLivenessDefaults(t *testing.T) { } } +func TestRunWithSessionRotationRestartsAfterMaxDuration(t *testing.T) { + oldRestartDelay := sessionRestartDelay + sessionRestartDelay = time.Millisecond + t.Cleanup(func() { sessionRestartDelay = oldRestartDelay }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var calls atomic.Int32 + err := runWithSessionRotation(ctx, 5*time.Millisecond, func(ctx context.Context) error { + if calls.Add(1) >= 2 { + cancel() + return nil + } + <-ctx.Done() + return nil + }) + if err != nil { + t.Fatalf("runWithSessionRotation() error = %v", err) + } + if got := calls.Load(); got < 2 { + t.Fatalf("run calls = %d, want at least 2", got) + } +} + //nolint:maintidx // table-driven validation test naturally has many cases func TestValidate(t *testing.T) { RegisterDefaults() @@ -469,6 +496,32 @@ func TestValidate(t *testing.T) { }(), want: ErrLivenessFailuresInvalid, }, + { + name: "lifecycle accepts max session duration", + cfg: func() Config { + cfg := base + cfg.MaxSessionDuration = "1h" + return cfg + }(), + }, + { + name: "lifecycle rejects bad max session duration", + cfg: func() Config { + cfg := base + cfg.MaxSessionDuration = "nope" + return cfg + }(), + want: ErrLifecycleMaxSessionDurationInvalid, + }, + { + name: "lifecycle rejects zero max session duration", + cfg: func() Config { + cfg := base + cfg.MaxSessionDuration = "0s" + return cfg + }(), + want: ErrLifecycleMaxSessionDurationInvalid, + }, } for _, tt := range tests { diff --git a/internal/config/config.go b/internal/config/config.go index 9524363..770adf5 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -30,40 +30,42 @@ var ( // File is the on-disk YAML schema. type File struct { - Mode string `yaml:"mode"` - Link string `yaml:"link"` - Auth Auth `yaml:"auth"` - Room Room `yaml:"room"` - Crypto Crypto `yaml:"crypto"` - Net Net `yaml:"net"` - SOCKS SOCKS `yaml:"socks"` - Engine Engine `yaml:"engine"` - Video Video `yaml:"video"` - VP8 VP8 `yaml:"vp8"` - SEI SEI `yaml:"sei"` - Liveness Liveness `yaml:"liveness"` - Gen Gen `yaml:"gen"` - Profiles []Profile `yaml:"profiles"` - Failover Failover `yaml:"failover"` - Data string `yaml:"data"` - Debug bool `yaml:"debug"` - FFmpeg string `yaml:"ffmpeg"` + Mode string `yaml:"mode"` + Link string `yaml:"link"` + Auth Auth `yaml:"auth"` + Room Room `yaml:"room"` + Crypto Crypto `yaml:"crypto"` + Net Net `yaml:"net"` + SOCKS SOCKS `yaml:"socks"` + Engine Engine `yaml:"engine"` + Video Video `yaml:"video"` + VP8 VP8 `yaml:"vp8"` + SEI SEI `yaml:"sei"` + Liveness Liveness `yaml:"liveness"` + Lifecycle Lifecycle `yaml:"lifecycle"` + Gen Gen `yaml:"gen"` + Profiles []Profile `yaml:"profiles"` + Failover Failover `yaml:"failover"` + Data string `yaml:"data"` + Debug bool `yaml:"debug"` + FFmpeg string `yaml:"ffmpeg"` } // Profile is a failover entry that overrides top-level runtime fields. type Profile struct { - Name string `yaml:"name"` - Link string `yaml:"link"` - Auth Auth `yaml:"auth"` - Room Room `yaml:"room"` - Crypto Crypto `yaml:"crypto"` - Net Net `yaml:"net"` - SOCKS SOCKS `yaml:"socks"` - Engine Engine `yaml:"engine"` - Video Video `yaml:"video"` - VP8 VP8 `yaml:"vp8"` - SEI SEI `yaml:"sei"` - Liveness Liveness `yaml:"liveness"` + Name string `yaml:"name"` + Link string `yaml:"link"` + Auth Auth `yaml:"auth"` + Room Room `yaml:"room"` + Crypto Crypto `yaml:"crypto"` + Net Net `yaml:"net"` + SOCKS SOCKS `yaml:"socks"` + Engine Engine `yaml:"engine"` + Video Video `yaml:"video"` + VP8 VP8 `yaml:"vp8"` + SEI SEI `yaml:"sei"` + Liveness Liveness `yaml:"liveness"` + Lifecycle Lifecycle `yaml:"lifecycle"` } // Failover controls ordered profile failover. @@ -146,6 +148,11 @@ type Liveness struct { Failures int `yaml:"failures"` } +// Lifecycle controls planned session rebuilds. +type Lifecycle struct { + MaxSessionDuration string `yaml:"max_session_duration"` +} + // Gen controls room-generation mode. type Gen struct { Amount int `yaml:"amount"` @@ -260,6 +267,7 @@ func Apply(dst session.Config, f File) session.Config { dst.LivenessInterval = pickString(dst.LivenessInterval, f.Liveness.Interval) dst.LivenessTimeout = pickString(dst.LivenessTimeout, f.Liveness.Timeout) dst.LivenessFailures = pickInt(dst.LivenessFailures, f.Liveness.Failures) + dst.MaxSessionDuration = pickString(dst.MaxSessionDuration, f.Lifecycle.MaxSessionDuration) dst.Amount = pickInt(dst.Amount, f.Gen.Amount) return dst } @@ -301,6 +309,7 @@ func ApplyProfile(base session.Config, p Profile) session.Config { dst.LivenessInterval = overlayString(dst.LivenessInterval, p.Liveness.Interval) dst.LivenessTimeout = overlayString(dst.LivenessTimeout, p.Liveness.Timeout) dst.LivenessFailures = overlayInt(dst.LivenessFailures, p.Liveness.Failures) + dst.MaxSessionDuration = overlayString(dst.MaxSessionDuration, p.Lifecycle.MaxSessionDuration) return dst } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index b41604c..06d1406 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -43,6 +43,8 @@ liveness: interval: 2s timeout: 500ms failures: 4 +lifecycle: + max_session_duration: 6h gen: amount: 3 debug: true @@ -80,23 +82,24 @@ func requireLoadedFile(t *testing.T, f File) { func requireAppliedConfig(t *testing.T, got session.Config) { t.Helper() want := session.Config{ - Mode: testModeSrv, - Link: "direct", - Auth: testAuthProvider, - RoomID: testRoomID, - KeyHex: testCryptoKey, - Transport: "datachannel", - DNSServer: "1.1.1.1:53", - SOCKSHost: "127.0.0.1", - SOCKSPort: 1080, - SOCKSUser: "u", - SOCKSPass: "p", - VP8FPS: 25, - VP8BatchSize: 4, - LivenessInterval: "2s", - LivenessTimeout: "500ms", - LivenessFailures: 4, - Amount: 3, + Mode: testModeSrv, + Link: "direct", + Auth: testAuthProvider, + RoomID: testRoomID, + KeyHex: testCryptoKey, + Transport: "datachannel", + DNSServer: "1.1.1.1:53", + SOCKSHost: "127.0.0.1", + SOCKSPort: 1080, + SOCKSUser: "u", + SOCKSPass: "p", + VP8FPS: 25, + VP8BatchSize: 4, + LivenessInterval: "2s", + LivenessTimeout: "500ms", + LivenessFailures: 4, + MaxSessionDuration: "6h", + Amount: 3, } if got != want { t.Fatalf("Apply produced wrong config: %+v, want %+v", got, want) @@ -143,6 +146,8 @@ liveness: interval: 5s timeout: 2s failures: 5 +lifecycle: + max_session_duration: 6h profiles: - name: wb-vp8 auth: @@ -155,6 +160,8 @@ profiles: fps: 30 liveness: interval: 1s + lifecycle: + max_session_duration: 30m - name: jitsi-dc auth: provider: jitsi @@ -188,7 +195,8 @@ failover: t.Fatalf("first profile = %+v", first) } if first.KeyHex != "shared-key" || first.DNSServer != "1.1.1.1:53" || first.VP8FPS != 30 || - first.LivenessInterval != "1s" || first.LivenessTimeout != "2s" || first.LivenessFailures != 5 { + first.LivenessInterval != "1s" || first.LivenessTimeout != "2s" || first.LivenessFailures != 5 || + first.MaxSessionDuration != "30m" { t.Fatalf("first inherited/overlaid fields = %+v", first) } second := ApplyProfile(base, f.Profiles[1]) @@ -196,8 +204,9 @@ failover: second.RoomID != "https://meet.example/room" || second.DNSServer != "8.8.8.8:53" { t.Fatalf("second profile = %+v", second) } - if second.LivenessInterval != "5s" || second.LivenessTimeout != "2s" || second.LivenessFailures != 5 { - t.Fatalf("second liveness fields = %+v", second) + if second.LivenessInterval != "5s" || second.LivenessTimeout != "2s" || second.LivenessFailures != 5 || + second.MaxSessionDuration != "6h" { + t.Fatalf("second lifecycle/liveness fields = %+v", second) } } From b0aee57aa5aa8a34d0b2769a18b319d13107d1de Mon Sep 17 00:00:00 2001 From: cyber-debug Date: Sat, 16 May 2026 00:53:00 +0300 Subject: [PATCH 6/8] feat: track failover supervisor status --- cmd/olcrtc/main.go | 30 ++++++ docs/configuration.md | 4 + docs/project-map.md | 3 +- internal/supervisor/supervisor.go | 133 +++++++++++++++++++++++++ internal/supervisor/supervisor_test.go | 85 ++++++++++++++++ 5 files changed, 254 insertions(+), 1 deletion(-) diff --git a/cmd/olcrtc/main.go b/cmd/olcrtc/main.go index af7b87f..45662af 100644 --- a/cmd/olcrtc/main.go +++ b/cmd/olcrtc/main.go @@ -214,10 +214,40 @@ func runFailoverSessionMode(dataDir string, profiles []supervisor.Profile, failo } logger.Warnf("failover cycle=%d profile=%s ended", cycle, profile.Name) }, + OnStatus: logFailoverStatus, }, runSession) }) } +func logFailoverStatus(status supervisor.Status) { + if !logger.IsVerbose() { + return + } + active := status.ActiveProfile + if active == "" { + active = "none" + } + logger.Debugf("failover status cycle=%d active=%s last_error=%q profiles=%s history=%d", + status.Cycle, active, status.LastError, formatProfileStatuses(status.Profiles), len(status.History)) +} + +func formatProfileStatuses(profiles []supervisor.ProfileStatus) string { + if len(profiles) == 0 { + return "[]" + } + var buf bytes.Buffer + buf.WriteByte('[') + for i, profile := range profiles { + if i > 0 { + buf.WriteByte(' ') + } + fmt.Fprintf(&buf, "%s{starts=%d failures=%d clean=%d}", + profile.Name, profile.Starts, profile.Failures, profile.CleanEnds) + } + buf.WriteByte(']') + return buf.String() +} + func prepareRuntimeData(dataDir string) error { if dataDir == "" { return ErrDataDirRequired diff --git a/docs/configuration.md b/docs/configuration.md index 41bdeaa..52123f1 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -127,3 +127,7 @@ failover: Both peers must use compatible profile order and room settings. This first failover layer rebuilds the session on the next profile; active smux streams do not migrate, but new connections can recover on the next profile. + +When `debug: true` is enabled, the CLI also emits a compact supervisor status +snapshot with the active profile, per-profile start/failure counters, and +bounded failover history size. diff --git a/docs/project-map.md b/docs/project-map.md index 55fd291..0b09cc3 100644 --- a/docs/project-map.md +++ b/docs/project-map.md @@ -305,13 +305,14 @@ Implemented: - `failover.max_cycles`. - Profile start/end logs. - Planned session rotation with `lifecycle.max_session_duration`. +- Shared supervisor status snapshots with bounded failover history. Still valuable: - Health scoring per profile. - Control-stream coordination before switching. - Stream draining and migration instead of dropping active smux streams. -- Shared status output for the active profile and failover history. +- User-facing status endpoint/export for the active profile and failover history. Likely files: diff --git a/internal/supervisor/supervisor.go b/internal/supervisor/supervisor.go index 929fed6..293a4eb 100644 --- a/internal/supervisor/supervisor.go +++ b/internal/supervisor/supervisor.go @@ -11,6 +11,14 @@ import ( ) const DefaultRetryDelay = 2 * time.Second +const DefaultHistoryLimit = 20 + +const ( + // EventProfileStart marks a profile attempt starting. + EventProfileStart = "profile_start" + // EventProfileEnd marks a profile attempt ending. + EventProfileEnd = "profile_end" +) var ( // ErrNoProfiles is returned when the supervisor is started without profiles. @@ -25,6 +33,36 @@ type Profile struct { Config session.Config } +// ProfileStatus summarizes one profile's failover history. +type ProfileStatus struct { + Name string + Starts int + Failures int + CleanEnds int + LastStarted time.Time + LastEnded time.Time + LastError string +} + +// Event is one bounded failover history entry. +type Event struct { + Time time.Time + Type string + Profile string + Cycle int + Error string +} + +// Status is a point-in-time view of the supervisor. +type Status struct { + Cycle int + ActiveProfile string + ActiveProfileIndex int + Profiles []ProfileStatus + History []Event + LastError string +} + // Runner starts one session profile and blocks until it ends or fails. type Runner func(ctx context.Context, cfg session.Config) error @@ -36,6 +74,8 @@ type Config struct { OnProfileStart func(profile Profile, cycle int) OnProfileEnd func(profile Profile, cycle int, err error) + OnStatus func(status Status) + HistoryLimit int } // Run starts profiles in order. If a profile exits while ctx is still active, @@ -47,6 +87,7 @@ func Run(ctx context.Context, cfg Config, run Runner) error { if cfg.RetryDelay == 0 { cfg.RetryDelay = DefaultRetryDelay } + state := newStatusTracker(cfg.Profiles, cfg.HistoryLimit, cfg.OnStatus) var lastErr error for cycle := 1; ; cycle++ { @@ -54,6 +95,7 @@ func Run(ctx context.Context, cfg Config, run Runner) error { if ctx.Err() != nil { return nil } + state.start(i, cycle) if cfg.OnProfileStart != nil { cfg.OnProfileStart(profile, cycle) } @@ -67,6 +109,7 @@ func Run(ctx context.Context, cfg Config, run Runner) error { } else { lastErr = fmt.Errorf("profile %q ended", profile.Name) } + state.end(i, cycle, err) if cfg.OnProfileEnd != nil { cfg.OnProfileEnd(profile, cycle, err) } @@ -81,6 +124,96 @@ func Run(ctx context.Context, cfg Config, run Runner) error { } } +type statusTracker struct { + status Status + notify func(Status) + historyLimit int +} + +func newStatusTracker(profiles []Profile, historyLimit int, notify func(Status)) *statusTracker { + if historyLimit == 0 { + historyLimit = DefaultHistoryLimit + } + statusProfiles := make([]ProfileStatus, 0, len(profiles)) + for _, profile := range profiles { + statusProfiles = append(statusProfiles, ProfileStatus{Name: profile.Name}) + } + return &statusTracker{ + status: Status{ + ActiveProfileIndex: -1, + Profiles: statusProfiles, + }, + notify: notify, + historyLimit: historyLimit, + } +} + +func (t *statusTracker) start(profileIndex, cycle int) { + now := time.Now() + profile := &t.status.Profiles[profileIndex] + profile.Starts++ + profile.LastStarted = now + t.status.Cycle = cycle + t.status.ActiveProfile = profile.Name + t.status.ActiveProfileIndex = profileIndex + t.appendHistory(Event{ + Time: now, + Type: EventProfileStart, + Profile: profile.Name, + Cycle: cycle, + }) + t.emit() +} + +func (t *statusTracker) end(profileIndex, cycle int, err error) { + now := time.Now() + profile := &t.status.Profiles[profileIndex] + profile.LastEnded = now + event := Event{ + Time: now, + Type: EventProfileEnd, + Profile: profile.Name, + Cycle: cycle, + } + if err != nil { + profile.Failures++ + profile.LastError = err.Error() + t.status.LastError = fmt.Sprintf("profile %q: %v", profile.Name, err) + event.Error = err.Error() + } else { + profile.CleanEnds++ + profile.LastError = "" + t.status.LastError = fmt.Sprintf("profile %q ended", profile.Name) + } + t.status.ActiveProfile = "" + t.status.ActiveProfileIndex = -1 + t.appendHistory(event) + t.emit() +} + +func (t *statusTracker) appendHistory(event Event) { + if t.historyLimit < 0 { + return + } + t.status.History = append(t.status.History, event) + if len(t.status.History) > t.historyLimit { + t.status.History = t.status.History[len(t.status.History)-t.historyLimit:] + } +} + +func (t *statusTracker) emit() { + if t.notify == nil { + return + } + t.notify(cloneStatus(t.status)) +} + +func cloneStatus(status Status) Status { + status.Profiles = append([]ProfileStatus(nil), status.Profiles...) + status.History = append([]Event(nil), status.History...) + return status +} + func waitRetryDelay(ctx context.Context, delay time.Duration) error { if delay <= 0 { return nil diff --git a/internal/supervisor/supervisor_test.go b/internal/supervisor/supervisor_test.go index aab0dee..253d310 100644 --- a/internal/supervisor/supervisor_test.go +++ b/internal/supervisor/supervisor_test.go @@ -58,6 +58,91 @@ func TestRunAdvancesProfilesAndStopsAtMaxCycles(t *testing.T) { } } +func TestRunEmitsStatusHistory(t *testing.T) { + profiles := []Profile{ + {Name: "first", Config: session.Config{Auth: "wbstream"}}, + {Name: "second", Config: session.Config{Auth: "jitsi"}}, + } + var snapshots []Status + err := Run(context.Background(), Config{ + Profiles: profiles, + RetryDelay: -1, + MaxCycles: 1, + HistoryLimit: 3, + OnStatus: func(status Status) { + snapshots = append(snapshots, status) + }, + }, func(_ context.Context, cfg session.Config) error { + if cfg.Auth == "first" { + t.Fatal("runner received profile name instead of config") + } + return errRunnerBoom + }) + if !errors.Is(err, ErrMaxCyclesExceeded) { + t.Fatalf("Run() error = %v, want %v", err, ErrMaxCyclesExceeded) + } + if len(snapshots) != 4 { + t.Fatalf("status snapshots = %d, want 4", len(snapshots)) + } + first := snapshots[0] + if first.ActiveProfile != "first" || first.ActiveProfileIndex != 0 || first.Cycle != 1 { + t.Fatalf("first status = %+v", first) + } + if first.Profiles[0].Starts != 1 || first.Profiles[0].LastStarted.IsZero() { + t.Fatalf("first profile start status = %+v", first.Profiles[0]) + } + last := snapshots[len(snapshots)-1] + if last.ActiveProfile != "" || last.ActiveProfileIndex != -1 { + t.Fatalf("last active status = %+v", last) + } + if last.Profiles[0].Failures != 1 || last.Profiles[1].Failures != 1 { + t.Fatalf("profile failures = %+v", last.Profiles) + } + if last.LastError == "" || last.Profiles[1].LastError == "" { + t.Fatalf("last errors missing: %+v", last) + } + if len(last.History) != 3 { + t.Fatalf("history length = %d, want 3", len(last.History)) + } + if last.History[0].Type != EventProfileEnd || last.History[0].Profile != "first" { + t.Fatalf("oldest bounded history event = %+v", last.History[0]) + } + if last.History[2].Type != EventProfileEnd || last.History[2].Profile != "second" || + last.History[2].Error == "" { + t.Fatalf("last history event = %+v", last.History[2]) + } +} + +func TestRunStatusSnapshotIsImmutable(t *testing.T) { + var first Status + var second Status + err := Run(context.Background(), Config{ + Profiles: []Profile{{Name: "one"}}, + RetryDelay: -1, + MaxCycles: 1, + OnStatus: func(status Status) { + if first.Profiles == nil { + first = status + first.Profiles[0].Starts = 99 + first.History[0].Profile = "mutated" + return + } + second = status + }, + }, func(context.Context, session.Config) error { + return errRunnerBoom + }) + if !errors.Is(err, ErrMaxCyclesExceeded) { + t.Fatalf("Run() error = %v, want %v", err, ErrMaxCyclesExceeded) + } + if first.Profiles[0].Starts != 99 || first.History[0].Profile != "mutated" { + t.Fatalf("test mutation did not apply to snapshot: %+v", first) + } + if second.Profiles[0].Starts != 1 || second.History[0].Profile != "one" { + t.Fatalf("snapshot mutation leaked into later status: %+v", second) + } +} + func TestRunReturnsNilOnContextCancel(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) err := Run(ctx, Config{ From b7a7e4089979fe767b9addd2f6fbc32331cf0603 Mon Sep 17 00:00:00 2001 From: cyber-debug Date: Sat, 16 May 2026 01:17:07 +0300 Subject: [PATCH 7/8] feat: add safe traffic shaping and TLS hardening --- docs/client.example.yaml | 6 + docs/configuration.md | 23 ++++ docs/failover.example.yaml | 6 + docs/project-map.md | 5 + docs/server.example.yaml | 6 + docs/settings.md | 9 ++ internal/app/session/session.go | 148 +++++++++++++++++------ internal/app/session/session_test.go | 57 +++++++++ internal/auth/salutejazz/api.go | 15 +-- internal/auth/telemost/api.go | 4 +- internal/auth/wbstream/api.go | 13 +- internal/client/client.go | 23 +++- internal/client/client_test.go | 5 + internal/config/config.go | 15 +++ internal/config/config_test.go | 56 ++++++--- internal/crypto/chacha.go | 3 + internal/engine/goolom/lifecycle.go | 5 +- internal/engine/salutejazz/salutejazz.go | 5 +- internal/link/direct/direct.go | 4 + internal/link/direct/direct_test.go | 7 +- internal/link/link.go | 17 ++- internal/protect/protect.go | 93 ++++++++++++-- internal/protect/protect_test.go | 50 +++++++- internal/server/server.go | 23 +++- internal/server/server_test.go | 5 + internal/transport/traffic.go | 91 ++++++++++++++ internal/transport/traffic_test.go | 67 ++++++++++ internal/transport/transport.go | 19 ++- 28 files changed, 662 insertions(+), 118 deletions(-) create mode 100644 internal/transport/traffic.go create mode 100644 internal/transport/traffic_test.go diff --git a/docs/client.example.yaml b/docs/client.example.yaml index 06b9b5e..c29fae5 100644 --- a/docs/client.example.yaml +++ b/docs/client.example.yaml @@ -30,6 +30,12 @@ liveness: # lifecycle: # max_session_duration: 6h +# Optional reliability shaping for encrypted wire messages. +# traffic: +# max_payload_size: 4096 +# min_delay: 5ms +# max_delay: 30ms + # Local SOCKS5 listener exposed to applications socks: host: "127.0.0.1" diff --git a/docs/configuration.md b/docs/configuration.md index 52123f1..07d1713 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -35,6 +35,8 @@ olcrtc /etc/olcrtc/server.yaml | `liveness.timeout` | pong timeout, default `5s` | | `liveness.failures` | missed pongs before reconnect, default `3` | | `lifecycle.max_session_duration` | planned session rebuild interval, e.g. `6h`; unset = off | +| `traffic.max_payload_size` | safe encrypted wire-message cap; `0` = transport default | +| `traffic.min_delay` / `.max_delay` | optional send pacing jitter, e.g. `5ms` / `30ms` | | `gen.amount` | gen mode: number of rooms to create | | `profiles[]` | ordered srv/cnc failover profiles | | `failover.retry_delay` | delay before trying the next profile, e.g. `2s` | @@ -86,6 +88,27 @@ lifecycle: The field is optional and disabled when omitted. Values use Go duration syntax such as `30m`, `2h`, or `6h`; zero and negative durations are rejected. +## Traffic Shaping + +`traffic` applies a shared reliability-oriented wrapper around the selected +transport. It can cap encrypted wire-message size and add small send pacing +delays without truncating data. When a payload would exceed the effective cap, +the send fails clearly instead of cutting bytes and corrupting smux. + +```yaml +traffic: + max_payload_size: 4096 + min_delay: 5ms + max_delay: 30ms +``` + +The wrapper clamps the configured payload cap to the selected transport's +advertised `MaxPayloadSize`. Client and server also reduce smux frame size to +fit the effective encrypted payload cap, accounting for crypto overhead. `0` +adds no extra cap beyond the selected transport's advertised limit. Delays use +Go duration syntax; if only `min_delay` is set, it is a fixed delay. Use the +same traffic settings on both peers. + ## Failover Profiles `mode: srv` and `mode: cnc` can define `profiles`. Top-level fields are used diff --git a/docs/failover.example.yaml b/docs/failover.example.yaml index 298a847..bf42482 100644 --- a/docs/failover.example.yaml +++ b/docs/failover.example.yaml @@ -19,6 +19,12 @@ liveness: # lifecycle: # max_session_duration: 6h +# Optional reliability shaping for encrypted wire messages. +# traffic: +# max_payload_size: 4096 +# min_delay: 5ms +# max_delay: 30ms + data: data profiles: diff --git a/docs/project-map.md b/docs/project-map.md index 0b09cc3..d0ebd41 100644 --- a/docs/project-map.md +++ b/docs/project-map.md @@ -73,6 +73,8 @@ Important fields: | `socks.*` | SOCKS fields | Client listener and optional server egress proxy. | | `engine.*` | direct engine fields | Used only with `auth.provider: none`. | | `liveness.*` | control liveness | Ping/pong interval, timeout, and missed-pong threshold. | +| `lifecycle.*` | session lifecycle | Planned call/session rotation. | +| `traffic.*` | send shaping | Encrypted wire-message size cap and optional pacing jitter. | `internal/app/session` is the main router: @@ -306,6 +308,7 @@ Implemented: - Profile start/end logs. - Planned session rotation with `lifecycle.max_session_duration`. - Shared supervisor status snapshots with bounded failover history. +- Shared traffic wrapper with payload cap, pacing jitter, and smux frame sizing. Still valuable: @@ -371,6 +374,8 @@ This mostly belongs in `pkg/olcrtc/tunnel` and `internal/server`. Provider APIs can drift. Worth adding: +- Central protected HTTP/WebSocket client creation with TLS 1.2+, + environment proxy support, HTTP/2 for HTTP, and bounded timeouts. - Better typed errors from auth providers. - Provider health probes. - Fixture-based contract tests for API response changes. diff --git a/docs/server.example.yaml b/docs/server.example.yaml index 300f7cf..112ce42 100644 --- a/docs/server.example.yaml +++ b/docs/server.example.yaml @@ -32,6 +32,12 @@ liveness: # lifecycle: # max_session_duration: 6h +# Optional reliability shaping for encrypted wire messages. +# traffic: +# max_payload_size: 4096 +# min_delay: 5ms +# max_delay: 30ms + # Outbound SOCKS5 proxy for server-side egress (optional) socks: proxy_addr: "" # e.g. "127.0.0.1" diff --git a/docs/settings.md b/docs/settings.md index 9f9d215..b3bf159 100644 --- a/docs/settings.md +++ b/docs/settings.md @@ -67,6 +67,8 @@ | `liveness.timeout` | Сколько ждать pong, по умолчанию `5s` | | `liveness.failures` | Сколько pong можно пропустить перед rebuild, по умолчанию `3` | | `lifecycle.max_session_duration` | Плановый rebuild сессии после указанного времени, например `6h`; если поле не задано, выключено | +| `traffic.max_payload_size` | Лимит размера зашифрованного wire-message; `0` = лимит транспорта | +| `traffic.min_delay` / `.max_delay` | Необязательный pacing отправки, например `5ms` / `30ms` | `crypto.key_file` читается относительно YAML-файла. Не указывай `crypto.key` и `crypto.key_file` одновременно. @@ -86,6 +88,13 @@ provider session. Когда таймер истекает, текущая `srv` плановый rebuild. Формат значения: `30m`, `2h`, `6h`; `0s` и отрицательные значения не принимаются. +`traffic` добавляет общий wrapper над выбранным transport. Он может ограничить +размер зашифрованного сообщения и добавить небольшую задержку перед отправкой. +Данные не обрезаются: если сообщение не помещается в эффективный лимит, send +возвращает явную ошибку. При заданном `max_payload_size` smux frame size также +уменьшается с учетом crypto overhead; при `0` остается лимит выбранного +transport. Используй одинаковые traffic-настройки на обеих сторонах. + --- ## mode: gen diff --git a/internal/app/session/session.go b/internal/app/session/session.go index 0b48f50..8df7b65 100644 --- a/internal/app/session/session.go +++ b/internal/app/session/session.go @@ -15,6 +15,7 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/carrier/builtin" "github.com/openlibrecommunity/olcrtc/internal/client" "github.com/openlibrecommunity/olcrtc/internal/control" + "github.com/openlibrecommunity/olcrtc/internal/crypto" "github.com/openlibrecommunity/olcrtc/internal/link" "github.com/openlibrecommunity/olcrtc/internal/link/direct" "github.com/openlibrecommunity/olcrtc/internal/logger" @@ -137,47 +138,59 @@ var ( // ErrLifecycleMaxSessionDurationInvalid indicates that lifecycle.max_session_duration is not a positive duration. ErrLifecycleMaxSessionDurationInvalid = errors.New( "invalid max session duration (set lifecycle.max_session_duration to a duration > 0)") + // ErrTrafficMaxPayloadSizeInvalid indicates that traffic.max_payload_size is not valid. + ErrTrafficMaxPayloadSizeInvalid = errors.New( + "invalid traffic max payload size (set traffic.max_payload_size to 0 or a value above crypto overhead)") + // ErrTrafficMinDelayInvalid indicates that traffic.min_delay is not a non-negative duration. + ErrTrafficMinDelayInvalid = errors.New( + "invalid traffic min delay (set traffic.min_delay to a duration >= 0)") + // ErrTrafficMaxDelayInvalid indicates that traffic.max_delay is not a non-negative duration. + ErrTrafficMaxDelayInvalid = errors.New( + "invalid traffic max delay (set traffic.max_delay to a duration >= 0 and >= traffic.min_delay)") ) // Config holds runtime session settings. type Config struct { - Mode string - Link string - Transport string - Auth string - Engine string - URL string - Token string - RoomID string - KeyHex string - SOCKSHost string - SOCKSPort int - SOCKSUser string - SOCKSPass string - DNSServer string - SOCKSProxyAddr string - SOCKSProxyPort int - VideoWidth int - VideoHeight int - VideoFPS int - VideoBitrate string - VideoHW string - VideoQRSize int - VideoQRRecovery string - VideoCodec string - VideoTileModule int - VideoTileRS int - VP8FPS int - VP8BatchSize int - SEIFPS int - SEIBatchSize int - SEIFragmentSize int - SEIAckTimeoutMS int - LivenessInterval string - LivenessTimeout string - LivenessFailures int - MaxSessionDuration string - Amount int + Mode string + Link string + Transport string + Auth string + Engine string + URL string + Token string + RoomID string + KeyHex string + SOCKSHost string + SOCKSPort int + SOCKSUser string + SOCKSPass string + DNSServer string + SOCKSProxyAddr string + SOCKSProxyPort int + VideoWidth int + VideoHeight int + VideoFPS int + VideoBitrate string + VideoHW string + VideoQRSize int + VideoQRRecovery string + VideoCodec string + VideoTileModule int + VideoTileRS int + VP8FPS int + VP8BatchSize int + SEIFPS int + SEIBatchSize int + SEIFragmentSize int + SEIAckTimeoutMS int + LivenessInterval string + LivenessTimeout string + LivenessFailures int + MaxSessionDuration string + TrafficMaxPayloadSize int + TrafficMinDelay string + TrafficMaxDelay string + Amount int } // RegisterDefaults registers built-in carriers and transports. @@ -333,6 +346,9 @@ func Validate(cfg Config) error { if err := validateLifecycleConfig(cfg); err != nil { return err } + if err := validateTrafficConfig(cfg); err != nil { + return err + } return validateModeConfig(cfg) } @@ -539,6 +555,48 @@ func maxSessionDuration(cfg Config) (time.Duration, error) { return d, nil } +func validateTrafficConfig(cfg Config) error { + _, err := trafficConfig(cfg) + return err +} + +func trafficConfig(cfg Config) (transport.TrafficConfig, error) { + if cfg.TrafficMaxPayloadSize < 0 || (cfg.TrafficMaxPayloadSize > 0 && + cfg.TrafficMaxPayloadSize <= crypto.WireOverhead) { + return transport.TrafficConfig{}, ErrTrafficMaxPayloadSizeInvalid + } + minDelay, err := parseOptionalNonNegativeDuration(cfg.TrafficMinDelay) + if err != nil { + return transport.TrafficConfig{}, fmt.Errorf("%w: %v", ErrTrafficMinDelayInvalid, err) + } + maxDelay, err := parseOptionalNonNegativeDuration(cfg.TrafficMaxDelay) + if err != nil { + return transport.TrafficConfig{}, fmt.Errorf("%w: %v", ErrTrafficMaxDelayInvalid, err) + } + if maxDelay > 0 && maxDelay < minDelay { + return transport.TrafficConfig{}, ErrTrafficMaxDelayInvalid + } + return transport.TrafficConfig{ + MaxPayloadSize: cfg.TrafficMaxPayloadSize, + MinDelay: minDelay, + MaxDelay: maxDelay, + }, nil +} + +func parseOptionalNonNegativeDuration(value string) (time.Duration, error) { + if value == "" { + return 0, nil + } + d, err := time.ParseDuration(value) + if err != nil { + return 0, err + } + if d < 0 { + return 0, fmt.Errorf("duration must be >= 0") + } + return d, nil +} + func isLoopbackListenHost(host string) bool { if host == "localhost" { return true @@ -560,9 +618,13 @@ func Run(ctx context.Context, cfg Config) error { if err != nil { return err } + traffic, err := trafficConfig(cfg) + if err != nil { + return err + } run := func(ctx context.Context) error { - return runOnce(ctx, cfg, roomURL, liveness) + return runOnce(ctx, cfg, roomURL, liveness, traffic) } if maxDuration > 0 { return runWithSessionRotation(ctx, maxDuration, run) @@ -570,7 +632,13 @@ func Run(ctx context.Context, cfg Config) error { return run(ctx) } -func runOnce(ctx context.Context, cfg Config, roomURL string, liveness control.Config) error { +func runOnce( + ctx context.Context, + cfg Config, + roomURL string, + liveness control.Config, + traffic transport.TrafficConfig, +) error { switch cfg.Mode { case modeSRV: if err := server.Run(ctx, server.Config{ @@ -602,6 +670,7 @@ func runOnce(ctx context.Context, cfg Config, roomURL string, liveness control.C URL: cfg.URL, Token: cfg.Token, Liveness: liveness, + Traffic: traffic, OnSessionOpen: func(sessionID, deviceID string, claims map[string]any) { logger.Infof("session opened: id=%s device=%s claims=%v", sessionID, deviceID, claims) }, @@ -646,6 +715,7 @@ func runOnce(ctx context.Context, cfg Config, roomURL string, liveness control.C URL: cfg.URL, Token: cfg.Token, Liveness: liveness, + Traffic: traffic, }); err != nil { return fmt.Errorf("client: %w", err) } diff --git a/internal/app/session/session_test.go b/internal/app/session/session_test.go index 5fc219d..d75371b 100644 --- a/internal/app/session/session_test.go +++ b/internal/app/session/session_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/openlibrecommunity/olcrtc/internal/control" + "github.com/openlibrecommunity/olcrtc/internal/crypto" ) func TestApplyTransportDefaults(t *testing.T) { @@ -522,6 +523,62 @@ func TestValidate(t *testing.T) { }(), want: ErrLifecycleMaxSessionDurationInvalid, }, + { + name: "traffic accepts shaping", + cfg: func() Config { + cfg := base + cfg.TrafficMaxPayloadSize = 4096 + cfg.TrafficMinDelay = "5ms" + cfg.TrafficMaxDelay = "30ms" + return cfg + }(), + }, + { + name: "traffic rejects negative max payload", + cfg: func() Config { + cfg := base + cfg.TrafficMaxPayloadSize = -1 + return cfg + }(), + want: ErrTrafficMaxPayloadSizeInvalid, + }, + { + name: "traffic rejects payload smaller than crypto overhead", + cfg: func() Config { + cfg := base + cfg.TrafficMaxPayloadSize = crypto.WireOverhead + return cfg + }(), + want: ErrTrafficMaxPayloadSizeInvalid, + }, + { + name: "traffic rejects bad min delay", + cfg: func() Config { + cfg := base + cfg.TrafficMinDelay = "nope" + return cfg + }(), + want: ErrTrafficMinDelayInvalid, + }, + { + name: "traffic rejects negative max delay", + cfg: func() Config { + cfg := base + cfg.TrafficMaxDelay = "-1ms" + return cfg + }(), + want: ErrTrafficMaxDelayInvalid, + }, + { + name: "traffic rejects max delay below min delay", + cfg: func() Config { + cfg := base + cfg.TrafficMinDelay = "30ms" + cfg.TrafficMaxDelay = "5ms" + return cfg + }(), + want: ErrTrafficMaxDelayInvalid, + }, } for _, tt := range tests { diff --git a/internal/auth/salutejazz/api.go b/internal/auth/salutejazz/api.go index 594ac5c..40cd092 100644 --- a/internal/auth/salutejazz/api.go +++ b/internal/auth/salutejazz/api.go @@ -9,9 +9,7 @@ import ( "encoding/json" "errors" "fmt" - "io" "net/http" - "strings" "github.com/google/uuid" "github.com/openlibrecommunity/olcrtc/internal/protect" @@ -122,7 +120,7 @@ func createMeeting(ctx context.Context, headers map[string]string) (*createRespo defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - return nil, statusError(errCreateRoomFailed, resp) + return nil, protect.StatusError(errCreateRoomFailed, resp, 1024) } var res createResponse @@ -174,7 +172,7 @@ func preconnect(ctx context.Context, roomID, password string, headers map[string defer func() { _ = preResp.Body.Close() }() if preResp.StatusCode != http.StatusOK { - return "", statusError(errPreconnectFailed, preResp) + return "", protect.StatusError(errPreconnectFailed, preResp, 1024) } var preconnectResp struct { @@ -186,15 +184,6 @@ func preconnect(ctx context.Context, roomID, password string, headers map[string return preconnectResp.ConnectorURL, nil } -func statusError(base error, resp *http.Response) error { - body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) - bodyText := strings.TrimSpace(string(body)) - if bodyText == "" { - return fmt.Errorf("%w: status %d", base, resp.StatusCode) - } - return fmt.Errorf("%w: status %d: %s", base, resp.StatusCode, bodyText) -} - func joinRoom(ctx context.Context, roomID, password string) (*roomInfo, error) { headers := anonymousHeaders() connectorURL, err := preconnect(ctx, roomID, password, headers) diff --git a/internal/auth/telemost/api.go b/internal/auth/telemost/api.go index cde00f0..a9b1116 100644 --- a/internal/auth/telemost/api.go +++ b/internal/auth/telemost/api.go @@ -11,7 +11,6 @@ import ( "encoding/json" "errors" "fmt" - "io" "net/http" "net/url" @@ -69,8 +68,7 @@ func GetConnectionInfo(ctx context.Context, roomURL, displayName string) (*Conne defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("%w %d: %s", ErrAPI, resp.StatusCode, body) + return nil, protect.StatusError(ErrAPI, resp, 4096) } var info ConnectionInfo diff --git a/internal/auth/wbstream/api.go b/internal/auth/wbstream/api.go index 4fc277b..ea1a927 100644 --- a/internal/auth/wbstream/api.go +++ b/internal/auth/wbstream/api.go @@ -10,7 +10,6 @@ import ( "encoding/json" "errors" "fmt" - "io" "net/http" "github.com/openlibrecommunity/olcrtc/internal/protect" @@ -84,8 +83,7 @@ func registerGuest(ctx context.Context, displayName string) (string, error) { defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - b, _ := io.ReadAll(resp.Body) - return "", fmt.Errorf("%w: %d %s", errGuestRegister, resp.StatusCode, b) + return "", protect.StatusError(errGuestRegister, resp, 4096) } var res guestRegisterResponse @@ -122,8 +120,7 @@ func createRoom(ctx context.Context, accessToken string) (string, error) { defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { - b, _ := io.ReadAll(resp.Body) - return "", fmt.Errorf("%w: %d %s", errCreateRoom, resp.StatusCode, b) + return "", protect.StatusError(errCreateRoom, resp, 4096) } var res createRoomResponse @@ -151,8 +148,7 @@ func joinRoom(ctx context.Context, accessToken, roomID string) error { defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - b, _ := io.ReadAll(resp.Body) - return fmt.Errorf("%w: %d %s", errJoinRoom, resp.StatusCode, b) + return protect.StatusError(errJoinRoom, resp, 4096) } return nil } @@ -180,8 +176,7 @@ func getToken(ctx context.Context, accessToken, roomID, displayName string) (tok defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - b, _ := io.ReadAll(resp.Body) - return tokenResponse{}, fmt.Errorf("%w: %d %s", errGetToken, resp.StatusCode, b) + return tokenResponse{}, protect.StatusError(errGetToken, resp, 4096) } var res tokenResponse diff --git a/internal/client/client.go b/internal/client/client.go index 001cb4c..2dfc153 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -24,6 +24,7 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/logger" "github.com/openlibrecommunity/olcrtc/internal/muxconn" "github.com/openlibrecommunity/olcrtc/internal/names" + "github.com/openlibrecommunity/olcrtc/internal/transport" "github.com/xtaci/smux" ) @@ -103,6 +104,7 @@ type Config struct { URL string Token string Liveness control.Config + Traffic transport.TrafficConfig // DeviceID overrides the persistent client-side device identifier. Leave // empty to derive one from DeviceIDPath (or generate a random one if both @@ -216,6 +218,7 @@ func (c *Client) bringUpLink( SEIBatchSize: cfg.SEIBatchSize, SEIFragmentSize: cfg.SEIFragmentSize, SEIAckTimeoutMS: cfg.SEIAckTimeoutMS, + Traffic: cfg.Traffic, }) if err != nil { return fmt.Errorf("failed to create link: %w", err) @@ -241,7 +244,7 @@ func (c *Client) bringUpLink( } c.conn = muxconn.New(ln, c.cipher) - sess, err := smux.Client(c.conn, smuxConfig()) + sess, err := smux.Client(c.conn, smuxConfig(linkMaxPayload(ln))) if err != nil { return fmt.Errorf("smux client: %w", err) } @@ -332,11 +335,17 @@ func resolveDeviceID(deviceID, path string) (string, error) { } // smuxConfig returns the tuned smux config used on both ends. -func smuxConfig() *smux.Config { +func smuxConfig(maxWirePayload ...int) *smux.Config { cfg := smux.DefaultConfig() cfg.Version = 2 cfg.KeepAliveDisabled = true cfg.MaxFrameSize = 32768 + if len(maxWirePayload) > 0 && maxWirePayload[0] > crypto.WireOverhead { + maxFrameSize := maxWirePayload[0] - crypto.WireOverhead + if maxFrameSize < cfg.MaxFrameSize { + cfg.MaxFrameSize = maxFrameSize + } + } cfg.MaxReceiveBuffer = 16 * 1024 * 1024 cfg.MaxStreamBuffer = 1024 * 1024 cfg.KeepAliveInterval = 10 * time.Second @@ -344,6 +353,14 @@ func smuxConfig() *smux.Config { return cfg } +func linkMaxPayload(ln link.Link) int { + provider, ok := ln.(link.FeaturesProvider) + if !ok { + return 0 + } + return provider.Features().MaxPayloadSize +} + func (c *Client) handleReconnect(ctx context.Context, cfg Config, cancel context.CancelFunc, reason string) bool { c.reconnectMu.Lock() defer c.reconnectMu.Unlock() @@ -421,7 +438,7 @@ func (c *Client) tryReopenSession( _ = old.Close() } - sess, err := smux.Client(conn, smuxConfig()) + sess, err := smux.Client(conn, smuxConfig(linkMaxPayload(c.ln))) if err != nil { logger.Warnf("smux re-init failed (attempt %d): %v", attempt, err) return false diff --git a/internal/client/client_test.go b/internal/client/client_test.go index 82d0099..40b3c22 100644 --- a/internal/client/client_test.go +++ b/internal/client/client_test.go @@ -49,6 +49,11 @@ func TestSmuxConfig(t *testing.T) { if cfg.Version != 2 || !cfg.KeepAliveDisabled || cfg.MaxFrameSize != 32768 || cfg.MaxReceiveBuffer != 16*1024*1024 { t.Fatalf("smuxConfig() = %+v", cfg) } + capped := smuxConfig(4096) + if capped.MaxFrameSize != 4096-cryptopkg.WireOverhead { + t.Fatalf("smuxConfig(4096).MaxFrameSize = %d, want %d", + capped.MaxFrameSize, 4096-cryptopkg.WireOverhead) + } } func TestSocks5Handshake(t *testing.T) { diff --git a/internal/config/config.go b/internal/config/config.go index 770adf5..3cd5a0a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -43,6 +43,7 @@ type File struct { SEI SEI `yaml:"sei"` Liveness Liveness `yaml:"liveness"` Lifecycle Lifecycle `yaml:"lifecycle"` + Traffic Traffic `yaml:"traffic"` Gen Gen `yaml:"gen"` Profiles []Profile `yaml:"profiles"` Failover Failover `yaml:"failover"` @@ -66,6 +67,7 @@ type Profile struct { SEI SEI `yaml:"sei"` Liveness Liveness `yaml:"liveness"` Lifecycle Lifecycle `yaml:"lifecycle"` + Traffic Traffic `yaml:"traffic"` } // Failover controls ordered profile failover. @@ -153,6 +155,13 @@ type Lifecycle struct { MaxSessionDuration string `yaml:"max_session_duration"` } +// Traffic controls optional reliability-oriented send shaping. +type Traffic struct { + MaxPayloadSize int `yaml:"max_payload_size"` + MinDelay string `yaml:"min_delay"` + MaxDelay string `yaml:"max_delay"` +} + // Gen controls room-generation mode. type Gen struct { Amount int `yaml:"amount"` @@ -268,6 +277,9 @@ func Apply(dst session.Config, f File) session.Config { dst.LivenessTimeout = pickString(dst.LivenessTimeout, f.Liveness.Timeout) dst.LivenessFailures = pickInt(dst.LivenessFailures, f.Liveness.Failures) dst.MaxSessionDuration = pickString(dst.MaxSessionDuration, f.Lifecycle.MaxSessionDuration) + dst.TrafficMaxPayloadSize = pickInt(dst.TrafficMaxPayloadSize, f.Traffic.MaxPayloadSize) + dst.TrafficMinDelay = pickString(dst.TrafficMinDelay, f.Traffic.MinDelay) + dst.TrafficMaxDelay = pickString(dst.TrafficMaxDelay, f.Traffic.MaxDelay) dst.Amount = pickInt(dst.Amount, f.Gen.Amount) return dst } @@ -310,6 +322,9 @@ func ApplyProfile(base session.Config, p Profile) session.Config { dst.LivenessTimeout = overlayString(dst.LivenessTimeout, p.Liveness.Timeout) dst.LivenessFailures = overlayInt(dst.LivenessFailures, p.Liveness.Failures) dst.MaxSessionDuration = overlayString(dst.MaxSessionDuration, p.Lifecycle.MaxSessionDuration) + dst.TrafficMaxPayloadSize = overlayInt(dst.TrafficMaxPayloadSize, p.Traffic.MaxPayloadSize) + dst.TrafficMinDelay = overlayString(dst.TrafficMinDelay, p.Traffic.MinDelay) + dst.TrafficMaxDelay = overlayString(dst.TrafficMaxDelay, p.Traffic.MaxDelay) return dst } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 06d1406..c699283 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -45,6 +45,10 @@ liveness: failures: 4 lifecycle: max_session_duration: 6h +traffic: + max_payload_size: 4096 + min_delay: 5ms + max_delay: 30ms gen: amount: 3 debug: true @@ -82,24 +86,27 @@ func requireLoadedFile(t *testing.T, f File) { func requireAppliedConfig(t *testing.T, got session.Config) { t.Helper() want := session.Config{ - Mode: testModeSrv, - Link: "direct", - Auth: testAuthProvider, - RoomID: testRoomID, - KeyHex: testCryptoKey, - Transport: "datachannel", - DNSServer: "1.1.1.1:53", - SOCKSHost: "127.0.0.1", - SOCKSPort: 1080, - SOCKSUser: "u", - SOCKSPass: "p", - VP8FPS: 25, - VP8BatchSize: 4, - LivenessInterval: "2s", - LivenessTimeout: "500ms", - LivenessFailures: 4, - MaxSessionDuration: "6h", - Amount: 3, + Mode: testModeSrv, + Link: "direct", + Auth: testAuthProvider, + RoomID: testRoomID, + KeyHex: testCryptoKey, + Transport: "datachannel", + DNSServer: "1.1.1.1:53", + SOCKSHost: "127.0.0.1", + SOCKSPort: 1080, + SOCKSUser: "u", + SOCKSPass: "p", + VP8FPS: 25, + VP8BatchSize: 4, + LivenessInterval: "2s", + LivenessTimeout: "500ms", + LivenessFailures: 4, + MaxSessionDuration: "6h", + TrafficMaxPayloadSize: 4096, + TrafficMinDelay: "5ms", + TrafficMaxDelay: "30ms", + Amount: 3, } if got != want { t.Fatalf("Apply produced wrong config: %+v, want %+v", got, want) @@ -148,6 +155,10 @@ liveness: failures: 5 lifecycle: max_session_duration: 6h +traffic: + max_payload_size: 8192 + min_delay: 10ms + max_delay: 40ms profiles: - name: wb-vp8 auth: @@ -162,6 +173,9 @@ profiles: interval: 1s lifecycle: max_session_duration: 30m + traffic: + max_payload_size: 4096 + max_delay: 20ms - name: jitsi-dc auth: provider: jitsi @@ -196,7 +210,8 @@ failover: } if first.KeyHex != "shared-key" || first.DNSServer != "1.1.1.1:53" || first.VP8FPS != 30 || first.LivenessInterval != "1s" || first.LivenessTimeout != "2s" || first.LivenessFailures != 5 || - first.MaxSessionDuration != "30m" { + first.MaxSessionDuration != "30m" || first.TrafficMaxPayloadSize != 4096 || + first.TrafficMinDelay != "10ms" || first.TrafficMaxDelay != "20ms" { t.Fatalf("first inherited/overlaid fields = %+v", first) } second := ApplyProfile(base, f.Profiles[1]) @@ -205,7 +220,8 @@ failover: t.Fatalf("second profile = %+v", second) } if second.LivenessInterval != "5s" || second.LivenessTimeout != "2s" || second.LivenessFailures != 5 || - second.MaxSessionDuration != "6h" { + second.MaxSessionDuration != "6h" || second.TrafficMaxPayloadSize != 8192 || + second.TrafficMinDelay != "10ms" || second.TrafficMaxDelay != "40ms" { t.Fatalf("second lifecycle/liveness fields = %+v", second) } } diff --git a/internal/crypto/chacha.go b/internal/crypto/chacha.go index 686d8b8..93a8425 100644 --- a/internal/crypto/chacha.go +++ b/internal/crypto/chacha.go @@ -10,6 +10,9 @@ import ( "golang.org/x/crypto/chacha20poly1305" ) +// WireOverhead is the number of bytes added to each encrypted message. +const WireOverhead = chacha20poly1305.NonceSizeX + chacha20poly1305.Overhead + var ( // ErrInvalidKeySize is returned when the encryption key is not 32 bytes. ErrInvalidKeySize = errors.New("invalid key size") diff --git a/internal/engine/goolom/lifecycle.go b/internal/engine/goolom/lifecycle.go index 316107f..7dd803d 100644 --- a/internal/engine/goolom/lifecycle.go +++ b/internal/engine/goolom/lifecycle.go @@ -112,10 +112,7 @@ func (s *Session) setupPeerConnections(config webrtc.Configuration) error { } func (s *Session) dialWebSocket() error { - wsDialer := websocket.Dialer{ - NetDialContext: protect.DialContext, - HandshakeTimeout: wsHandshakeTimeout, - } + wsDialer := protect.NewWebSocketDialer(wsHandshakeTimeout) ws, resp, err := wsDialer.Dial(s.mediaServerURL, nil) if err != nil { return fmt.Errorf("dial ws: %w", err) diff --git a/internal/engine/salutejazz/salutejazz.go b/internal/engine/salutejazz/salutejazz.go index 5daf47f..b1b8903 100644 --- a/internal/engine/salutejazz/salutejazz.go +++ b/internal/engine/salutejazz/salutejazz.go @@ -417,10 +417,7 @@ func (s *Session) waitForMediaReady(ctx context.Context, timeout time.Duration) } func (s *Session) dialWebSocket() error { - wsDialer := websocket.Dialer{ - NetDialContext: protect.DialContext, - HandshakeTimeout: wsHandshakeTimeout, - } + wsDialer := protect.NewWebSocketDialer(wsHandshakeTimeout) ws, resp, err := wsDialer.Dial(s.connectorURL, nil) if err != nil { diff --git a/internal/link/direct/direct.go b/internal/link/direct/direct.go index 4b2aa73..65089ab 100644 --- a/internal/link/direct/direct.go +++ b/internal/link/direct/direct.go @@ -43,6 +43,7 @@ func New(ctx context.Context, cfg link.Config) (link.Link, error) { SEIBatchSize: cfg.SEIBatchSize, SEIFragmentSize: cfg.SEIFragmentSize, SEIAckTimeoutMS: cfg.SEIAckTimeoutMS, + Traffic: cfg.Traffic, }) if err != nil { return nil, fmt.Errorf("create transport for direct link: %w", err) @@ -79,3 +80,6 @@ func (d *directLink) WatchConnection(ctx context.Context) { d.transport.WatchConnection(ctx) } func (d *directLink) CanSend() bool { return d.transport.CanSend() } + +// Features reports the direct link's underlying transport capabilities. +func (d *directLink) Features() link.Features { return d.transport.Features() } diff --git a/internal/link/direct/direct_test.go b/internal/link/direct/direct_test.go index 18edd2e..f891e88 100644 --- a/internal/link/direct/direct_test.go +++ b/internal/link/direct/direct_test.go @@ -79,12 +79,14 @@ func TestNewForwardsConfigAndMethods(t *testing.T) { VideoTileRS: 20, VP8FPS: 25, VP8BatchSize: 8, + Traffic: transport.TrafficConfig{MaxPayloadSize: 4096}, }) if err != nil { t.Fatalf("New() error = %v", err) } - if seen.DeviceID != "client" || seen.ProxyPort != 1080 || seen.VideoTileRS != 20 || seen.VP8BatchSize != 8 { + if seen.DeviceID != "client" || seen.ProxyPort != 1080 || seen.VideoTileRS != 20 || seen.VP8BatchSize != 8 || + seen.Traffic.MaxPayloadSize != 4096 { t.Fatalf("forwarded config = %+v", seen) } @@ -112,6 +114,9 @@ func TestNewForwardsConfigAndMethods(t *testing.T) { if !ln.CanSend() { t.Fatal("CanSend() = false, want true") } + if features := ln.(link.FeaturesProvider).Features(); features.MaxPayloadSize != 4096 { + t.Fatalf("Features() = %+v, want shaped max payload 4096", features) + } } func TestNewWrapsFactoryError(t *testing.T) { diff --git a/internal/link/link.go b/internal/link/link.go index f094cd0..c8957ac 100644 --- a/internal/link/link.go +++ b/internal/link/link.go @@ -4,6 +4,8 @@ package link import ( "context" "errors" + + "github.com/openlibrecommunity/olcrtc/internal/transport" ) var ( @@ -23,11 +25,19 @@ type Link interface { CanSend() bool } +// Features mirrors the underlying transport capabilities when a link can expose them. +type Features = transport.Features + +// FeaturesProvider is optionally implemented by links that can report wire limits. +type FeaturesProvider interface { + Features() Features +} + // Config holds common link configuration. type Config struct { - Transport string - Carrier string - RoomURL string + Transport string + Carrier string + RoomURL string // Engine, URL, Token are forwarded for the "none" auth carrier. Engine string URL string @@ -54,6 +64,7 @@ type Config struct { SEIBatchSize int SEIFragmentSize int SEIAckTimeoutMS int + Traffic transport.TrafficConfig } // Factory creates a link instance. diff --git a/internal/protect/protect.go b/internal/protect/protect.go index 29bc277..2919fa3 100644 --- a/internal/protect/protect.go +++ b/internal/protect/protect.go @@ -3,13 +3,38 @@ package protect import ( "context" + "crypto/tls" "fmt" + "io" "net" "net/http" + "regexp" + "strings" "syscall" "time" + + "github.com/gorilla/websocket" ) +const ( + defaultDialTimeout = 10 * time.Second + defaultKeepAlive = 30 * time.Second + defaultIdleConnTimeout = 30 * time.Second + defaultTLSHandshake = 10 * time.Second + defaultResponseHeader = 10 * time.Second + defaultWebSocketTimeout = 10 * time.Second + defaultHTTPClientTimeout = 30 * time.Second + defaultStatusBodyLimit = 1024 +) + +var ( + sensitiveFieldRE = regexp.MustCompile( + `(?i)((?:access[_-]?token|room[_-]?token|token|credentials)"?\s*[:=]\s*"?)` + + `[^",\s}]+`, + ) + sensitiveBearerRE = regexp.MustCompile(`(?i)(bearer\s+)[A-Za-z0-9._~+/=-]+`) +) //nolint:gochecknoglobals // compiled once for provider error redaction + // Protector is called with a socket file descriptor before connect. // On Android, this calls VpnService.protect(fd) to bypass VPN routing. var Protector func(fd int) bool //nolint:gochecknoglobals // package-level state intentional @@ -33,24 +58,70 @@ func controlFunc(network, _ string, c syscall.RawConn) error { // NewDialer returns a net.Dialer that calls Protector on each new socket. func NewDialer() *net.Dialer { return &net.Dialer{ - Timeout: 10 * time.Second, - KeepAlive: 30 * time.Second, + Timeout: defaultDialTimeout, + KeepAlive: defaultKeepAlive, Control: controlFunc, } } +// NewTLSConfig returns the shared TLS policy for provider HTTP/WebSocket clients. +func NewTLSConfig() *tls.Config { + return &tls.Config{MinVersion: tls.VersionTLS12} +} + +// NewHTTPTransport returns an HTTP transport using protected sockets and sane timeouts. +func NewHTTPTransport() *http.Transport { + dialer := NewDialer() + return &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: dialer.DialContext, + TLSClientConfig: NewTLSConfig(), + ForceAttemptHTTP2: true, + MaxIdleConns: 10, + IdleConnTimeout: defaultIdleConnTimeout, + TLSHandshakeTimeout: defaultTLSHandshake, + ResponseHeaderTimeout: defaultResponseHeader, + } +} + // NewHTTPClient returns an http.Client using protected sockets. func NewHTTPClient() *http.Client { - dialer := NewDialer() - transport := &http.Transport{ - DialContext: dialer.DialContext, - ForceAttemptHTTP2: true, - MaxIdleConns: 10, - IdleConnTimeout: 30 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ResponseHeaderTimeout: 10 * time.Second, + return &http.Client{ + Transport: NewHTTPTransport(), + Timeout: defaultHTTPClientTimeout, } - return &http.Client{Transport: transport} +} + +// NewWebSocketDialer returns a WebSocket dialer using protected sockets and shared TLS policy. +func NewWebSocketDialer(handshakeTimeout time.Duration) websocket.Dialer { + if handshakeTimeout <= 0 { + handshakeTimeout = defaultWebSocketTimeout + } + return websocket.Dialer{ + NetDialContext: DialContext, + Proxy: http.ProxyFromEnvironment, + TLSClientConfig: NewTLSConfig(), + HandshakeTimeout: handshakeTimeout, + } +} + +// StatusError formats an upstream HTTP error while bounding and redacting the body. +func StatusError(base error, resp *http.Response, limit int64) error { + if limit <= 0 { + limit = defaultStatusBodyLimit + } + body, _ := io.ReadAll(io.LimitReader(resp.Body, limit)) + bodyText := RedactSensitive(strings.TrimSpace(string(body))) + if bodyText == "" { + return fmt.Errorf("%w: status %d", base, resp.StatusCode) + } + return fmt.Errorf("%w: status %d: %s", base, resp.StatusCode, bodyText) +} + +// RedactSensitive removes common token-like values from provider error text. +func RedactSensitive(text string) string { + text = sensitiveBearerRE.ReplaceAllString(text, "${1}") + return sensitiveFieldRE.ReplaceAllString(text, "${1}") } // DialContext dials using a protected socket. diff --git a/internal/protect/protect_test.go b/internal/protect/protect_test.go index 515f82d..e07a666 100644 --- a/internal/protect/protect_test.go +++ b/internal/protect/protect_test.go @@ -2,9 +2,11 @@ package protect import ( "context" + "crypto/tls" "errors" "net" "net/http" + "strings" "syscall" "testing" "time" @@ -88,13 +90,57 @@ func TestNewDialerAndHTTPClient(t *testing.T) { if !ok { t.Fatalf("Transport type = %T, want *http.Transport", client.Transport) } - if tr.DialContext == nil || !tr.ForceAttemptHTTP2 || tr.MaxIdleConns != 10 || + if tr.Proxy == nil || tr.DialContext == nil || tr.TLSClientConfig == nil || + tr.TLSClientConfig.MinVersion != tls.VersionTLS12 || !tr.ForceAttemptHTTP2 || tr.MaxIdleConns != 10 || tr.IdleConnTimeout != 30*time.Second || tr.TLSHandshakeTimeout != 10*time.Second || - tr.ResponseHeaderTimeout != 10*time.Second { + tr.ResponseHeaderTimeout != 10*time.Second || client.Timeout != 30*time.Second { t.Fatalf("transport = %+v", tr) } } +func TestNewWebSocketDialer(t *testing.T) { + dialer := NewWebSocketDialer(3 * time.Second) + if dialer.NetDialContext == nil || dialer.Proxy == nil || dialer.TLSClientConfig == nil || + dialer.TLSClientConfig.MinVersion != tls.VersionTLS12 || + dialer.HandshakeTimeout != 3*time.Second { + t.Fatalf("NewWebSocketDialer() = %+v", dialer) + } + + defaulted := NewWebSocketDialer(0) + if defaulted.HandshakeTimeout != defaultWebSocketTimeout { + t.Fatalf("default HandshakeTimeout = %v, want %v", + defaulted.HandshakeTimeout, defaultWebSocketTimeout) + } +} + +func TestStatusErrorRedactsAndLimitsBody(t *testing.T) { + resp := &http.Response{ + StatusCode: http.StatusForbidden, + Body: ioNopCloser{strings.NewReader(`{"accessToken":"secret","message":"no"}`)}, + } + err := StatusError(errProtectBoom, resp, 1024) + if err == nil { + t.Fatal("StatusError() error = nil") + } + text := err.Error() + if strings.Contains(text, "secret") || !strings.Contains(text, "") { + t.Fatalf("StatusError() = %q, want redacted token", text) + } +} + +func TestRedactSensitiveBearer(t *testing.T) { + got := RedactSensitive("Authorization: Bearer abc.def") + if strings.Contains(got, "abc.def") || !strings.Contains(got, "Bearer ") { + t.Fatalf("RedactSensitive() = %q", got) + } +} + +type ioNopCloser struct { + *strings.Reader +} + +func (c ioNopCloser) Close() error { return nil } + func TestDialContextAndProxyDialer(t *testing.T) { var lc net.ListenConfig ln, err := lc.Listen(context.Background(), "tcp4", "127.0.0.1:0") diff --git a/internal/server/server.go b/internal/server/server.go index 7dae4eb..2c28805 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -21,6 +21,7 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/logger" "github.com/openlibrecommunity/olcrtc/internal/muxconn" "github.com/openlibrecommunity/olcrtc/internal/names" + "github.com/openlibrecommunity/olcrtc/internal/transport" "github.com/xtaci/smux" ) @@ -116,6 +117,7 @@ type Config struct { URL string Token string Liveness control.Config + Traffic transport.TrafficConfig // AuthHook is invoked after CLIENT_HELLO to authorize the client and // return a session ID. If nil, every client is admitted with a random UUID. @@ -234,11 +236,17 @@ func (s *Server) setupResolver() { // smuxConfig mirrors the client side. Both peers must agree on Version and // MaxFrameSize. -func smuxConfig() *smux.Config { +func smuxConfig(maxWirePayload ...int) *smux.Config { cfg := smux.DefaultConfig() cfg.Version = 2 cfg.KeepAliveDisabled = true cfg.MaxFrameSize = 32768 + if len(maxWirePayload) > 0 && maxWirePayload[0] > crypto.WireOverhead { + maxFrameSize := maxWirePayload[0] - crypto.WireOverhead + if maxFrameSize < cfg.MaxFrameSize { + cfg.MaxFrameSize = maxFrameSize + } + } cfg.MaxReceiveBuffer = 16 * 1024 * 1024 cfg.MaxStreamBuffer = 1024 * 1024 cfg.KeepAliveInterval = 10 * time.Second @@ -246,6 +254,14 @@ func smuxConfig() *smux.Config { return cfg } +func linkMaxPayload(ln link.Link) int { + provider, ok := ln.(link.FeaturesProvider) + if !ok { + return 0 + } + return provider.Features().MaxPayloadSize +} + func (s *Server) bringUpLink( ctx context.Context, cfg Config, @@ -280,6 +296,7 @@ func (s *Server) bringUpLink( SEIBatchSize: cfg.SEIBatchSize, SEIFragmentSize: cfg.SEIFragmentSize, SEIAckTimeoutMS: cfg.SEIAckTimeoutMS, + Traffic: cfg.Traffic, }) if err != nil { return fmt.Errorf("failed to create link: %w", err) @@ -316,7 +333,7 @@ func (s *Server) bringUpLink( func (s *Server) installSession() { conn := muxconn.New(s.ln, s.cipher) - sess, err := smux.Server(conn, smuxConfig()) + sess, err := smux.Server(conn, smuxConfig(linkMaxPayload(s.ln))) if err != nil { logger.Warnf("smux server init failed: %v", err) return @@ -342,7 +359,7 @@ func (s *Server) reinstallSession(dead *smux.Session) { // Pre-build the replacement so we can swap atomically below. newConn := muxconn.New(s.ln, s.cipher) - newSess, err := smux.Server(newConn, smuxConfig()) + newSess, err := smux.Server(newConn, smuxConfig(linkMaxPayload(s.ln))) if err != nil { logger.Warnf("smux server init failed: %v", err) _ = newConn.Close() diff --git a/internal/server/server_test.go b/internal/server/server_test.go index dc80b21..65a2bc5 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -50,6 +50,11 @@ func TestSmuxConfig(t *testing.T) { if cfg.Version != 2 || !cfg.KeepAliveDisabled || cfg.MaxFrameSize != 32768 || cfg.MaxReceiveBuffer != 16*1024*1024 { t.Fatalf("smuxConfig() = %+v", cfg) } + capped := smuxConfig(4096) + if capped.MaxFrameSize != 4096-cryptopkg.WireOverhead { + t.Fatalf("smuxConfig(4096).MaxFrameSize = %d, want %d", + capped.MaxFrameSize, 4096-cryptopkg.WireOverhead) + } } func TestParseConnectRequest(t *testing.T) { diff --git a/internal/transport/traffic.go b/internal/transport/traffic.go new file mode 100644 index 0000000..31f194b --- /dev/null +++ b/internal/transport/traffic.go @@ -0,0 +1,91 @@ +package transport + +import ( + "context" + "errors" + "fmt" + "math/rand/v2" + "sync" + "time" +) + +var ErrTrafficPayloadTooLarge = errors.New("traffic payload exceeds max_payload_size") + +type trafficTransport struct { + inner Transport + maxPayloadSize int + minDelay time.Duration + maxDelay time.Duration + sendMu sync.Mutex +} + +// WithTraffic wraps tr with optional payload caps and send pacing. +func WithTraffic(tr Transport, cfg TrafficConfig) Transport { + if tr == nil { + return nil + } + cfg = effectiveTrafficConfig(tr.Features(), cfg) + if cfg.MaxPayloadSize <= 0 && cfg.MinDelay <= 0 && cfg.MaxDelay <= 0 { + return tr + } + return &trafficTransport{ + inner: tr, + maxPayloadSize: cfg.MaxPayloadSize, + minDelay: cfg.MinDelay, + maxDelay: cfg.MaxDelay, + } +} + +func effectiveTrafficConfig(features Features, cfg TrafficConfig) TrafficConfig { + if cfg.MaxPayloadSize > 0 && features.MaxPayloadSize > 0 && features.MaxPayloadSize < cfg.MaxPayloadSize { + cfg.MaxPayloadSize = features.MaxPayloadSize + } + return cfg +} + +func (t *trafficTransport) Connect(ctx context.Context) error { return t.inner.Connect(ctx) } + +func (t *trafficTransport) Send(data []byte) error { + t.sendMu.Lock() + defer t.sendMu.Unlock() + if t.maxPayloadSize > 0 && len(data) > t.maxPayloadSize { + return fmt.Errorf("%w: size=%d max=%d", ErrTrafficPayloadTooLarge, len(data), t.maxPayloadSize) + } + if delay := t.nextDelay(); delay > 0 { + time.Sleep(delay) + } + return t.inner.Send(data) +} + +func (t *trafficTransport) Close() error { return t.inner.Close() } + +func (t *trafficTransport) SetReconnectCallback(cb func()) { t.inner.SetReconnectCallback(cb) } + +func (t *trafficTransport) SetShouldReconnect(fn func() bool) { t.inner.SetShouldReconnect(fn) } + +func (t *trafficTransport) SetEndedCallback(cb func(string)) { t.inner.SetEndedCallback(cb) } + +func (t *trafficTransport) WatchConnection(ctx context.Context) { t.inner.WatchConnection(ctx) } + +func (t *trafficTransport) CanSend() bool { return t.inner.CanSend() } + +func (t *trafficTransport) Features() Features { + features := t.inner.Features() + if t.maxPayloadSize > 0 && + (features.MaxPayloadSize == 0 || t.maxPayloadSize < features.MaxPayloadSize) { + features.MaxPayloadSize = t.maxPayloadSize + } + return features +} + +func (t *trafficTransport) nextDelay() time.Duration { + if t.maxDelay <= 0 && t.minDelay <= 0 { + return 0 + } + minDelay := t.minDelay + maxDelay := t.maxDelay + if maxDelay <= minDelay { + return minDelay + } + return minDelay + time.Duration(rand.Int64N(int64(maxDelay-minDelay))) //nolint:gosec,lll // G404: non-cryptographic pacing jitter +} diff --git a/internal/transport/traffic_test.go b/internal/transport/traffic_test.go new file mode 100644 index 0000000..9f6139a --- /dev/null +++ b/internal/transport/traffic_test.go @@ -0,0 +1,67 @@ +package transport + +import ( + "context" + "errors" + "testing" + "time" +) + +type trafficStubTransport struct { + features Features + sent [][]byte +} + +func (s *trafficStubTransport) Connect(context.Context) error { return nil } +func (s *trafficStubTransport) Send(data []byte) error { + s.sent = append(s.sent, append([]byte(nil), data...)) + return nil +} +func (s *trafficStubTransport) Close() error { return nil } +func (s *trafficStubTransport) SetReconnectCallback(func()) {} +func (s *trafficStubTransport) SetShouldReconnect(func() bool) {} +func (s *trafficStubTransport) SetEndedCallback(func(string)) {} +func (s *trafficStubTransport) WatchConnection(context.Context) {} +func (s *trafficStubTransport) CanSend() bool { return true } +func (s *trafficStubTransport) Features() Features { return s.features } + +func TestWithTrafficReturnsInnerWhenDisabled(t *testing.T) { + inner := &trafficStubTransport{} + got := WithTraffic(inner, TrafficConfig{}) + if got != inner { + t.Fatalf("WithTraffic disabled returned %T, want inner", got) + } +} + +func TestTrafficWrapperRejectsOversizedPayloadAndClampsFeatures(t *testing.T) { + inner := &trafficStubTransport{features: Features{MaxPayloadSize: 5}} + tr := WithTraffic(inner, TrafficConfig{MaxPayloadSize: 10}) + if features := tr.Features(); features.MaxPayloadSize != 5 { + t.Fatalf("Features().MaxPayloadSize = %d, want 5", features.MaxPayloadSize) + } + err := tr.Send([]byte("123456")) + if !errors.Is(err, ErrTrafficPayloadTooLarge) { + t.Fatalf("Send() error = %v, want %v", err, ErrTrafficPayloadTooLarge) + } + if len(inner.sent) != 0 { + t.Fatalf("inner sent %d payloads, want 0", len(inner.sent)) + } + if err := tr.Send([]byte("12345")); err != nil { + t.Fatalf("Send(max sized) error = %v", err) + } + if got := string(inner.sent[0]); got != "12345" { + t.Fatalf("inner payload = %q, want 12345", got) + } +} + +func TestTrafficWrapperAppliesMinimumDelay(t *testing.T) { + inner := &trafficStubTransport{} + tr := WithTraffic(inner, TrafficConfig{MinDelay: 2 * time.Millisecond}) + start := time.Now() + if err := tr.Send([]byte("x")); err != nil { + t.Fatalf("Send() error = %v", err) + } + if elapsed := time.Since(start); elapsed < 2*time.Millisecond { + t.Fatalf("Send() elapsed = %v, want at least 2ms", elapsed) + } +} diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 9e11240..2f37a41 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -4,6 +4,7 @@ package transport import ( "context" "errors" + "time" ) var ( @@ -32,10 +33,17 @@ type Transport interface { Features() Features } +// TrafficConfig controls optional reliability-oriented send shaping. +type TrafficConfig struct { + MaxPayloadSize int + MinDelay time.Duration + MaxDelay time.Duration +} + // Config holds common transport configuration. type Config struct { - Carrier string - RoomURL string + Carrier string + RoomURL string // Engine, URL, Token are forwarded to carrier.Config for the "none" auth // carrier (direct engine access without a service-specific auth flow). Engine string @@ -63,6 +71,7 @@ type Config struct { SEIBatchSize int SEIFragmentSize int SEIAckTimeoutMS int + Traffic TrafficConfig } // Factory creates a transport instance. @@ -81,7 +90,11 @@ func New(ctx context.Context, name string, cfg Config) (Transport, error) { if !ok { return nil, ErrTransportNotFound } - return factory(ctx, cfg) + tr, err := factory(ctx, cfg) + if err != nil { + return nil, err + } + return WithTraffic(tr, cfg.Traffic), nil } // Available returns a list of registered transport names. From 79c151126892905d140d1a0bedb3e46dd02722ba Mon Sep 17 00:00:00 2001 From: cyber-debug Date: Sat, 16 May 2026 02:40:17 +0300 Subject: [PATCH 8/8] Fix seichannel readiness before sending --- internal/e2e/tunnel_test.go | 18 +++++++++++--- internal/transport/seichannel/transport.go | 24 +++++++++++++++---- .../transport/seichannel/transport_test.go | 10 ++++++++ .../seichannel/transport_unit_test.go | 6 ++++- 4 files changed, 50 insertions(+), 8 deletions(-) diff --git a/internal/e2e/tunnel_test.go b/internal/e2e/tunnel_test.go index b5cf0dd..deb9f44 100644 --- a/internal/e2e/tunnel_test.go +++ b/internal/e2e/tunnel_test.go @@ -1021,9 +1021,7 @@ func TestDirectLinkConnectsFastProviderTransportMatrix(t *testing.T) { if err := ln.Connect(context.Background()); err != nil { t.Fatalf("Connect() error = %v", err) } - if !ln.CanSend() { - t.Fatal("CanSend() = false, want true") - } + assertLinkCanSendAfterConnect(t, ln, transportName) if err := ln.Close(); err != nil { t.Fatalf("Close() error = %v", err) } @@ -1033,6 +1031,20 @@ func TestDirectLinkConnectsFastProviderTransportMatrix(t *testing.T) { } } +func assertLinkCanSendAfterConnect(t *testing.T, ln link.Link, transportName string) { + t.Helper() + + if transportName == transportSEI { + if ln.CanSend() { + t.Fatal("CanSend() = true before peer seichannel frame") + } + return + } + if !ln.CanSend() { + t.Fatal("CanSend() = false, want true") + } +} + //nolint:cyclop // table-driven test naturally has many branches func TestRealProviderTransportMatrix(t *testing.T) { if !*realE2E { diff --git a/internal/transport/seichannel/transport.go b/internal/transport/seichannel/transport.go index 6cb7f9b..73b54f9 100644 --- a/internal/transport/seichannel/transport.go +++ b/internal/transport/seichannel/transport.go @@ -35,6 +35,7 @@ const ( protocolVersion byte = 1 frameTypeData byte = 1 frameTypeAck byte = 2 + frameTypeHello byte = 3 ) var ( @@ -86,6 +87,7 @@ type streamTransport struct { nextSeq atomic.Uint32 closed atomic.Bool writerUp atomic.Bool + peerReady atomic.Bool sendMu sync.Mutex startWriter sync.Once ackMu sync.Mutex @@ -286,7 +288,7 @@ func (p *streamTransport) WatchConnection(ctx context.Context) { // CanSend reports whether transport is ready for sending. func (p *streamTransport) CanSend() bool { - return !p.closed.Load() && p.stream.CanSend() + return !p.closed.Load() && p.peerReady.Load() && p.stream.CanSend() } // Features describes the current seichannel transport semantics. @@ -333,7 +335,7 @@ func (p *streamTransport) writerLoop() { ticker := time.NewTicker(p.effectiveFrameInterval()) defer ticker.Stop() - idle := buildVideoAccessUnit(nil) + idle := buildVideoAccessUnit(encodeHelloFrame()) for { select { @@ -443,9 +445,13 @@ func (p *streamTransport) handleSample(sample []byte) { } switch frame.typ { + case frameTypeHello: + p.peerReady.Store(true) case frameTypeAck: + p.peerReady.Store(true) p.resolveAck(frame.seq, frame.crc) case frameTypeData: + p.peerReady.Store(true) p.handleInboundFrame(frame) } } @@ -562,8 +568,8 @@ func encodeDataFrame(seq, crc uint32, totalLen, fragIdx, fragTotal int, payload out[5] = frameTypeData binary.BigEndian.PutUint32(out[6:10], seq) binary.BigEndian.PutUint32(out[10:14], crc) - binary.BigEndian.PutUint32(out[14:18], uint32(totalLen)) //nolint:gosec,lll // G115: bounded conversion verified by surrounding logic - binary.BigEndian.PutUint16(out[18:20], uint16(fragIdx)) //nolint:gosec,lll // G115: bounded conversion verified by surrounding logic + binary.BigEndian.PutUint32(out[14:18], uint32(totalLen)) //nolint:gosec,lll // G115: bounded conversion verified by surrounding logic + binary.BigEndian.PutUint16(out[18:20], uint16(fragIdx)) //nolint:gosec,lll // G115: bounded conversion verified by surrounding logic binary.BigEndian.PutUint16(out[20:22], uint16(fragTotal)) //nolint:gosec,lll // G115: bounded conversion verified by surrounding logic copy(out[22:], payload) return out @@ -579,6 +585,14 @@ func encodeAckFrame(seq, crc uint32) []byte { return out } +func encodeHelloFrame() []byte { + out := make([]byte, 6) + binary.BigEndian.PutUint32(out[0:4], protocolMagic) + out[4] = protocolVersion + out[5] = frameTypeHello + return out +} + func decodeTransportFrame(data []byte) (transportFrame, error) { if len(data) < 6 { return transportFrame{}, ErrFrameTooShort @@ -592,6 +606,8 @@ func decodeTransportFrame(data []byte) (transportFrame, error) { frame := transportFrame{typ: data[5]} switch frame.typ { + case frameTypeHello: + return frame, nil case frameTypeAck: if len(data) < 14 { return transportFrame{}, ErrAckTooShort diff --git a/internal/transport/seichannel/transport_test.go b/internal/transport/seichannel/transport_test.go index 8f11c6f..51c8272 100644 --- a/internal/transport/seichannel/transport_test.go +++ b/internal/transport/seichannel/transport_test.go @@ -78,3 +78,13 @@ func TestTransportFrameRoundTrip(t *testing.T) { t.Fatalf("payload mismatch: got=%q", decoded.payload) } } + +func TestHelloFrameRoundTrip(t *testing.T) { + hello, err := decodeTransportFrame(encodeHelloFrame()) + if err != nil { + t.Fatalf("decodeTransportFrame(hello) failed: %v", err) + } + if hello.typ != frameTypeHello { + t.Fatalf("hello frame type = %d, want %d", hello.typ, frameTypeHello) + } +} diff --git a/internal/transport/seichannel/transport_unit_test.go b/internal/transport/seichannel/transport_unit_test.go index 00abf58..716b970 100644 --- a/internal/transport/seichannel/transport_unit_test.go +++ b/internal/transport/seichannel/transport_unit_test.go @@ -103,8 +103,12 @@ func TestNewConnectCallbacksAndFeatures(t *testing.T) { if stream.reconnect == nil || stream.should == nil || stream.ended == nil || !stream.watched { t.Fatal("callbacks/watch were not forwarded") } + if tr.CanSend() { + t.Fatal("CanSend() = true before peer hello") + } + tr.handleSample(buildVideoAccessUnit(encodeHelloFrame())) if !tr.CanSend() { - t.Fatal("CanSend() = false, want true") + t.Fatal("CanSend() = false after peer hello") } if features := tr.Features(); !features.Reliable || !features.Ordered || !features.MessageOriented || features.MaxPayloadSize == 0 { //nolint:lll // long test description t.Fatalf("Features() = %+v", features)