From a86f5c6948c1dddab4a84d59b64729029caf64d7 Mon Sep 17 00:00:00 2001 From: cyber-debug Date: Fri, 15 May 2026 23:49:14 +0300 Subject: [PATCH] feat: add reconnect hardening and failover profiles --- .gitignore | 1 + cmd/olcrtc/main.go | 115 ++++++- cmd/olcrtc/main_test.go | 119 +++++++ docs/about.md | 4 +- docs/client.example.yaml | 1 + docs/configuration.md | 54 +++- docs/failover.example.yaml | 34 ++ docs/manual.md | 2 +- docs/project-map.md | 400 ++++++++++++++++++++++++ docs/server.example.yaml | 1 + docs/settings.md | 20 +- internal/app/session/session.go | 158 ++++++++-- internal/app/session/session_test.go | 112 +++++++ internal/config/config.go | 160 +++++++++- internal/config/config_test.go | 152 +++++++++ internal/e2e/tunnel_test.go | 193 ++++++++++++ internal/engine/livekit/livekit.go | 385 +++++++++++++++++++---- internal/engine/livekit/livekit_test.go | 306 ++++++++++++++++++ internal/supervisor/supervisor.go | 96 ++++++ internal/supervisor/supervisor_test.go | 85 +++++ 20 files changed, 2280 insertions(+), 118 deletions(-) create mode 100644 docs/failover.example.yaml create mode 100644 docs/project-map.md create mode 100644 internal/engine/livekit/livekit_test.go create mode 100644 internal/supervisor/supervisor.go create mode 100644 internal/supervisor/supervisor_test.go diff --git a/.gitignore b/.gitignore index d0b6a3c..61fcb93 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ # Prerequisites *.d +.DS_Store # Object files *.o diff --git a/cmd/olcrtc/main.go b/cmd/olcrtc/main.go index b8c2bdf..777949b 100644 --- a/cmd/olcrtc/main.go +++ b/cmd/olcrtc/main.go @@ -24,6 +24,7 @@ import ( configpkg "github.com/openlibrecommunity/olcrtc/internal/config" "github.com/openlibrecommunity/olcrtc/internal/logger" "github.com/openlibrecommunity/olcrtc/internal/names" + "github.com/openlibrecommunity/olcrtc/internal/supervisor" "github.com/openlibrecommunity/olcrtc/internal/transport/videochannel" ) @@ -35,6 +36,9 @@ var ErrConfigPathRequired = 
errors.New("usage: olcrtc ") // ErrDataDirRequired is returned when the YAML config does not specify a data directory. var ErrDataDirRequired = errors.New("data directory required (set 'data:' in YAML)") +// ErrProfilesUnsupportedForGen is returned when failover profiles are configured for gen mode. +var ErrProfilesUnsupportedForGen = errors.New("profiles are only supported for srv and cnc modes") + //nolint:gochecknoglobals // Tests replace the long-running session runner with a bounded function. var runSession = session.Run @@ -44,11 +48,18 @@ var runGen = execGen // loadedConfig bundles the parsed YAML file and the derived session config. type loadedConfig struct { scfg session.Config + profiles []supervisor.Profile + failover failoverConfig dataDir string debug bool ffmpegPath string } +type failoverConfig struct { + retryDelay time.Duration + maxCycles int +} + func main() { if err := run(); err != nil { logger.Error(err) @@ -79,14 +90,44 @@ func loadConfig(path string) (loadedConfig, error) { if err != nil { return loadedConfig{}, fmt.Errorf("load config: %w", err) } + base := configpkg.Apply(session.Config{}, f) + profiles := make([]supervisor.Profile, 0, len(f.Profiles)) + for i, profile := range f.Profiles { + name := profile.Name + if name == "" { + name = fmt.Sprintf("profile-%d", i+1) + } + profiles = append(profiles, supervisor.Profile{ + Name: name, + Config: configpkg.ApplyProfile(base, profile), + }) + } + failover, err := parseFailoverConfig(f.Failover) + if err != nil { + return loadedConfig{}, err + } return loadedConfig{ - scfg: configpkg.Apply(session.Config{}, f), + scfg: base, + profiles: profiles, + failover: failover, dataDir: f.Data, debug: f.Debug, ffmpegPath: f.FFmpeg, }, nil } +func parseFailoverConfig(f configpkg.Failover) (failoverConfig, error) { + retryDelay := supervisor.DefaultRetryDelay + if f.RetryDelay != "" { + parsed, err := time.ParseDuration(f.RetryDelay) + if err != nil { + return failoverConfig{}, fmt.Errorf("parse 
failover.retry_delay: %w", err) + } + retryDelay = parsed + } + return failoverConfig{retryDelay: retryDelay, maxCycles: f.MaxCycles}, nil +} + func runWithConfig(cfg loadedConfig) error { configureLogging(cfg.debug) @@ -98,19 +139,85 @@ func runWithConfig(cfg loadedConfig) error { if err != nil { return fmt.Errorf("validate config: %w", err) } + scfg = session.ApplyTransportDefaults(scfg) if scfg.Mode == modeGen { + if len(cfg.profiles) > 0 { + return ErrProfilesUnsupportedForGen + } return runGen(scfg) } + if len(cfg.profiles) > 0 { + profiles, err := prepareProfiles(cfg.profiles) + if err != nil { + return err + } + return runFailoverSessionMode(cfg.dataDir, profiles, cfg.failover) + } + return runSessionMode(cfg.dataDir, scfg) } +func prepareProfiles(profiles []supervisor.Profile) ([]supervisor.Profile, error) { + out := make([]supervisor.Profile, 0, len(profiles)) + for _, profile := range profiles { + scfg, err := session.ApplyAuthDefaults(profile.Config) + if err != nil { + return nil, fmt.Errorf("validate profile %q: %w", profile.Name, err) + } + profile.Config = session.ApplyTransportDefaults(scfg) + out = append(out, profile) + } + return out, nil +} + func runSessionMode(dataDir string, scfg session.Config) error { if err := session.Validate(scfg); err != nil { return fmt.Errorf("validate config: %w", err) } + if err := prepareRuntimeData(dataDir); err != nil { + return err + } + + return runManaged(func(ctx context.Context) error { + return runSession(ctx, scfg) + }) +} + +func runFailoverSessionMode(dataDir string, profiles []supervisor.Profile, failover failoverConfig) error { + for _, profile := range profiles { + if err := session.Validate(profile.Config); err != nil { + return fmt.Errorf("validate profile %q: %w", profile.Name, err) + } + } + + if err := prepareRuntimeData(dataDir); err != nil { + return err + } + + return runManaged(func(ctx context.Context) error { + return supervisor.Run(ctx, supervisor.Config{ + Profiles: profiles, + 
RetryDelay: failover.retryDelay, + MaxCycles: failover.maxCycles, + OnProfileStart: func(profile supervisor.Profile, cycle int) { + logger.Infof("failover cycle=%d starting profile=%s carrier=%s transport=%s", + cycle, profile.Name, profile.Config.Auth, profile.Config.Transport) + }, + OnProfileEnd: func(profile supervisor.Profile, cycle int, err error) { + if err != nil { + logger.Warnf("failover cycle=%d profile=%s ended with error: %v", cycle, profile.Name, err) + return + } + logger.Warnf("failover cycle=%d profile=%s ended", cycle, profile.Name) + }, + }, runSession) + }) +} + +func prepareRuntimeData(dataDir string) error { if dataDir == "" { return ErrDataDirRequired } @@ -124,6 +231,10 @@ func runSessionMode(dataDir string, scfg session.Config) error { return err } + return nil +} + +func runManaged(run func(context.Context) error) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -132,7 +243,7 @@ func runSessionMode(dataDir string, scfg session.Config) error { errCh := make(chan error, 1) go func() { - errCh <- runSession(ctx, scfg) + errCh <- run(ctx) }() select { diff --git a/cmd/olcrtc/main_test.go b/cmd/olcrtc/main_test.go index acb6a1d..96a4aeb 100644 --- a/cmd/olcrtc/main_test.go +++ b/cmd/olcrtc/main_test.go @@ -9,6 +9,7 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/app/session" "github.com/openlibrecommunity/olcrtc/internal/logger" + "github.com/openlibrecommunity/olcrtc/internal/supervisor" ) var errBoom = errors.New("boom") @@ -149,6 +150,112 @@ data: `+dir+` } } +func TestRunWithArgsAppliesTransportDefaults(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "names"), []byte("A\n"), 0o600); err != nil { + t.Fatalf("WriteFile(names) error = %v", err) + } + if err := os.WriteFile(filepath.Join(dir, "surnames"), []byte("B\n"), 0o600); err != nil { + t.Fatalf("WriteFile(surnames) error = %v", err) + } + + oldRunSession := runSession + t.Cleanup(func() { runSession = 
oldRunSession }) + runSession = func(ctx context.Context, cfg session.Config) error { + if cfg.VP8FPS != 25 || cfg.VP8BatchSize != 1 { + t.Fatalf("VP8 defaults = fps %d batch %d, want 25/1", cfg.VP8FPS, cfg.VP8BatchSize) + } + return nil + } + + yamlPath := writeYAML(t, ` +mode: srv +link: direct +auth: + provider: wbstream +room: + id: room +crypto: + key: key +net: + transport: vp8channel + dns: 1.1.1.1:53 +data: `+dir+` +`) + + if err := runWithArgs([]string{yamlPath}); err != nil { + t.Fatalf("runWithArgs() error = %v", err) + } +} + +func TestRunWithArgsFailoverProfiles(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "names"), []byte("A\n"), 0o600); err != nil { + t.Fatalf("WriteFile(names) error = %v", err) + } + if err := os.WriteFile(filepath.Join(dir, "surnames"), []byte("B\n"), 0o600); err != nil { + t.Fatalf("WriteFile(surnames) error = %v", err) + } + + oldRunSession := runSession + t.Cleanup(func() { runSession = oldRunSession }) + var seen []string + runSession = func(ctx context.Context, cfg session.Config) error { + seen = append(seen, cfg.Auth+"/"+cfg.Transport) + if cfg.Auth == "wbstream" && (cfg.VP8FPS != 25 || cfg.VP8BatchSize != 1) { + t.Fatalf("VP8 defaults = fps %d batch %d, want 25/1", cfg.VP8FPS, cfg.VP8BatchSize) + } + return errBoom + } + + yamlPath := writeYAML(t, ` +mode: srv +link: direct +crypto: + key: key +net: + dns: 1.1.1.1:53 +profiles: + - name: wb-primary + auth: + provider: wbstream + room: + id: room + net: + transport: vp8channel + - name: jitsi-backup + auth: + provider: jitsi + room: + id: https://meet.example/room + net: + transport: datachannel +failover: + retry_delay: -1ns + max_cycles: 1 +data: `+dir+` +`) + + err := runWithArgs([]string{yamlPath}) + if !errors.Is(err, supervisor.ErrMaxCyclesExceeded) { + t.Fatalf("runWithArgs() error = %v, want %v", err, supervisor.ErrMaxCyclesExceeded) + } + want := []string{"wbstream/vp8channel", "jitsi/datachannel"} + if !equalStrings(seen, want) { 
+ t.Fatalf("seen profiles = %v, want %v", seen, want) + } +} + +func TestRunWithConfigRejectsProfilesInGenMode(t *testing.T) { + cfg := loadedConfig{ + scfg: session.Config{Mode: modeGen}, + profiles: []supervisor.Profile{{Name: "one"}}, + } + if err := runWithConfig(cfg); !errors.Is(err, ErrProfilesUnsupportedForGen) { + t.Fatalf("runWithConfig() error = %v, want %v", err, ErrProfilesUnsupportedForGen) + } +} + func TestConfigureLogging(t *testing.T) { t.Setenv("PION_LOG_DISABLE", "") logger.SetVerbose(false) @@ -170,6 +277,18 @@ func TestConfigureLogging(t *testing.T) { } } +func equalStrings(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + func TestResolveDataDir(t *testing.T) { abs := filepath.Join(t.TempDir(), "data") got, err := resolveDataDir(abs) diff --git a/docs/about.md b/docs/about.md index 112c6bb..a67149d 100644 --- a/docs/about.md +++ b/docs/about.md @@ -234,7 +234,7 @@ internal/e2e/ E2E тесты на реальных провайдер | Файл | Что делает | |---|---| -| `session.go` | Главная точка конфигурации. `RegisterDefaults()` регистрирует все carriers, links, transports. `Validate()` проверяет все настройки. `Run()` роутит в `server.Run` или `client.Run`. `Gen()` генерирует Room ID для jazz с ретраями (wbstream больше не поддерживает автогенерацию - руму нужно создавать вручную через stream.wb.ru) | +| `session.go` | Главная точка конфигурации. `RegisterDefaults()` регистрирует все carriers, links, transports. `Validate()` проверяет все настройки. `Run()` роутит в `server.Run` или `client.Run`. 
`Gen()` генерирует Room ID для auth-провайдеров с `RoomCreator` и ретраями | | `session_test.go` | Тесты валидации конфига | ### `internal/config/` @@ -452,7 +452,7 @@ Carrier - это WebRTC сервис видеозвонков, через кот - Минимальная прослойка, почти прямой relay - Работает с vp8channel, seichannel, videochannel - DataChannel **не работает** в обычном guest flow: WB Stream выдаёт токены с `canPublishData=false`, DC не маршрутизирует данные (expected fail в E2E тестах) -- Room ID нужно создавать вручную через stream.wb.ru +- Room ID можно создать вручную через stream.wb.ru или через `mode: gen` - Инициализация звонка автоматически --- diff --git a/docs/client.example.yaml b/docs/client.example.yaml index 5ec8792..fe83e0d 100644 --- a/docs/client.example.yaml +++ b/docs/client.example.yaml @@ -14,6 +14,7 @@ room: id: "https://meet.cryptopro.ru/REPLACE_WITH_ROOM_NAME" crypto: + # Or use key_file: "./olcrtc.key" to keep the secret out of this file. key: "REPLACE_ME_WITH_64_HEX_CHARS" # must match the server net: diff --git a/docs/configuration.md b/docs/configuration.md index 97d77fd..46edd07 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -11,6 +11,7 @@ olcrtc /etc/olcrtc/server.yaml - [`server.example.yaml`](./server.example.yaml) - [`client.example.yaml`](./client.example.yaml) +- [`failover.example.yaml`](./failover.example.yaml) ## Схема @@ -20,7 +21,7 @@ olcrtc /etc/olcrtc/server.yaml | `link` | `direct` | | `auth.provider` | `jitsi`, `telemost`, `jazz`, `wbstream`, `none` | | `room.id` | conference room id | -| `crypto.key` | 64-char hex (32 bytes) | +| `crypto.key` / `crypto.key_file` | 64-char hex (32 bytes), inline or read from file | | `net.transport` | `datachannel`, `videochannel`, `seichannel`, `vp8channel` | | `net.dns` | resolver `host:port` | | `socks.host` / `.port` | client-side listener | @@ -31,6 +32,57 @@ olcrtc /etc/olcrtc/server.yaml | `vp8.*` | vp8channel tuning | | `sei.fps` / `.batch_size` / `.fragment_size` / 
`.ack_timeout_ms` | seichannel tuning | | `gen.amount` | gen mode: number of rooms to create | +| `profiles[]` | ordered srv/cnc failover profiles | +| `failover.retry_delay` | delay before trying the next profile, e.g. `2s` | +| `failover.max_cycles` | stop after N full profile-list passes; `0` = forever | | `data` | path to data directory | | `debug` | verbose logging | | `ffmpeg` | path to ffmpeg binary | + +`mode: cnc` refuses non-loopback `socks.host` values unless both +`socks.user` and `socks.pass` are set. + +`crypto.key_file` is resolved relative to the YAML file. Do not set it +together with `crypto.key`. + +## Failover Profiles + +`mode: srv` and `mode: cnc` can define `profiles`. Top-level fields are used +as common defaults; each profile overrides only the fields it sets. The CLI +runs profiles in order. If a profile fails or ends while the process is still +alive, olcrtc waits `failover.retry_delay` and starts the next profile. + +```yaml +mode: srv +link: direct +crypto: + key_file: ./olcrtc.key +net: + dns: "1.1.1.1:53" +data: data + +profiles: + - name: wb-vp8 + auth: + provider: wbstream + room: + id: "WB_ROOM_ID" + net: + transport: vp8channel + + - name: jitsi-dc + auth: + provider: jitsi + room: + id: "https://meet.example.org/olcrtc-room" + net: + transport: datachannel + +failover: + retry_delay: 2s + max_cycles: 0 +``` + +Both peers must use compatible profile order and room settings. This first +failover layer rebuilds the session on the next profile; active smux streams +do not migrate, but new connections can recover on the next profile. diff --git a/docs/failover.example.yaml b/docs/failover.example.yaml new file mode 100644 index 0000000..7aa8149 --- /dev/null +++ b/docs/failover.example.yaml @@ -0,0 +1,34 @@ +# olcrtc failover config example +# Use the same profile order on both peers. 
+ +mode: srv +link: direct + +crypto: + key_file: "./olcrtc.key" + +net: + dns: "1.1.1.1:53" + +data: data + +profiles: + - name: wb-vp8 + auth: + provider: wbstream + room: + id: "REPLACE_WITH_WB_ROOM_ID" + net: + transport: vp8channel + + - name: jitsi-datachannel + auth: + provider: jitsi + room: + id: "https://meet.example.org/REPLACE_WITH_ROOM_NAME" + net: + transport: datachannel + +failover: + retry_delay: 2s + max_cycles: 0 diff --git a/docs/manual.md b/docs/manual.md index a2a3a21..d623d86 100644 --- a/docs/manual.md +++ b/docs/manual.md @@ -177,7 +177,7 @@ data: data ### wbstream + vp8channel (альтернатива) -Сначала создай руму вручную через сайт [wbstream](https://stream.wb.ru) (автогенерация через `mode: gen` для wbstream больше не поддерживается) и сохрани её ID. +Создай руму через сайт [wbstream](https://stream.wb.ru) или заранее сгенерируй ID через `mode: gen` с `auth.provider: wbstream`. `wbstream + datachannel` **не работает** в обычном guest flow — WB Stream выдаёт токены с `canPublishData=false`, и DC не маршрутизирует данные. Для обычного использования выбирай `vp8channel`. diff --git a/docs/project-map.md b/docs/project-map.md new file mode 100644 index 0000000..c4c8791 --- /dev/null +++ b/docs/project-map.md @@ -0,0 +1,400 @@ +# olcRTC Project Map + +This is a developer map for finding the useful parts of the project quickly. +It focuses on code ownership, runtime flow, extension points, and areas that +are worth deeper work. + +## One-Sentence Model + +olcRTC is an encrypted TCP-over-WebRTC tunnel: the client exposes a local +SOCKS5 listener, the server dials requested TCP targets, and both sides carry +the smux byte stream through a selected WebRTC carrier and transport. 
+ +## Runtime Stack + +```text +YAML config + -> cmd/olcrtc + -> internal/config + -> internal/app/session + -> internal/server or internal/client + -> internal/link/direct + -> internal/transport/{datachannel,vp8channel,seichannel,videochannel} + -> internal/carrier/builtin + -> internal/auth/* + internal/engine/* + -> external service SFU / signaling +``` + +Tunnel data path: + +```text +local app + -> client SOCKS5 + -> smux stream + -> muxconn AEAD encrypt + -> link.Send + -> transport encoding + -> carrier/engine + -> SFU/service + -> peer engine/carrier + -> transport decoding + -> muxconn AEAD decrypt + -> smux stream + -> server TCP dial + -> target host +``` + +## Entrypoints + +| Path | Purpose | +|---|---| +| `cmd/olcrtc/main.go` | Main CLI. Accepts one YAML file, applies auth and transport defaults, starts `srv`, `cnc`, or `gen`. | +| `cmd/olcrtc-cgo/main.go` | Small c-shared entrypoint for desktop/native consumers. | +| `pkg/olcrtc` | Embeddable lower-level API that returns a `net.Conn`-like handle over an engine data path. | +| `pkg/olcrtc/tunnel` | Embeddable server-side tunnel API with auth and traffic hooks. | +| `mobile/mobile.go` | gomobile API for Android clients, including VPN socket protection. | +| `script/srv.sh`, `script/cnc.sh` | Interactive shell launchers that generate YAML and run/build the app. | +| `Dockerfile`, `script/docker/*` | Container build and server entrypoint/healthcheck. | + +## Config And Session Layer + +`internal/config` owns YAML parsing and file-backed secret loading. + +Important fields: + +| YAML | Runtime field | Notes | +|---|---|---| +| `mode` | `session.Config.Mode` | `srv`, `cnc`, or `gen`. | +| `auth.provider` | `Auth` | `jitsi`, `telemost`, `jazz`, `wbstream`, or `none`. | +| `room.id` | `RoomID` | Carrier-specific room reference. | +| `crypto.key` / `crypto.key_file` | `KeyHex` | Shared 32-byte key encoded as 64 hex chars. 
| +| `net.transport` | `Transport` | `datachannel`, `vp8channel`, `seichannel`, or `videochannel`. | +| `net.dns` | `DNSServer` | Resolver used by server-side target dials and provider HTTP where wired. | +| `socks.*` | SOCKS fields | Client listener and optional server egress proxy. | +| `engine.*` | direct engine fields | Used only with `auth.provider: none`. | + +`internal/app/session` is the main router: + +1. Registers built-ins via `RegisterDefaults`. +2. Applies auth defaults: auth provider decides engine and default service URL. +3. Applies transport defaults: documented defaults for `vp8`, `sei`, and `video`. +4. Validates mode, auth, link, transport, room, key, DNS, transport options, and SOCKS listener safety. +5. Runs `server.Run`, `client.Run`, or `Gen`. + +## Server Side + +`internal/server` accepts encrypted smux sessions from the peer and proxies +each smux stream to a TCP target. + +Core pieces: + +| Symbol | Role | +|---|---| +| `server.Run` | Creates cipher, link, smux server, and serve loop. | +| `bringUpLink` | Builds `link.Link`, wires reconnect callbacks, connects carrier. | +| `installSession` / `reinstallSession` | Creates or replaces `muxconn + smux.Session`. | +| `acceptHandshake` | First smux stream; runs `handshake.Server`. | +| `handleStream` | Reads connect JSON and dispatches a tunnel stream. | +| `dispatch` | Dials target, sends ready byte, copies both directions. | +| `AuthHook` | Embedders can authorize clients after `CLIENT_HELLO`. | +| `OnSessionOpen`, `OnSessionClose`, `OnTraffic` | Observability hooks. | + +Server risk areas: + +- Target dialing is powerful by design. Any real product wrapper should add + an `AuthHook` and probably destination policy. +- `defaultAuthHook` admits everyone who knows the room and key. +- Reconnect rebuilds smux sessions; active streams are sacrificed. + +## Client Side + +`internal/client` exposes a local SOCKS5 listener and opens one smux stream +per SOCKS CONNECT request. 
+ +Core pieces: + +| Symbol | Role | +|---|---| +| `RunWithReady` | Starts link, opens smux client, listens on local SOCKS. | +| `openControlStream` | First smux stream; runs `handshake.Client`. | +| `handleSocks5` | SOCKS method negotiation and CONNECT parsing. | +| `sendConnectRequest` | Sends server-side target JSON and waits for ready byte. | +| `handleReconnect` | Rebuilds smux and control stream after carrier reconnect. | +| `resolveDeviceID` | Optional persistent client identity for hooks. | + +Client risk areas: + +- A non-loopback SOCKS listener must require `socks.user` and `socks.pass`. +- SOCKS credentials are simple static credentials, not a full account system. +- Existing streams do not survive reconnect; new SOCKS connections can recover. + +## Wire Protocol Above WebRTC + +`internal/muxconn` adapts `link.Link` to `io.ReadWriteCloser`. + +- Every smux write is encrypted with `internal/crypto`. +- Every inbound link message is decrypted and appended to an internal byte buffer. +- Bad AEAD frames are dropped. +- `CanSend` provides backpressure before encrypting and sending. + +`internal/crypto` uses XChaCha20-Poly1305 with a random nonce prepended to +each ciphertext. + +`internal/handshake` runs on the first smux stream: + +```text +CLIENT_HELLO { version, device_id, claims } +SERVER_WELCOME { version, session_id } +or +SERVER_REJECT { version, reason } +``` + +The handshake has a 64 KiB frame cap and a default 15 second timeout. + +## Registries And Plugin Shape + +The universal-carrier refactor centers on small registries: + +| Registry | Package | Registers | +|---|---|---| +| Auth providers | `internal/auth` | Service-specific credential and room creation flows. | +| Engines | `internal/engine` | Wire-level SFU protocol implementations. | +| Carriers | `internal/carrier` | Auth + engine adapters exposed as byte/video capability providers. | +| Transports | `internal/transport` | Byte transport strategy over carrier primitives. 
| +| Links | `internal/link` | Higher-level link abstraction; currently only `direct`. | + +`internal/carrier/builtin` connects the auth and engine worlds: + +```text +carrier "wbstream" -> auth/wbstream -> engine/livekit +carrier "jazz" -> auth/salutejazz -> engine/salutejazz +carrier "telemost"-> auth/telemost -> engine/goolom +carrier "jitsi" -> auth/jitsi -> engine/jitsi +carrier "none" -> direct user-supplied engine/url/token +``` + +## Auth Providers + +| Provider | Engine | Room generation | Notes | +|---|---|---:|---| +| `jitsi` | `jitsi` | No | Parses host/room from a public or self-hosted Jitsi URL. No HTTP auth. | +| `telemost` | `goolom` | No | Calls Telemost room-info flow and returns Goolom credentials. | +| `wbstream` | `livekit` | Yes | Registers guest, optionally creates room, joins room, fetches LiveKit token. | +| `jazz` / `salutejazz` | `salutejazz` | Yes | Creates or joins SaluteJazz room and returns room/password tuple. | +| `none` | chosen by config | No | Direct engine mode for downstream tools or self-hosted SFUs. | + +## Engines + +Engines expose the low-level service/SFU protocol. + +| Engine | Package | Byte stream | Video track | Main job | +|---|---|---:|---:|---| +| `livekit` | `internal/engine/livekit` | Yes | Yes | LiveKit SDK room, data packets, local/remote tracks, reconnect with credential refresh. | +| `goolom` | `internal/engine/goolom` | Yes | Yes | Yandex Telemost/Goolom signaling, split publisher/subscriber peer connections, telemetry/keepalive. | +| `jitsi` | `internal/engine/jitsi` | Yes | Best effort | Jitsi MUC/Jingle/colibri-ws plus optional video track negotiation. | +| `salutejazz` | `internal/engine/salutejazz` | Yes | Yes | SaluteJazz WebSocket signaling and split media peer connections. | + +Engine work is where most provider breakage and reconnect complexity lives. + +## Transports + +Transports decide how raw tunnel bytes are carried once the carrier provides +either a byte stream or a video track. 
+ +| Transport | Primitive | Reliability model | Best fit | Notes | +|---|---|---|---|---| +| `datachannel` | Carrier byte stream | Native reliable ordered messages | Jitsi, direct engines, some Jazz cases | Simple pass-through with 12 KiB message cap. | +| `vp8channel` | VP8 video track | KCP over VP8-looking frames | WB Stream and Telemost-style video paths | Highest-performance video-path transport. Uses epochs and binding tokens to survive restarts/loopback. | +| `seichannel` | H264 SEI video track | Custom fragments + ACK/retry | WB Stream fallback | Carries data in SEI NAL units with fragmentation, CRC, ACK. | +| `videochannel` | Visual frames via ffmpeg | QR/tile frames + ACK/retry | Experimental/inspection-friendly path | Encodes visual payload frames, requires ffmpeg, supports QR and tile codecs. | + +Transport work is where throughput, loss recovery, and adaptive tuning should +happen. + +## Public/Embedding Surfaces + +| Package | User | +|---|---| +| `pkg/olcrtc` | Go programs that want a `net.Conn` over a selected auth/engine. | +| `pkg/olcrtc/tunnel` | Go programs that want to embed the server-side tunnel with auth/traffic hooks. | +| `mobile` | Android app bindings. Wraps client mode, VPN socket protection, logging, simple health checks. | +| `cmd/olcrtc-cgo` | Native desktop/client integrations using c-shared Go export. | + +These surfaces are important if the CLI becomes only one frontend among many. + +## Tests + +The project has broad unit coverage: + +- Config/session validation and defaults. +- Auth provider HTTP flows with test servers. +- Engine helper logic and reconnect paths. +- SOCKS parsing, smux handshake, server dispatch. +- Crypto, muxconn, names, protect, logging. +- Transport frame codecs, ACK paths, KCP loopback, ffmpeg helpers. +- Memory-backed E2E tunnel tests and optional real-provider E2E matrix. + +Useful commands: + +```sh +go test -count=1 ./... 
+go test -race -count=1 ./cmd/olcrtc ./internal/app/session ./internal/config ./internal/engine/livekit +go test -race -count=1 -v ./internal/e2e +E2E_CARRIERS=wbstream E2E_TRANSPORTS=vp8channel mage e2e +go build -trimpath -o build/olcrtc ./cmd/olcrtc +``` + +## High-Value Coding Areas + +### 1. Supervisor And Multi-Profile Failover + +The first supervisor layer exists in `internal/supervisor`: the CLI can run a +prioritized list of carrier/transport profiles and move to the next profile +when the active one fails or ends. + +```yaml +mode: srv +link: direct +crypto: + key_file: ./olcrtc.key +net: + dns: "1.1.1.1:53" +profiles: + - name: wb-vp8 + auth: + provider: wbstream + room: + id: WB_ROOM_ID + net: + transport: vp8channel + - name: jitsi-dc + auth: + provider: jitsi + room: + id: https://meet.example.org/olcrtc-room + net: + transport: datachannel +failover: + retry_delay: 2s + max_cycles: 0 +``` + +Implemented: + +- Config schema for `profiles[]`. +- Ordered supervisor loop. +- `failover.retry_delay`. +- `failover.max_cycles`. +- Profile start/end logs. + +Still valuable: + +- Health scoring per profile. +- Control-stream coordination before switching. +- Stream draining and migration instead of dropping active smux streams. +- Shared status output for the active profile and failover history. + +Likely files: + +- `internal/config/config.go` +- `internal/app/session/session.go` +- `internal/supervisor` +- `internal/server` +- `internal/client` +- `docs/configuration.md` +- `internal/e2e/tunnel_test.go` + +### 2. Transport Telemetry And Adaptive Tuning + +Add metrics from transport to link/session: + +- Send queue depth. +- ACK latency. +- Retries. +- Reconnect count. +- Dropped/decrypt-failed frames. +- KCP RTT/loss where available. + +Then make `vp8.batch_size`, `sei.fragment_size`, ACK timeout, and pacing +adaptive instead of static YAML knobs. + +### 3. Control Stream Protocol + +The first smux stream is parked after handshake. 
It is the natural place for: + +- Ping/pong and peer liveness. +- Server policy updates. +- Graceful reconnect notifications. +- Drain/start markers for failover. +- Per-session stats. + +Likely files: + +- `internal/handshake` +- `internal/server` +- `internal/client` + +### 4. Destination Policy And Real Auth + +The tunnel can dial arbitrary server-side TCP targets. A production wrapper +should use `AuthHook` and enforce: + +- Allowed destination CIDRs/domains/ports. +- Per-device or per-plan policy. +- Session expiration. +- Traffic accounting limits. +- Sanitized rejection reasons. + +This mostly belongs in `pkg/olcrtc/tunnel` and `internal/server`. + +### 5. Provider Hardening + +Provider APIs can drift. Worth adding: + +- Better typed errors from auth providers. +- Provider health probes. +- Fixture-based contract tests for API response changes. +- Per-provider rate/backoff policy. +- Safer secret/log redaction. + +Likely files: + +- `internal/auth/*` +- `internal/engine/*` +- `internal/carrier/builtin` + +### 6. Codebase Hygiene + +Some public-facing text and comments are not suitable for a serious external +project. Cleaning that up would improve maintainability and downstream trust. +The most obvious targets are top-level docs and a large hostile block comment +in `internal/transport/vp8channel/transport.go`. + +## Where To Look First + +| Goal | Start here | +|---|---| +| Change YAML schema | `internal/config/config.go`, `cmd/olcrtc/main.go`, docs examples. | +| Change validation/defaults | `internal/app/session/session.go`. | +| Add a new auth provider | `internal/auth`, then register in `internal/carrier/builtin/register.go`. | +| Add a new SFU protocol | `internal/engine`, then connect through auth/carrier. | +| Add a new byte transport | `internal/transport`, then register in `session.RegisterDefaults`. | +| Add link behavior above transports | `internal/link`; currently only `direct`. | +| Improve SOCKS behavior | `internal/client`. 
| +| Improve server target dialing or policy | `internal/server`, `pkg/olcrtc/tunnel`. | +| Improve reconnect | Engines first, then `internal/client` and `internal/server` smux rebuild behavior. | +| Improve Android app integration | `mobile`, `internal/protect`, `client.RunWithReady`. | + +## Mental Model For Big Changes + +Prefer to keep the layer boundaries: + +- Auth creates credentials; it should not know transport details. +- Engine speaks service/SFU protocol; it should not know SOCKS or smux. +- Carrier adapts auth+engine into byte/video capabilities. +- Transport turns byte/video capabilities into reliable-ish tunnel bytes. +- Link is policy above transport. +- Client/server own SOCKS, smux, handshake, target dialing, and session hooks. + +If a change crosses more than two layers, it probably deserves a new +orchestrator package instead of pushing more state into an engine or transport. diff --git a/docs/server.example.yaml b/docs/server.example.yaml index 7a5f638..9f5ee38 100644 --- a/docs/server.example.yaml +++ b/docs/server.example.yaml @@ -16,6 +16,7 @@ room: crypto: # 32-byte hex (64 chars). Generate with: openssl rand -hex 32 + # Or use key_file: "./olcrtc.key" to keep the secret out of this file. key: "REPLACE_ME_WITH_64_HEX_CHARS" net: diff --git a/docs/settings.md b/docs/settings.md index f9d6c80..28855ce 100644 --- a/docs/settings.md +++ b/docs/settings.md @@ -48,7 +48,7 @@ | `auth.provider` | `telemost`, `jazz`, `wbstream` или `jitsi` | | `net.transport` | `datachannel`, `vp8channel`, `seichannel` или `videochannel` | | `room.id` | Room ID | -| `crypto.key` | Ключ шифрования hex 64 символа. Генерация: `openssl rand -hex 32` | +| `crypto.key` или `crypto.key_file` | Ключ шифрования hex 64 символа. 
Генерация: `openssl rand -hex 32` | | `link` | Всегда `direct` | | `data` | Всегда `data` | | `net.dns` | DNS-сервер, например `1.1.1.1:53` | @@ -60,18 +60,27 @@ | YAML поле | Описание | |-----------|----------| | `debug` | `true` для подробных логов соединений | +| `profiles` | Список профилей failover для `srv`/`cnc` | +| `failover.retry_delay` | Пауза перед следующим профилем, например `2s` | +| `failover.max_cycles` | Сколько полных проходов по профилям сделать; `0` = бесконечно | + +`crypto.key_file` читается относительно YAML-файла. Не указывай `crypto.key` и `crypto.key_file` одновременно. + +Если задан `profiles`, поля верхнего уровня становятся общими defaults, а +каждый профиль переопределяет только свои `auth`, `room`, `net`, `engine` и +настройки транспорта. Порядок профилей должен совпадать на сервере и клиенте. --- ## mode: gen -Генерирует Room ID заранее, не запуская сервер. Поддерживается только для `jazz`. Для `wbstream` создавай руму вручную через [stream.wb.ru](https://stream.wb.ru) (автогенерация отключена со стороны WB). +Генерирует Room ID заранее, не запуская сервер. Поддерживается для auth-провайдеров с автосозданием комнат: `jazz` и `wbstream`. Для `telemost` комнату нужно создавать вручную через сайт. **Обязательные поля:** | YAML поле | Описание | |-----------|----------| -| `auth.provider` | `jazz` | +| `auth.provider` | `jazz` или `wbstream` | | `net.dns` | DNS-сервер | | `gen.amount` | Количество комнат | @@ -79,7 +88,7 @@ # gen.yaml mode: gen auth: - provider: jazz + provider: wbstream net: dns: "1.1.1.1:53" gen: @@ -116,6 +125,9 @@ gen: Если `socks.user` не задан - аутентификация отключена (любой локальный клиент может подключиться). Если задан - клиент принимает только подключения с правильным логином и паролем (RFC 1929). +Если `socks.host` не loopback (`127.0.0.1`, `::1`, `localhost`), `socks.user` и `socks.pass` обязательны. +Это защита от случайного открытого SOCKS5-прокси в локальной сети или интернете. 
+ --- ## datachannel diff --git a/internal/app/session/session.go b/internal/app/session/session.go index 89900bf..89de5f5 100644 --- a/internal/app/session/session.go +++ b/internal/app/session/session.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "net" "slices" "time" @@ -37,18 +38,33 @@ const ( videoCodecTile = "tile" ) +const ( + defaultVideoWidth = 1920 + defaultVideoHeight = 1080 + defaultVideoFPS = 30 + defaultVideoBitrate = "2M" + defaultVideoHW = "none" + defaultVideoQRRecovery = "low" + defaultVP8FPS = 25 + defaultVP8BatchSize = 1 + defaultSEIFPS = 60 + defaultSEIBatchSize = 64 + defaultSEIFragmentSize = 900 + defaultSEIAckTimeoutMS = 2000 +) + var ( // ErrRoomIDRequired indicates that a room id is required for the selected carrier. - ErrRoomIDRequired = errors.New("room ID required (use -id )") + ErrRoomIDRequired = errors.New("room ID required (set room.id)") // ErrModeRequired indicates that mode is not one of the supported values. - ErrModeRequired = errors.New("mode required (use -mode srv, -mode cnc or -mode gen)") - // ErrAmountRequired indicates that -amount is required for gen mode. - ErrAmountRequired = errors.New("amount required for gen mode (use -amount )") + ErrModeRequired = errors.New("mode required (set mode to srv, cnc or gen)") + // ErrAmountRequired indicates that gen.amount is required for gen mode. + ErrAmountRequired = errors.New("amount required for gen mode (set gen.amount)") // ErrAuthRequired indicates that no auth provider was selected. ErrAuthRequired = errors.New( - "auth provider required (use -auth jitsi, -auth telemost, -auth jazz, -auth wbstream or -auth none)") - // ErrURLRequired indicates that -url must be provided when the auth provider has no default URL. 
- ErrURLRequired = errors.New("SFU URL required (use -url wss://...)") + "auth provider required (set auth.provider to jitsi, telemost, jazz, wbstream or none)") + // ErrURLRequired indicates that auth.url must be provided when the auth provider has no default URL. + ErrURLRequired = errors.New("SFU URL required (set auth.url)") // ErrUnsupportedCarrier indicates that carrier is not registered. ErrUnsupportedCarrier = errors.New("unsupported carrier") // ErrUnsupportedLink indicates that link is not registered. @@ -57,51 +73,53 @@ var ( ErrUnsupportedTransport = errors.New("unsupported transport") // ErrLinkRequired indicates that link is not provided. - ErrLinkRequired = errors.New("link required (use -link direct)") + ErrLinkRequired = errors.New("link required (set link to direct)") // ErrTransportRequired indicates that transport is not provided. ErrTransportRequired = errors.New( - "transport required (use -transport datachannel, -transport videochannel, " + - "-transport seichannel or -transport vp8channel)") + "transport required (set transport to datachannel, videochannel, seichannel or vp8channel)") // ErrKeyRequired indicates that encryption key is not provided. - ErrKeyRequired = errors.New("key required (use -key )") + ErrKeyRequired = errors.New("key required (set crypto.key)") // ErrDNSServerRequired indicates that dns server is not provided. - ErrDNSServerRequired = errors.New("dns server required (use -dns 1.1.1.1:53)") + ErrDNSServerRequired = errors.New("dns server required (set net.dns)") // ErrVideoWidthRequired indicates that video width is required for videochannel. - ErrVideoWidthRequired = errors.New("video width required for videochannel (use -video-w)") + ErrVideoWidthRequired = errors.New("video width required for videochannel (set video.width)") // ErrVideoHeightRequired indicates that video height is required for videochannel. 
- ErrVideoHeightRequired = errors.New("video height required for videochannel (use -video-h)") + ErrVideoHeightRequired = errors.New("video height required for videochannel (set video.height)") // ErrVideoFPSRequired indicates that video fps is required for videochannel. - ErrVideoFPSRequired = errors.New("video fps required for videochannel (use -video-fps)") + ErrVideoFPSRequired = errors.New("video fps required for videochannel (set video.fps)") // ErrVideoBitrateRequired indicates that video bitrate is required for videochannel. ErrVideoBitrateRequired = errors.New( - "video bitrate required for videochannel (use -video-bitrate)") + "video bitrate required for videochannel (set video.bitrate)") // ErrVideoHWRequired indicates that video hardware acceleration is required. ErrVideoHWRequired = errors.New( - "video hardware acceleration required for videochannel (use -video-hw none/nvenc)") + "video hardware acceleration required for videochannel (set video.hw to none or nvenc)") // ErrVideoCodecInvalid indicates that the video codec is not valid. ErrVideoCodecInvalid = errors.New( - "invalid video codec for videochannel (use -video-codec qrcode or -video-codec tile)") + "invalid video codec for videochannel (set video.codec to qrcode or tile)") // ErrTileCodecDimensions indicates that tile codec requires 1080x1080 dimensions. - ErrTileCodecDimensions = errors.New("tile codec requires -video-w 1080 -video-h 1080") + ErrTileCodecDimensions = errors.New("tile codec requires video.width: 1080 and video.height: 1080") // ErrVP8FPSRequired indicates that vp8 fps is required for vp8channel. - ErrVP8FPSRequired = errors.New("vp8 fps required for vp8channel (use -vp8-fps)") + ErrVP8FPSRequired = errors.New("vp8 fps required for vp8channel (set vp8.fps)") // ErrVP8BatchSizeRequired indicates that vp8 batch size is required for vp8channel. 
- ErrVP8BatchSizeRequired = errors.New("vp8 batch size required for vp8channel (use -vp8-batch)") + ErrVP8BatchSizeRequired = errors.New("vp8 batch size required for vp8channel (set vp8.batch_size)") // ErrSEIFPSRequired indicates that seichannel fps is required. - ErrSEIFPSRequired = errors.New("fps required for seichannel (use -fps)") + ErrSEIFPSRequired = errors.New("fps required for seichannel (set sei.fps)") // ErrSEIBatchSizeRequired indicates that seichannel batch size is required. - ErrSEIBatchSizeRequired = errors.New("batch size required for seichannel (use -batch)") + ErrSEIBatchSizeRequired = errors.New("batch size required for seichannel (set sei.batch_size)") // ErrSEIFragmentSizeRequired indicates that seichannel fragment size is required. - ErrSEIFragmentSizeRequired = errors.New("fragment size required for seichannel (use -frag)") + ErrSEIFragmentSizeRequired = errors.New("fragment size required for seichannel (set sei.fragment_size)") // ErrSEIAckTimeoutRequired indicates that seichannel ack timeout is required. - ErrSEIAckTimeoutRequired = errors.New("ack timeout required for seichannel (use -ack-ms)") + ErrSEIAckTimeoutRequired = errors.New("ack timeout required for seichannel (set sei.ack_timeout_ms)") // ErrSOCKSHostRequired indicates that socks host is required for cnc mode. - ErrSOCKSHostRequired = errors.New("socks host required for cnc mode (use -socks-host)") + ErrSOCKSHostRequired = errors.New("socks host required for cnc mode (set socks.host)") // ErrSOCKSPortRequired indicates that socks port is required for cnc mode. - ErrSOCKSPortRequired = errors.New("socks port required for cnc mode (use -socks-port)") + ErrSOCKSPortRequired = errors.New("socks port required for cnc mode (set socks.port)") + // ErrSOCKSAuthRequired indicates that a non-loopback SOCKS listener requires authentication. 
+ ErrSOCKSAuthRequired = errors.New( + "socks auth required when binding outside loopback (set socks.user and socks.pass)") ) // Config holds runtime session settings. @@ -180,6 +198,80 @@ func ApplyAuthDefaults(cfg Config) (Config, error) { return cfg, nil } +// ApplyTransportDefaults fills documented transport defaults without changing core routing fields. +func ApplyTransportDefaults(cfg Config) Config { + switch cfg.Transport { + case transportVideo: + return applyVideoDefaults(cfg) + case transportVP8: + return applyVP8Defaults(cfg) + case transportSEI: + return applySEIDefaults(cfg) + default: + return cfg + } +} + +func applyVideoDefaults(cfg Config) Config { + if cfg.VideoCodec == "" { + cfg.VideoCodec = videoCodecQRCode + } + if cfg.VideoCodec == videoCodecTile { + if cfg.VideoWidth == 0 { + cfg.VideoWidth = 1080 + } + if cfg.VideoHeight == 0 { + cfg.VideoHeight = 1080 + } + } else { + if cfg.VideoWidth == 0 { + cfg.VideoWidth = defaultVideoWidth + } + if cfg.VideoHeight == 0 { + cfg.VideoHeight = defaultVideoHeight + } + } + if cfg.VideoFPS == 0 { + cfg.VideoFPS = defaultVideoFPS + } + if cfg.VideoBitrate == "" { + cfg.VideoBitrate = defaultVideoBitrate + } + if cfg.VideoHW == "" { + cfg.VideoHW = defaultVideoHW + } + if cfg.VideoQRRecovery == "" { + cfg.VideoQRRecovery = defaultVideoQRRecovery + } + return cfg +} + +func applyVP8Defaults(cfg Config) Config { + if cfg.VP8FPS == 0 { + cfg.VP8FPS = defaultVP8FPS + } + if cfg.VP8BatchSize == 0 { + cfg.VP8BatchSize = defaultVP8BatchSize + } + return cfg +} + +func applySEIDefaults(cfg Config) Config { + if cfg.SEIFPS == 0 { + cfg.SEIFPS = defaultSEIFPS + } + if cfg.SEIBatchSize == 0 { + cfg.SEIBatchSize = defaultSEIBatchSize + } + if cfg.SEIFragmentSize == 0 { + cfg.SEIFragmentSize = defaultSEIFragmentSize + } + if cfg.SEIAckTimeoutMS == 0 { + cfg.SEIAckTimeoutMS = defaultSEIAckTimeoutMS + } + return cfg +} + // Validate verifies that the runtime config refers to registered components and all required fields 
are present. func Validate(cfg Config) error { if err := validateMode(cfg); err != nil { @@ -333,11 +425,23 @@ func validateModeConfig(cfg Config) error { if cfg.SOCKSPort == 0 { return ErrSOCKSPortRequired } + if !isLoopbackListenHost(cfg.SOCKSHost) && (cfg.SOCKSUser == "" || cfg.SOCKSPass == "") { + return ErrSOCKSAuthRequired + } return nil } +func isLoopbackListenHost(host string) bool { + if host == "localhost" { + return true + } + ip := net.ParseIP(host) + return ip != nil && ip.IsLoopback() +} + // Run starts the configured mode. func Run(ctx context.Context, cfg Config) error { + cfg = ApplyTransportDefaults(cfg) roomURL := cfg.RoomID switch cfg.Mode { diff --git a/internal/app/session/session_test.go b/internal/app/session/session_test.go index 6ca3f79..f20e70d 100644 --- a/internal/app/session/session_test.go +++ b/internal/app/session/session_test.go @@ -6,6 +6,85 @@ import ( "testing" ) +func TestApplyTransportDefaults(t *testing.T) { + tests := []struct { + name string + in Config + want Config + }{ + { + name: "vp8", + in: Config{Transport: transportVP8}, + want: Config{Transport: transportVP8, VP8FPS: 25, VP8BatchSize: 1}, + }, + { + name: "sei", + in: Config{Transport: transportSEI}, + want: Config{ + Transport: transportSEI, + SEIFPS: 60, + SEIBatchSize: 64, + SEIFragmentSize: 900, + SEIAckTimeoutMS: 2000, + }, + }, + { + name: "video qrcode", + in: Config{Transport: transportVideo}, + want: Config{ + Transport: transportVideo, + VideoWidth: 1920, + VideoHeight: 1080, + VideoFPS: 30, + VideoBitrate: "2M", + VideoHW: "none", + VideoQRRecovery: "low", + VideoCodec: videoCodecQRCode, + }, + }, + { + name: "video tile dimensions", + in: Config{Transport: transportVideo, VideoCodec: videoCodecTile}, + want: Config{ + Transport: transportVideo, + VideoWidth: 1080, + VideoHeight: 1080, + VideoFPS: 30, + VideoBitrate: "2M", + VideoHW: "none", + VideoQRRecovery: "low", + VideoCodec: videoCodecTile, + }, + }, + { + name: "keeps explicit values", + in: 
Config{ + Transport: transportSEI, + SEIFPS: 10, + SEIBatchSize: 2, + SEIFragmentSize: 300, + SEIAckTimeoutMS: 1500, + }, + want: Config{ + Transport: transportSEI, + SEIFPS: 10, + SEIBatchSize: 2, + SEIFragmentSize: 300, + SEIAckTimeoutMS: 1500, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ApplyTransportDefaults(tt.in) + if got != tt.want { + t.Fatalf("ApplyTransportDefaults() = %+v, want %+v", got, tt.want) + } + }) + } +} + //nolint:maintidx // table-driven validation test naturally has many cases func TestValidate(t *testing.T) { RegisterDefaults() @@ -310,6 +389,39 @@ func TestValidate(t *testing.T) { }(), want: ErrSOCKSPortRequired, }, + { + name: "cnc rejects unauthenticated wildcard socks bind", + cfg: func() Config { + cfg := base + cfg.Mode = modeCNC + cfg.SOCKSHost = "0.0.0.0" + cfg.SOCKSPort = 1080 + return cfg + }(), + want: ErrSOCKSAuthRequired, + }, + { + name: "cnc allows authenticated wildcard socks bind", + cfg: func() Config { + cfg := base + cfg.Mode = modeCNC + cfg.SOCKSHost = "0.0.0.0" + cfg.SOCKSPort = 1080 + cfg.SOCKSUser = "user" + cfg.SOCKSPass = "pass" + return cfg + }(), + }, + { + name: "cnc allows localhost socks bind without auth", + cfg: func() Config { + cfg := base + cfg.Mode = modeCNC + cfg.SOCKSHost = "localhost" + cfg.SOCKSPort = 1080 + return cfg + }(), + }, } for _, tt := range tests { diff --git a/internal/config/config.go b/internal/config/config.go index 49b0f60..5fe206c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,10 +1,9 @@ // Package config loads olcrtc runtime configuration from YAML files. // // The YAML schema mirrors [session.Config]. Fields left unset in the file -// remain at their zero value, allowing CLI flags to fill them in. Use -// [Apply] to merge a parsed [File] onto an existing [session.Config]; -// non-zero fields in the session config (typically populated from CLI flags) -// take precedence over the YAML values. 
+// remain at their zero value. Use [Apply] to map a parsed [File] onto an +// existing [session.Config]; non-zero fields in the session config take +// precedence over the YAML values. // //nolint:tagliatelle // YAML keys are the documented config file schema. package config @@ -13,17 +12,46 @@ import ( "errors" "fmt" "os" + "path/filepath" + "strings" "github.com/openlibrecommunity/olcrtc/internal/app/session" "gopkg.in/yaml.v3" ) -// ErrConfigNotFound is returned when a config file path is set but the file does not exist. -var ErrConfigNotFound = errors.New("config file not found") +var ( + // ErrConfigNotFound is returned when a config file path is set but the file does not exist. + ErrConfigNotFound = errors.New("config file not found") + // ErrCryptoKeyConflict is returned when both inline and file-backed keys are configured. + ErrCryptoKeyConflict = errors.New("crypto.key and crypto.key_file cannot both be set") + // ErrCryptoKeyFileEmpty is returned when crypto.key_file points to an empty file. + ErrCryptoKeyFileEmpty = errors.New("crypto key file is empty") +) // File is the on-disk YAML schema. type File struct { - Mode string `yaml:"mode"` + Mode string `yaml:"mode"` + Link string `yaml:"link"` + Auth Auth `yaml:"auth"` + Room Room `yaml:"room"` + Crypto Crypto `yaml:"crypto"` + Net Net `yaml:"net"` + SOCKS SOCKS `yaml:"socks"` + Engine Engine `yaml:"engine"` + Video Video `yaml:"video"` + VP8 VP8 `yaml:"vp8"` + SEI SEI `yaml:"sei"` + Gen Gen `yaml:"gen"` + Profiles []Profile `yaml:"profiles"` + Failover Failover `yaml:"failover"` + Data string `yaml:"data"` + Debug bool `yaml:"debug"` + FFmpeg string `yaml:"ffmpeg"` +} + +// Profile is a failover entry that overrides top-level runtime fields. 
+type Profile struct { + Name string `yaml:"name"` Link string `yaml:"link"` Auth Auth `yaml:"auth"` Room Room `yaml:"room"` @@ -34,10 +62,12 @@ type File struct { Video Video `yaml:"video"` VP8 VP8 `yaml:"vp8"` SEI SEI `yaml:"sei"` - Gen Gen `yaml:"gen"` - Data string `yaml:"data"` - Debug bool `yaml:"debug"` - FFmpeg string `yaml:"ffmpeg"` +} + +// Failover controls ordered profile failover. +type Failover struct { + RetryDelay string `yaml:"retry_delay"` + MaxCycles int `yaml:"max_cycles"` } // Auth selects the auth provider. @@ -52,7 +82,8 @@ type Room struct { // Crypto holds the shared secret used to authenticate and encrypt the tunnel. type Crypto struct { - Key string `yaml:"key"` // 64-char hex (32 bytes) + Key string `yaml:"key"` // 64-char hex (32 bytes) + KeyFile string `yaml:"key_file"` // path to a file containing crypto.key } // Net groups network and transport selection. @@ -125,9 +156,63 @@ func Load(path string) (File, error) { if err := yaml.Unmarshal(data, &f); err != nil { return File{}, fmt.Errorf("parse config %s: %w", path, err) } + if err := loadExternalSecrets(path, &f); err != nil { + return File{}, err + } return f, nil } +func loadExternalSecrets(configPath string, f *File) error { + if f.Crypto.KeyFile == "" { + return loadProfileSecrets(configPath, f.Profiles) + } + if f.Crypto.Key != "" { + return ErrCryptoKeyConflict + } + + key, err := readKeyFile(configPath, f.Crypto.KeyFile) + if err != nil { + return err + } + f.Crypto.Key = key + return loadProfileSecrets(configPath, f.Profiles) +} + +func loadProfileSecrets(configPath string, profiles []Profile) error { + for i := range profiles { + if profiles[i].Crypto.KeyFile == "" { + continue + } + if profiles[i].Crypto.Key != "" { + return fmt.Errorf("profiles[%d]: %w", i, ErrCryptoKeyConflict) + } + key, err := readKeyFile(configPath, profiles[i].Crypto.KeyFile) + if err != nil { + return fmt.Errorf("profiles[%d]: %w", i, err) + } + profiles[i].Crypto.Key = key + } + return nil +} + 
+func readKeyFile(configPath, keyFile string) (string, error) { + keyPath := keyFile + if !filepath.IsAbs(keyPath) { + keyPath = filepath.Join(filepath.Dir(configPath), keyPath) + } + + // #nosec G304 -- key_file is an explicit path in the user's config file. + data, err := os.ReadFile(keyPath) + if err != nil { + return "", fmt.Errorf("read crypto key file %s: %w", keyPath, err) + } + key := strings.TrimSpace(string(data)) + if key == "" { + return "", ErrCryptoKeyFileEmpty + } + return key, nil +} + // Apply merges f onto dst. CLI-set fields (non-zero values in dst) win; // YAML values fill in the rest. func Apply(dst session.Config, f File) session.Config { @@ -167,6 +252,43 @@ func Apply(dst session.Config, f File) session.Config { return dst } +// ApplyProfile overlays a failover profile onto an already-applied base config. +func ApplyProfile(base session.Config, p Profile) session.Config { + dst := base + dst.Link = overlayString(dst.Link, p.Link) + dst.Transport = overlayString(dst.Transport, p.Net.Transport) + dst.Auth = overlayString(dst.Auth, p.Auth.Provider) + dst.Engine = overlayString(dst.Engine, p.Engine.Name) + dst.URL = overlayString(dst.URL, p.Engine.URL) + dst.Token = overlayString(dst.Token, p.Engine.Token) + dst.RoomID = overlayString(dst.RoomID, p.Room.ID) + dst.KeyHex = overlayString(dst.KeyHex, p.Crypto.Key) + dst.SOCKSHost = overlayString(dst.SOCKSHost, p.SOCKS.Host) + dst.SOCKSPort = overlayInt(dst.SOCKSPort, p.SOCKS.Port) + dst.SOCKSUser = overlayString(dst.SOCKSUser, p.SOCKS.User) + dst.SOCKSPass = overlayString(dst.SOCKSPass, p.SOCKS.Pass) + dst.DNSServer = overlayString(dst.DNSServer, p.Net.DNS) + dst.SOCKSProxyAddr = overlayString(dst.SOCKSProxyAddr, p.SOCKS.ProxyAddr) + dst.SOCKSProxyPort = overlayInt(dst.SOCKSProxyPort, p.SOCKS.ProxyPort) + dst.VideoWidth = overlayInt(dst.VideoWidth, p.Video.Width) + dst.VideoHeight = overlayInt(dst.VideoHeight, p.Video.Height) + dst.VideoFPS = overlayInt(dst.VideoFPS, p.Video.FPS) + dst.VideoBitrate 
= overlayString(dst.VideoBitrate, p.Video.Bitrate) + dst.VideoHW = overlayString(dst.VideoHW, p.Video.HW) + dst.VideoQRSize = overlayInt(dst.VideoQRSize, p.Video.QRSize) + dst.VideoQRRecovery = overlayString(dst.VideoQRRecovery, p.Video.QRRecovery) + dst.VideoCodec = overlayString(dst.VideoCodec, p.Video.Codec) + dst.VideoTileModule = overlayInt(dst.VideoTileModule, p.Video.TileModule) + dst.VideoTileRS = overlayInt(dst.VideoTileRS, p.Video.TileRS) + dst.VP8FPS = overlayInt(dst.VP8FPS, p.VP8.FPS) + dst.VP8BatchSize = overlayInt(dst.VP8BatchSize, p.VP8.BatchSize) + dst.SEIFPS = overlayInt(dst.SEIFPS, p.SEI.FPS) + dst.SEIBatchSize = overlayInt(dst.SEIBatchSize, p.SEI.BatchSize) + dst.SEIFragmentSize = overlayInt(dst.SEIFragmentSize, p.SEI.FragmentSize) + dst.SEIAckTimeoutMS = overlayInt(dst.SEIAckTimeoutMS, p.SEI.AckTimeoutMS) + return dst +} + func pickString(cli, yamlVal string) string { if cli != "" { return cli @@ -180,3 +302,17 @@ func pickInt(cli, yamlVal int) int { } return yamlVal } + +func overlayString(base, override string) string { + if override != "" { + return override + } + return base +} + +func overlayInt(base, override int) int { + if override != 0 { + return override + } + return base +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 95c4d9b..7504110 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1,6 +1,7 @@ package config import ( + "errors" "os" "path/filepath" "testing" @@ -121,6 +122,157 @@ func TestApplyCLIWins(t *testing.T) { } } +func TestLoadAndApplyProfile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "olcrtc.yaml") + body := ` +mode: srv +link: direct +crypto: + key: shared-key +net: + dns: 1.1.1.1:53 +profiles: + - name: wb-vp8 + auth: + provider: wbstream + room: + id: wb-room + net: + transport: vp8channel + vp8: + fps: 30 + - name: jitsi-dc + auth: + provider: jitsi + room: + id: https://meet.example/room + net: + transport: datachannel + 
dns: 8.8.8.8:53 +failover: + retry_delay: 100ms + max_cycles: 2 +` + if err := os.WriteFile(path, []byte(body), 0o600); err != nil { + t.Fatalf("write config: %v", err) + } + + f, err := Load(path) + if err != nil { + t.Fatalf("Load: %v", err) + } + if len(f.Profiles) != 2 { + t.Fatalf("profiles = %d, want 2", len(f.Profiles)) + } + if f.Failover.RetryDelay != "100ms" || f.Failover.MaxCycles != 2 { + t.Fatalf("Failover = %+v, want retry_delay 100ms max_cycles 2", f.Failover) + } + + base := Apply(session.Config{}, f) + first := ApplyProfile(base, f.Profiles[0]) + if first.Auth != "wbstream" || first.Transport != "vp8channel" || first.RoomID != "wb-room" { + t.Fatalf("first profile = %+v", first) + } + if first.KeyHex != "shared-key" || first.DNSServer != "1.1.1.1:53" || first.VP8FPS != 30 { + t.Fatalf("first inherited/overlaid fields = %+v", first) + } + second := ApplyProfile(base, f.Profiles[1]) + if second.Auth != "jitsi" || second.Transport != "datachannel" || + second.RoomID != "https://meet.example/room" || second.DNSServer != "8.8.8.8:53" { + t.Fatalf("second profile = %+v", second) + } +} + +func TestLoadProfileCryptoKeyFile(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "profile.key"), []byte(testCryptoKey+"\n"), 0o600); err != nil { + t.Fatalf("write key: %v", err) + } + path := filepath.Join(dir, "olcrtc.yaml") + body := ` +profiles: + - name: file-key + crypto: + key_file: profile.key +` + if err := os.WriteFile(path, []byte(body), 0o600); err != nil { + t.Fatalf("write config: %v", err) + } + + f, err := Load(path) + if err != nil { + t.Fatalf("Load: %v", err) + } + if got := f.Profiles[0].Crypto.Key; got != testCryptoKey { + t.Fatalf("profile key = %q, want %q", got, testCryptoKey) + } +} + +func TestLoadCryptoKeyFileRelativeToConfig(t *testing.T) { + dir := t.TempDir() + keyPath := filepath.Join(dir, "secret.key") + if err := os.WriteFile(keyPath, []byte(testCryptoKey+"\n"), 0o600); err != nil { + t.Fatalf("write 
key: %v", err) + } + path := filepath.Join(dir, "olcrtc.yaml") + body := ` +mode: srv +crypto: + key_file: secret.key +` + if err := os.WriteFile(path, []byte(body), 0o600); err != nil { + t.Fatalf("write config: %v", err) + } + + f, err := Load(path) + if err != nil { + t.Fatalf("Load: %v", err) + } + if f.Crypto.Key != testCryptoKey { + t.Fatalf("Crypto.Key = %q, want %q", f.Crypto.Key, testCryptoKey) + } +} + +func TestLoadCryptoKeyFileConflict(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "olcrtc.yaml") + body := ` +crypto: + key: deadbeef + key_file: secret.key +` + if err := os.WriteFile(path, []byte(body), 0o600); err != nil { + t.Fatalf("write config: %v", err) + } + + _, err := Load(path) + if !errors.Is(err, ErrCryptoKeyConflict) { + t.Fatalf("Load() error = %v, want %v", err, ErrCryptoKeyConflict) + } +} + +func TestLoadCryptoKeyFileEmpty(t *testing.T) { + dir := t.TempDir() + keyPath := filepath.Join(dir, "secret.key") + if err := os.WriteFile(keyPath, []byte("\n"), 0o600); err != nil { + t.Fatalf("write key: %v", err) + } + path := filepath.Join(dir, "olcrtc.yaml") + body := ` +crypto: + key_file: secret.key +` + if err := os.WriteFile(path, []byte(body), 0o600); err != nil { + t.Fatalf("write config: %v", err) + } + + _, err := Load(path) + if !errors.Is(err, ErrCryptoKeyFileEmpty) { + t.Fatalf("Load() error = %v, want %v", err, ErrCryptoKeyFileEmpty) + } +} + func TestLoadMissing(t *testing.T) { _, err := Load(filepath.Join(t.TempDir(), "nope.yaml")) if err == nil { diff --git a/internal/e2e/tunnel_test.go b/internal/e2e/tunnel_test.go index 835bd65..b5cf0dd 100644 --- a/internal/e2e/tunnel_test.go +++ b/internal/e2e/tunnel_test.go @@ -24,6 +24,7 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/client" "github.com/openlibrecommunity/olcrtc/internal/link" "github.com/openlibrecommunity/olcrtc/internal/server" + "github.com/openlibrecommunity/olcrtc/internal/supervisor" "github.com/pion/webrtc/v4" ) @@ -47,6 +48,7 @@ var ( 
errSocksUnexpectedReply = errors.New("unexpected SOCKS5 reply") errSocksUnexpectedHello = errors.New("unexpected SOCKS5 greeting") errPayloadMismatchOffset = errors.New("payload mismatch at offset") + errFailoverCarrier = errors.New("intentional failover carrier failure") ) var ( @@ -347,6 +349,17 @@ func registerMemoryCarrierAs(t *testing.T, name string) { }) } +func registerFailingCarrier(t *testing.T) string { + t.Helper() + session.RegisterDefaults() + + name := "e2e-fail-" + t.Name() + carrier.Register(name, func(context.Context, carrier.Config) (carrier.Session, error) { + return nil, errFailoverCarrier + }) + return name +} + func builtInCarrierNames() []string { return []string{"jazz", "telemost", "wbstream", "jitsi"} //nolint:goconst // test literal, repetition is intentional } @@ -1163,6 +1176,186 @@ func TestFrequentReconnectsStillAllowNewSOCKSConnections(t *testing.T) { } } +func TestSupervisorFailoverProfilesReachWorkingSOCKS(t *testing.T) { + echoAddr := startEchoServer(t) + failingCarrier := registerFailingCarrier(t) + memoryCarrier, room := registerMemoryCarrier(t) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + socksAddr := freeLocalAddr(ctx, t) + socksHost, socksPort := splitHostPort(t, socksAddr) + + serverProfiles := []supervisor.Profile{ + {Name: "failing-server", Config: failoverSessionConfig("srv", failingCarrier, "", 0)}, + {Name: "memory-server", Config: failoverSessionConfig("srv", memoryCarrier, "", 0)}, + } + clientProfiles := []supervisor.Profile{ + {Name: "failing-client", Config: failoverSessionConfig("cnc", failingCarrier, socksHost, socksPort)}, + {Name: "memory-client", Config: failoverSessionConfig("cnc", memoryCarrier, socksHost, socksPort)}, + } + + started := make(chan string, 8) + serverErr := make(chan error, 1) + go func() { + serverErr <- supervisor.Run(ctx, failoverE2EConfig(serverProfiles, started, "server"), session.Run) + }() + room.waitConnected(t, 1) + + ready := make(chan struct{}) + 
var readyOnce sync.Once + clientErr := make(chan error, 1) + go func() { + clientErr <- supervisor.Run(ctx, failoverE2EConfig(clientProfiles, started, "client"), func(ctx context.Context, cfg session.Config) error { + return client.RunWithReady(ctx, clientConfigFromSession(cfg, socksAddr), func() { + if cfg.Auth == memoryCarrier { + readyOnce.Do(func() { close(ready) }) + } + }) + }) + }() + + waitForReady(t, ready) + conn := eventuallyConnectViaSOCKS(t, socksAddr, echoAddr) + defer func() { _ = conn.Close() }() + + payload := []byte("olcrtc-failover-e2e\n") + if _, err := conn.Write(payload); err != nil { + t.Fatalf("write failover payload: %v", err) + } + if err := conn.SetReadDeadline(time.Now().Add(3 * time.Second)); err != nil { + t.Fatalf("set failover read deadline: %v", err) + } + line, err := bufio.NewReader(conn).ReadBytes('\n') + if err != nil { + t.Fatalf("read failover echo: %v", err) + } + if !bytes.Equal(line, payload) { + t.Fatalf("failover echo = %q, want %q", line, payload) + } + + requireStartedProfiles(t, started, []string{ + "server:failing-server", + "server:memory-server", + "client:failing-client", + "client:memory-client", + }) + + cancel() + waitSupervisorStopped(t, "client", clientErr) + waitSupervisorStopped(t, "server", serverErr) +} + +func failoverSessionConfig(mode, carrierName, socksHost string, socksPort int) session.Config { + cfg := session.Config{ + Mode: mode, + Link: linkDirect, + Transport: transportData, + Auth: carrierName, + RoomID: testRoom, + KeyHex: testKeyHex, + DNSServer: localDNSServer, + } + if mode == "cnc" { + cfg.SOCKSHost = socksHost + cfg.SOCKSPort = socksPort + } + return cfg +} + +func clientConfigFromSession(cfg session.Config, socksAddr string) client.Config { + return client.Config{ + Link: cfg.Link, + Transport: cfg.Transport, + Carrier: cfg.Auth, + RoomURL: cfg.RoomID, + KeyHex: cfg.KeyHex, + LocalAddr: socksAddr, + DNSServer: cfg.DNSServer, + DeviceID: testClientDeviceID, + VideoWidth: cfg.VideoWidth, + 
VideoHeight: cfg.VideoHeight, + VideoFPS: cfg.VideoFPS, + VideoBitrate: cfg.VideoBitrate, + VideoHW: cfg.VideoHW, + VideoQRSize: cfg.VideoQRSize, + VideoQRRecovery: cfg.VideoQRRecovery, + VideoCodec: cfg.VideoCodec, + VideoTileModule: cfg.VideoTileModule, + VideoTileRS: cfg.VideoTileRS, + VP8FPS: cfg.VP8FPS, + VP8BatchSize: cfg.VP8BatchSize, + SEIFPS: cfg.SEIFPS, + SEIBatchSize: cfg.SEIBatchSize, + SEIFragmentSize: cfg.SEIFragmentSize, + SEIAckTimeoutMS: cfg.SEIAckTimeoutMS, + Engine: cfg.Engine, + URL: cfg.URL, + Token: cfg.Token, + } +} + +func failoverE2EConfig( + profiles []supervisor.Profile, + started chan<- string, + side string, +) supervisor.Config { + return supervisor.Config{ + Profiles: profiles, + RetryDelay: time.Millisecond, + OnProfileStart: func(profile supervisor.Profile, _ int) { + select { + case started <- side + ":" + profile.Name: + default: + } + }, + } +} + +func splitHostPort(t *testing.T, addr string) (string, int) { + t.Helper() + host, portText, err := net.SplitHostPort(addr) + if err != nil { + t.Fatalf("split host port %q: %v", addr, err) + } + port, err := strconv.Atoi(portText) + if err != nil { + t.Fatalf("parse port %q: %v", portText, err) + } + return host, port +} + +func requireStartedProfiles(t *testing.T, started <-chan string, want []string) { + t.Helper() + seen := make(map[string]bool) + deadline := time.After(3 * time.Second) + for len(seen) < len(want) { + select { + case item := <-started: + seen[item] = true + case <-deadline: + t.Fatalf("started profiles = %v, want all %v", seen, want) + } + } + for _, item := range want { + if !seen[item] { + t.Fatalf("started profiles = %v, missing %s", seen, item) + } + } +} + +func waitSupervisorStopped(t *testing.T, name string, ch <-chan error) { + t.Helper() + select { + case err := <-ch: + if err != nil { + t.Fatalf("%s supervisor returned error: %v", name, err) + } + case <-time.After(3 * time.Second): + t.Fatalf("%s supervisor did not stop", name) + } +} + func 
TestEndedCallbackStopsClientAndServer(t *testing.T) { rt := startTunnel(t) rt.room.triggerEnded("conference ended") diff --git a/internal/engine/livekit/livekit.go b/internal/engine/livekit/livekit.go index 24c62bd..ad7e64d 100644 --- a/internal/engine/livekit/livekit.go +++ b/internal/engine/livekit/livekit.go @@ -19,13 +19,17 @@ import ( protoLogger "github.com/livekit/protocol/logger" lksdk "github.com/livekit/server-sdk-go/v2" "github.com/openlibrecommunity/olcrtc/internal/engine" + "github.com/openlibrecommunity/olcrtc/internal/logger" "github.com/pion/webrtc/v4" ) const ( - defaultSendQueueSize = 5000 - dataPublishTopic = "olcrtc" - videoTrackName = "videochannel" + defaultSendQueueSize = 5000 + defaultSendQueueCapHard = 4000 + dataPublishTopic = "olcrtc" + videoTrackName = "videochannel" + reconnectWindow = 5 * time.Minute + maxReconnects = 10 ) var ( @@ -41,20 +45,98 @@ var ( ErrTokenRequired = errors.New("livekit access token required") ) +type roomHandle interface { + publishData([]byte) error + publishTrack(webrtc.TrackLocal) error + unpublishLocalTracks() + disconnect() + connectionState() lksdk.ConnectionState +} + +type sdkRoom struct { + room *lksdk.Room +} + +func (r *sdkRoom) publishData(data []byte) error { + return r.room.LocalParticipant.PublishDataPacket( + lksdk.UserData(data), + lksdk.WithDataPublishTopic(dataPublishTopic), + lksdk.WithDataPublishReliable(true), + ) +} + +func (r *sdkRoom) publishTrack(track webrtc.TrackLocal) error { + _, err := r.room.LocalParticipant.PublishTrack(track, &lksdk.TrackPublicationOptions{Name: videoTrackName}) + return err +} + +func (r *sdkRoom) unpublishLocalTracks() { + if r.room == nil || r.room.LocalParticipant == nil { + return + } + for _, publication := range r.room.LocalParticipant.TrackPublications() { + if publication.SID() == "" { + continue + } + if err := r.room.LocalParticipant.UnpublishTrack(publication.SID()); err != nil { + log.Printf("livekit unpublish track error: %v", err) + } + } +} + 
+func (r *sdkRoom) disconnect() { + r.room.Disconnect() + // LiveKit's Disconnect returns after local SDK teardown, before the + // server necessarily evicts the participant. Give the signalling path a + // short grace period so immediate reconnects do not inherit stale room + // state from a ghost participant. + time.Sleep(2 * time.Second) +} + +func (r *sdkRoom) connectionState() lksdk.ConnectionState { + return r.room.ConnectionState() +} + +type connectRoomFunc func(url, token string, callback *lksdk.RoomCallback) (roomHandle, error) + +func connectSDKRoom(url, token string, callback *lksdk.RoomCallback) (roomHandle, error) { + room, err := lksdk.ConnectToRoomWithToken( + url, + token, + callback, + lksdk.WithAutoSubscribe(true), + lksdk.WithLogger(protoLogger.GetDiscardLogger()), + ) + if err != nil { + return nil, err + } + return &sdkRoom{room: room}, nil +} + // Session is the LiveKit engine handle. type Session struct { url string token string name string - room *lksdk.Room + refresh func(ctx context.Context) (engine.Credentials, error) + connectRoom connectRoomFunc + room roomHandle + roomMu sync.RWMutex onData func([]byte) onReconnect func(*webrtc.DataChannel) shouldReconnect func() bool onEnded func(string) + reconnectCh chan struct{} + closeCh chan struct{} + lastReconnect time.Time + reconnectCount int sendQueue chan []byte closed atomic.Bool + reconnecting atomic.Bool done chan struct{} cancel context.CancelFunc + shutdownOnce sync.Once + sendWorkerOnce sync.Once videoTrackMu sync.RWMutex videoTracks []webrtc.TrackLocal onVideoTrack func(*webrtc.TrackRemote, *webrtc.RTPReceiver) @@ -71,13 +153,17 @@ func New(ctx context.Context, cfg engine.Config) (engine.Session, error) { } _, cancel := context.WithCancel(ctx) return &Session{ - url: cfg.URL, - token: cfg.Token, - name: cfg.Name, - onData: cfg.OnData, - sendQueue: make(chan []byte, defaultSendQueueSize), - done: make(chan struct{}), - cancel: cancel, + url: cfg.URL, + token: cfg.Token, + name: 
cfg.Name, + refresh: cfg.Refresh, + connectRoom: connectSDKRoom, + onData: cfg.OnData, + reconnectCh: make(chan struct{}, 1), + closeCh: make(chan struct{}), + sendQueue: make(chan []byte, defaultSendQueueSize), + done: make(chan struct{}), + cancel: cancel, }, nil } @@ -87,7 +173,16 @@ func (s *Session) Capabilities() engine.Capabilities { } // Connect joins the LiveKit room. -func (s *Session) Connect(_ context.Context) error { +func (s *Session) Connect(ctx context.Context) error { + s.closed.Store(false) + if err := s.connectSession(ctx); err != nil { + return err + } + s.startSendWorker() + return nil +} + +func (s *Session) connectSession(_ context.Context) error { roomCB := &lksdk.RoomCallback{ ParticipantCallback: lksdk.ParticipantCallback{ OnDataReceived: func(data []byte, _ lksdk.DataReceiveParams) { @@ -108,45 +203,49 @@ func (s *Session) Connect(_ context.Context) error { }, }, OnDisconnected: func() { - if !s.closed.Load() && s.onEnded != nil { - s.onEnded("disconnected from livekit") + if s.closed.Load() || s.reconnecting.Load() { + return + } + if !s.queueReconnect() { + s.signalEnded("disconnected from livekit") } }, } - room, err := lksdk.ConnectToRoomWithToken( - s.url, - s.token, - roomCB, - lksdk.WithAutoSubscribe(true), - lksdk.WithLogger(protoLogger.GetDiscardLogger()), - ) + room, err := s.connectRoom(s.url, s.token, roomCB) if err != nil { return fmt.Errorf("connect to room: %w", err) } - s.room = room + s.setRoom(room) if err := s.publishPendingTracks(); err != nil { return err } - s.wg.Add(1) - go s.processSendQueue() return nil } func (s *Session) publishPendingTracks() error { + room := s.currentRoom() + if room == nil { + return ErrRoomNotConnected + } s.videoTrackMu.RLock() defer s.videoTrackMu.RUnlock() for _, track := range s.videoTracks { - if _, err := s.room.LocalParticipant.PublishTrack(track, &lksdk.TrackPublicationOptions{ - Name: videoTrackName, - }); err != nil { + if err := room.publishTrack(track); err != nil { return 
fmt.Errorf("failed to publish track: %w", err) } } return nil } +func (s *Session) startSendWorker() { + s.sendWorkerOnce.Do(func() { + s.wg.Add(1) + go s.processSendQueue() + }) +} + func (s *Session) processSendQueue() { defer s.wg.Done() for { @@ -157,17 +256,33 @@ func (s *Session) processSendQueue() { if !ok { return } - if err := s.room.LocalParticipant.PublishDataPacket( - lksdk.UserData(data), - lksdk.WithDataPublishTopic(dataPublishTopic), - lksdk.WithDataPublishReliable(true), - ); err != nil { + room := s.waitForConnectedRoom() + if room == nil { + return + } + if err := room.publishData(data); err != nil { log.Printf("livekit publish data error: %v", err) } } } } +func (s *Session) waitForConnectedRoom() roomHandle { + ticker := time.NewTicker(50 * time.Millisecond) + defer ticker.Stop() + for { + room := s.currentRoom() + if room != nil && room.connectionState() == lksdk.ConnectionStateConnected { + return room + } + select { + case <-s.done: + return nil + case <-ticker.C: + } + } +} + // Send queues data for transmission. func (s *Session) Send(data []byte) error { if s.closed.Load() { @@ -183,55 +298,160 @@ func (s *Session) Send(data []byte) error { // Close terminates the session. func (s *Session) Close() error { - if s.closed.CompareAndSwap(false, true) { - s.cancel() - close(s.done) - if s.room != nil { - s.unpublishLocalTracks() - s.room.Disconnect() - // LiveKit's Disconnect() returns once the local SDK state - // is torn down, not when the server has actually evicted - // the participant. Without giving the signalling channel - // time to flush the LEAVE_REQUEST and the server to act on - // it, a back-to-back reconnect from the same identity in - // the same room sees a still-alive ghost participant on - // the SFU and inherits stale publication state. 
- time.Sleep(2 * time.Second) - } - close(s.sendQueue) - s.wg.Wait() - } + s.closed.Store(true) + s.shutdown() return nil } -func (s *Session) unpublishLocalTracks() { - if s.room == nil || s.room.LocalParticipant == nil { - return - } - for _, publication := range s.room.LocalParticipant.TrackPublications() { - if publication.SID() == "" { - continue +func (s *Session) shutdown() { + s.shutdownOnce.Do(func() { + if s.cancel != nil { + s.cancel() } - if err := s.room.LocalParticipant.UnpublishTrack(publication.SID()); err != nil { - log.Printf("livekit unpublish track error: %v", err) + closeSignal(s.closeCh) + closeSignal(s.done) + if room := s.swapRoom(nil); room != nil { + room.unpublishLocalTracks() + room.disconnect() } - } + s.wg.Wait() + }) } -// SetReconnectCallback stores the reconnect callback (LiveKit reconnects internally; this is kept for API parity). +// SetReconnectCallback stores the reconnect callback. func (s *Session) SetReconnectCallback(cb func(*webrtc.DataChannel)) { s.onReconnect = cb } -// SetShouldReconnect stores the reconnect predicate (kept for API parity). +// SetShouldReconnect stores the reconnect predicate. func (s *Session) SetShouldReconnect(fn func() bool) { s.shouldReconnect = fn } // SetEndedCallback registers a function to call when the session ends. func (s *Session) SetEndedCallback(cb func(string)) { s.onEnded = cb } -// WatchConnection is a no-op; LiveKit handles connection supervision itself. -func (s *Session) WatchConnection(_ context.Context) {} +// WatchConnection monitors the connection lifecycle and reconnects as needed. 
+func (s *Session) WatchConnection(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case <-s.closeCh: + return + case <-s.reconnectCh: + if s.handleReconnectAttempt(ctx) { + return + } + } + } +} + +func (s *Session) handleReconnectAttempt(ctx context.Context) bool { + if time.Since(s.lastReconnect) > reconnectWindow { + s.reconnectCount = 0 + } + s.reconnectCount++ + s.lastReconnect = time.Now() + + if s.reconnectCount > maxReconnects { + s.signalEnded("reconnect limit reached") + return true + } + + backoff := time.Duration(s.reconnectCount) * 2 * time.Second + if backoff > 30*time.Second { + backoff = 30 * time.Second + } + + for { + if err := s.reconnect(ctx); err != nil { + logger.Debugf("livekit reconnect failed: %v", err) + select { + case <-ctx.Done(): + return true + case <-s.closeCh: + return true + case <-time.After(backoff): + continue + } + } + s.drainReconnectQueue() + return false + } +} + +func (s *Session) reconnect(ctx context.Context) error { + s.reconnecting.Store(true) + defer s.reconnecting.Store(false) + + if room := s.swapRoom(nil); room != nil { + room.unpublishLocalTracks() + room.disconnect() + } + + if s.refresh != nil { + creds, err := s.refresh(ctx) + if err != nil { + return fmt.Errorf("refresh credentials: %w", err) + } + s.applyRefreshedCredentials(creds) + } + + if err := s.connectSession(ctx); err != nil { + return err + } + if s.onReconnect != nil { + s.onReconnect(nil) + } + return nil +} + +func (s *Session) applyRefreshedCredentials(creds engine.Credentials) { + if creds.URL != "" { + s.url = creds.URL + } + if creds.Token != "" { + s.token = creds.Token + } +} + +func (s *Session) queueReconnect() bool { + if s.closed.Load() || s.reconnecting.Load() { + return false + } + if s.shouldReconnect != nil && !s.shouldReconnect() { + return false + } + select { + case s.reconnectCh <- struct{}{}: + default: + } + return true +} + +func (s *Session) drainReconnectQueue() { + for { + select { + case 
<-s.reconnectCh: + default: + return + } + } +} + +func (s *Session) signalEnded(reason string) { + s.closed.Store(true) + s.shutdown() + if s.onEnded != nil { + s.onEnded(reason) + } +} // CanSend reports whether the session is ready to accept data. -func (s *Session) CanSend() bool { return !s.closed.Load() && s.room != nil } +func (s *Session) CanSend() bool { + if s.closed.Load() || s.reconnecting.Load() || len(s.sendQueue) >= defaultSendQueueCapHard { + return false + } + room := s.currentRoom() + return room != nil && room.connectionState() == lksdk.ConnectionStateConnected +} // GetSendQueue exposes the outbound queue. func (s *Session) GetSendQueue() chan []byte { return s.sendQueue } @@ -245,12 +465,11 @@ func (s *Session) AddVideoTrack(track webrtc.TrackLocal) error { s.videoTracks = append(s.videoTracks, track) s.videoTrackMu.Unlock() - if s.room == nil || s.room.LocalParticipant == nil { + room := s.currentRoom() + if room == nil { return nil } - if _, err := s.room.LocalParticipant.PublishTrack(track, &lksdk.TrackPublicationOptions{ - Name: videoTrackName, - }); err != nil { + if err := room.publishTrack(track); err != nil { return fmt.Errorf("failed to publish track: %w", err) } return nil @@ -263,6 +482,34 @@ func (s *Session) SetVideoTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPR s.onVideoTrack = cb } +func (s *Session) currentRoom() roomHandle { + s.roomMu.RLock() + defer s.roomMu.RUnlock() + return s.room +} + +func (s *Session) setRoom(room roomHandle) { + s.roomMu.Lock() + defer s.roomMu.Unlock() + s.room = room +} + +func (s *Session) swapRoom(room roomHandle) roomHandle { + s.roomMu.Lock() + defer s.roomMu.Unlock() + old := s.room + s.room = room + return old +} + +func closeSignal(ch chan struct{}) { + select { + case <-ch: + default: + close(ch) + } +} + func init() { //nolint:gochecknoinits // engine registration is the canonical Go pattern for plugins engine.Register("livekit", New) } diff --git 
a/internal/engine/livekit/livekit_test.go b/internal/engine/livekit/livekit_test.go new file mode 100644 index 0000000..7a46fd5 --- /dev/null +++ b/internal/engine/livekit/livekit_test.go @@ -0,0 +1,306 @@ +package livekit + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + lksdk "github.com/livekit/server-sdk-go/v2" + "github.com/openlibrecommunity/olcrtc/internal/engine" + "github.com/pion/webrtc/v4" +) + +type fakeRoom struct { + mu sync.Mutex + state lksdk.ConnectionState + published [][]byte + tracks int + unpublished int + disconnected int +} + +func newFakeRoom() *fakeRoom { + return &fakeRoom{state: lksdk.ConnectionStateConnected} +} + +func (r *fakeRoom) publishData(data []byte) error { + r.mu.Lock() + defer r.mu.Unlock() + r.published = append(r.published, append([]byte(nil), data...)) + return nil +} + +func (r *fakeRoom) publishTrack(webrtc.TrackLocal) error { + r.mu.Lock() + defer r.mu.Unlock() + r.tracks++ + return nil +} + +func (r *fakeRoom) unpublishLocalTracks() { + r.mu.Lock() + defer r.mu.Unlock() + r.unpublished++ +} + +func (r *fakeRoom) disconnect() { + r.mu.Lock() + defer r.mu.Unlock() + r.disconnected++ + r.state = lksdk.ConnectionStateDisconnected +} + +func (r *fakeRoom) connectionState() lksdk.ConnectionState { + r.mu.Lock() + defer r.mu.Unlock() + return r.state +} + +type fakeConnector struct { + mu sync.Mutex + urls []string + tokens []string + callbacks []*lksdk.RoomCallback + rooms []*fakeRoom + connected chan struct{} + err error +} + +func newFakeConnector() *fakeConnector { + return &fakeConnector{connected: make(chan struct{}, 8)} +} + +func (c *fakeConnector) connect(url, token string, cb *lksdk.RoomCallback) (roomHandle, error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.err != nil { + return nil, c.err + } + room := newFakeRoom() + c.urls = append(c.urls, url) + c.tokens = append(c.tokens, token) + c.callbacks = append(c.callbacks, cb) + c.rooms = append(c.rooms, room) + c.connected <- struct{}{} + return 
room, nil +} + +func (c *fakeConnector) count() int { + c.mu.Lock() + defer c.mu.Unlock() + return len(c.rooms) +} + +func (c *fakeConnector) callback(i int) *lksdk.RoomCallback { + c.mu.Lock() + defer c.mu.Unlock() + return c.callbacks[i] +} + +func (c *fakeConnector) room(i int) *fakeRoom { + c.mu.Lock() + defer c.mu.Unlock() + return c.rooms[i] +} + +func (c *fakeConnector) snapshot() ([]string, []string) { + c.mu.Lock() + defer c.mu.Unlock() + return append([]string(nil), c.urls...), append([]string(nil), c.tokens...) +} + +func waitFor(t *testing.T, cond func() bool) { + t.Helper() + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if cond() { + return + } + time.Sleep(10 * time.Millisecond) + } + t.Fatal("condition was not met before timeout") +} + +func TestReconnectRefreshesCredentialsAndReplacesRoom(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + refreshes := 0 + sess, err := New(ctx, engine.Config{ + URL: "wss://old", + Token: "old-token", + Refresh: func(context.Context) (engine.Credentials, error) { + refreshes++ + return engine.Credentials{URL: "wss://new", Token: "new-token"}, nil + }, + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + s := sess.(*Session) + connector := newFakeConnector() + s.connectRoom = connector.connect + + reconnected := make(chan struct{}, 1) + s.SetReconnectCallback(func(*webrtc.DataChannel) { + reconnected <- struct{}{} + }) + + if err := s.Connect(ctx); err != nil { + t.Fatalf("Connect() error = %v", err) + } + go s.WatchConnection(ctx) + + connector.callback(0).OnDisconnected() + + waitFor(t, func() bool { return connector.count() == 2 }) + select { + case <-reconnected: + case <-time.After(time.Second): + t.Fatal("reconnect callback was not called") + } + + urls, tokens := connector.snapshot() + if got, want := urls, []string{"wss://old", "wss://new"}; !equalStrings(got, want) { + t.Fatalf("connect urls = %v, want %v", got, want) 
+ } + if got, want := tokens, []string{"old-token", "new-token"}; !equalStrings(got, want) { + t.Fatalf("connect tokens = %v, want %v", got, want) + } + if refreshes != 1 { + t.Fatalf("refreshes = %d, want 1", refreshes) + } + oldRoom := connector.room(0) + oldRoom.mu.Lock() + if oldRoom.disconnected != 1 || oldRoom.unpublished != 1 { + t.Fatalf("old room cleanup disconnected=%d unpublished=%d, want 1/1", + oldRoom.disconnected, oldRoom.unpublished) + } + oldRoom.mu.Unlock() + if !s.CanSend() { + t.Fatal("CanSend() = false after reconnect, want true") + } + + if err := s.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } +} + +func TestDisconnectedEndsWhenReconnectDisallowed(t *testing.T) { + ctx := context.Background() + sess, err := New(ctx, engine.Config{URL: "wss://old", Token: "old-token"}) + if err != nil { + t.Fatalf("New() error = %v", err) + } + s := sess.(*Session) + connector := newFakeConnector() + s.connectRoom = connector.connect + s.SetShouldReconnect(func() bool { return false }) + + ended := make(chan string, 1) + s.SetEndedCallback(func(reason string) { + ended <- reason + }) + + if err := s.Connect(ctx); err != nil { + t.Fatalf("Connect() error = %v", err) + } + connector.callback(0).OnDisconnected() + + select { + case reason := <-ended: + if reason != "disconnected from livekit" { + t.Fatalf("ended reason = %q, want disconnected from livekit", reason) + } + case <-time.After(time.Second): + t.Fatal("ended callback was not called") + } + if !s.closed.Load() { + t.Fatal("closed = false after terminal disconnect") + } + if connector.count() != 1 { + t.Fatalf("connect count = %d, want 1", connector.count()) + } + room := connector.room(0) + room.mu.Lock() + if room.disconnected != 1 || room.unpublished != 1 { + t.Fatalf("terminal room cleanup disconnected=%d unpublished=%d, want 1/1", + room.disconnected, room.unpublished) + } + room.mu.Unlock() + + if err := s.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + 
room.mu.Lock() + if room.disconnected != 1 || room.unpublished != 1 { + t.Fatalf("second close cleanup disconnected=%d unpublished=%d, want still 1/1", + room.disconnected, room.unpublished) + } + room.mu.Unlock() +} + +func TestCanSendRequiresConnectedRoomAndQueueHeadroom(t *testing.T) { + s := &Session{ + sendQueue: make(chan []byte, defaultSendQueueSize), + done: make(chan struct{}), + closeCh: make(chan struct{}), + } + if s.CanSend() { + t.Fatal("CanSend() = true without room") + } + + room := newFakeRoom() + room.state = lksdk.ConnectionStateDisconnected + s.setRoom(room) + if s.CanSend() { + t.Fatal("CanSend() = true for disconnected room") + } + + room.state = lksdk.ConnectionStateConnected + if !s.CanSend() { + t.Fatal("CanSend() = false for connected room") + } + + for i := 0; i < defaultSendQueueCapHard; i++ { + s.sendQueue <- []byte("x") + } + if s.CanSend() { + t.Fatal("CanSend() = true at queue high watermark") + } +} + +func TestReconnectFailureRetriesUntilContextDone(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := &Session{ + url: "wss://old", + token: "old-token", + connectRoom: func(string, string, *lksdk.RoomCallback) (roomHandle, error) { + cancel() + return nil, errors.New("boom") + }, + reconnectCh: make(chan struct{}, 1), + closeCh: make(chan struct{}), + sendQueue: make(chan []byte, defaultSendQueueSize), + done: make(chan struct{}), + } + if terminal := s.handleReconnectAttempt(ctx); !terminal { + t.Fatal("handleReconnectAttempt() = false after context cancellation") + } +} + +func equalStrings(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/internal/supervisor/supervisor.go b/internal/supervisor/supervisor.go new file mode 100644 index 0000000..929fed6 --- /dev/null +++ b/internal/supervisor/supervisor.go @@ -0,0 +1,96 @@ +// Package supervisor runs ordered session profiles with 
failover. +package supervisor + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/openlibrecommunity/olcrtc/internal/app/session" +) + +const DefaultRetryDelay = 2 * time.Second + +var ( + // ErrNoProfiles is returned when the supervisor is started without profiles. + ErrNoProfiles = errors.New("supervisor: no profiles configured") + // ErrMaxCyclesExceeded is returned after MaxCycles complete profile-list passes. + ErrMaxCyclesExceeded = errors.New("supervisor: max failover cycles exceeded") +) + +// Profile is one runnable session configuration in an ordered failover list. +type Profile struct { + Name string + Config session.Config +} + +// Runner starts one session profile and blocks until it ends or fails. +type Runner func(ctx context.Context, cfg session.Config) error + +// Config controls ordered failover behavior. +type Config struct { + Profiles []Profile + RetryDelay time.Duration + MaxCycles int + + OnProfileStart func(profile Profile, cycle int) + OnProfileEnd func(profile Profile, cycle int, err error) +} + +// Run starts profiles in order. If a profile exits while ctx is still active, +// the supervisor waits RetryDelay and advances to the next profile. 
+func Run(ctx context.Context, cfg Config, run Runner) error { + if len(cfg.Profiles) == 0 { + return ErrNoProfiles + } + if cfg.RetryDelay == 0 { + cfg.RetryDelay = DefaultRetryDelay + } + + var lastErr error + for cycle := 1; ; cycle++ { + for i, profile := range cfg.Profiles { + if ctx.Err() != nil { + return nil + } + if cfg.OnProfileStart != nil { + cfg.OnProfileStart(profile, cycle) + } + + err := run(ctx, profile.Config) + if ctx.Err() != nil { + return nil + } + if err != nil { + lastErr = fmt.Errorf("profile %q: %w", profile.Name, err) + } else { + lastErr = fmt.Errorf("profile %q ended", profile.Name) + } + if cfg.OnProfileEnd != nil { + cfg.OnProfileEnd(profile, cycle, err) + } + + if cfg.MaxCycles > 0 && cycle >= cfg.MaxCycles && i == len(cfg.Profiles)-1 { + return fmt.Errorf("%w after %d cycle(s): %w", ErrMaxCyclesExceeded, cycle, lastErr) + } + if err := waitRetryDelay(ctx, cfg.RetryDelay); err != nil { + return nil + } + } + } +} + +func waitRetryDelay(ctx context.Context, delay time.Duration) error { + if delay <= 0 { + return nil + } + timer := time.NewTimer(delay) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} diff --git a/internal/supervisor/supervisor_test.go b/internal/supervisor/supervisor_test.go new file mode 100644 index 0000000..aab0dee --- /dev/null +++ b/internal/supervisor/supervisor_test.go @@ -0,0 +1,85 @@ +package supervisor + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/openlibrecommunity/olcrtc/internal/app/session" +) + +var errRunnerBoom = errors.New("boom") + +func TestRunRequiresProfiles(t *testing.T) { + err := Run(context.Background(), Config{}, func(context.Context, session.Config) error { return nil }) + if !errors.Is(err, ErrNoProfiles) { + t.Fatalf("Run() error = %v, want %v", err, ErrNoProfiles) + } +} + +func TestRunAdvancesProfilesAndStopsAtMaxCycles(t *testing.T) { + profiles := []Profile{ + {Name: "first", Config: 
session.Config{Auth: "wbstream"}}, + {Name: "second", Config: session.Config{Auth: "jitsi"}}, + } + var started []string + var ended []string + err := Run(context.Background(), Config{ + Profiles: profiles, + RetryDelay: -1, + MaxCycles: 1, + OnProfileStart: func(profile Profile, cycle int) { + started = append(started, profile.Name) + if cycle != 1 { + t.Fatalf("cycle = %d, want 1", cycle) + } + }, + OnProfileEnd: func(profile Profile, _ int, err error) { + ended = append(ended, profile.Name) + if !errors.Is(err, errRunnerBoom) { + t.Fatalf("profile %s err = %v, want %v", profile.Name, err, errRunnerBoom) + } + }, + }, func(_ context.Context, cfg session.Config) error { + if cfg.Auth == "" { + t.Fatal("runner received empty auth") + } + return errRunnerBoom + }) + if !errors.Is(err, ErrMaxCyclesExceeded) { + t.Fatalf("Run() error = %v, want %v", err, ErrMaxCyclesExceeded) + } + if got, want := started, []string{"first", "second"}; !equalStrings(got, want) { + t.Fatalf("started = %v, want %v", got, want) + } + if got, want := ended, []string{"first", "second"}; !equalStrings(got, want) { + t.Fatalf("ended = %v, want %v", got, want) + } +} + +func TestRunReturnsNilOnContextCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + err := Run(ctx, Config{ + Profiles: []Profile{{Name: "one"}}, + RetryDelay: time.Hour, + }, func(context.Context, session.Config) error { + cancel() + return nil + }) + if err != nil { + t.Fatalf("Run() error = %v, want nil", err) + } +} + +func equalStrings(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +}