Merge pull request #58 from cyber-debug/refine/livekit-reconnect

refine livekit reconnect and liveness
This commit is contained in:
zarazaex
2026-05-16 03:47:47 +03:00
committed by GitHub
46 changed files with 4785 additions and 294 deletions

1
.gitignore vendored
View File

@@ -1,5 +1,6 @@
# Prerequisites
*.d
.DS_Store
# Object files
*.o

View File

@@ -24,6 +24,7 @@ import (
configpkg "github.com/openlibrecommunity/olcrtc/internal/config"
"github.com/openlibrecommunity/olcrtc/internal/logger"
"github.com/openlibrecommunity/olcrtc/internal/names"
"github.com/openlibrecommunity/olcrtc/internal/supervisor"
"github.com/openlibrecommunity/olcrtc/internal/transport/videochannel"
)
@@ -35,6 +36,9 @@ var ErrConfigPathRequired = errors.New("usage: olcrtc <config.yaml>")
// ErrDataDirRequired is returned when the YAML config does not specify a data directory.
var ErrDataDirRequired = errors.New("data directory required (set 'data:' in YAML)")
// ErrProfilesUnsupportedForGen is returned when failover profiles are configured for gen mode.
var ErrProfilesUnsupportedForGen = errors.New("profiles are only supported for srv and cnc modes")
//nolint:gochecknoglobals // Tests replace the long-running session runner with a bounded function.
var runSession = session.Run
@@ -44,11 +48,18 @@ var runGen = execGen
// loadedConfig bundles the parsed YAML file and the derived session config.
type loadedConfig struct {
scfg session.Config
profiles []supervisor.Profile
failover failoverConfig
dataDir string
debug bool
ffmpegPath string
}
type failoverConfig struct {
retryDelay time.Duration
maxCycles int
}
func main() {
if err := run(); err != nil {
logger.Error(err)
@@ -79,14 +90,44 @@ func loadConfig(path string) (loadedConfig, error) {
if err != nil {
return loadedConfig{}, fmt.Errorf("load config: %w", err)
}
base := configpkg.Apply(session.Config{}, f)
profiles := make([]supervisor.Profile, 0, len(f.Profiles))
for i, profile := range f.Profiles {
name := profile.Name
if name == "" {
name = fmt.Sprintf("profile-%d", i+1)
}
profiles = append(profiles, supervisor.Profile{
Name: name,
Config: configpkg.ApplyProfile(base, profile),
})
}
failover, err := parseFailoverConfig(f.Failover)
if err != nil {
return loadedConfig{}, err
}
return loadedConfig{
scfg: configpkg.Apply(session.Config{}, f),
scfg: base,
profiles: profiles,
failover: failover,
dataDir: f.Data,
debug: f.Debug,
ffmpegPath: f.FFmpeg,
}, nil
}
func parseFailoverConfig(f configpkg.Failover) (failoverConfig, error) {
retryDelay := supervisor.DefaultRetryDelay
if f.RetryDelay != "" {
parsed, err := time.ParseDuration(f.RetryDelay)
if err != nil {
return failoverConfig{}, fmt.Errorf("parse failover.retry_delay: %w", err)
}
retryDelay = parsed
}
return failoverConfig{retryDelay: retryDelay, maxCycles: f.MaxCycles}, nil
}
func runWithConfig(cfg loadedConfig) error {
configureLogging(cfg.debug)
@@ -98,19 +139,116 @@ func runWithConfig(cfg loadedConfig) error {
if err != nil {
return fmt.Errorf("validate config: %w", err)
}
scfg = session.ApplyTransportDefaults(scfg)
scfg = session.ApplyLivenessDefaults(scfg)
if scfg.Mode == modeGen {
if len(cfg.profiles) > 0 {
return ErrProfilesUnsupportedForGen
}
return runGen(scfg)
}
if len(cfg.profiles) > 0 {
profiles, err := prepareProfiles(cfg.profiles)
if err != nil {
return err
}
return runFailoverSessionMode(cfg.dataDir, profiles, cfg.failover)
}
return runSessionMode(cfg.dataDir, scfg)
}
func prepareProfiles(profiles []supervisor.Profile) ([]supervisor.Profile, error) {
out := make([]supervisor.Profile, 0, len(profiles))
for _, profile := range profiles {
scfg, err := session.ApplyAuthDefaults(profile.Config)
if err != nil {
return nil, fmt.Errorf("validate profile %q: %w", profile.Name, err)
}
profile.Config = session.ApplyLivenessDefaults(session.ApplyTransportDefaults(scfg))
out = append(out, profile)
}
return out, nil
}
func runSessionMode(dataDir string, scfg session.Config) error {
if err := session.Validate(scfg); err != nil {
return fmt.Errorf("validate config: %w", err)
}
if err := prepareRuntimeData(dataDir); err != nil {
return err
}
return runManaged(func(ctx context.Context) error {
return runSession(ctx, scfg)
})
}
func runFailoverSessionMode(dataDir string, profiles []supervisor.Profile, failover failoverConfig) error {
for _, profile := range profiles {
if err := session.Validate(profile.Config); err != nil {
return fmt.Errorf("validate profile %q: %w", profile.Name, err)
}
}
if err := prepareRuntimeData(dataDir); err != nil {
return err
}
return runManaged(func(ctx context.Context) error {
return supervisor.Run(ctx, supervisor.Config{
Profiles: profiles,
RetryDelay: failover.retryDelay,
MaxCycles: failover.maxCycles,
OnProfileStart: func(profile supervisor.Profile, cycle int) {
logger.Infof("failover cycle=%d starting profile=%s carrier=%s transport=%s",
cycle, profile.Name, profile.Config.Auth, profile.Config.Transport)
},
OnProfileEnd: func(profile supervisor.Profile, cycle int, err error) {
if err != nil {
logger.Warnf("failover cycle=%d profile=%s ended with error: %v", cycle, profile.Name, err)
return
}
logger.Warnf("failover cycle=%d profile=%s ended", cycle, profile.Name)
},
OnStatus: logFailoverStatus,
}, runSession)
})
}
func logFailoverStatus(status supervisor.Status) {
if !logger.IsVerbose() {
return
}
active := status.ActiveProfile
if active == "" {
active = "none"
}
logger.Debugf("failover status cycle=%d active=%s last_error=%q profiles=%s history=%d",
status.Cycle, active, status.LastError, formatProfileStatuses(status.Profiles), len(status.History))
}
func formatProfileStatuses(profiles []supervisor.ProfileStatus) string {
if len(profiles) == 0 {
return "[]"
}
var buf bytes.Buffer
buf.WriteByte('[')
for i, profile := range profiles {
if i > 0 {
buf.WriteByte(' ')
}
fmt.Fprintf(&buf, "%s{starts=%d failures=%d clean=%d}",
profile.Name, profile.Starts, profile.Failures, profile.CleanEnds)
}
buf.WriteByte(']')
return buf.String()
}
func prepareRuntimeData(dataDir string) error {
if dataDir == "" {
return ErrDataDirRequired
}
@@ -124,6 +262,10 @@ func runSessionMode(dataDir string, scfg session.Config) error {
return err
}
return nil
}
func runManaged(run func(context.Context) error) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -132,7 +274,7 @@ func runSessionMode(dataDir string, scfg session.Config) error {
errCh := make(chan error, 1)
go func() {
errCh <- runSession(ctx, scfg)
errCh <- run(ctx)
}()
select {

View File

@@ -9,6 +9,7 @@ import (
"github.com/openlibrecommunity/olcrtc/internal/app/session"
"github.com/openlibrecommunity/olcrtc/internal/logger"
"github.com/openlibrecommunity/olcrtc/internal/supervisor"
)
var errBoom = errors.New("boom")
@@ -149,6 +150,112 @@ data: `+dir+`
}
}
func TestRunWithArgsAppliesTransportDefaults(t *testing.T) {
dir := t.TempDir()
if err := os.WriteFile(filepath.Join(dir, "names"), []byte("A\n"), 0o600); err != nil {
t.Fatalf("WriteFile(names) error = %v", err)
}
if err := os.WriteFile(filepath.Join(dir, "surnames"), []byte("B\n"), 0o600); err != nil {
t.Fatalf("WriteFile(surnames) error = %v", err)
}
oldRunSession := runSession
t.Cleanup(func() { runSession = oldRunSession })
runSession = func(ctx context.Context, cfg session.Config) error {
if cfg.VP8FPS != 25 || cfg.VP8BatchSize != 1 {
t.Fatalf("VP8 defaults = fps %d batch %d, want 25/1", cfg.VP8FPS, cfg.VP8BatchSize)
}
return nil
}
yamlPath := writeYAML(t, `
mode: srv
link: direct
auth:
provider: wbstream
room:
id: room
crypto:
key: key
net:
transport: vp8channel
dns: 1.1.1.1:53
data: `+dir+`
`)
if err := runWithArgs([]string{yamlPath}); err != nil {
t.Fatalf("runWithArgs() error = %v", err)
}
}
func TestRunWithArgsFailoverProfiles(t *testing.T) {
dir := t.TempDir()
if err := os.WriteFile(filepath.Join(dir, "names"), []byte("A\n"), 0o600); err != nil {
t.Fatalf("WriteFile(names) error = %v", err)
}
if err := os.WriteFile(filepath.Join(dir, "surnames"), []byte("B\n"), 0o600); err != nil {
t.Fatalf("WriteFile(surnames) error = %v", err)
}
oldRunSession := runSession
t.Cleanup(func() { runSession = oldRunSession })
var seen []string
runSession = func(ctx context.Context, cfg session.Config) error {
seen = append(seen, cfg.Auth+"/"+cfg.Transport)
if cfg.Auth == "wbstream" && (cfg.VP8FPS != 25 || cfg.VP8BatchSize != 1) {
t.Fatalf("VP8 defaults = fps %d batch %d, want 25/1", cfg.VP8FPS, cfg.VP8BatchSize)
}
return errBoom
}
yamlPath := writeYAML(t, `
mode: srv
link: direct
crypto:
key: key
net:
dns: 1.1.1.1:53
profiles:
- name: wb-primary
auth:
provider: wbstream
room:
id: room
net:
transport: vp8channel
- name: jitsi-backup
auth:
provider: jitsi
room:
id: https://meet.example/room
net:
transport: datachannel
failover:
retry_delay: -1ns
max_cycles: 1
data: `+dir+`
`)
err := runWithArgs([]string{yamlPath})
if !errors.Is(err, supervisor.ErrMaxCyclesExceeded) {
t.Fatalf("runWithArgs() error = %v, want %v", err, supervisor.ErrMaxCyclesExceeded)
}
want := []string{"wbstream/vp8channel", "jitsi/datachannel"}
if !equalStrings(seen, want) {
t.Fatalf("seen profiles = %v, want %v", seen, want)
}
}
func TestRunWithConfigRejectsProfilesInGenMode(t *testing.T) {
cfg := loadedConfig{
scfg: session.Config{Mode: modeGen},
profiles: []supervisor.Profile{{Name: "one"}},
}
if err := runWithConfig(cfg); !errors.Is(err, ErrProfilesUnsupportedForGen) {
t.Fatalf("runWithConfig() error = %v, want %v", err, ErrProfilesUnsupportedForGen)
}
}
func TestConfigureLogging(t *testing.T) {
t.Setenv("PION_LOG_DISABLE", "")
logger.SetVerbose(false)
@@ -170,6 +277,18 @@ func TestConfigureLogging(t *testing.T) {
}
}
func equalStrings(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
func TestResolveDataDir(t *testing.T) {
abs := filepath.Join(t.TempDir(), "data")
got, err := resolveDataDir(abs)

View File

@@ -234,7 +234,7 @@ internal/e2e/ E2E тесты на реальных провайдер
| Файл | Что делает |
|---|---|
| `session.go` | Главная точка конфигурации. `RegisterDefaults()` регистрирует все carriers, links, transports. `Validate()` проверяет все настройки. `Run()` роутит в `server.Run` или `client.Run`. `Gen()` генерирует Room ID для jazz с ретраями (wbstream больше не поддерживает автогенерацию - руму нужно создавать вручную через stream.wb.ru) |
| `session.go` | Главная точка конфигурации. `RegisterDefaults()` регистрирует все carriers, links, transports. `Validate()` проверяет все настройки. `Run()` роутит в `server.Run` или `client.Run`. `Gen()` генерирует Room ID для auth-провайдеров с `RoomCreator` и ретраями |
| `session_test.go` | Тесты валидации конфига |
### `internal/config/`
@@ -452,7 +452,7 @@ Carrier - это WebRTC сервис видеозвонков, через кот
- Минимальная прослойка, почти прямой relay
- Работает с vp8channel, seichannel, videochannel
- DataChannel **не работает** в обычном guest flow: WB Stream выдаёт токены с `canPublishData=false`, DC не маршрутизирует данные (expected fail в E2E тестах)
- Room ID нужно создавать вручную через stream.wb.ru
- Room ID можно создать вручную через stream.wb.ru или через `mode: gen`
- Инициализация звонка автоматически
---

View File

@@ -14,12 +14,28 @@ room:
id: "https://meet.cryptopro.ru/REPLACE_WITH_ROOM_NAME"
crypto:
# Or use key_file: "./olcrtc.key" to keep the secret out of this file.
key: "REPLACE_ME_WITH_64_HEX_CHARS" # must match the server
net:
transport: datachannel # must match the server
dns: "8.8.8.8:53"
liveness:
interval: 10s
timeout: 5s
failures: 3
# Optional planned rebuild for long-running calls.
# lifecycle:
# max_session_duration: 6h
# Optional reliability shaping for encrypted wire messages.
# traffic:
# max_payload_size: 4096
# min_delay: 5ms
# max_delay: 30ms
# Local SOCKS5 listener exposed to applications
socks:
host: "127.0.0.1"

View File

@@ -11,6 +11,7 @@ olcrtc /etc/olcrtc/server.yaml
- [`server.example.yaml`](./server.example.yaml)
- [`client.example.yaml`](./client.example.yaml)
- [`failover.example.yaml`](./failover.example.yaml)
## Схема
@@ -20,7 +21,7 @@ olcrtc /etc/olcrtc/server.yaml
| `link` | `direct` |
| `auth.provider` | `jitsi`, `telemost`, `jazz`, `wbstream`, `none` |
| `room.id` | conference room id |
| `crypto.key` | 64-char hex (32 bytes) |
| `crypto.key` / `crypto.key_file` | 64-char hex (32 bytes), inline or read from file |
| `net.transport` | `datachannel`, `videochannel`, `seichannel`, `vp8channel` |
| `net.dns` | resolver `host:port` |
| `socks.host` / `.port` | client-side listener |
@@ -30,7 +31,126 @@ olcrtc /etc/olcrtc/server.yaml
| `video.*` | videochannel tuning |
| `vp8.*` | vp8channel tuning |
| `sei.fps` / `.batch_size` / `.fragment_size` / `.ack_timeout_ms` | seichannel tuning |
| `liveness.interval` | control-stream ping interval, default `10s` |
| `liveness.timeout` | pong timeout, default `5s` |
| `liveness.failures` | missed pongs before reconnect, default `3` |
| `lifecycle.max_session_duration` | planned session rebuild interval, e.g. `6h`; unset = off |
| `traffic.max_payload_size` | safe encrypted wire-message cap; `0` = transport default |
| `traffic.min_delay` / `.max_delay` | optional send pacing jitter, e.g. `5ms` / `30ms` |
| `gen.amount` | gen mode: number of rooms to create |
| `profiles[]` | ordered srv/cnc failover profiles |
| `failover.retry_delay` | delay before trying the next profile, e.g. `2s` |
| `failover.max_cycles` | stop after N full profile-list passes; `0` = forever |
| `data` | path to data directory |
| `debug` | verbose logging |
| `ffmpeg` | path to ffmpeg binary |
`mode: cnc` refuses non-loopback `socks.host` values unless both
`socks.user` and `socks.pass` are set.
`crypto.key_file` is resolved relative to the YAML file. Do not set it
together with `crypto.key`.
## Liveness
After `CLIENT_HELLO` / `SERVER_WELCOME`, the first smux stream stays open as
an encrypted control stream. olcrtc now sends `CONTROL_PING` / `CONTROL_PONG`
messages over that stream to prove the real tunnel path still round-trips.
This detects states where a provider or WebRTC layer looks connected but the
encrypted smux path is no longer usable.
```yaml
liveness:
interval: 10s
timeout: 5s
failures: 3
```
When the failure threshold is reached, the current smux session is rebuilt.
In failover mode, a profile that exits after liveness-triggered reconnect
failure lets the supervisor advance to the next profile.
## Lifecycle Rotation
`lifecycle.max_session_duration` sets a planned upper bound for one provider
call/session. When the duration expires, olcrtc cancels the active server or
client session and starts a fresh one with the same config. While this option
is enabled, clean session endings are also restarted so the peer that did not
fire the timer can follow the rebuild. This is useful for long-running
deployments where provider calls get stale, accumulate media state, or should
be periodically re-created.
```yaml
lifecycle:
max_session_duration: 6h
```
The field is optional and disabled when omitted. Values use Go duration syntax
such as `30m`, `2h`, or `6h`; zero and negative durations are rejected.
## Traffic Shaping
`traffic` applies a shared reliability-oriented wrapper around the selected
transport. It can cap encrypted wire-message size and add small send pacing
delays without truncating data. When a payload would exceed the effective cap,
the send fails clearly instead of cutting bytes and corrupting smux.
```yaml
traffic:
max_payload_size: 4096
min_delay: 5ms
max_delay: 30ms
```
The wrapper clamps the configured payload cap to the selected transport's
advertised `MaxPayloadSize`. Client and server also reduce smux frame size to
fit the effective encrypted payload cap, accounting for crypto overhead. `0`
adds no extra cap beyond the selected transport's advertised limit. Delays use
Go duration syntax; if only `min_delay` is set, it is a fixed delay. Use the
same traffic settings on both peers.
## Failover Profiles
`mode: srv` and `mode: cnc` can define `profiles`. Top-level fields are used
as common defaults; each profile overrides only the fields it sets. The CLI
runs profiles in order. If a profile fails or ends while the process is still
alive, olcrtc waits `failover.retry_delay` and starts the next profile.
```yaml
mode: srv
link: direct
crypto:
key_file: ./olcrtc.key
net:
dns: "1.1.1.1:53"
data: data
profiles:
- name: wb-vp8
auth:
provider: wbstream
room:
id: "WB_ROOM_ID"
net:
transport: vp8channel
- name: jitsi-dc
auth:
provider: jitsi
room:
id: "https://meet.example.org/olcrtc-room"
net:
transport: datachannel
failover:
retry_delay: 2s
max_cycles: 0
```
Both peers must use compatible profile order and room settings. This first
failover layer rebuilds the session on the next profile; active smux streams
do not migrate, but new connections can recover on the next profile.
When `debug: true` is enabled, the CLI also emits a compact supervisor status
snapshot with the active profile, per-profile start/failure counters, and
bounded failover history size.

View File

@@ -0,0 +1,49 @@
# olcrtc failover config example
# Use the same profile order on both peers.
mode: srv
link: direct
crypto:
key_file: "./olcrtc.key"
net:
dns: "1.1.1.1:53"
liveness:
interval: 10s
timeout: 5s
failures: 3
# Optional planned rebuild for each active profile.
# lifecycle:
# max_session_duration: 6h
# Optional reliability shaping for encrypted wire messages.
# traffic:
# max_payload_size: 4096
# min_delay: 5ms
# max_delay: 30ms
data: data
profiles:
- name: wb-vp8
auth:
provider: wbstream
room:
id: "REPLACE_WITH_WB_ROOM_ID"
net:
transport: vp8channel
- name: jitsi-datachannel
auth:
provider: jitsi
room:
id: "https://meet.example.org/REPLACE_WITH_ROOM_NAME"
net:
transport: datachannel
failover:
retry_delay: 2s
max_cycles: 0

View File

@@ -177,7 +177,7 @@ data: data
### wbstream + vp8channel (альтернатива)
Сначала создай руму вручную через сайт [wbstream](https://stream.wb.ru) (автогенерация через `mode: gen` для wbstream больше не поддерживается) и сохрани её ID.
Создай руму через сайт [wbstream](https://stream.wb.ru) или заранее сгенерируй ID через `mode: gen` с `auth.provider: wbstream`.
`wbstream + datachannel` **не работает** в обычном guest flow — WB Stream выдаёт токены с `canPublishData=false`, и DC не маршрутизирует данные. Для обычного использования выбирай `vp8channel`.

425
docs/project-map.md Normal file
View File

@@ -0,0 +1,425 @@
# olcRTC Project Map
This is a developer map for finding the useful parts of the project quickly.
It focuses on code ownership, runtime flow, extension points, and areas that
are worth deeper work.
## One-Sentence Model
olcRTC is an encrypted TCP-over-WebRTC tunnel: the client exposes a local
SOCKS5 listener, the server dials requested TCP targets, and both sides carry
the smux byte stream through a selected WebRTC carrier and transport.
## Runtime Stack
```text
YAML config
-> cmd/olcrtc
-> internal/config
-> internal/app/session
-> internal/server or internal/client
-> internal/link/direct
-> internal/transport/{datachannel,vp8channel,seichannel,videochannel}
-> internal/carrier/builtin
-> internal/auth/<provider> + internal/engine/<engine>
-> external service SFU / signaling
```
Tunnel data path:
```text
local app
-> client SOCKS5
-> smux stream
-> muxconn AEAD encrypt
-> link.Send
-> transport encoding
-> carrier/engine
-> SFU/service
-> peer engine/carrier
-> transport decoding
-> muxconn AEAD decrypt
-> smux stream
-> server TCP dial
-> target host
```
## Entrypoints
| Path | Purpose |
|---|---|
| `cmd/olcrtc/main.go` | Main CLI. Accepts one YAML file, applies auth and transport defaults, starts `srv`, `cnc`, or `gen`. |
| `cmd/olcrtc-cgo/main.go` | Small c-shared entrypoint for desktop/native consumers. |
| `pkg/olcrtc` | Embeddable lower-level API that returns a `net.Conn`-like handle over an engine data path. |
| `pkg/olcrtc/tunnel` | Embeddable server-side tunnel API with auth and traffic hooks. |
| `mobile/mobile.go` | gomobile API for Android clients, including VPN socket protection. |
| `script/srv.sh`, `script/cnc.sh` | Interactive shell launchers that generate YAML and run/build the app. |
| `Dockerfile`, `script/docker/*` | Container build and server entrypoint/healthcheck. |
## Config And Session Layer
`internal/config` owns YAML parsing and file-backed secret loading.
Important fields:
| YAML | Runtime field | Notes |
|---|---|---|
| `mode` | `session.Config.Mode` | `srv`, `cnc`, or `gen`. |
| `auth.provider` | `Auth` | `jitsi`, `telemost`, `jazz`, `wbstream`, or `none`. |
| `room.id` | `RoomID` | Carrier-specific room reference. |
| `crypto.key` / `crypto.key_file` | `KeyHex` | Shared 32-byte key encoded as 64 hex chars. |
| `net.transport` | `Transport` | `datachannel`, `vp8channel`, `seichannel`, or `videochannel`. |
| `net.dns` | `DNSServer` | Resolver used by server-side target dials and provider HTTP where wired. |
| `socks.*` | SOCKS fields | Client listener and optional server egress proxy. |
| `engine.*` | direct engine fields | Used only with `auth.provider: none`. |
| `liveness.*` | control liveness | Ping/pong interval, timeout, and missed-pong threshold. |
| `lifecycle.*` | session lifecycle | Planned call/session rotation. |
| `traffic.*` | send shaping | Encrypted wire-message size cap and optional pacing jitter. |
`internal/app/session` is the main router:
1. Registers built-ins via `RegisterDefaults`.
2. Applies auth defaults: auth provider decides engine and default service URL.
3. Applies transport defaults: documented defaults for `vp8`, `sei`, and `video`.
4. Validates mode, auth, link, transport, room, key, DNS, transport options, and SOCKS listener safety.
5. Runs `server.Run`, `client.Run`, or `Gen`.
## Server Side
`internal/server` accepts encrypted smux sessions from the peer and proxies
each smux stream to a TCP target.
Core pieces:
| Symbol | Role |
|---|---|
| `server.Run` | Creates cipher, link, smux server, and serve loop. |
| `bringUpLink` | Builds `link.Link`, wires reconnect callbacks, connects carrier. |
| `installSession` / `reinstallSession` | Creates or replaces `muxconn + smux.Session`. |
| `acceptHandshake` | First smux stream; runs `handshake.Server`. |
| `handleStream` | Reads connect JSON and dispatches a tunnel stream. |
| `dispatch` | Dials target, sends ready byte, copies both directions. |
| `AuthHook` | Embedders can authorize clients after `CLIENT_HELLO`. |
| `OnSessionOpen`, `OnSessionClose`, `OnTraffic` | Observability hooks. |
Server risk areas:
- Target dialing is powerful by design. Any real product wrapper should add
an `AuthHook` and probably destination policy.
- `defaultAuthHook` admits everyone who knows the room and key.
- Reconnect rebuilds smux sessions; active streams are sacrificed.
## Client Side
`internal/client` exposes a local SOCKS5 listener and opens one smux stream
per SOCKS CONNECT request.
Core pieces:
| Symbol | Role |
|---|---|
| `RunWithReady` | Starts link, opens smux client, listens on local SOCKS. |
| `openControlStream` | First smux stream; runs `handshake.Client`. |
| `handleSocks5` | SOCKS method negotiation and CONNECT parsing. |
| `sendConnectRequest` | Sends server-side target JSON and waits for ready byte. |
| `handleReconnect` | Rebuilds smux and control stream after carrier reconnect. |
| `resolveDeviceID` | Optional persistent client identity for hooks. |
Client risk areas:
- A non-loopback SOCKS listener must require `socks.user` and `socks.pass`.
- SOCKS credentials are simple static credentials, not a full account system.
- Existing streams do not survive reconnect; new SOCKS connections can recover.
## Wire Protocol Above WebRTC
`internal/muxconn` adapts `link.Link` to `io.ReadWriteCloser`.
- Every smux write is encrypted with `internal/crypto`.
- Every inbound link message is decrypted and appended to an internal byte buffer.
- Bad AEAD frames are dropped.
- `CanSend` provides backpressure before encrypting and sending.
`internal/crypto` uses XChaCha20-Poly1305 with a random nonce prepended to
each ciphertext.
`internal/handshake` runs on the first smux stream:
```text
CLIENT_HELLO { version, device_id, claims }
SERVER_WELCOME { version, session_id }
or
SERVER_REJECT { version, reason }
```
The handshake has a 64 KiB frame cap and a default 15 second timeout.
After handshake, `internal/control` keeps that same encrypted smux stream open
and exchanges length-prefixed JSON control messages:
```text
CONTROL_PING { version, seq, sent_unix_nano }
CONTROL_PONG { version, seq, sent_unix_nano }
```
Defaults are `liveness.interval: 10s`, `liveness.timeout: 5s`, and
`liveness.failures: 3`. Missed pongs mark the smux session unhealthy and
trigger a session rebuild/reconnect path.
Client and server runtimes also maintain a `control.Status` snapshot with
session ID, last pong time, RTT, missed pongs, reconnect count, and unhealthy
event count. Embedders can consume it through the client/server health
callbacks.
## Registries And Plugin Shape
The universal-carrier refactor centers on small registries:
| Registry | Package | Registers |
|---|---|---|
| Auth providers | `internal/auth` | Service-specific credential and room creation flows. |
| Engines | `internal/engine` | Wire-level SFU protocol implementations. |
| Carriers | `internal/carrier` | Auth + engine adapters exposed as byte/video capability providers. |
| Transports | `internal/transport` | Byte transport strategy over carrier primitives. |
| Links | `internal/link` | Higher-level link abstraction; currently only `direct`. |
`internal/carrier/builtin` connects the auth and engine worlds:
```text
carrier "wbstream" -> auth/wbstream -> engine/livekit
carrier "jazz" -> auth/salutejazz -> engine/salutejazz
carrier "telemost"-> auth/telemost -> engine/goolom
carrier "jitsi" -> auth/jitsi -> engine/jitsi
carrier "none" -> direct user-supplied engine/url/token
```
## Auth Providers
| Provider | Engine | Room generation | Notes |
|---|---|---:|---|
| `jitsi` | `jitsi` | No | Parses host/room from a public or self-hosted Jitsi URL. No HTTP auth. |
| `telemost` | `goolom` | No | Calls Telemost room-info flow and returns Goolom credentials. |
| `wbstream` | `livekit` | Yes | Registers guest, optionally creates room, joins room, fetches LiveKit token. |
| `jazz` / `salutejazz` | `salutejazz` | Yes | Creates or joins SaluteJazz room and returns room/password tuple. |
| `none` | chosen by config | No | Direct engine mode for downstream tools or self-hosted SFUs. |
## Engines
Engines expose the low-level service/SFU protocol.
| Engine | Package | Byte stream | Video track | Main job |
|---|---|---:|---:|---|
| `livekit` | `internal/engine/livekit` | Yes | Yes | LiveKit SDK room, data packets, local/remote tracks, reconnect with credential refresh. |
| `goolom` | `internal/engine/goolom` | Yes | Yes | Yandex Telemost/Goolom signaling, split publisher/subscriber peer connections, telemetry/keepalive. |
| `jitsi` | `internal/engine/jitsi` | Yes | Best effort | Jitsi MUC/Jingle/colibri-ws plus optional video track negotiation. |
| `salutejazz` | `internal/engine/salutejazz` | Yes | Yes | SaluteJazz WebSocket signaling and split media peer connections. |
Engine work is where most provider breakage and reconnect complexity lives.
## Transports
Transports decide how raw tunnel bytes are carried once the carrier provides
either a byte stream or a video track.
| Transport | Primitive | Reliability model | Best fit | Notes |
|---|---|---|---|---|
| `datachannel` | Carrier byte stream | Native reliable ordered messages | Jitsi, direct engines, some Jazz cases | Simple pass-through with 12 KiB message cap. |
| `vp8channel` | VP8 video track | KCP over VP8-looking frames | WB Stream and Telemost-style video paths | Highest-performance video-path transport. Uses epochs and binding tokens to survive restarts/loopback. |
| `seichannel` | H264 SEI video track | Custom fragments + ACK/retry | WB Stream fallback | Carries data in SEI NAL units with fragmentation, CRC, ACK. |
| `videochannel` | Visual frames via ffmpeg | QR/tile frames + ACK/retry | Experimental/inspection-friendly path | Encodes visual payload frames, requires ffmpeg, supports QR and tile codecs. |
Transport work is where throughput, loss recovery, and adaptive tuning should
happen.
## Public/Embedding Surfaces
| Package | User |
|---|---|
| `pkg/olcrtc` | Go programs that want a `net.Conn` over a selected auth/engine. |
| `pkg/olcrtc/tunnel` | Go programs that want to embed the server-side tunnel with auth/traffic hooks. |
| `mobile` | Android app bindings. Wraps client mode, VPN socket protection, logging, simple health checks. |
| `cmd/olcrtc-cgo` | Native desktop/client integrations using c-shared Go export. |
These surfaces are important if the CLI becomes only one frontend among many.
## Tests
The project has broad unit coverage:
- Config/session validation and defaults.
- Auth provider HTTP flows with test servers.
- Engine helper logic and reconnect paths.
- SOCKS parsing, smux handshake, server dispatch.
- Crypto, muxconn, names, protect, logging.
- Transport frame codecs, ACK paths, KCP loopback, ffmpeg helpers.
- Memory-backed E2E tunnel tests and optional real-provider E2E matrix.
Useful commands:
```sh
go test -count=1 ./...
go test -race -count=1 ./cmd/olcrtc ./internal/app/session ./internal/config ./internal/engine/livekit
go test -race -count=1 -v ./internal/e2e
E2E_CARRIERS=wbstream E2E_TRANSPORTS=vp8channel mage e2e
go build -trimpath -o build/olcrtc ./cmd/olcrtc
```
## High-Value Coding Areas
### 1. Supervisor And Multi-Profile Failover
The first supervisor layer exists in `internal/supervisor`: the CLI can run a
prioritized list of carrier/transport profiles and move to the next profile
when the active one fails or ends.
```yaml
mode: srv
link: direct
crypto:
key_file: ./olcrtc.key
net:
dns: "1.1.1.1:53"
profiles:
- name: wb-vp8
auth:
provider: wbstream
room:
id: WB_ROOM_ID
net:
transport: vp8channel
- name: jitsi-dc
auth:
provider: jitsi
room:
id: https://meet.example.org/olcrtc-room
net:
transport: datachannel
failover:
retry_delay: 2s
max_cycles: 0
```
Implemented:
- Config schema for `profiles[]`.
- Ordered supervisor loop.
- `failover.retry_delay`.
- `failover.max_cycles`.
- Profile start/end logs.
- Planned session rotation with `lifecycle.max_session_duration`.
- Shared supervisor status snapshots with bounded failover history.
- Shared traffic wrapper with payload cap, pacing jitter, and smux frame sizing.
Still valuable:
- Health scoring per profile.
- Control-stream coordination before switching.
- Stream draining and migration instead of dropping active smux streams.
- User-facing status endpoint/export for the active profile and failover history.
Likely files:
- `internal/config/config.go`
- `internal/app/session/session.go`
- `internal/supervisor`
- `internal/server`
- `internal/client`
- `docs/configuration.md`
- `internal/e2e/tunnel_test.go`
### 2. Transport Telemetry And Adaptive Tuning
Add metrics from transport to link/session:
- Send queue depth.
- ACK latency.
- Retries.
- Reconnect count.
- Dropped/decrypt-failed frames.
- KCP RTT/loss where available.
Then make `vp8.batch_size`, `sei.fragment_size`, ACK timeout, and pacing
adaptive instead of static YAML knobs.
### 3. Control Stream Protocol
The first smux stream now carries control ping/pong after handshake. It is
still the natural place for:
- Server policy updates.
- Graceful reconnect notifications.
- Drain/start markers for failover.
- More per-session stats.
Likely files:
- `internal/control`
- `internal/server`
- `internal/client`
### 4. Destination Policy And Real Auth
The tunnel can dial arbitrary server-side TCP targets. A production wrapper
should use `AuthHook` and enforce:
- Allowed destination CIDRs/domains/ports.
- Per-device or per-plan policy.
- Session expiration.
- Traffic accounting limits.
- Sanitized rejection reasons.
This mostly belongs in `pkg/olcrtc/tunnel` and `internal/server`.
### 5. Provider Hardening
Provider APIs can drift. Worth adding:
- Central protected HTTP/WebSocket client creation with TLS 1.2+,
environment proxy support, HTTP/2 for HTTP, and bounded timeouts.
- Better typed errors from auth providers.
- Provider health probes.
- Fixture-based contract tests for API response changes.
- Per-provider rate/backoff policy.
- Safer secret/log redaction.
Likely files:
- `internal/auth/*`
- `internal/engine/*`
- `internal/carrier/builtin`
### 6. Codebase Hygiene
Some public-facing text and comments are not suitable for a serious external
project. Cleaning that up would improve maintainability and downstream trust.
The most obvious targets are top-level docs and a large hostile block comment
in `internal/transport/vp8channel/transport.go`.
## Where To Look First
| Goal | Start here |
|---|---|
| Change YAML schema | `internal/config/config.go`, `cmd/olcrtc/main.go`, docs examples. |
| Change validation/defaults | `internal/app/session/session.go`. |
| Add a new auth provider | `internal/auth`, then register in `internal/carrier/builtin/register.go`. |
| Add a new SFU protocol | `internal/engine`, then connect through auth/carrier. |
| Add a new byte transport | `internal/transport`, then register in `session.RegisterDefaults`. |
| Add link behavior above transports | `internal/link`; currently only `direct`. |
| Improve SOCKS behavior | `internal/client`. |
| Improve server target dialing or policy | `internal/server`, `pkg/olcrtc/tunnel`. |
| Improve reconnect | Engines first, then `internal/client` and `internal/server` smux rebuild behavior. |
| Improve Android app integration | `mobile`, `internal/protect`, `client.RunWithReady`. |
## Mental Model For Big Changes
Prefer to keep the layer boundaries:
- Auth creates credentials; it should not know transport details.
- Engine speaks service/SFU protocol; it should not know SOCKS or smux.
- Carrier adapts auth+engine into byte/video capabilities.
- Transport turns byte/video capabilities into reliable-ish tunnel bytes.
- Link is policy above transport.
- Client/server own SOCKS, smux, handshake, target dialing, and session hooks.
If a change crosses more than two layers, it probably deserves a new
orchestrator package instead of pushing more state into an engine or transport.

View File

@@ -16,12 +16,28 @@ room:
crypto:
# 32-byte hex (64 chars). Generate with: openssl rand -hex 32
# Or use key_file: "./olcrtc.key" to keep the secret out of this file.
key: "REPLACE_ME_WITH_64_HEX_CHARS"
net:
transport: datachannel # datachannel | videochannel | seichannel | vp8channel
dns: "8.8.8.8:53"
liveness:
interval: 10s
timeout: 5s
failures: 3
# Optional planned rebuild for long-running calls.
# lifecycle:
# max_session_duration: 6h
# Optional reliability shaping for encrypted wire messages.
# traffic:
# max_payload_size: 4096
# min_delay: 5ms
# max_delay: 30ms
# Outbound SOCKS5 proxy for server-side egress (optional)
socks:
proxy_addr: "" # e.g. "127.0.0.1"

View File

@@ -48,7 +48,7 @@
| `auth.provider` | `telemost`, `jazz`, `wbstream` или `jitsi` |
| `net.transport` | `datachannel`, `vp8channel`, `seichannel` или `videochannel` |
| `room.id` | Room ID |
| `crypto.key` | Ключ шифрования hex 64 символа. Генерация: `openssl rand -hex 32` |
| `crypto.key` или `crypto.key_file` | Ключ шифрования hex 64 символа. Генерация: `openssl rand -hex 32` |
| `link` | Всегда `direct` |
| `data` | Всегда `data` |
| `net.dns` | DNS-сервер, например `1.1.1.1:53` |
@@ -60,18 +60,52 @@
| YAML поле | Описание |
|-----------|----------|
| `debug` | `true` для подробных логов соединений |
| `profiles` | Список профилей failover для `srv`/`cnc` |
| `failover.retry_delay` | Пауза перед следующим профилем, например `2s` |
| `failover.max_cycles` | Сколько полных проходов по профилям сделать; `0` = бесконечно |
| `liveness.interval` | Интервал ping по control stream, по умолчанию `10s` |
| `liveness.timeout` | Сколько ждать pong, по умолчанию `5s` |
| `liveness.failures` | Сколько pong можно пропустить перед rebuild, по умолчанию `3` |
| `lifecycle.max_session_duration` | Плановый rebuild сессии после указанного времени, например `6h`; если поле не задано, выключено |
| `traffic.max_payload_size` | Лимит размера зашифрованного wire-message; `0` = лимит транспорта |
| `traffic.min_delay` / `.max_delay` | Необязательный pacing отправки, например `5ms` / `30ms` |
`crypto.key_file` читается относительно YAML-файла. Не указывай `crypto.key` и `crypto.key_file` одновременно.
Если задан `profiles`, поля верхнего уровня становятся общими defaults, а
каждый профиль переопределяет только свои `auth`, `room`, `net`, `engine` и
настройки транспорта/liveness. Порядок профилей должен совпадать на сервере и
клиенте.
`liveness` проверяет именно зашифрованный smux control stream после handshake,
а не только статус WebRTC/provider соединения. Если pong не приходит несколько
раз подряд, текущая smux-сессия пересоздается.
`lifecycle.max_session_duration` ограничивает длительность одного звонка /
provider session. Когда таймер истекает, текущая `srv` или `cnc` сессия
закрывается и стартует заново с тем же конфигом. Пока эта настройка включена,
чистое завершение сессии тоже перезапускается, чтобы второй peer мог догнать
плановый rebuild. Формат значения: `30m`, `2h`, `6h`; `0s` и отрицательные
значения не принимаются.
`traffic` добавляет общий wrapper над выбранным transport. Он может ограничить
размер зашифрованного сообщения и добавить небольшую задержку перед отправкой.
Данные не обрезаются: если сообщение не помещается в эффективный лимит, send
возвращает явную ошибку. При заданном `max_payload_size` smux frame size также
уменьшается с учетом crypto overhead; при `0` остается лимит выбранного
transport. Используй одинаковые traffic-настройки на обеих сторонах.
---
## mode: gen
Генерирует Room ID заранее, не запуская сервер. Поддерживается только для `jazz`. Для `wbstream` создавай руму вручную через [stream.wb.ru](https://stream.wb.ru) (автогенерация отключена со стороны WB).
Генерирует Room ID заранее, не запуская сервер. Поддерживается для auth-провайдеров с автосозданием комнат: `jazz` и `wbstream`. Для `telemost` комнату нужно создавать вручную через сайт.
**Обязательные поля:**
| YAML поле | Описание |
|-----------|----------|
| `auth.provider` | `jazz` |
| `auth.provider` | `jazz` или `wbstream` |
| `net.dns` | DNS-сервер |
| `gen.amount` | Количество комнат |
@@ -79,7 +113,7 @@
# gen.yaml
mode: gen
auth:
provider: jazz
provider: wbstream
net:
dns: "1.1.1.1:53"
gen:
@@ -116,6 +150,9 @@ gen:
Если `socks.user` не задан - аутентификация отключена (любой локальный клиент может подключиться).
Если задан - клиент принимает только подключения с правильным логином и паролем (RFC 1929).
Если `socks.host` не loopback (`127.0.0.1`, `::1`, `localhost`), `socks.user` и `socks.pass` обязательны.
Это защита от случайного открытого SOCKS5-прокси в локальной сети или интернете.
---
## datachannel

View File

@@ -5,13 +5,17 @@ import (
"context"
"errors"
"fmt"
"net"
"slices"
"sync/atomic"
"time"
"github.com/openlibrecommunity/olcrtc/internal/auth"
"github.com/openlibrecommunity/olcrtc/internal/carrier"
"github.com/openlibrecommunity/olcrtc/internal/carrier/builtin"
"github.com/openlibrecommunity/olcrtc/internal/client"
"github.com/openlibrecommunity/olcrtc/internal/control"
"github.com/openlibrecommunity/olcrtc/internal/crypto"
"github.com/openlibrecommunity/olcrtc/internal/link"
"github.com/openlibrecommunity/olcrtc/internal/link/direct"
"github.com/openlibrecommunity/olcrtc/internal/logger"
@@ -37,18 +41,35 @@ const (
videoCodecTile = "tile"
)
const (
defaultVideoWidth = 1920
defaultVideoHeight = 1080
defaultVideoFPS = 30
defaultVideoBitrate = "2M"
defaultVideoHW = "none"
defaultVideoQRRecovery = "low"
defaultVP8FPS = 25
defaultVP8BatchSize = 1
defaultSEIFPS = 60
defaultSEIBatchSize = 64
defaultSEIFragmentSize = 900
defaultSEIAckTimeoutMS = 2000
)
var sessionRestartDelay = 2 * time.Second
var (
// ErrRoomIDRequired indicates that a room id is required for the selected carrier.
ErrRoomIDRequired = errors.New("room ID required (use -id <id>)")
ErrRoomIDRequired = errors.New("room ID required (set room.id)")
// ErrModeRequired indicates that mode is not one of the supported values.
ErrModeRequired = errors.New("mode required (use -mode srv, -mode cnc or -mode gen)")
// ErrAmountRequired indicates that -amount is required for gen mode.
ErrAmountRequired = errors.New("amount required for gen mode (use -amount <n>)")
ErrModeRequired = errors.New("mode required (set mode to srv, cnc or gen)")
// ErrAmountRequired indicates that gen.amount is required for gen mode.
ErrAmountRequired = errors.New("amount required for gen mode (set gen.amount)")
// ErrAuthRequired indicates that no auth provider was selected.
ErrAuthRequired = errors.New(
"auth provider required (use -auth jitsi, -auth telemost, -auth jazz, -auth wbstream or -auth none)")
// ErrURLRequired indicates that -url must be provided when the auth provider has no default URL.
ErrURLRequired = errors.New("SFU URL required (use -url wss://...)")
"auth provider required (set auth.provider to jitsi, telemost, jazz, wbstream or none)")
// ErrURLRequired indicates that auth.url must be provided when the auth provider has no default URL.
ErrURLRequired = errors.New("SFU URL required (set auth.url)")
// ErrUnsupportedCarrier indicates that carrier is not registered.
ErrUnsupportedCarrier = errors.New("unsupported carrier")
// ErrUnsupportedLink indicates that link is not registered.
@@ -57,88 +78,119 @@ var (
ErrUnsupportedTransport = errors.New("unsupported transport")
// ErrLinkRequired indicates that link is not provided.
ErrLinkRequired = errors.New("link required (use -link direct)")
ErrLinkRequired = errors.New("link required (set link to direct)")
// ErrTransportRequired indicates that transport is not provided.
ErrTransportRequired = errors.New(
"transport required (use -transport datachannel, -transport videochannel, " +
"-transport seichannel or -transport vp8channel)")
"transport required (set transport to datachannel, videochannel, seichannel or vp8channel)")
// ErrKeyRequired indicates that encryption key is not provided.
ErrKeyRequired = errors.New("key required (use -key <hex>)")
ErrKeyRequired = errors.New("key required (set crypto.key)")
// ErrDNSServerRequired indicates that dns server is not provided.
ErrDNSServerRequired = errors.New("dns server required (use -dns 1.1.1.1:53)")
ErrDNSServerRequired = errors.New("dns server required (set net.dns)")
// ErrVideoWidthRequired indicates that video width is required for videochannel.
ErrVideoWidthRequired = errors.New("video width required for videochannel (use -video-w)")
ErrVideoWidthRequired = errors.New("video width required for videochannel (set video.width)")
// ErrVideoHeightRequired indicates that video height is required for videochannel.
ErrVideoHeightRequired = errors.New("video height required for videochannel (use -video-h)")
ErrVideoHeightRequired = errors.New("video height required for videochannel (set video.height)")
// ErrVideoFPSRequired indicates that video fps is required for videochannel.
ErrVideoFPSRequired = errors.New("video fps required for videochannel (use -video-fps)")
ErrVideoFPSRequired = errors.New("video fps required for videochannel (set video.fps)")
// ErrVideoBitrateRequired indicates that video bitrate is required for videochannel.
ErrVideoBitrateRequired = errors.New(
"video bitrate required for videochannel (use -video-bitrate)")
"video bitrate required for videochannel (set video.bitrate)")
// ErrVideoHWRequired indicates that video hardware acceleration is required.
ErrVideoHWRequired = errors.New(
"video hardware acceleration required for videochannel (use -video-hw none/nvenc)")
"video hardware acceleration required for videochannel (set video.hw to none or nvenc)")
// ErrVideoCodecInvalid indicates that the video codec is not valid.
ErrVideoCodecInvalid = errors.New(
"invalid video codec for videochannel (use -video-codec qrcode or -video-codec tile)")
"invalid video codec for videochannel (set video.codec to qrcode or tile)")
// ErrTileCodecDimensions indicates that tile codec requires 1080x1080 dimensions.
ErrTileCodecDimensions = errors.New("tile codec requires -video-w 1080 -video-h 1080")
ErrTileCodecDimensions = errors.New("tile codec requires video.width: 1080 and video.height: 1080")
// ErrVP8FPSRequired indicates that vp8 fps is required for vp8channel.
ErrVP8FPSRequired = errors.New("vp8 fps required for vp8channel (use -vp8-fps)")
ErrVP8FPSRequired = errors.New("vp8 fps required for vp8channel (set vp8.fps)")
// ErrVP8BatchSizeRequired indicates that vp8 batch size is required for vp8channel.
ErrVP8BatchSizeRequired = errors.New("vp8 batch size required for vp8channel (use -vp8-batch)")
ErrVP8BatchSizeRequired = errors.New("vp8 batch size required for vp8channel (set vp8.batch_size)")
// ErrSEIFPSRequired indicates that seichannel fps is required.
ErrSEIFPSRequired = errors.New("fps required for seichannel (use -fps)")
ErrSEIFPSRequired = errors.New("fps required for seichannel (set sei.fps)")
// ErrSEIBatchSizeRequired indicates that seichannel batch size is required.
ErrSEIBatchSizeRequired = errors.New("batch size required for seichannel (use -batch)")
ErrSEIBatchSizeRequired = errors.New("batch size required for seichannel (set sei.batch_size)")
// ErrSEIFragmentSizeRequired indicates that seichannel fragment size is required.
ErrSEIFragmentSizeRequired = errors.New("fragment size required for seichannel (use -frag)")
ErrSEIFragmentSizeRequired = errors.New("fragment size required for seichannel (set sei.fragment_size)")
// ErrSEIAckTimeoutRequired indicates that seichannel ack timeout is required.
ErrSEIAckTimeoutRequired = errors.New("ack timeout required for seichannel (use -ack-ms)")
ErrSEIAckTimeoutRequired = errors.New("ack timeout required for seichannel (set sei.ack_timeout_ms)")
// ErrSOCKSHostRequired indicates that socks host is required for cnc mode.
ErrSOCKSHostRequired = errors.New("socks host required for cnc mode (use -socks-host)")
ErrSOCKSHostRequired = errors.New("socks host required for cnc mode (set socks.host)")
// ErrSOCKSPortRequired indicates that socks port is required for cnc mode.
ErrSOCKSPortRequired = errors.New("socks port required for cnc mode (use -socks-port)")
ErrSOCKSPortRequired = errors.New("socks port required for cnc mode (set socks.port)")
// ErrSOCKSAuthRequired indicates that a non-loopback SOCKS listener requires authentication.
ErrSOCKSAuthRequired = errors.New(
"socks auth required when binding outside loopback (set socks.user and socks.pass)")
// ErrLivenessIntervalInvalid indicates that liveness.interval is not a positive duration.
ErrLivenessIntervalInvalid = errors.New(
"invalid liveness interval (set liveness.interval to a duration > 0)")
// ErrLivenessTimeoutInvalid indicates that liveness.timeout is not a positive duration.
ErrLivenessTimeoutInvalid = errors.New(
"invalid liveness timeout (set liveness.timeout to a duration > 0)")
// ErrLivenessFailuresInvalid indicates that liveness.failures is not positive.
ErrLivenessFailuresInvalid = errors.New(
"invalid liveness failures (set liveness.failures to a value > 0)")
// ErrLifecycleMaxSessionDurationInvalid indicates that lifecycle.max_session_duration is not a positive duration.
ErrLifecycleMaxSessionDurationInvalid = errors.New(
"invalid max session duration (set lifecycle.max_session_duration to a duration > 0)")
// ErrTrafficMaxPayloadSizeInvalid indicates that traffic.max_payload_size is not valid.
ErrTrafficMaxPayloadSizeInvalid = errors.New(
"invalid traffic max payload size (set traffic.max_payload_size to 0 or a value above crypto overhead)")
// ErrTrafficMinDelayInvalid indicates that traffic.min_delay is not a non-negative duration.
ErrTrafficMinDelayInvalid = errors.New(
"invalid traffic min delay (set traffic.min_delay to a duration >= 0)")
// ErrTrafficMaxDelayInvalid indicates that traffic.max_delay is not a non-negative duration.
ErrTrafficMaxDelayInvalid = errors.New(
"invalid traffic max delay (set traffic.max_delay to a duration >= 0 and >= traffic.min_delay)")
)
// Config holds runtime session settings.
type Config struct {
Mode string
Link string
Transport string
Auth string
Engine string
URL string
Token string
RoomID string
KeyHex string
SOCKSHost string
SOCKSPort int
SOCKSUser string
SOCKSPass string
DNSServer string
SOCKSProxyAddr string
SOCKSProxyPort int
VideoWidth int
VideoHeight int
VideoFPS int
VideoBitrate string
VideoHW string
VideoQRSize int
VideoQRRecovery string
VideoCodec string
VideoTileModule int
VideoTileRS int
VP8FPS int
VP8BatchSize int
SEIFPS int
SEIBatchSize int
SEIFragmentSize int
SEIAckTimeoutMS int
Amount int
Mode string
Link string
Transport string
Auth string
Engine string
URL string
Token string
RoomID string
KeyHex string
SOCKSHost string
SOCKSPort int
SOCKSUser string
SOCKSPass string
DNSServer string
SOCKSProxyAddr string
SOCKSProxyPort int
VideoWidth int
VideoHeight int
VideoFPS int
VideoBitrate string
VideoHW string
VideoQRSize int
VideoQRRecovery string
VideoCodec string
VideoTileModule int
VideoTileRS int
VP8FPS int
VP8BatchSize int
SEIFPS int
SEIBatchSize int
SEIFragmentSize int
SEIAckTimeoutMS int
LivenessInterval string
LivenessTimeout string
LivenessFailures int
MaxSessionDuration string
TrafficMaxPayloadSize int
TrafficMinDelay string
TrafficMaxDelay string
Amount int
}
// RegisterDefaults registers built-in carriers and transports.
@@ -180,6 +232,94 @@ func ApplyAuthDefaults(cfg Config) (Config, error) {
return cfg, nil
}
// ApplyTransportDefaults fills documented transport defaults without changing core routing fields.
func ApplyTransportDefaults(cfg Config) Config {
switch cfg.Transport {
case transportVideo:
return applyVideoDefaults(cfg)
case transportVP8:
return applyVP8Defaults(cfg)
case transportSEI:
return applySEIDefaults(cfg)
default:
return cfg
}
}
// ApplyLivenessDefaults fills documented control-stream liveness defaults.
func ApplyLivenessDefaults(cfg Config) Config {
if cfg.LivenessInterval == "" {
cfg.LivenessInterval = control.DefaultInterval.String()
}
if cfg.LivenessTimeout == "" {
cfg.LivenessTimeout = control.DefaultTimeout.String()
}
if cfg.LivenessFailures == 0 {
cfg.LivenessFailures = control.DefaultFailures
}
return cfg
}
func applyVideoDefaults(cfg Config) Config {
if cfg.VideoCodec == "" {
cfg.VideoCodec = videoCodecQRCode
}
if cfg.VideoCodec == videoCodecTile {
if cfg.VideoWidth == 0 {
cfg.VideoWidth = 1080
}
if cfg.VideoHeight == 0 {
cfg.VideoHeight = 1080
}
} else {
if cfg.VideoWidth == 0 {
cfg.VideoWidth = defaultVideoWidth
}
if cfg.VideoHeight == 0 {
cfg.VideoHeight = defaultVideoHeight
}
}
if cfg.VideoFPS == 0 {
cfg.VideoFPS = defaultVideoFPS
}
if cfg.VideoBitrate == "" {
cfg.VideoBitrate = defaultVideoBitrate
}
if cfg.VideoHW == "" {
cfg.VideoHW = defaultVideoHW
}
if cfg.VideoQRRecovery == "" {
cfg.VideoQRRecovery = defaultVideoQRRecovery
}
return cfg
}
func applyVP8Defaults(cfg Config) Config {
if cfg.VP8FPS == 0 {
cfg.VP8FPS = defaultVP8FPS
}
if cfg.VP8BatchSize == 0 {
cfg.VP8BatchSize = defaultVP8BatchSize
}
return cfg
}
func applySEIDefaults(cfg Config) Config {
if cfg.SEIFPS == 0 {
cfg.SEIFPS = defaultSEIFPS
}
if cfg.SEIBatchSize == 0 {
cfg.SEIBatchSize = defaultSEIBatchSize
}
if cfg.SEIFragmentSize == 0 {
cfg.SEIFragmentSize = defaultSEIFragmentSize
}
if cfg.SEIAckTimeoutMS == 0 {
cfg.SEIAckTimeoutMS = defaultSEIAckTimeoutMS
}
return cfg
}
// Validate verifies that the runtime config refers to registered components and all required fields are present.
func Validate(cfg Config) error {
if err := validateMode(cfg); err != nil {
@@ -200,6 +340,15 @@ func Validate(cfg Config) error {
if err := validateTransportConfig(cfg); err != nil {
return err
}
if err := validateLivenessConfig(cfg); err != nil {
return err
}
if err := validateLifecycleConfig(cfg); err != nil {
return err
}
if err := validateTrafficConfig(cfg); err != nil {
return err
}
return validateModeConfig(cfg)
}
@@ -333,13 +482,163 @@ func validateModeConfig(cfg Config) error {
if cfg.SOCKSPort == 0 {
return ErrSOCKSPortRequired
}
if !isLoopbackListenHost(cfg.SOCKSHost) && (cfg.SOCKSUser == "" || cfg.SOCKSPass == "") {
return ErrSOCKSAuthRequired
}
return nil
}
func validateLivenessConfig(cfg Config) error {
if _, err := parseLivenessDuration(cfg.LivenessInterval, control.DefaultInterval); err != nil {
return fmt.Errorf("%w: %v", ErrLivenessIntervalInvalid, err)
}
if _, err := parseLivenessDuration(cfg.LivenessTimeout, control.DefaultTimeout); err != nil {
return fmt.Errorf("%w: %v", ErrLivenessTimeoutInvalid, err)
}
if cfg.LivenessFailures < 0 {
return ErrLivenessFailuresInvalid
}
return nil
}
func validateLifecycleConfig(cfg Config) error {
if _, err := maxSessionDuration(cfg); err != nil {
return err
}
return nil
}
func parseLivenessDuration(value string, def time.Duration) (time.Duration, error) {
if value == "" {
return def, nil
}
d, err := time.ParseDuration(value)
if err != nil {
return 0, err
}
if d <= 0 {
return 0, fmt.Errorf("duration must be > 0")
}
return d, nil
}
func livenessConfig(cfg Config) (control.Config, error) {
interval, err := parseLivenessDuration(cfg.LivenessInterval, control.DefaultInterval)
if err != nil {
return control.Config{}, fmt.Errorf("%w: %v", ErrLivenessIntervalInvalid, err)
}
timeout, err := parseLivenessDuration(cfg.LivenessTimeout, control.DefaultTimeout)
if err != nil {
return control.Config{}, fmt.Errorf("%w: %v", ErrLivenessTimeoutInvalid, err)
}
failures := cfg.LivenessFailures
if failures == 0 {
failures = control.DefaultFailures
}
if failures < 0 {
return control.Config{}, ErrLivenessFailuresInvalid
}
return control.Config{Interval: interval, Timeout: timeout, Failures: failures}, nil
}
func maxSessionDuration(cfg Config) (time.Duration, error) {
if cfg.MaxSessionDuration == "" {
return 0, nil
}
d, err := time.ParseDuration(cfg.MaxSessionDuration)
if err != nil {
return 0, fmt.Errorf("%w: %v", ErrLifecycleMaxSessionDurationInvalid, err)
}
if d <= 0 {
return 0, ErrLifecycleMaxSessionDurationInvalid
}
return d, nil
}
func validateTrafficConfig(cfg Config) error {
_, err := trafficConfig(cfg)
return err
}
func trafficConfig(cfg Config) (transport.TrafficConfig, error) {
if cfg.TrafficMaxPayloadSize < 0 || (cfg.TrafficMaxPayloadSize > 0 &&
cfg.TrafficMaxPayloadSize <= crypto.WireOverhead) {
return transport.TrafficConfig{}, ErrTrafficMaxPayloadSizeInvalid
}
minDelay, err := parseOptionalNonNegativeDuration(cfg.TrafficMinDelay)
if err != nil {
return transport.TrafficConfig{}, fmt.Errorf("%w: %v", ErrTrafficMinDelayInvalid, err)
}
maxDelay, err := parseOptionalNonNegativeDuration(cfg.TrafficMaxDelay)
if err != nil {
return transport.TrafficConfig{}, fmt.Errorf("%w: %v", ErrTrafficMaxDelayInvalid, err)
}
if maxDelay > 0 && maxDelay < minDelay {
return transport.TrafficConfig{}, ErrTrafficMaxDelayInvalid
}
return transport.TrafficConfig{
MaxPayloadSize: cfg.TrafficMaxPayloadSize,
MinDelay: minDelay,
MaxDelay: maxDelay,
}, nil
}
func parseOptionalNonNegativeDuration(value string) (time.Duration, error) {
if value == "" {
return 0, nil
}
d, err := time.ParseDuration(value)
if err != nil {
return 0, err
}
if d < 0 {
return 0, fmt.Errorf("duration must be >= 0")
}
return d, nil
}
func isLoopbackListenHost(host string) bool {
if host == "localhost" {
return true
}
ip := net.ParseIP(host)
return ip != nil && ip.IsLoopback()
}
// Run starts the configured mode.
func Run(ctx context.Context, cfg Config) error {
cfg = ApplyTransportDefaults(cfg)
cfg = ApplyLivenessDefaults(cfg)
roomURL := cfg.RoomID
liveness, err := livenessConfig(cfg)
if err != nil {
return err
}
maxDuration, err := maxSessionDuration(cfg)
if err != nil {
return err
}
traffic, err := trafficConfig(cfg)
if err != nil {
return err
}
run := func(ctx context.Context) error {
return runOnce(ctx, cfg, roomURL, liveness, traffic)
}
if maxDuration > 0 {
return runWithSessionRotation(ctx, maxDuration, run)
}
return run(ctx)
}
func runOnce(
ctx context.Context,
cfg Config,
roomURL string,
liveness control.Config,
traffic transport.TrafficConfig,
) error {
switch cfg.Mode {
case modeSRV:
if err := server.Run(ctx, server.Config{
@@ -370,6 +669,8 @@ func Run(ctx context.Context, cfg Config) error {
Engine: cfg.Engine,
URL: cfg.URL,
Token: cfg.Token,
Liveness: liveness,
Traffic: traffic,
OnSessionOpen: func(sessionID, deviceID string, claims map[string]any) {
logger.Infof("session opened: id=%s device=%s claims=%v", sessionID, deviceID, claims)
},
@@ -413,6 +714,8 @@ func Run(ctx context.Context, cfg Config) error {
Engine: cfg.Engine,
URL: cfg.URL,
Token: cfg.Token,
Liveness: liveness,
Traffic: traffic,
}); err != nil {
return fmt.Errorf("client: %w", err)
}
@@ -422,6 +725,52 @@ func Run(ctx context.Context, cfg Config) error {
}
}
func runWithSessionRotation(ctx context.Context, maxDuration time.Duration, run func(context.Context) error) error {
for cycle := 1; ; cycle++ {
currentCycle := cycle
runCtx, cancel := context.WithCancel(ctx)
var rotated atomic.Bool
timer := time.AfterFunc(maxDuration, func() {
rotated.Store(true)
logger.Infof("session max duration reached: duration=%s cycle=%d", maxDuration, currentCycle)
cancel()
})
err := run(runCtx)
cancel()
timer.Stop()
if ctx.Err() != nil {
return nil
}
if rotated.Load() {
if err != nil {
logger.Warnf("session rotation ended with error: cycle=%d err=%v", currentCycle, err)
}
logger.Infof("session rotation restarting: next_cycle=%d", currentCycle+1)
if err := waitSessionRestart(ctx); err != nil {
return nil
}
continue
}
if err != nil {
return err
}
logger.Infof("session ended cleanly with lifecycle rotation enabled: next_cycle=%d", currentCycle+1)
if err := waitSessionRestart(ctx); err != nil {
return nil
}
}
}
func waitSessionRestart(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(sessionRestartDelay):
return nil
}
}
// ValidateGen validates that the config contains enough fields to run gen mode.
func ValidateGen(cfg Config) error {
if cfg.Auth == "" {

View File

@@ -3,9 +3,136 @@ package session
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/openlibrecommunity/olcrtc/internal/control"
"github.com/openlibrecommunity/olcrtc/internal/crypto"
)
func TestApplyTransportDefaults(t *testing.T) {
tests := []struct {
name string
in Config
want Config
}{
{
name: "vp8",
in: Config{Transport: transportVP8},
want: Config{Transport: transportVP8, VP8FPS: 25, VP8BatchSize: 1},
},
{
name: "sei",
in: Config{Transport: transportSEI},
want: Config{
Transport: transportSEI,
SEIFPS: 60,
SEIBatchSize: 64,
SEIFragmentSize: 900,
SEIAckTimeoutMS: 2000,
},
},
{
name: "video qrcode",
in: Config{Transport: transportVideo},
want: Config{
Transport: transportVideo,
VideoWidth: 1920,
VideoHeight: 1080,
VideoFPS: 30,
VideoBitrate: "2M",
VideoHW: "none",
VideoQRRecovery: "low",
VideoCodec: videoCodecQRCode,
},
},
{
name: "video tile dimensions",
in: Config{Transport: transportVideo, VideoCodec: videoCodecTile},
want: Config{
Transport: transportVideo,
VideoWidth: 1080,
VideoHeight: 1080,
VideoFPS: 30,
VideoBitrate: "2M",
VideoHW: "none",
VideoQRRecovery: "low",
VideoCodec: videoCodecTile,
},
},
{
name: "keeps explicit values",
in: Config{
Transport: transportSEI,
SEIFPS: 10,
SEIBatchSize: 2,
SEIFragmentSize: 300,
SEIAckTimeoutMS: 1500,
},
want: Config{
Transport: transportSEI,
SEIFPS: 10,
SEIBatchSize: 2,
SEIFragmentSize: 300,
SEIAckTimeoutMS: 1500,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ApplyTransportDefaults(tt.in)
if got != tt.want {
t.Fatalf("ApplyTransportDefaults() = %+v, want %+v", got, tt.want)
}
})
}
}
func TestApplyLivenessDefaults(t *testing.T) {
got := ApplyLivenessDefaults(Config{})
if got.LivenessInterval != control.DefaultInterval.String() {
t.Fatalf("LivenessInterval = %q, want %q", got.LivenessInterval, control.DefaultInterval.String())
}
if got.LivenessTimeout != control.DefaultTimeout.String() {
t.Fatalf("LivenessTimeout = %q, want %q", got.LivenessTimeout, control.DefaultTimeout.String())
}
if got.LivenessFailures != control.DefaultFailures {
t.Fatalf("LivenessFailures = %d, want %d", got.LivenessFailures, control.DefaultFailures)
}
explicit := Config{LivenessInterval: "1s", LivenessTimeout: "500ms", LivenessFailures: 9}
if got := ApplyLivenessDefaults(explicit); got != explicit {
t.Fatalf("ApplyLivenessDefaults() = %+v, want %+v", got, explicit)
}
}
func TestRunWithSessionRotationRestartsAfterMaxDuration(t *testing.T) {
oldRestartDelay := sessionRestartDelay
sessionRestartDelay = time.Millisecond
t.Cleanup(func() { sessionRestartDelay = oldRestartDelay })
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var calls atomic.Int32
err := runWithSessionRotation(ctx, 5*time.Millisecond, func(ctx context.Context) error {
if calls.Add(1) >= 2 {
cancel()
return nil
}
<-ctx.Done()
return nil
})
if err != nil {
t.Fatalf("runWithSessionRotation() error = %v", err)
}
if got := calls.Load(); got < 2 {
t.Fatalf("run calls = %d, want at least 2", got)
}
}
//nolint:maintidx // table-driven validation test naturally has many cases
func TestValidate(t *testing.T) {
RegisterDefaults()
@@ -310,6 +437,148 @@ func TestValidate(t *testing.T) {
}(),
want: ErrSOCKSPortRequired,
},
{
name: "cnc rejects unauthenticated wildcard socks bind",
cfg: func() Config {
cfg := base
cfg.Mode = modeCNC
cfg.SOCKSHost = "0.0.0.0"
cfg.SOCKSPort = 1080
return cfg
}(),
want: ErrSOCKSAuthRequired,
},
{
name: "cnc allows authenticated wildcard socks bind",
cfg: func() Config {
cfg := base
cfg.Mode = modeCNC
cfg.SOCKSHost = "0.0.0.0"
cfg.SOCKSPort = 1080
cfg.SOCKSUser = "user"
cfg.SOCKSPass = "pass"
return cfg
}(),
},
{
name: "cnc allows localhost socks bind without auth",
cfg: func() Config {
cfg := base
cfg.Mode = modeCNC
cfg.SOCKSHost = "localhost"
cfg.SOCKSPort = 1080
return cfg
}(),
},
{
name: "liveness rejects bad interval",
cfg: func() Config {
cfg := base
cfg.LivenessInterval = "nope"
return cfg
}(),
want: ErrLivenessIntervalInvalid,
},
{
name: "liveness rejects zero timeout",
cfg: func() Config {
cfg := base
cfg.LivenessTimeout = "0s"
return cfg
}(),
want: ErrLivenessTimeoutInvalid,
},
{
name: "liveness rejects negative failures",
cfg: func() Config {
cfg := base
cfg.LivenessFailures = -1
return cfg
}(),
want: ErrLivenessFailuresInvalid,
},
{
name: "lifecycle accepts max session duration",
cfg: func() Config {
cfg := base
cfg.MaxSessionDuration = "1h"
return cfg
}(),
},
{
name: "lifecycle rejects bad max session duration",
cfg: func() Config {
cfg := base
cfg.MaxSessionDuration = "nope"
return cfg
}(),
want: ErrLifecycleMaxSessionDurationInvalid,
},
{
name: "lifecycle rejects zero max session duration",
cfg: func() Config {
cfg := base
cfg.MaxSessionDuration = "0s"
return cfg
}(),
want: ErrLifecycleMaxSessionDurationInvalid,
},
{
name: "traffic accepts shaping",
cfg: func() Config {
cfg := base
cfg.TrafficMaxPayloadSize = 4096
cfg.TrafficMinDelay = "5ms"
cfg.TrafficMaxDelay = "30ms"
return cfg
}(),
},
{
name: "traffic rejects negative max payload",
cfg: func() Config {
cfg := base
cfg.TrafficMaxPayloadSize = -1
return cfg
}(),
want: ErrTrafficMaxPayloadSizeInvalid,
},
{
name: "traffic rejects payload smaller than crypto overhead",
cfg: func() Config {
cfg := base
cfg.TrafficMaxPayloadSize = crypto.WireOverhead
return cfg
}(),
want: ErrTrafficMaxPayloadSizeInvalid,
},
{
name: "traffic rejects bad min delay",
cfg: func() Config {
cfg := base
cfg.TrafficMinDelay = "nope"
return cfg
}(),
want: ErrTrafficMinDelayInvalid,
},
{
name: "traffic rejects negative max delay",
cfg: func() Config {
cfg := base
cfg.TrafficMaxDelay = "-1ms"
return cfg
}(),
want: ErrTrafficMaxDelayInvalid,
},
{
name: "traffic rejects max delay below min delay",
cfg: func() Config {
cfg := base
cfg.TrafficMinDelay = "30ms"
cfg.TrafficMaxDelay = "5ms"
return cfg
}(),
want: ErrTrafficMaxDelayInvalid,
},
}
for _, tt := range tests {

View File

@@ -9,9 +9,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/google/uuid"
"github.com/openlibrecommunity/olcrtc/internal/protect"
@@ -122,7 +120,7 @@ func createMeeting(ctx context.Context, headers map[string]string) (*createRespo
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return nil, statusError(errCreateRoomFailed, resp)
return nil, protect.StatusError(errCreateRoomFailed, resp, 1024)
}
var res createResponse
@@ -174,7 +172,7 @@ func preconnect(ctx context.Context, roomID, password string, headers map[string
defer func() { _ = preResp.Body.Close() }()
if preResp.StatusCode != http.StatusOK {
return "", statusError(errPreconnectFailed, preResp)
return "", protect.StatusError(errPreconnectFailed, preResp, 1024)
}
var preconnectResp struct {
@@ -186,15 +184,6 @@ func preconnect(ctx context.Context, roomID, password string, headers map[string
return preconnectResp.ConnectorURL, nil
}
func statusError(base error, resp *http.Response) error {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
bodyText := strings.TrimSpace(string(body))
if bodyText == "" {
return fmt.Errorf("%w: status %d", base, resp.StatusCode)
}
return fmt.Errorf("%w: status %d: %s", base, resp.StatusCode, bodyText)
}
func joinRoom(ctx context.Context, roomID, password string) (*roomInfo, error) {
headers := anonymousHeaders()
connectorURL, err := preconnect(ctx, roomID, password, headers)

View File

@@ -11,7 +11,6 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
@@ -69,8 +68,7 @@ func GetConnectionInfo(ctx context.Context, roomURL, displayName string) (*Conne
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("%w %d: %s", ErrAPI, resp.StatusCode, body)
return nil, protect.StatusError(ErrAPI, resp, 4096)
}
var info ConnectionInfo

View File

@@ -10,7 +10,6 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"github.com/openlibrecommunity/olcrtc/internal/protect"
@@ -84,8 +83,7 @@ func registerGuest(ctx context.Context, displayName string) (string, error) {
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
b, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("%w: %d %s", errGuestRegister, resp.StatusCode, b)
return "", protect.StatusError(errGuestRegister, resp, 4096)
}
var res guestRegisterResponse
@@ -122,8 +120,7 @@ func createRoom(ctx context.Context, accessToken string) (string, error) {
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
b, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("%w: %d %s", errCreateRoom, resp.StatusCode, b)
return "", protect.StatusError(errCreateRoom, resp, 4096)
}
var res createRoomResponse
@@ -151,8 +148,7 @@ func joinRoom(ctx context.Context, accessToken, roomID string) error {
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
b, _ := io.ReadAll(resp.Body)
return fmt.Errorf("%w: %d %s", errJoinRoom, resp.StatusCode, b)
return protect.StatusError(errJoinRoom, resp, 4096)
}
return nil
}
@@ -180,8 +176,7 @@ func getToken(ctx context.Context, accessToken, roomID, displayName string) (tok
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
b, _ := io.ReadAll(resp.Body)
return tokenResponse{}, fmt.Errorf("%w: %d %s", errGetToken, resp.StatusCode, b)
return tokenResponse{}, protect.StatusError(errGetToken, resp, 4096)
}
var res tokenResponse

View File

@@ -17,12 +17,14 @@ import (
"time"
"github.com/google/uuid"
"github.com/openlibrecommunity/olcrtc/internal/control"
"github.com/openlibrecommunity/olcrtc/internal/crypto"
"github.com/openlibrecommunity/olcrtc/internal/handshake"
"github.com/openlibrecommunity/olcrtc/internal/link"
"github.com/openlibrecommunity/olcrtc/internal/logger"
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
"github.com/openlibrecommunity/olcrtc/internal/names"
"github.com/openlibrecommunity/olcrtc/internal/transport"
"github.com/xtaci/smux"
)
@@ -54,7 +56,12 @@ type Client struct {
conn *muxconn.Conn
session *smux.Session
controlStrm *smux.Stream
controlStop context.CancelFunc
sessMu sync.RWMutex
reconnectMu sync.Mutex
healthMu sync.RWMutex
health control.Status
onHealth HealthFunc
deviceID string
sessionID string
claims map[string]any
@@ -63,6 +70,9 @@ type Client struct {
socksPass string
}
// HealthFunc is called when the client control health snapshot changes.
type HealthFunc func(control.Status)
// Config holds runtime configuration for [Run] and [RunWithReady].
type Config struct {
Link string
@@ -93,6 +103,8 @@ type Config struct {
Engine string
URL string
Token string
Liveness control.Config
Traffic transport.TrafficConfig
// DeviceID overrides the persistent client-side device identifier. Leave
// empty to derive one from DeviceIDPath (or generate a random one if both
@@ -106,6 +118,9 @@ type Config struct {
// Claims is sent to the server in CLIENT_HELLO and forwarded verbatim to
// the server's AuthHook. Free-form key/value bag for plan, user, region, etc.
Claims map[string]any
// OnHealth receives liveness/reconnect status updates. Nil means no-op.
OnHealth HealthFunc
}
// Run starts the client with the given configuration.
@@ -135,6 +150,7 @@ func RunWithReady(ctx context.Context, cfg Config, onReady func()) error {
dnsServer: cfg.DNSServer,
socksUser: cfg.SOCKSUser,
socksPass: cfg.SOCKSPass,
onHealth: cfg.OnHealth,
}
// shutdown is registered BEFORE bringUpLink so we always close any
@@ -202,6 +218,7 @@ func (c *Client) bringUpLink(
SEIBatchSize: cfg.SEIBatchSize,
SEIFragmentSize: cfg.SEIFragmentSize,
SEIAckTimeoutMS: cfg.SEIAckTimeoutMS,
Traffic: cfg.Traffic,
})
if err != nil {
return fmt.Errorf("failed to create link: %w", err)
@@ -217,7 +234,9 @@ func (c *Client) bringUpLink(
if ctx.Err() != nil {
return
}
c.handleReconnect()
if !c.handleReconnect(ctx, cfg, cancel, "carrier") {
cancel()
}
})
if err := ln.Connect(ctx); err != nil {
@@ -225,7 +244,7 @@ func (c *Client) bringUpLink(
}
c.conn = muxconn.New(ln, c.cipher)
sess, err := smux.Client(c.conn, smuxConfig())
sess, err := smux.Client(c.conn, smuxConfig(linkMaxPayload(ln)))
if err != nil {
return fmt.Errorf("smux client: %w", err)
}
@@ -243,14 +262,16 @@ func (c *Client) bringUpLink(
c.controlStrm = control
c.sessionID = sid
c.sessMu.Unlock()
c.recordSession(sid)
c.startControlLoop(ctx, cfg, cancel, control)
go ln.WatchConnection(ctx)
return nil
}
// openControlStream opens stream #1 on sess and performs the handshake.
// The stream stays open for the lifetime of the smux session — the server
// holds it parked, and it would carry future control messages.
// The stream stays open for the lifetime of the smux session and carries
// post-handshake control messages.
func openControlStream(
sess *smux.Session,
deviceID string,
@@ -314,11 +335,17 @@ func resolveDeviceID(deviceID, path string) (string, error) {
}
// smuxConfig returns the tuned smux config used on both ends.
func smuxConfig() *smux.Config {
func smuxConfig(maxWirePayload ...int) *smux.Config {
cfg := smux.DefaultConfig()
cfg.Version = 2
cfg.KeepAliveDisabled = true
cfg.MaxFrameSize = 32768
if len(maxWirePayload) > 0 && maxWirePayload[0] > crypto.WireOverhead {
maxFrameSize := maxWirePayload[0] - crypto.WireOverhead
if maxFrameSize < cfg.MaxFrameSize {
cfg.MaxFrameSize = maxFrameSize
}
}
cfg.MaxReceiveBuffer = 16 * 1024 * 1024
cfg.MaxStreamBuffer = 1024 * 1024
cfg.KeepAliveInterval = 10 * time.Second
@@ -326,8 +353,20 @@ func smuxConfig() *smux.Config {
return cfg
}
func (c *Client) handleReconnect() {
logger.Infof("client link reconnect - tearing down smux session")
func linkMaxPayload(ln link.Link) int {
provider, ok := ln.(link.FeaturesProvider)
if !ok {
return 0
}
return provider.Features().MaxPayloadSize
}
func (c *Client) handleReconnect(ctx context.Context, cfg Config, cancel context.CancelFunc, reason string) bool {
c.reconnectMu.Lock()
defer c.reconnectMu.Unlock()
c.recordReconnect()
logger.Infof("client reconnect reason=%s - tearing down smux session", reason)
// Install a fresh muxconn immediately so onData never hits nil while
// the old session is being torn down. tryReopenSession will swap it
@@ -336,14 +375,19 @@ func (c *Client) handleReconnect() {
c.sessMu.Lock()
oldControl := c.controlStrm
oldControlStop := c.controlStop
oldSess := c.session
oldConn := c.conn
c.conn = newConn
c.session = nil
c.controlStrm = nil
c.controlStop = nil
c.sessionID = ""
c.sessMu.Unlock()
if oldControlStop != nil {
oldControlStop()
}
if oldControl != nil {
_ = oldControl.Close()
}
@@ -364,15 +408,26 @@ func (c *Client) handleReconnect() {
attemptDelay = 300 * time.Millisecond
)
for attempt := 1; attempt <= maxAttempts; attempt++ {
if c.tryReopenSession(attempt) {
return
logger.Infof("client reconnect attempt=%d reason=%s", attempt, reason)
if c.tryReopenSession(ctx, cfg, cancel, attempt) {
return true
}
select {
case <-ctx.Done():
return false
case <-time.After(attemptDelay):
}
time.Sleep(attemptDelay)
}
logger.Warnf("client reconnect: exhausted %d handshake attempts", maxAttempts)
return false
}
func (c *Client) tryReopenSession(attempt int) bool {
func (c *Client) tryReopenSession(
ctx context.Context,
cfg Config,
cancel context.CancelFunc,
attempt int,
) bool {
conn := muxconn.New(c.ln, c.cipher)
c.sessMu.Lock()
@@ -383,7 +438,7 @@ func (c *Client) tryReopenSession(attempt int) bool {
_ = old.Close()
}
sess, err := smux.Client(conn, smuxConfig())
sess, err := smux.Client(conn, smuxConfig(linkMaxPayload(c.ln)))
if err != nil {
logger.Warnf("smux re-init failed (attempt %d): %v", attempt, err)
return false
@@ -400,19 +455,138 @@ func (c *Client) tryReopenSession(attempt int) bool {
c.controlStrm = control
c.sessionID = sid
c.sessMu.Unlock()
c.recordSession(sid)
c.startControlLoop(ctx, cfg, cancel, control)
return true
}
func (c *Client) startControlLoop(
ctx context.Context,
cfg Config,
cancel context.CancelFunc,
stream *smux.Stream,
) {
controlCtx, stop := context.WithCancel(ctx)
c.sessMu.Lock()
c.controlStop = stop
c.sessMu.Unlock()
liveness := cfg.Liveness
onPong := liveness.OnPong
onMissedPong := liveness.OnMissedPong
onUnhealthy := liveness.OnUnhealthy
liveness.OnPong = func(h control.Health) {
c.sessMu.RLock()
sid := c.sessionID
c.sessMu.RUnlock()
c.recordPong(h)
logger.Debugf("control alive session=%s rtt=%v seq=%d", sid, h.RTT, h.Seq)
if onPong != nil {
onPong(h)
}
}
liveness.OnMissedPong = func(missed int) {
c.recordMissed(missed)
logger.Warnf("control missed pong on client: missed_pongs=%d", missed)
if onMissedPong != nil {
onMissedPong(missed)
}
}
liveness.OnUnhealthy = func(missed int) {
c.recordUnhealthy(missed)
logger.Warnf("control stream unhealthy on client: missed_pongs=%d", missed)
if onUnhealthy != nil {
onUnhealthy(missed)
}
}
go func() {
err := control.Run(controlCtx, stream, liveness)
if controlCtx.Err() != nil || ctx.Err() != nil {
return
}
if err != nil {
logger.Warnf("client control stream ended: %v", err)
}
if !c.handleReconnect(ctx, cfg, cancel, "liveness") {
cancel()
}
}()
}
// Status returns the latest client-side control health snapshot.
func (c *Client) Status() control.Status {
c.healthMu.RLock()
defer c.healthMu.RUnlock()
return c.health
}
func (c *Client) recordSession(sessionID string) {
c.healthMu.Lock()
c.health.SessionID = sessionID
c.health.MissedPongs = 0
status := c.health
c.healthMu.Unlock()
c.notifyHealth(status)
}
func (c *Client) recordPong(h control.Health) {
c.healthMu.Lock()
c.health.LastPong = h.LastSeen
c.health.LastRTT = h.RTT
c.health.MissedPongs = 0
status := c.health
c.healthMu.Unlock()
c.notifyHealth(status)
}
func (c *Client) recordMissed(missed int) {
c.healthMu.Lock()
c.health.MissedPongs = missed
status := c.health
c.healthMu.Unlock()
c.notifyHealth(status)
}
func (c *Client) recordUnhealthy(missed int) {
c.healthMu.Lock()
c.health.MissedPongs = missed
c.health.UnhealthyEvents++
c.health.LastUnhealthy = time.Now()
status := c.health
c.healthMu.Unlock()
c.notifyHealth(status)
}
func (c *Client) recordReconnect() {
c.healthMu.Lock()
c.health.Reconnects++
status := c.health
c.healthMu.Unlock()
c.notifyHealth(status)
}
func (c *Client) notifyHealth(status control.Status) {
if c.onHealth != nil {
c.onHealth(status)
}
}
func (c *Client) shutdown() {
c.sessMu.Lock()
control := c.controlStrm
controlStop := c.controlStop
sess := c.session
conn := c.conn
c.controlStrm = nil
c.controlStop = nil
c.session = nil
c.conn = nil
c.sessMu.Unlock()
if controlStop != nil {
controlStop()
}
if conn != nil {
_ = conn.Close()
}

View File

@@ -11,6 +11,7 @@ import (
"testing"
"time"
"github.com/openlibrecommunity/olcrtc/internal/control"
cryptopkg "github.com/openlibrecommunity/olcrtc/internal/crypto"
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
"github.com/xtaci/smux"
@@ -48,6 +49,11 @@ func TestSmuxConfig(t *testing.T) {
if cfg.Version != 2 || !cfg.KeepAliveDisabled || cfg.MaxFrameSize != 32768 || cfg.MaxReceiveBuffer != 16*1024*1024 {
t.Fatalf("smuxConfig() = %+v", cfg)
}
capped := smuxConfig(4096)
if capped.MaxFrameSize != 4096-cryptopkg.WireOverhead {
t.Fatalf("smuxConfig(4096).MaxFrameSize = %d, want %d",
capped.MaxFrameSize, 4096-cryptopkg.WireOverhead)
}
}
func TestSocks5Handshake(t *testing.T) {
@@ -517,3 +523,96 @@ func TestShutdownClosesLinkAndConn(t *testing.T) {
t.Fatal("shutdown() did not close link")
}
}
func TestStartControlLoopReportsPong(t *testing.T) {
a, b := net.Pipe()
defer func() {
_ = a.Close()
_ = b.Close()
}()
serverSess, err := smux.Server(a, smuxConfig())
if err != nil {
t.Fatalf("smux.Server() error = %v", err)
}
defer func() { _ = serverSess.Close() }()
clientSess, err := smux.Client(b, smuxConfig())
if err != nil {
t.Fatalf("smux.Client() error = %v", err)
}
defer func() { _ = clientSess.Close() }()
peerStreamCh := make(chan *smux.Stream, 1)
go func() {
stream, err := serverSess.AcceptStream()
if err == nil {
peerStreamCh <- stream
}
}()
stream, err := clientSess.OpenStream()
if err != nil {
t.Fatalf("OpenStream() error = %v", err)
}
peerStream := <-peerStreamCh
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
got := make(chan control.Health, 1)
c := &Client{sessionID: "sid-control"}
c.recordSession("sid-control")
c.startControlLoop(ctx, Config{
Liveness: control.Config{
Interval: 10 * time.Millisecond,
Timeout: 100 * time.Millisecond,
Failures: 2,
OnPong: func(h control.Health) {
select {
case got <- h:
default:
}
},
},
}, cancel, stream)
go func() {
_ = control.Run(ctx, peerStream, control.Config{
Interval: 10 * time.Millisecond,
Timeout: 100 * time.Millisecond,
Failures: 2,
})
}()
select {
case h := <-got:
if h.Seq == 0 {
t.Fatal("Health.Seq = 0")
}
case <-time.After(time.Second):
t.Fatal("timed out waiting for control pong")
}
status := c.Status()
if status.SessionID != "sid-control" {
t.Fatalf("Status.SessionID = %q, want sid-control", status.SessionID)
}
if status.LastPong.IsZero() || status.LastRTT < 0 || status.MissedPongs != 0 {
t.Fatalf("Status() = %+v", status)
}
}
func TestStatusRecordsReconnectAndUnhealthy(t *testing.T) {
updates := 0
c := &Client{onHealth: func(control.Status) { updates++ }}
c.recordSession("sid-1")
c.recordMissed(2)
c.recordUnhealthy(3)
c.recordReconnect()
status := c.Status()
if status.SessionID != "sid-1" || status.MissedPongs != 3 ||
status.UnhealthyEvents != 1 || status.Reconnects != 1 || status.LastUnhealthy.IsZero() {
t.Fatalf("Status() = %+v", status)
}
if updates != 4 {
t.Fatalf("health updates = %d, want 4", updates)
}
}

View File

@@ -1,10 +1,9 @@
// Package config loads olcrtc runtime configuration from YAML files.
//
// The YAML schema mirrors [session.Config]. Fields left unset in the file
// remain at their zero value, allowing CLI flags to fill them in. Use
// [Apply] to merge a parsed [File] onto an existing [session.Config];
// non-zero fields in the session config (typically populated from CLI flags)
// take precedence over the YAML values.
// remain at their zero value. Use [Apply] to map a parsed [File] onto an
// existing [session.Config]; non-zero fields in the session config take
// precedence over the YAML values.
//
//nolint:tagliatelle // YAML keys are the documented config file schema.
package config
@@ -13,31 +12,68 @@ import (
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/openlibrecommunity/olcrtc/internal/app/session"
"gopkg.in/yaml.v3"
)
// ErrConfigNotFound is returned when a config file path is set but the file does not exist.
var ErrConfigNotFound = errors.New("config file not found")
var (
// ErrConfigNotFound is returned when a config file path is set but the file does not exist.
ErrConfigNotFound = errors.New("config file not found")
// ErrCryptoKeyConflict is returned when both inline and file-backed keys are configured.
ErrCryptoKeyConflict = errors.New("crypto.key and crypto.key_file cannot both be set")
// ErrCryptoKeyFileEmpty is returned when crypto.key_file points to an empty file.
ErrCryptoKeyFileEmpty = errors.New("crypto key file is empty")
)
// File is the on-disk YAML schema.
type File struct {
Mode string `yaml:"mode"`
Link string `yaml:"link"`
Auth Auth `yaml:"auth"`
Room Room `yaml:"room"`
Crypto Crypto `yaml:"crypto"`
Net Net `yaml:"net"`
SOCKS SOCKS `yaml:"socks"`
Engine Engine `yaml:"engine"`
Video Video `yaml:"video"`
VP8 VP8 `yaml:"vp8"`
SEI SEI `yaml:"sei"`
Gen Gen `yaml:"gen"`
Data string `yaml:"data"`
Debug bool `yaml:"debug"`
FFmpeg string `yaml:"ffmpeg"`
Mode string `yaml:"mode"`
Link string `yaml:"link"`
Auth Auth `yaml:"auth"`
Room Room `yaml:"room"`
Crypto Crypto `yaml:"crypto"`
Net Net `yaml:"net"`
SOCKS SOCKS `yaml:"socks"`
Engine Engine `yaml:"engine"`
Video Video `yaml:"video"`
VP8 VP8 `yaml:"vp8"`
SEI SEI `yaml:"sei"`
Liveness Liveness `yaml:"liveness"`
Lifecycle Lifecycle `yaml:"lifecycle"`
Traffic Traffic `yaml:"traffic"`
Gen Gen `yaml:"gen"`
Profiles []Profile `yaml:"profiles"`
Failover Failover `yaml:"failover"`
Data string `yaml:"data"`
Debug bool `yaml:"debug"`
FFmpeg string `yaml:"ffmpeg"`
}
// Profile is a failover entry that overrides top-level runtime fields.
type Profile struct {
Name string `yaml:"name"`
Link string `yaml:"link"`
Auth Auth `yaml:"auth"`
Room Room `yaml:"room"`
Crypto Crypto `yaml:"crypto"`
Net Net `yaml:"net"`
SOCKS SOCKS `yaml:"socks"`
Engine Engine `yaml:"engine"`
Video Video `yaml:"video"`
VP8 VP8 `yaml:"vp8"`
SEI SEI `yaml:"sei"`
Liveness Liveness `yaml:"liveness"`
Lifecycle Lifecycle `yaml:"lifecycle"`
Traffic Traffic `yaml:"traffic"`
}
// Failover controls ordered profile failover.
type Failover struct {
RetryDelay string `yaml:"retry_delay"`
MaxCycles int `yaml:"max_cycles"`
}
// Auth selects the auth provider.
@@ -52,7 +88,8 @@ type Room struct {
// Crypto holds the shared secret used to authenticate and encrypt the tunnel.
type Crypto struct {
Key string `yaml:"key"` // 64-char hex (32 bytes)
Key string `yaml:"key"` // 64-char hex (32 bytes)
KeyFile string `yaml:"key_file"` // path to a file containing crypto.key
}
// Net groups network and transport selection.
@@ -106,6 +143,25 @@ type SEI struct {
AckTimeoutMS int `yaml:"ack_timeout_ms"`
}
// Liveness tunes the post-handshake control stream ping/pong checks.
type Liveness struct {
Interval string `yaml:"interval"`
Timeout string `yaml:"timeout"`
Failures int `yaml:"failures"`
}
// Lifecycle controls planned session rebuilds.
type Lifecycle struct {
MaxSessionDuration string `yaml:"max_session_duration"`
}
// Traffic controls optional reliability-oriented send shaping.
type Traffic struct {
MaxPayloadSize int `yaml:"max_payload_size"`
MinDelay string `yaml:"min_delay"`
MaxDelay string `yaml:"max_delay"`
}
// Gen controls room-generation mode.
type Gen struct {
Amount int `yaml:"amount"`
@@ -125,9 +181,63 @@ func Load(path string) (File, error) {
if err := yaml.Unmarshal(data, &f); err != nil {
return File{}, fmt.Errorf("parse config %s: %w", path, err)
}
if err := loadExternalSecrets(path, &f); err != nil {
return File{}, err
}
return f, nil
}
func loadExternalSecrets(configPath string, f *File) error {
if f.Crypto.KeyFile == "" {
return loadProfileSecrets(configPath, f.Profiles)
}
if f.Crypto.Key != "" {
return ErrCryptoKeyConflict
}
key, err := readKeyFile(configPath, f.Crypto.KeyFile)
if err != nil {
return err
}
f.Crypto.Key = key
return loadProfileSecrets(configPath, f.Profiles)
}
func loadProfileSecrets(configPath string, profiles []Profile) error {
for i := range profiles {
if profiles[i].Crypto.KeyFile == "" {
continue
}
if profiles[i].Crypto.Key != "" {
return fmt.Errorf("profiles[%d]: %w", i, ErrCryptoKeyConflict)
}
key, err := readKeyFile(configPath, profiles[i].Crypto.KeyFile)
if err != nil {
return fmt.Errorf("profiles[%d]: %w", i, err)
}
profiles[i].Crypto.Key = key
}
return nil
}
func readKeyFile(configPath, keyFile string) (string, error) {
keyPath := keyFile
if !filepath.IsAbs(keyPath) {
keyPath = filepath.Join(filepath.Dir(configPath), keyPath)
}
// #nosec G304 -- key_file is an explicit path in the user's config file.
data, err := os.ReadFile(keyPath)
if err != nil {
return "", fmt.Errorf("read crypto key file %s: %w", keyPath, err)
}
key := strings.TrimSpace(string(data))
if key == "" {
return "", ErrCryptoKeyFileEmpty
}
return key, nil
}
// Apply merges f onto dst. CLI-set fields (non-zero values in dst) win;
// YAML values fill in the rest.
func Apply(dst session.Config, f File) session.Config {
@@ -163,10 +273,61 @@ func Apply(dst session.Config, f File) session.Config {
dst.SEIBatchSize = pickInt(dst.SEIBatchSize, f.SEI.BatchSize)
dst.SEIFragmentSize = pickInt(dst.SEIFragmentSize, f.SEI.FragmentSize)
dst.SEIAckTimeoutMS = pickInt(dst.SEIAckTimeoutMS, f.SEI.AckTimeoutMS)
dst.LivenessInterval = pickString(dst.LivenessInterval, f.Liveness.Interval)
dst.LivenessTimeout = pickString(dst.LivenessTimeout, f.Liveness.Timeout)
dst.LivenessFailures = pickInt(dst.LivenessFailures, f.Liveness.Failures)
dst.MaxSessionDuration = pickString(dst.MaxSessionDuration, f.Lifecycle.MaxSessionDuration)
dst.TrafficMaxPayloadSize = pickInt(dst.TrafficMaxPayloadSize, f.Traffic.MaxPayloadSize)
dst.TrafficMinDelay = pickString(dst.TrafficMinDelay, f.Traffic.MinDelay)
dst.TrafficMaxDelay = pickString(dst.TrafficMaxDelay, f.Traffic.MaxDelay)
dst.Amount = pickInt(dst.Amount, f.Gen.Amount)
return dst
}
// ApplyProfile overlays a failover profile onto an already-applied base config.
func ApplyProfile(base session.Config, p Profile) session.Config {
dst := base
dst.Link = overlayString(dst.Link, p.Link)
dst.Transport = overlayString(dst.Transport, p.Net.Transport)
dst.Auth = overlayString(dst.Auth, p.Auth.Provider)
dst.Engine = overlayString(dst.Engine, p.Engine.Name)
dst.URL = overlayString(dst.URL, p.Engine.URL)
dst.Token = overlayString(dst.Token, p.Engine.Token)
dst.RoomID = overlayString(dst.RoomID, p.Room.ID)
dst.KeyHex = overlayString(dst.KeyHex, p.Crypto.Key)
dst.SOCKSHost = overlayString(dst.SOCKSHost, p.SOCKS.Host)
dst.SOCKSPort = overlayInt(dst.SOCKSPort, p.SOCKS.Port)
dst.SOCKSUser = overlayString(dst.SOCKSUser, p.SOCKS.User)
dst.SOCKSPass = overlayString(dst.SOCKSPass, p.SOCKS.Pass)
dst.DNSServer = overlayString(dst.DNSServer, p.Net.DNS)
dst.SOCKSProxyAddr = overlayString(dst.SOCKSProxyAddr, p.SOCKS.ProxyAddr)
dst.SOCKSProxyPort = overlayInt(dst.SOCKSProxyPort, p.SOCKS.ProxyPort)
dst.VideoWidth = overlayInt(dst.VideoWidth, p.Video.Width)
dst.VideoHeight = overlayInt(dst.VideoHeight, p.Video.Height)
dst.VideoFPS = overlayInt(dst.VideoFPS, p.Video.FPS)
dst.VideoBitrate = overlayString(dst.VideoBitrate, p.Video.Bitrate)
dst.VideoHW = overlayString(dst.VideoHW, p.Video.HW)
dst.VideoQRSize = overlayInt(dst.VideoQRSize, p.Video.QRSize)
dst.VideoQRRecovery = overlayString(dst.VideoQRRecovery, p.Video.QRRecovery)
dst.VideoCodec = overlayString(dst.VideoCodec, p.Video.Codec)
dst.VideoTileModule = overlayInt(dst.VideoTileModule, p.Video.TileModule)
dst.VideoTileRS = overlayInt(dst.VideoTileRS, p.Video.TileRS)
dst.VP8FPS = overlayInt(dst.VP8FPS, p.VP8.FPS)
dst.VP8BatchSize = overlayInt(dst.VP8BatchSize, p.VP8.BatchSize)
dst.SEIFPS = overlayInt(dst.SEIFPS, p.SEI.FPS)
dst.SEIBatchSize = overlayInt(dst.SEIBatchSize, p.SEI.BatchSize)
dst.SEIFragmentSize = overlayInt(dst.SEIFragmentSize, p.SEI.FragmentSize)
dst.SEIAckTimeoutMS = overlayInt(dst.SEIAckTimeoutMS, p.SEI.AckTimeoutMS)
dst.LivenessInterval = overlayString(dst.LivenessInterval, p.Liveness.Interval)
dst.LivenessTimeout = overlayString(dst.LivenessTimeout, p.Liveness.Timeout)
dst.LivenessFailures = overlayInt(dst.LivenessFailures, p.Liveness.Failures)
dst.MaxSessionDuration = overlayString(dst.MaxSessionDuration, p.Lifecycle.MaxSessionDuration)
dst.TrafficMaxPayloadSize = overlayInt(dst.TrafficMaxPayloadSize, p.Traffic.MaxPayloadSize)
dst.TrafficMinDelay = overlayString(dst.TrafficMinDelay, p.Traffic.MinDelay)
dst.TrafficMaxDelay = overlayString(dst.TrafficMaxDelay, p.Traffic.MaxDelay)
return dst
}
func pickString(cli, yamlVal string) string {
if cli != "" {
return cli
@@ -180,3 +341,17 @@ func pickInt(cli, yamlVal int) int {
}
return yamlVal
}
func overlayString(base, override string) string {
if override != "" {
return override
}
return base
}
func overlayInt(base, override int) int {
if override != 0 {
return override
}
return base
}

View File

@@ -1,6 +1,7 @@
package config
import (
"errors"
"os"
"path/filepath"
"testing"
@@ -38,6 +39,16 @@ socks:
vp8:
fps: 25
batch_size: 4
liveness:
interval: 2s
timeout: 500ms
failures: 4
lifecycle:
max_session_duration: 6h
traffic:
max_payload_size: 4096
min_delay: 5ms
max_delay: 30ms
gen:
amount: 3
debug: true
@@ -75,20 +86,27 @@ func requireLoadedFile(t *testing.T, f File) {
func requireAppliedConfig(t *testing.T, got session.Config) {
t.Helper()
want := session.Config{
Mode: testModeSrv,
Link: "direct",
Auth: testAuthProvider,
RoomID: testRoomID,
KeyHex: testCryptoKey,
Transport: "datachannel",
DNSServer: "1.1.1.1:53",
SOCKSHost: "127.0.0.1",
SOCKSPort: 1080,
SOCKSUser: "u",
SOCKSPass: "p",
VP8FPS: 25,
VP8BatchSize: 4,
Amount: 3,
Mode: testModeSrv,
Link: "direct",
Auth: testAuthProvider,
RoomID: testRoomID,
KeyHex: testCryptoKey,
Transport: "datachannel",
DNSServer: "1.1.1.1:53",
SOCKSHost: "127.0.0.1",
SOCKSPort: 1080,
SOCKSUser: "u",
SOCKSPass: "p",
VP8FPS: 25,
VP8BatchSize: 4,
LivenessInterval: "2s",
LivenessTimeout: "500ms",
LivenessFailures: 4,
MaxSessionDuration: "6h",
TrafficMaxPayloadSize: 4096,
TrafficMinDelay: "5ms",
TrafficMaxDelay: "30ms",
Amount: 3,
}
if got != want {
t.Fatalf("Apply produced wrong config: %+v, want %+v", got, want)
@@ -121,6 +139,182 @@ func TestApplyCLIWins(t *testing.T) {
}
}
func TestLoadAndApplyProfile(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "olcrtc.yaml")
body := `
mode: srv
link: direct
crypto:
key: shared-key
net:
dns: 1.1.1.1:53
liveness:
interval: 5s
timeout: 2s
failures: 5
lifecycle:
max_session_duration: 6h
traffic:
max_payload_size: 8192
min_delay: 10ms
max_delay: 40ms
profiles:
- name: wb-vp8
auth:
provider: wbstream
room:
id: wb-room
net:
transport: vp8channel
vp8:
fps: 30
liveness:
interval: 1s
lifecycle:
max_session_duration: 30m
traffic:
max_payload_size: 4096
max_delay: 20ms
- name: jitsi-dc
auth:
provider: jitsi
room:
id: https://meet.example/room
net:
transport: datachannel
dns: 8.8.8.8:53
failover:
retry_delay: 100ms
max_cycles: 2
`
if err := os.WriteFile(path, []byte(body), 0o600); err != nil {
t.Fatalf("write config: %v", err)
}
f, err := Load(path)
if err != nil {
t.Fatalf("Load: %v", err)
}
if len(f.Profiles) != 2 {
t.Fatalf("profiles = %d, want 2", len(f.Profiles))
}
if f.Failover.RetryDelay != "100ms" || f.Failover.MaxCycles != 2 {
t.Fatalf("Failover = %+v, want retry_delay 100ms max_cycles 2", f.Failover)
}
base := Apply(session.Config{}, f)
first := ApplyProfile(base, f.Profiles[0])
if first.Auth != "wbstream" || first.Transport != "vp8channel" || first.RoomID != "wb-room" {
t.Fatalf("first profile = %+v", first)
}
if first.KeyHex != "shared-key" || first.DNSServer != "1.1.1.1:53" || first.VP8FPS != 30 ||
first.LivenessInterval != "1s" || first.LivenessTimeout != "2s" || first.LivenessFailures != 5 ||
first.MaxSessionDuration != "30m" || first.TrafficMaxPayloadSize != 4096 ||
first.TrafficMinDelay != "10ms" || first.TrafficMaxDelay != "20ms" {
t.Fatalf("first inherited/overlaid fields = %+v", first)
}
second := ApplyProfile(base, f.Profiles[1])
if second.Auth != "jitsi" || second.Transport != "datachannel" ||
second.RoomID != "https://meet.example/room" || second.DNSServer != "8.8.8.8:53" {
t.Fatalf("second profile = %+v", second)
}
if second.LivenessInterval != "5s" || second.LivenessTimeout != "2s" || second.LivenessFailures != 5 ||
second.MaxSessionDuration != "6h" || second.TrafficMaxPayloadSize != 8192 ||
second.TrafficMinDelay != "10ms" || second.TrafficMaxDelay != "40ms" {
t.Fatalf("second lifecycle/liveness fields = %+v", second)
}
}
func TestLoadProfileCryptoKeyFile(t *testing.T) {
dir := t.TempDir()
if err := os.WriteFile(filepath.Join(dir, "profile.key"), []byte(testCryptoKey+"\n"), 0o600); err != nil {
t.Fatalf("write key: %v", err)
}
path := filepath.Join(dir, "olcrtc.yaml")
body := `
profiles:
- name: file-key
crypto:
key_file: profile.key
`
if err := os.WriteFile(path, []byte(body), 0o600); err != nil {
t.Fatalf("write config: %v", err)
}
f, err := Load(path)
if err != nil {
t.Fatalf("Load: %v", err)
}
if got := f.Profiles[0].Crypto.Key; got != testCryptoKey {
t.Fatalf("profile key = %q, want %q", got, testCryptoKey)
}
}
func TestLoadCryptoKeyFileRelativeToConfig(t *testing.T) {
dir := t.TempDir()
keyPath := filepath.Join(dir, "secret.key")
if err := os.WriteFile(keyPath, []byte(testCryptoKey+"\n"), 0o600); err != nil {
t.Fatalf("write key: %v", err)
}
path := filepath.Join(dir, "olcrtc.yaml")
body := `
mode: srv
crypto:
key_file: secret.key
`
if err := os.WriteFile(path, []byte(body), 0o600); err != nil {
t.Fatalf("write config: %v", err)
}
f, err := Load(path)
if err != nil {
t.Fatalf("Load: %v", err)
}
if f.Crypto.Key != testCryptoKey {
t.Fatalf("Crypto.Key = %q, want %q", f.Crypto.Key, testCryptoKey)
}
}
func TestLoadCryptoKeyFileConflict(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "olcrtc.yaml")
body := `
crypto:
key: deadbeef
key_file: secret.key
`
if err := os.WriteFile(path, []byte(body), 0o600); err != nil {
t.Fatalf("write config: %v", err)
}
_, err := Load(path)
if !errors.Is(err, ErrCryptoKeyConflict) {
t.Fatalf("Load() error = %v, want %v", err, ErrCryptoKeyConflict)
}
}
func TestLoadCryptoKeyFileEmpty(t *testing.T) {
dir := t.TempDir()
keyPath := filepath.Join(dir, "secret.key")
if err := os.WriteFile(keyPath, []byte("\n"), 0o600); err != nil {
t.Fatalf("write key: %v", err)
}
path := filepath.Join(dir, "olcrtc.yaml")
body := `
crypto:
key_file: secret.key
`
if err := os.WriteFile(path, []byte(body), 0o600); err != nil {
t.Fatalf("write config: %v", err)
}
_, err := Load(path)
if !errors.Is(err, ErrCryptoKeyFileEmpty) {
t.Fatalf("Load() error = %v, want %v", err, ErrCryptoKeyFileEmpty)
}
}
func TestLoadMissing(t *testing.T) {
_, err := Load(filepath.Join(t.TempDir(), "nope.yaml"))
if err == nil {

343
internal/control/control.go Normal file
View File

@@ -0,0 +1,343 @@
// Package control implements the post-handshake control stream protocol.
//
// The control stream is the first smux stream after the olcrtc handshake. It
// stays inside the encrypted muxconn path, so ping/pong proves that the actual
// tunnel path still round-trips, not merely that the provider connection is up.
//
// Wire format matches the handshake framing: a 4-byte big-endian length
// followed by a JSON message.
//
//nolint:tagliatelle // JSON keys are the stable wire protocol schema.
package control
import (
"context"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"sync"
"time"
)
const (
// ProtoVersion identifies the control stream wire format.
ProtoVersion = 1
// MaxMessageSize caps one control frame.
MaxMessageSize = 16 * 1024
// DefaultInterval is the default interval between ping probes.
DefaultInterval = 10 * time.Second
// DefaultTimeout is the default time to wait for a pong.
DefaultTimeout = 5 * time.Second
// DefaultFailures is the default number of consecutive missed pongs before
// the stream is marked unhealthy.
DefaultFailures = 3
)
// MsgType labels a control message.
type MsgType string
const (
// TypePing is sent periodically to prove control-stream liveness.
TypePing MsgType = "CONTROL_PING"
// TypePong replies to a ping with the same sequence and timestamp.
TypePong MsgType = "CONTROL_PONG"
)
var (
// ErrUnhealthy is returned when the stream misses too many pong replies.
ErrUnhealthy = errors.New("control stream unhealthy")
// ErrProtocolVersion is returned when the peer announces an incompatible version.
ErrProtocolVersion = errors.New("incompatible control protocol version")
// ErrUnexpectedMessage is returned for unknown or malformed control message types.
ErrUnexpectedMessage = errors.New("unexpected control message")
// ErrFrameTooLarge is returned when a frame exceeds [MaxMessageSize].
ErrFrameTooLarge = errors.New("control frame too large")
)
// Message is one control-stream frame.
type Message struct {
Version int `json:"version"`
Type MsgType `json:"type"`
Seq uint64 `json:"seq,omitempty"`
SentUnixNano int64 `json:"sent_unix_nano,omitempty"`
}
// Health is reported when a ping round trip completes.
type Health struct {
Seq uint64
RTT time.Duration
LastSeen time.Time
}
// Status is a point-in-time view of control-stream health maintained by
// callers that embed the control loop.
type Status struct {
SessionID string
LastPong time.Time
LastRTT time.Duration
MissedPongs int
Reconnects uint64
UnhealthyEvents uint64
LastUnhealthy time.Time
}
// Config controls the liveness loop.
type Config struct {
Interval time.Duration
Timeout time.Duration
Failures int
// OnPong is called after a matching pong is received.
OnPong func(Health)
// OnMissedPong is called when one or more outstanding pongs time out.
OnMissedPong func(missed int)
// OnUnhealthy is called before Run returns [ErrUnhealthy].
OnUnhealthy func(missed int)
}
func (cfg Config) withDefaults() Config {
if cfg.Interval <= 0 {
cfg.Interval = DefaultInterval
}
if cfg.Timeout <= 0 {
cfg.Timeout = DefaultTimeout
}
if cfg.Failures <= 0 {
cfg.Failures = DefaultFailures
}
return cfg
}
// Run drives bidirectional ping/pong liveness until ctx is canceled, rw closes,
// or the configured failure threshold is reached.
func Run(ctx context.Context, rw io.ReadWriteCloser, cfg Config) error {
cfg = cfg.withDefaults()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
state := &state{
rw: rw,
cfg: cfg,
pending: make(map[uint64]time.Time),
now: time.Now,
out: make(chan Message, 16),
}
errCh := make(chan error, 3)
go func() {
<-ctx.Done()
_ = rw.Close()
}()
go func() { errCh <- state.readLoop(ctx) }()
go func() { errCh <- state.probeLoop(ctx) }()
go func() { errCh <- state.writeLoop(ctx) }()
err := <-errCh
cancel()
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil
}
return err
}
type state struct {
rw io.ReadWriteCloser
cfg Config
now func() time.Time
out chan Message
mu sync.Mutex
pending map[uint64]time.Time
nextSeq uint64
failures int
}
func (s *state) readLoop(ctx context.Context) error {
for {
raw, err := readFrame(s.rw)
if err != nil {
if ctx.Err() != nil {
return ctx.Err()
}
return err
}
msg, err := parseMessage(raw)
if err != nil {
return err
}
switch msg.Type {
case TypePing:
if err := s.enqueue(ctx, Message{
Version: ProtoVersion,
Type: TypePong,
Seq: msg.Seq,
SentUnixNano: msg.SentUnixNano,
}); err != nil {
if ctx.Err() != nil {
return ctx.Err()
}
return err
}
case TypePong:
s.handlePong(msg)
default:
return fmt.Errorf("%w: got %q", ErrUnexpectedMessage, msg.Type)
}
}
}
func (s *state) probeLoop(ctx context.Context) error {
ticker := time.NewTicker(s.cfg.Interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
if err := s.sendProbe(ctx); err != nil {
return err
}
}
}
}
func (s *state) sendProbe(ctx context.Context) error {
now := s.now()
s.mu.Lock()
missedNow := 0
for seq, sent := range s.pending {
if now.Sub(sent) < s.cfg.Timeout {
continue
}
delete(s.pending, seq)
s.failures++
missedNow++
}
missed := s.failures
if s.failures >= s.cfg.Failures {
s.mu.Unlock()
if missedNow > 0 && s.cfg.OnMissedPong != nil {
s.cfg.OnMissedPong(missed)
}
if s.cfg.OnUnhealthy != nil {
s.cfg.OnUnhealthy(missed)
}
return fmt.Errorf("%w: missed %d pong(s)", ErrUnhealthy, missed)
}
s.nextSeq++
seq := s.nextSeq
s.pending[seq] = now
s.mu.Unlock()
if missedNow > 0 && s.cfg.OnMissedPong != nil {
s.cfg.OnMissedPong(missed)
}
return s.enqueue(ctx, Message{
Version: ProtoVersion,
Type: TypePing,
Seq: seq,
SentUnixNano: now.UnixNano(),
})
}
func (s *state) handlePong(msg Message) {
now := s.now()
s.mu.Lock()
sent, ok := s.pending[msg.Seq]
if ok {
delete(s.pending, msg.Seq)
s.failures = 0
}
s.mu.Unlock()
if !ok || s.cfg.OnPong == nil {
return
}
s.cfg.OnPong(Health{
Seq: msg.Seq,
RTT: now.Sub(sent),
LastSeen: now,
})
}
func (s *state) enqueue(ctx context.Context, msg Message) error {
select {
case <-ctx.Done():
return ctx.Err()
case s.out <- msg:
return nil
}
}
func (s *state) writeLoop(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return ctx.Err()
case msg := <-s.out:
if err := writeFrame(s.rw, msg); err != nil {
if ctx.Err() != nil {
return ctx.Err()
}
return err
}
}
}
}
func parseMessage(raw []byte) (Message, error) {
var msg Message
if err := json.Unmarshal(raw, &msg); err != nil {
return Message{}, fmt.Errorf("parse control message: %w", err)
}
if msg.Version != ProtoVersion {
return Message{}, fmt.Errorf("%w: peer v%d, local v%d",
ErrProtocolVersion, msg.Version, ProtoVersion)
}
if msg.Type != TypePing && msg.Type != TypePong {
return Message{}, fmt.Errorf("%w: got %q", ErrUnexpectedMessage, msg.Type)
}
return msg, nil
}
func writeFrame(w io.Writer, msg Message) error {
body, err := json.Marshal(msg)
if err != nil {
return fmt.Errorf("marshal control message: %w", err)
}
if len(body) > MaxMessageSize {
return fmt.Errorf("%w: %d > %d", ErrFrameTooLarge, len(body), MaxMessageSize)
}
var hdr [4]byte
binary.BigEndian.PutUint32(hdr[:], uint32(len(body))) //nolint:gosec // len(body) bounded by MaxMessageSize
if _, err := w.Write(hdr[:]); err != nil {
return fmt.Errorf("write control hdr: %w", err)
}
if _, err := w.Write(body); err != nil {
return fmt.Errorf("write control body: %w", err)
}
return nil
}
func readFrame(r io.Reader) ([]byte, error) {
var hdr [4]byte
if _, err := io.ReadFull(r, hdr[:]); err != nil {
return nil, fmt.Errorf("read control hdr: %w", err)
}
n := binary.BigEndian.Uint32(hdr[:])
if n > MaxMessageSize {
return nil, fmt.Errorf("%w: %d > %d", ErrFrameTooLarge, n, MaxMessageSize)
}
buf := make([]byte, n)
if _, err := io.ReadFull(r, buf); err != nil {
return nil, fmt.Errorf("read control body: %w", err)
}
return buf, nil
}

View File

@@ -0,0 +1,138 @@
package control
import (
"context"
"encoding/binary"
"errors"
"io"
"net"
"testing"
"time"
)
func controlPair(t *testing.T) (net.Conn, net.Conn) {
t.Helper()
a, b := net.Pipe()
t.Cleanup(func() {
_ = a.Close()
_ = b.Close()
})
return a, b
}
func TestRunPingPongReportsRTT(t *testing.T) {
a, b := controlPair(t)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
got := make(chan Health, 1)
cfg := Config{
Interval: 10 * time.Millisecond,
Timeout: 100 * time.Millisecond,
Failures: 2,
OnPong: func(h Health) {
select {
case got <- h:
default:
}
},
}
errCh := make(chan error, 2)
go func() { errCh <- Run(ctx, a, cfg) }()
go func() { errCh <- Run(ctx, b, cfg) }()
select {
case h := <-got:
if h.Seq == 0 {
t.Fatal("Health.Seq = 0")
}
if h.RTT < 0 {
t.Fatalf("Health.RTT = %v", h.RTT)
}
case <-time.After(time.Second):
t.Fatal("timed out waiting for pong health")
}
cancel()
for range 2 {
if err := <-errCh; err != nil {
t.Fatalf("Run() after cancel = %v", err)
}
}
}
func TestRunMarksUnhealthyAfterMissedPongs(t *testing.T) {
a, b := controlPair(t)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
_, _ = io.Copy(io.Discard, b)
}()
missedCh := make(chan int, 1)
missedCallbackCh := make(chan int, 1)
errCh := make(chan error, 1)
go func() {
errCh <- Run(ctx, a, Config{
Interval: 10 * time.Millisecond,
Timeout: 5 * time.Millisecond,
Failures: 2,
OnMissedPong: func(missed int) {
select {
case missedCallbackCh <- missed:
default:
}
},
OnUnhealthy: func(missed int) { missedCh <- missed },
})
}()
select {
case err := <-errCh:
if !errors.Is(err, ErrUnhealthy) {
t.Fatalf("Run() error = %v, want ErrUnhealthy", err)
}
case <-time.After(time.Second):
t.Fatal("timed out waiting for unhealthy result")
}
if missed := <-missedCh; missed < 2 {
t.Fatalf("missed = %d, want >= 2", missed)
}
if missed := <-missedCallbackCh; missed < 1 {
t.Fatalf("missed callback = %d, want >= 1", missed)
}
}
func TestRunRejectsBadProtocolVersion(t *testing.T) {
a, b := controlPair(t)
errCh := make(chan error, 1)
go func() {
errCh <- Run(context.Background(), a, Config{Interval: time.Hour})
}()
if err := writeFrame(b, Message{Version: 999, Type: TypePing, Seq: 1}); err != nil {
t.Fatalf("writeFrame() error = %v", err)
}
select {
case err := <-errCh:
if !errors.Is(err, ErrProtocolVersion) {
t.Fatalf("Run() error = %v, want ErrProtocolVersion", err)
}
case <-time.After(time.Second):
t.Fatal("timed out waiting for protocol error")
}
}
func TestReadFrameRejectsTooLarge(t *testing.T) {
a, b := controlPair(t)
go func() {
var hdr [4]byte
binary.BigEndian.PutUint32(hdr[:], MaxMessageSize+1)
_, _ = b.Write(hdr[:])
}()
_, err := readFrame(a)
if !errors.Is(err, ErrFrameTooLarge) {
t.Fatalf("readFrame() error = %v, want ErrFrameTooLarge", err)
}
}

View File

@@ -10,6 +10,9 @@ import (
"golang.org/x/crypto/chacha20poly1305"
)
// WireOverhead is the number of bytes added to each encrypted message.
const WireOverhead = chacha20poly1305.NonceSizeX + chacha20poly1305.Overhead
var (
// ErrInvalidKeySize is returned when the encryption key is not 32 bytes.
ErrInvalidKeySize = errors.New("invalid key size")

View File

@@ -24,6 +24,7 @@ import (
"github.com/openlibrecommunity/olcrtc/internal/client"
"github.com/openlibrecommunity/olcrtc/internal/link"
"github.com/openlibrecommunity/olcrtc/internal/server"
"github.com/openlibrecommunity/olcrtc/internal/supervisor"
"github.com/pion/webrtc/v4"
)
@@ -47,6 +48,7 @@ var (
errSocksUnexpectedReply = errors.New("unexpected SOCKS5 reply")
errSocksUnexpectedHello = errors.New("unexpected SOCKS5 greeting")
errPayloadMismatchOffset = errors.New("payload mismatch at offset")
errFailoverCarrier = errors.New("intentional failover carrier failure")
)
var (
@@ -347,6 +349,17 @@ func registerMemoryCarrierAs(t *testing.T, name string) {
})
}
func registerFailingCarrier(t *testing.T) string {
t.Helper()
session.RegisterDefaults()
name := "e2e-fail-" + t.Name()
carrier.Register(name, func(context.Context, carrier.Config) (carrier.Session, error) {
return nil, errFailoverCarrier
})
return name
}
func builtInCarrierNames() []string {
return []string{"jazz", "telemost", "wbstream", "jitsi"} //nolint:goconst // test literal, repetition is intentional
}
@@ -1008,9 +1021,7 @@ func TestDirectLinkConnectsFastProviderTransportMatrix(t *testing.T) {
if err := ln.Connect(context.Background()); err != nil {
t.Fatalf("Connect() error = %v", err)
}
if !ln.CanSend() {
t.Fatal("CanSend() = false, want true")
}
assertLinkCanSendAfterConnect(t, ln, transportName)
if err := ln.Close(); err != nil {
t.Fatalf("Close() error = %v", err)
}
@@ -1020,6 +1031,20 @@ func TestDirectLinkConnectsFastProviderTransportMatrix(t *testing.T) {
}
}
func assertLinkCanSendAfterConnect(t *testing.T, ln link.Link, transportName string) {
t.Helper()
if transportName == transportSEI {
if ln.CanSend() {
t.Fatal("CanSend() = true before peer seichannel frame")
}
return
}
if !ln.CanSend() {
t.Fatal("CanSend() = false, want true")
}
}
//nolint:cyclop // table-driven test naturally has many branches
func TestRealProviderTransportMatrix(t *testing.T) {
if !*realE2E {
@@ -1163,6 +1188,186 @@ func TestFrequentReconnectsStillAllowNewSOCKSConnections(t *testing.T) {
}
}
func TestSupervisorFailoverProfilesReachWorkingSOCKS(t *testing.T) {
echoAddr := startEchoServer(t)
failingCarrier := registerFailingCarrier(t)
memoryCarrier, room := registerMemoryCarrier(t)
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
socksAddr := freeLocalAddr(ctx, t)
socksHost, socksPort := splitHostPort(t, socksAddr)
serverProfiles := []supervisor.Profile{
{Name: "failing-server", Config: failoverSessionConfig("srv", failingCarrier, "", 0)},
{Name: "memory-server", Config: failoverSessionConfig("srv", memoryCarrier, "", 0)},
}
clientProfiles := []supervisor.Profile{
{Name: "failing-client", Config: failoverSessionConfig("cnc", failingCarrier, socksHost, socksPort)},
{Name: "memory-client", Config: failoverSessionConfig("cnc", memoryCarrier, socksHost, socksPort)},
}
started := make(chan string, 8)
serverErr := make(chan error, 1)
go func() {
serverErr <- supervisor.Run(ctx, failoverE2EConfig(serverProfiles, started, "server"), session.Run)
}()
room.waitConnected(t, 1)
ready := make(chan struct{})
var readyOnce sync.Once
clientErr := make(chan error, 1)
go func() {
clientErr <- supervisor.Run(ctx, failoverE2EConfig(clientProfiles, started, "client"), func(ctx context.Context, cfg session.Config) error {
return client.RunWithReady(ctx, clientConfigFromSession(cfg, socksAddr), func() {
if cfg.Auth == memoryCarrier {
readyOnce.Do(func() { close(ready) })
}
})
})
}()
waitForReady(t, ready)
conn := eventuallyConnectViaSOCKS(t, socksAddr, echoAddr)
defer func() { _ = conn.Close() }()
payload := []byte("olcrtc-failover-e2e\n")
if _, err := conn.Write(payload); err != nil {
t.Fatalf("write failover payload: %v", err)
}
if err := conn.SetReadDeadline(time.Now().Add(3 * time.Second)); err != nil {
t.Fatalf("set failover read deadline: %v", err)
}
line, err := bufio.NewReader(conn).ReadBytes('\n')
if err != nil {
t.Fatalf("read failover echo: %v", err)
}
if !bytes.Equal(line, payload) {
t.Fatalf("failover echo = %q, want %q", line, payload)
}
requireStartedProfiles(t, started, []string{
"server:failing-server",
"server:memory-server",
"client:failing-client",
"client:memory-client",
})
cancel()
waitSupervisorStopped(t, "client", clientErr)
waitSupervisorStopped(t, "server", serverErr)
}
func failoverSessionConfig(mode, carrierName, socksHost string, socksPort int) session.Config {
cfg := session.Config{
Mode: mode,
Link: linkDirect,
Transport: transportData,
Auth: carrierName,
RoomID: testRoom,
KeyHex: testKeyHex,
DNSServer: localDNSServer,
}
if mode == "cnc" {
cfg.SOCKSHost = socksHost
cfg.SOCKSPort = socksPort
}
return cfg
}
func clientConfigFromSession(cfg session.Config, socksAddr string) client.Config {
return client.Config{
Link: cfg.Link,
Transport: cfg.Transport,
Carrier: cfg.Auth,
RoomURL: cfg.RoomID,
KeyHex: cfg.KeyHex,
LocalAddr: socksAddr,
DNSServer: cfg.DNSServer,
DeviceID: testClientDeviceID,
VideoWidth: cfg.VideoWidth,
VideoHeight: cfg.VideoHeight,
VideoFPS: cfg.VideoFPS,
VideoBitrate: cfg.VideoBitrate,
VideoHW: cfg.VideoHW,
VideoQRSize: cfg.VideoQRSize,
VideoQRRecovery: cfg.VideoQRRecovery,
VideoCodec: cfg.VideoCodec,
VideoTileModule: cfg.VideoTileModule,
VideoTileRS: cfg.VideoTileRS,
VP8FPS: cfg.VP8FPS,
VP8BatchSize: cfg.VP8BatchSize,
SEIFPS: cfg.SEIFPS,
SEIBatchSize: cfg.SEIBatchSize,
SEIFragmentSize: cfg.SEIFragmentSize,
SEIAckTimeoutMS: cfg.SEIAckTimeoutMS,
Engine: cfg.Engine,
URL: cfg.URL,
Token: cfg.Token,
}
}
func failoverE2EConfig(
profiles []supervisor.Profile,
started chan<- string,
side string,
) supervisor.Config {
return supervisor.Config{
Profiles: profiles,
RetryDelay: time.Millisecond,
OnProfileStart: func(profile supervisor.Profile, _ int) {
select {
case started <- side + ":" + profile.Name:
default:
}
},
}
}
func splitHostPort(t *testing.T, addr string) (string, int) {
t.Helper()
host, portText, err := net.SplitHostPort(addr)
if err != nil {
t.Fatalf("split host port %q: %v", addr, err)
}
port, err := strconv.Atoi(portText)
if err != nil {
t.Fatalf("parse port %q: %v", portText, err)
}
return host, port
}
func requireStartedProfiles(t *testing.T, started <-chan string, want []string) {
t.Helper()
seen := make(map[string]bool)
deadline := time.After(3 * time.Second)
for len(seen) < len(want) {
select {
case item := <-started:
seen[item] = true
case <-deadline:
t.Fatalf("started profiles = %v, want all %v", seen, want)
}
}
for _, item := range want {
if !seen[item] {
t.Fatalf("started profiles = %v, missing %s", seen, item)
}
}
}
func waitSupervisorStopped(t *testing.T, name string, ch <-chan error) {
t.Helper()
select {
case err := <-ch:
if err != nil {
t.Fatalf("%s supervisor returned error: %v", name, err)
}
case <-time.After(3 * time.Second):
t.Fatalf("%s supervisor did not stop", name)
}
}
func TestEndedCallbackStopsClientAndServer(t *testing.T) {
rt := startTunnel(t)
rt.room.triggerEnded("conference ended")

View File

@@ -112,10 +112,7 @@ func (s *Session) setupPeerConnections(config webrtc.Configuration) error {
}
func (s *Session) dialWebSocket() error {
wsDialer := websocket.Dialer{
NetDialContext: protect.DialContext,
HandshakeTimeout: wsHandshakeTimeout,
}
wsDialer := protect.NewWebSocketDialer(wsHandshakeTimeout)
ws, resp, err := wsDialer.Dial(s.mediaServerURL, nil)
if err != nil {
return fmt.Errorf("dial ws: %w", err)

View File

@@ -19,13 +19,17 @@ import (
protoLogger "github.com/livekit/protocol/logger"
lksdk "github.com/livekit/server-sdk-go/v2"
"github.com/openlibrecommunity/olcrtc/internal/engine"
"github.com/openlibrecommunity/olcrtc/internal/logger"
"github.com/pion/webrtc/v4"
)
const (
defaultSendQueueSize = 5000
dataPublishTopic = "olcrtc"
videoTrackName = "videochannel"
defaultSendQueueSize = 5000
defaultSendQueueCapHard = 4000
dataPublishTopic = "olcrtc"
videoTrackName = "videochannel"
reconnectWindow = 5 * time.Minute
maxReconnects = 10
)
var (
@@ -41,20 +45,98 @@ var (
ErrTokenRequired = errors.New("livekit access token required")
)
type roomHandle interface {
publishData([]byte) error
publishTrack(webrtc.TrackLocal) error
unpublishLocalTracks()
disconnect()
connectionState() lksdk.ConnectionState
}
type sdkRoom struct {
room *lksdk.Room
}
func (r *sdkRoom) publishData(data []byte) error {
return r.room.LocalParticipant.PublishDataPacket(
lksdk.UserData(data),
lksdk.WithDataPublishTopic(dataPublishTopic),
lksdk.WithDataPublishReliable(true),
)
}
func (r *sdkRoom) publishTrack(track webrtc.TrackLocal) error {
_, err := r.room.LocalParticipant.PublishTrack(track, &lksdk.TrackPublicationOptions{Name: videoTrackName})
return err
}
func (r *sdkRoom) unpublishLocalTracks() {
if r.room == nil || r.room.LocalParticipant == nil {
return
}
for _, publication := range r.room.LocalParticipant.TrackPublications() {
if publication.SID() == "" {
continue
}
if err := r.room.LocalParticipant.UnpublishTrack(publication.SID()); err != nil {
log.Printf("livekit unpublish track error: %v", err)
}
}
}
func (r *sdkRoom) disconnect() {
r.room.Disconnect()
// LiveKit's Disconnect returns after local SDK teardown, before the
// server necessarily evicts the participant. Give the signalling path a
// short grace period so immediate reconnects do not inherit stale room
// state from a ghost participant.
time.Sleep(2 * time.Second)
}
func (r *sdkRoom) connectionState() lksdk.ConnectionState {
return r.room.ConnectionState()
}
type connectRoomFunc func(url, token string, callback *lksdk.RoomCallback) (roomHandle, error)
func connectSDKRoom(url, token string, callback *lksdk.RoomCallback) (roomHandle, error) {
room, err := lksdk.ConnectToRoomWithToken(
url,
token,
callback,
lksdk.WithAutoSubscribe(true),
lksdk.WithLogger(protoLogger.GetDiscardLogger()),
)
if err != nil {
return nil, err
}
return &sdkRoom{room: room}, nil
}
// Session is the LiveKit engine handle.
type Session struct {
url string
token string
name string
room *lksdk.Room
refresh func(ctx context.Context) (engine.Credentials, error)
connectRoom connectRoomFunc
room roomHandle
roomMu sync.RWMutex
onData func([]byte)
onReconnect func(*webrtc.DataChannel)
shouldReconnect func() bool
onEnded func(string)
reconnectCh chan struct{}
closeCh chan struct{}
lastReconnect time.Time
reconnectCount int
sendQueue chan []byte
closed atomic.Bool
reconnecting atomic.Bool
done chan struct{}
cancel context.CancelFunc
shutdownOnce sync.Once
sendWorkerOnce sync.Once
videoTrackMu sync.RWMutex
videoTracks []webrtc.TrackLocal
onVideoTrack func(*webrtc.TrackRemote, *webrtc.RTPReceiver)
@@ -71,13 +153,17 @@ func New(ctx context.Context, cfg engine.Config) (engine.Session, error) {
}
_, cancel := context.WithCancel(ctx)
return &Session{
url: cfg.URL,
token: cfg.Token,
name: cfg.Name,
onData: cfg.OnData,
sendQueue: make(chan []byte, defaultSendQueueSize),
done: make(chan struct{}),
cancel: cancel,
url: cfg.URL,
token: cfg.Token,
name: cfg.Name,
refresh: cfg.Refresh,
connectRoom: connectSDKRoom,
onData: cfg.OnData,
reconnectCh: make(chan struct{}, 1),
closeCh: make(chan struct{}),
sendQueue: make(chan []byte, defaultSendQueueSize),
done: make(chan struct{}),
cancel: cancel,
}, nil
}
@@ -87,7 +173,16 @@ func (s *Session) Capabilities() engine.Capabilities {
}
// Connect joins the LiveKit room.
func (s *Session) Connect(_ context.Context) error {
func (s *Session) Connect(ctx context.Context) error {
s.closed.Store(false)
if err := s.connectSession(ctx); err != nil {
return err
}
s.startSendWorker()
return nil
}
func (s *Session) connectSession(_ context.Context) error {
roomCB := &lksdk.RoomCallback{
ParticipantCallback: lksdk.ParticipantCallback{
OnDataReceived: func(data []byte, _ lksdk.DataReceiveParams) {
@@ -108,45 +203,49 @@ func (s *Session) Connect(_ context.Context) error {
},
},
OnDisconnected: func() {
if !s.closed.Load() && s.onEnded != nil {
s.onEnded("disconnected from livekit")
if s.closed.Load() || s.reconnecting.Load() {
return
}
if !s.queueReconnect() {
s.signalEnded("disconnected from livekit")
}
},
}
room, err := lksdk.ConnectToRoomWithToken(
s.url,
s.token,
roomCB,
lksdk.WithAutoSubscribe(true),
lksdk.WithLogger(protoLogger.GetDiscardLogger()),
)
room, err := s.connectRoom(s.url, s.token, roomCB)
if err != nil {
return fmt.Errorf("connect to room: %w", err)
}
s.room = room
s.setRoom(room)
if err := s.publishPendingTracks(); err != nil {
return err
}
s.wg.Add(1)
go s.processSendQueue()
return nil
}
func (s *Session) publishPendingTracks() error {
room := s.currentRoom()
if room == nil {
return ErrRoomNotConnected
}
s.videoTrackMu.RLock()
defer s.videoTrackMu.RUnlock()
for _, track := range s.videoTracks {
if _, err := s.room.LocalParticipant.PublishTrack(track, &lksdk.TrackPublicationOptions{
Name: videoTrackName,
}); err != nil {
if err := room.publishTrack(track); err != nil {
return fmt.Errorf("failed to publish track: %w", err)
}
}
return nil
}
func (s *Session) startSendWorker() {
s.sendWorkerOnce.Do(func() {
s.wg.Add(1)
go s.processSendQueue()
})
}
func (s *Session) processSendQueue() {
defer s.wg.Done()
for {
@@ -157,17 +256,33 @@ func (s *Session) processSendQueue() {
if !ok {
return
}
if err := s.room.LocalParticipant.PublishDataPacket(
lksdk.UserData(data),
lksdk.WithDataPublishTopic(dataPublishTopic),
lksdk.WithDataPublishReliable(true),
); err != nil {
room := s.waitForConnectedRoom()
if room == nil {
return
}
if err := room.publishData(data); err != nil {
log.Printf("livekit publish data error: %v", err)
}
}
}
}
func (s *Session) waitForConnectedRoom() roomHandle {
ticker := time.NewTicker(50 * time.Millisecond)
defer ticker.Stop()
for {
room := s.currentRoom()
if room != nil && room.connectionState() == lksdk.ConnectionStateConnected {
return room
}
select {
case <-s.done:
return nil
case <-ticker.C:
}
}
}
// Send queues data for transmission.
func (s *Session) Send(data []byte) error {
if s.closed.Load() {
@@ -183,55 +298,160 @@ func (s *Session) Send(data []byte) error {
// Close terminates the session.
func (s *Session) Close() error {
if s.closed.CompareAndSwap(false, true) {
s.cancel()
close(s.done)
if s.room != nil {
s.unpublishLocalTracks()
s.room.Disconnect()
// LiveKit's Disconnect() returns once the local SDK state
// is torn down, not when the server has actually evicted
// the participant. Without giving the signalling channel
// time to flush the LEAVE_REQUEST and the server to act on
// it, a back-to-back reconnect from the same identity in
// the same room sees a still-alive ghost participant on
// the SFU and inherits stale publication state.
time.Sleep(2 * time.Second)
}
close(s.sendQueue)
s.wg.Wait()
}
s.closed.Store(true)
s.shutdown()
return nil
}
func (s *Session) unpublishLocalTracks() {
if s.room == nil || s.room.LocalParticipant == nil {
return
}
for _, publication := range s.room.LocalParticipant.TrackPublications() {
if publication.SID() == "" {
continue
func (s *Session) shutdown() {
s.shutdownOnce.Do(func() {
if s.cancel != nil {
s.cancel()
}
if err := s.room.LocalParticipant.UnpublishTrack(publication.SID()); err != nil {
log.Printf("livekit unpublish track error: %v", err)
closeSignal(s.closeCh)
closeSignal(s.done)
if room := s.swapRoom(nil); room != nil {
room.unpublishLocalTracks()
room.disconnect()
}
}
s.wg.Wait()
})
}
// SetReconnectCallback stores the reconnect callback (LiveKit reconnects internally; this is kept for API parity).
// SetReconnectCallback stores the reconnect callback.
func (s *Session) SetReconnectCallback(cb func(*webrtc.DataChannel)) { s.onReconnect = cb }
// SetShouldReconnect stores the reconnect predicate (kept for API parity).
// SetShouldReconnect stores the reconnect predicate.
func (s *Session) SetShouldReconnect(fn func() bool) { s.shouldReconnect = fn }
// SetEndedCallback registers a function to call when the session ends.
func (s *Session) SetEndedCallback(cb func(string)) { s.onEnded = cb }
// WatchConnection is a no-op; LiveKit handles connection supervision itself.
func (s *Session) WatchConnection(_ context.Context) {}
// WatchConnection monitors the connection lifecycle and reconnects as needed.
func (s *Session) WatchConnection(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
case <-s.closeCh:
return
case <-s.reconnectCh:
if s.handleReconnectAttempt(ctx) {
return
}
}
}
}
func (s *Session) handleReconnectAttempt(ctx context.Context) bool {
if time.Since(s.lastReconnect) > reconnectWindow {
s.reconnectCount = 0
}
s.reconnectCount++
s.lastReconnect = time.Now()
if s.reconnectCount > maxReconnects {
s.signalEnded("reconnect limit reached")
return true
}
backoff := time.Duration(s.reconnectCount) * 2 * time.Second
if backoff > 30*time.Second {
backoff = 30 * time.Second
}
for {
if err := s.reconnect(ctx); err != nil {
logger.Debugf("livekit reconnect failed: %v", err)
select {
case <-ctx.Done():
return true
case <-s.closeCh:
return true
case <-time.After(backoff):
continue
}
}
s.drainReconnectQueue()
return false
}
}
func (s *Session) reconnect(ctx context.Context) error {
s.reconnecting.Store(true)
defer s.reconnecting.Store(false)
if room := s.swapRoom(nil); room != nil {
room.unpublishLocalTracks()
room.disconnect()
}
if s.refresh != nil {
creds, err := s.refresh(ctx)
if err != nil {
return fmt.Errorf("refresh credentials: %w", err)
}
s.applyRefreshedCredentials(creds)
}
if err := s.connectSession(ctx); err != nil {
return err
}
if s.onReconnect != nil {
s.onReconnect(nil)
}
return nil
}
func (s *Session) applyRefreshedCredentials(creds engine.Credentials) {
if creds.URL != "" {
s.url = creds.URL
}
if creds.Token != "" {
s.token = creds.Token
}
}
func (s *Session) queueReconnect() bool {
if s.closed.Load() || s.reconnecting.Load() {
return false
}
if s.shouldReconnect != nil && !s.shouldReconnect() {
return false
}
select {
case s.reconnectCh <- struct{}{}:
default:
}
return true
}
func (s *Session) drainReconnectQueue() {
for {
select {
case <-s.reconnectCh:
default:
return
}
}
}
func (s *Session) signalEnded(reason string) {
s.closed.Store(true)
s.shutdown()
if s.onEnded != nil {
s.onEnded(reason)
}
}
// CanSend reports whether the session is ready to accept data.
func (s *Session) CanSend() bool { return !s.closed.Load() && s.room != nil }
func (s *Session) CanSend() bool {
if s.closed.Load() || s.reconnecting.Load() || len(s.sendQueue) >= defaultSendQueueCapHard {
return false
}
room := s.currentRoom()
return room != nil && room.connectionState() == lksdk.ConnectionStateConnected
}
// GetSendQueue exposes the outbound queue.
func (s *Session) GetSendQueue() chan []byte { return s.sendQueue }
@@ -245,12 +465,11 @@ func (s *Session) AddVideoTrack(track webrtc.TrackLocal) error {
s.videoTracks = append(s.videoTracks, track)
s.videoTrackMu.Unlock()
if s.room == nil || s.room.LocalParticipant == nil {
room := s.currentRoom()
if room == nil {
return nil
}
if _, err := s.room.LocalParticipant.PublishTrack(track, &lksdk.TrackPublicationOptions{
Name: videoTrackName,
}); err != nil {
if err := room.publishTrack(track); err != nil {
return fmt.Errorf("failed to publish track: %w", err)
}
return nil
@@ -263,6 +482,34 @@ func (s *Session) SetVideoTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPR
s.onVideoTrack = cb
}
func (s *Session) currentRoom() roomHandle {
s.roomMu.RLock()
defer s.roomMu.RUnlock()
return s.room
}
func (s *Session) setRoom(room roomHandle) {
s.roomMu.Lock()
defer s.roomMu.Unlock()
s.room = room
}
func (s *Session) swapRoom(room roomHandle) roomHandle {
s.roomMu.Lock()
defer s.roomMu.Unlock()
old := s.room
s.room = room
return old
}
func closeSignal(ch chan struct{}) {
select {
case <-ch:
default:
close(ch)
}
}
func init() { //nolint:gochecknoinits // engine registration is the canonical Go pattern for plugins
engine.Register("livekit", New)
}

View File

@@ -0,0 +1,306 @@
package livekit
import (
"context"
"errors"
"sync"
"testing"
"time"
lksdk "github.com/livekit/server-sdk-go/v2"
"github.com/openlibrecommunity/olcrtc/internal/engine"
"github.com/pion/webrtc/v4"
)
type fakeRoom struct {
mu sync.Mutex
state lksdk.ConnectionState
published [][]byte
tracks int
unpublished int
disconnected int
}
func newFakeRoom() *fakeRoom {
return &fakeRoom{state: lksdk.ConnectionStateConnected}
}
func (r *fakeRoom) publishData(data []byte) error {
r.mu.Lock()
defer r.mu.Unlock()
r.published = append(r.published, append([]byte(nil), data...))
return nil
}
func (r *fakeRoom) publishTrack(webrtc.TrackLocal) error {
r.mu.Lock()
defer r.mu.Unlock()
r.tracks++
return nil
}
func (r *fakeRoom) unpublishLocalTracks() {
r.mu.Lock()
defer r.mu.Unlock()
r.unpublished++
}
func (r *fakeRoom) disconnect() {
r.mu.Lock()
defer r.mu.Unlock()
r.disconnected++
r.state = lksdk.ConnectionStateDisconnected
}
func (r *fakeRoom) connectionState() lksdk.ConnectionState {
r.mu.Lock()
defer r.mu.Unlock()
return r.state
}
type fakeConnector struct {
mu sync.Mutex
urls []string
tokens []string
callbacks []*lksdk.RoomCallback
rooms []*fakeRoom
connected chan struct{}
err error
}
func newFakeConnector() *fakeConnector {
return &fakeConnector{connected: make(chan struct{}, 8)}
}
func (c *fakeConnector) connect(url, token string, cb *lksdk.RoomCallback) (roomHandle, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.err != nil {
return nil, c.err
}
room := newFakeRoom()
c.urls = append(c.urls, url)
c.tokens = append(c.tokens, token)
c.callbacks = append(c.callbacks, cb)
c.rooms = append(c.rooms, room)
c.connected <- struct{}{}
return room, nil
}
func (c *fakeConnector) count() int {
c.mu.Lock()
defer c.mu.Unlock()
return len(c.rooms)
}
func (c *fakeConnector) callback(i int) *lksdk.RoomCallback {
c.mu.Lock()
defer c.mu.Unlock()
return c.callbacks[i]
}
func (c *fakeConnector) room(i int) *fakeRoom {
c.mu.Lock()
defer c.mu.Unlock()
return c.rooms[i]
}
func (c *fakeConnector) snapshot() ([]string, []string) {
c.mu.Lock()
defer c.mu.Unlock()
return append([]string(nil), c.urls...), append([]string(nil), c.tokens...)
}
func waitFor(t *testing.T, cond func() bool) {
t.Helper()
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
if cond() {
return
}
time.Sleep(10 * time.Millisecond)
}
t.Fatal("condition was not met before timeout")
}
func TestReconnectRefreshesCredentialsAndReplacesRoom(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
refreshes := 0
sess, err := New(ctx, engine.Config{
URL: "wss://old",
Token: "old-token",
Refresh: func(context.Context) (engine.Credentials, error) {
refreshes++
return engine.Credentials{URL: "wss://new", Token: "new-token"}, nil
},
})
if err != nil {
t.Fatalf("New() error = %v", err)
}
s := sess.(*Session)
connector := newFakeConnector()
s.connectRoom = connector.connect
reconnected := make(chan struct{}, 1)
s.SetReconnectCallback(func(*webrtc.DataChannel) {
reconnected <- struct{}{}
})
if err := s.Connect(ctx); err != nil {
t.Fatalf("Connect() error = %v", err)
}
go s.WatchConnection(ctx)
connector.callback(0).OnDisconnected()
waitFor(t, func() bool { return connector.count() == 2 })
select {
case <-reconnected:
case <-time.After(time.Second):
t.Fatal("reconnect callback was not called")
}
urls, tokens := connector.snapshot()
if got, want := urls, []string{"wss://old", "wss://new"}; !equalStrings(got, want) {
t.Fatalf("connect urls = %v, want %v", got, want)
}
if got, want := tokens, []string{"old-token", "new-token"}; !equalStrings(got, want) {
t.Fatalf("connect tokens = %v, want %v", got, want)
}
if refreshes != 1 {
t.Fatalf("refreshes = %d, want 1", refreshes)
}
oldRoom := connector.room(0)
oldRoom.mu.Lock()
if oldRoom.disconnected != 1 || oldRoom.unpublished != 1 {
t.Fatalf("old room cleanup disconnected=%d unpublished=%d, want 1/1",
oldRoom.disconnected, oldRoom.unpublished)
}
oldRoom.mu.Unlock()
if !s.CanSend() {
t.Fatal("CanSend() = false after reconnect, want true")
}
if err := s.Close(); err != nil {
t.Fatalf("Close() error = %v", err)
}
}
func TestDisconnectedEndsWhenReconnectDisallowed(t *testing.T) {
ctx := context.Background()
sess, err := New(ctx, engine.Config{URL: "wss://old", Token: "old-token"})
if err != nil {
t.Fatalf("New() error = %v", err)
}
s := sess.(*Session)
connector := newFakeConnector()
s.connectRoom = connector.connect
s.SetShouldReconnect(func() bool { return false })
ended := make(chan string, 1)
s.SetEndedCallback(func(reason string) {
ended <- reason
})
if err := s.Connect(ctx); err != nil {
t.Fatalf("Connect() error = %v", err)
}
connector.callback(0).OnDisconnected()
select {
case reason := <-ended:
if reason != "disconnected from livekit" {
t.Fatalf("ended reason = %q, want disconnected from livekit", reason)
}
case <-time.After(time.Second):
t.Fatal("ended callback was not called")
}
if !s.closed.Load() {
t.Fatal("closed = false after terminal disconnect")
}
if connector.count() != 1 {
t.Fatalf("connect count = %d, want 1", connector.count())
}
room := connector.room(0)
room.mu.Lock()
if room.disconnected != 1 || room.unpublished != 1 {
t.Fatalf("terminal room cleanup disconnected=%d unpublished=%d, want 1/1",
room.disconnected, room.unpublished)
}
room.mu.Unlock()
if err := s.Close(); err != nil {
t.Fatalf("Close() error = %v", err)
}
room.mu.Lock()
if room.disconnected != 1 || room.unpublished != 1 {
t.Fatalf("second close cleanup disconnected=%d unpublished=%d, want still 1/1",
room.disconnected, room.unpublished)
}
room.mu.Unlock()
}
func TestCanSendRequiresConnectedRoomAndQueueHeadroom(t *testing.T) {
s := &Session{
sendQueue: make(chan []byte, defaultSendQueueSize),
done: make(chan struct{}),
closeCh: make(chan struct{}),
}
if s.CanSend() {
t.Fatal("CanSend() = true without room")
}
room := newFakeRoom()
room.state = lksdk.ConnectionStateDisconnected
s.setRoom(room)
if s.CanSend() {
t.Fatal("CanSend() = true for disconnected room")
}
room.state = lksdk.ConnectionStateConnected
if !s.CanSend() {
t.Fatal("CanSend() = false for connected room")
}
for i := 0; i < defaultSendQueueCapHard; i++ {
s.sendQueue <- []byte("x")
}
if s.CanSend() {
t.Fatal("CanSend() = true at queue high watermark")
}
}
func TestReconnectFailureRetriesUntilContextDone(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s := &Session{
url: "wss://old",
token: "old-token",
connectRoom: func(string, string, *lksdk.RoomCallback) (roomHandle, error) {
cancel()
return nil, errors.New("boom")
},
reconnectCh: make(chan struct{}, 1),
closeCh: make(chan struct{}),
sendQueue: make(chan []byte, defaultSendQueueSize),
done: make(chan struct{}),
}
if terminal := s.handleReconnectAttempt(ctx); !terminal {
t.Fatal("handleReconnectAttempt() = false after context cancellation")
}
}
func equalStrings(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}

View File

@@ -417,10 +417,7 @@ func (s *Session) waitForMediaReady(ctx context.Context, timeout time.Duration)
}
func (s *Session) dialWebSocket() error {
wsDialer := websocket.Dialer{
NetDialContext: protect.DialContext,
HandshakeTimeout: wsHandshakeTimeout,
}
wsDialer := protect.NewWebSocketDialer(wsHandshakeTimeout)
ws, resp, err := wsDialer.Dial(s.connectorURL, nil)
if err != nil {

View File

@@ -13,8 +13,8 @@
// │ │
//
// After the exchange the control stream stays open; tunnel traffic flows over
// additional smux streams opened by the client. The control stream may carry
// keepalives or future control messages.
// additional smux streams opened by the client. The control stream then
// carries ping/pong liveness and future control messages.
//
//nolint:tagliatelle // JSON keys are the stable wire protocol schema.
package handshake

View File

@@ -43,6 +43,7 @@ func New(ctx context.Context, cfg link.Config) (link.Link, error) {
SEIBatchSize: cfg.SEIBatchSize,
SEIFragmentSize: cfg.SEIFragmentSize,
SEIAckTimeoutMS: cfg.SEIAckTimeoutMS,
Traffic: cfg.Traffic,
})
if err != nil {
return nil, fmt.Errorf("create transport for direct link: %w", err)
@@ -79,3 +80,6 @@ func (d *directLink) WatchConnection(ctx context.Context) {
d.transport.WatchConnection(ctx)
}
func (d *directLink) CanSend() bool { return d.transport.CanSend() }
// Features reports the direct link's underlying transport capabilities.
func (d *directLink) Features() link.Features { return d.transport.Features() }

View File

@@ -79,12 +79,14 @@ func TestNewForwardsConfigAndMethods(t *testing.T) {
VideoTileRS: 20,
VP8FPS: 25,
VP8BatchSize: 8,
Traffic: transport.TrafficConfig{MaxPayloadSize: 4096},
})
if err != nil {
t.Fatalf("New() error = %v", err)
}
if seen.DeviceID != "client" || seen.ProxyPort != 1080 || seen.VideoTileRS != 20 || seen.VP8BatchSize != 8 {
if seen.DeviceID != "client" || seen.ProxyPort != 1080 || seen.VideoTileRS != 20 || seen.VP8BatchSize != 8 ||
seen.Traffic.MaxPayloadSize != 4096 {
t.Fatalf("forwarded config = %+v", seen)
}
@@ -112,6 +114,9 @@ func TestNewForwardsConfigAndMethods(t *testing.T) {
if !ln.CanSend() {
t.Fatal("CanSend() = false, want true")
}
if features := ln.(link.FeaturesProvider).Features(); features.MaxPayloadSize != 4096 {
t.Fatalf("Features() = %+v, want shaped max payload 4096", features)
}
}
func TestNewWrapsFactoryError(t *testing.T) {

View File

@@ -4,6 +4,8 @@ package link
import (
"context"
"errors"
"github.com/openlibrecommunity/olcrtc/internal/transport"
)
var (
@@ -23,11 +25,19 @@ type Link interface {
CanSend() bool
}
// Features mirrors the underlying transport capabilities when a link can expose them.
type Features = transport.Features
// FeaturesProvider is optionally implemented by links that can report wire limits.
type FeaturesProvider interface {
Features() Features
}
// Config holds common link configuration.
type Config struct {
Transport string
Carrier string
RoomURL string
Transport string
Carrier string
RoomURL string
// Engine, URL, Token are forwarded for the "none" auth carrier.
Engine string
URL string
@@ -54,6 +64,7 @@ type Config struct {
SEIBatchSize int
SEIFragmentSize int
SEIAckTimeoutMS int
Traffic transport.TrafficConfig
}
// Factory creates a link instance.

View File

@@ -3,13 +3,38 @@ package protect
import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"regexp"
"strings"
"syscall"
"time"
"github.com/gorilla/websocket"
)
const (
defaultDialTimeout = 10 * time.Second
defaultKeepAlive = 30 * time.Second
defaultIdleConnTimeout = 30 * time.Second
defaultTLSHandshake = 10 * time.Second
defaultResponseHeader = 10 * time.Second
defaultWebSocketTimeout = 10 * time.Second
defaultHTTPClientTimeout = 30 * time.Second
defaultStatusBodyLimit = 1024
)
var (
sensitiveFieldRE = regexp.MustCompile(
`(?i)((?:access[_-]?token|room[_-]?token|token|credentials)"?\s*[:=]\s*"?)` +
`[^",\s}]+`,
)
sensitiveBearerRE = regexp.MustCompile(`(?i)(bearer\s+)[A-Za-z0-9._~+/=-]+`)
) //nolint:gochecknoglobals // compiled once for provider error redaction
// Protector is called with a socket file descriptor before connect.
// On Android, this calls VpnService.protect(fd) to bypass VPN routing.
var Protector func(fd int) bool //nolint:gochecknoglobals // package-level state intentional
@@ -33,24 +58,70 @@ func controlFunc(network, _ string, c syscall.RawConn) error {
// NewDialer returns a net.Dialer that calls Protector on each new socket.
func NewDialer() *net.Dialer {
return &net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
Timeout: defaultDialTimeout,
KeepAlive: defaultKeepAlive,
Control: controlFunc,
}
}
// NewTLSConfig returns the shared TLS policy for provider HTTP/WebSocket clients.
func NewTLSConfig() *tls.Config {
return &tls.Config{MinVersion: tls.VersionTLS12}
}
// NewHTTPTransport returns an HTTP transport using protected sockets and sane timeouts.
func NewHTTPTransport() *http.Transport {
dialer := NewDialer()
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: dialer.DialContext,
TLSClientConfig: NewTLSConfig(),
ForceAttemptHTTP2: true,
MaxIdleConns: 10,
IdleConnTimeout: defaultIdleConnTimeout,
TLSHandshakeTimeout: defaultTLSHandshake,
ResponseHeaderTimeout: defaultResponseHeader,
}
}
// NewHTTPClient returns an http.Client using protected sockets.
func NewHTTPClient() *http.Client {
dialer := NewDialer()
transport := &http.Transport{
DialContext: dialer.DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 10,
IdleConnTimeout: 30 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ResponseHeaderTimeout: 10 * time.Second,
return &http.Client{
Transport: NewHTTPTransport(),
Timeout: defaultHTTPClientTimeout,
}
return &http.Client{Transport: transport}
}
// NewWebSocketDialer returns a WebSocket dialer using protected sockets and shared TLS policy.
func NewWebSocketDialer(handshakeTimeout time.Duration) websocket.Dialer {
if handshakeTimeout <= 0 {
handshakeTimeout = defaultWebSocketTimeout
}
return websocket.Dialer{
NetDialContext: DialContext,
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: NewTLSConfig(),
HandshakeTimeout: handshakeTimeout,
}
}
// StatusError formats an upstream HTTP error while bounding and redacting the body.
func StatusError(base error, resp *http.Response, limit int64) error {
if limit <= 0 {
limit = defaultStatusBodyLimit
}
body, _ := io.ReadAll(io.LimitReader(resp.Body, limit))
bodyText := RedactSensitive(strings.TrimSpace(string(body)))
if bodyText == "" {
return fmt.Errorf("%w: status %d", base, resp.StatusCode)
}
return fmt.Errorf("%w: status %d: %s", base, resp.StatusCode, bodyText)
}
// RedactSensitive removes common token-like values from provider error text.
func RedactSensitive(text string) string {
text = sensitiveBearerRE.ReplaceAllString(text, "${1}<redacted>")
return sensitiveFieldRE.ReplaceAllString(text, "${1}<redacted>")
}
// DialContext dials using a protected socket.

View File

@@ -2,9 +2,11 @@ package protect
import (
"context"
"crypto/tls"
"errors"
"net"
"net/http"
"strings"
"syscall"
"testing"
"time"
@@ -88,13 +90,57 @@ func TestNewDialerAndHTTPClient(t *testing.T) {
if !ok {
t.Fatalf("Transport type = %T, want *http.Transport", client.Transport)
}
if tr.DialContext == nil || !tr.ForceAttemptHTTP2 || tr.MaxIdleConns != 10 ||
if tr.Proxy == nil || tr.DialContext == nil || tr.TLSClientConfig == nil ||
tr.TLSClientConfig.MinVersion != tls.VersionTLS12 || !tr.ForceAttemptHTTP2 || tr.MaxIdleConns != 10 ||
tr.IdleConnTimeout != 30*time.Second || tr.TLSHandshakeTimeout != 10*time.Second ||
tr.ResponseHeaderTimeout != 10*time.Second {
tr.ResponseHeaderTimeout != 10*time.Second || client.Timeout != 30*time.Second {
t.Fatalf("transport = %+v", tr)
}
}
func TestNewWebSocketDialer(t *testing.T) {
dialer := NewWebSocketDialer(3 * time.Second)
if dialer.NetDialContext == nil || dialer.Proxy == nil || dialer.TLSClientConfig == nil ||
dialer.TLSClientConfig.MinVersion != tls.VersionTLS12 ||
dialer.HandshakeTimeout != 3*time.Second {
t.Fatalf("NewWebSocketDialer() = %+v", dialer)
}
defaulted := NewWebSocketDialer(0)
if defaulted.HandshakeTimeout != defaultWebSocketTimeout {
t.Fatalf("default HandshakeTimeout = %v, want %v",
defaulted.HandshakeTimeout, defaultWebSocketTimeout)
}
}
func TestStatusErrorRedactsAndLimitsBody(t *testing.T) {
resp := &http.Response{
StatusCode: http.StatusForbidden,
Body: ioNopCloser{strings.NewReader(`{"accessToken":"secret","message":"no"}`)},
}
err := StatusError(errProtectBoom, resp, 1024)
if err == nil {
t.Fatal("StatusError() error = nil")
}
text := err.Error()
if strings.Contains(text, "secret") || !strings.Contains(text, "<redacted>") {
t.Fatalf("StatusError() = %q, want redacted token", text)
}
}
func TestRedactSensitiveBearer(t *testing.T) {
got := RedactSensitive("Authorization: Bearer abc.def")
if strings.Contains(got, "abc.def") || !strings.Contains(got, "Bearer <redacted>") {
t.Fatalf("RedactSensitive() = %q", got)
}
}
type ioNopCloser struct {
*strings.Reader
}
func (c ioNopCloser) Close() error { return nil }
func TestDialContextAndProxyDialer(t *testing.T) {
var lc net.ListenConfig
ln, err := lc.Listen(context.Background(), "tcp4", "127.0.0.1:0")

View File

@@ -14,12 +14,14 @@ import (
"time"
"github.com/google/uuid"
"github.com/openlibrecommunity/olcrtc/internal/control"
"github.com/openlibrecommunity/olcrtc/internal/crypto"
"github.com/openlibrecommunity/olcrtc/internal/handshake"
"github.com/openlibrecommunity/olcrtc/internal/link"
"github.com/openlibrecommunity/olcrtc/internal/logger"
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
"github.com/openlibrecommunity/olcrtc/internal/names"
"github.com/openlibrecommunity/olcrtc/internal/transport"
"github.com/xtaci/smux"
)
@@ -49,25 +51,33 @@ type SessionCloseFunc func(sessionID, reason string)
// bytesIn counts client→target bytes; bytesOut counts target→client bytes.
type TrafficFunc func(sessionID, addr string, bytesIn, bytesOut uint64)
// HealthFunc is called when the server control health snapshot changes.
type HealthFunc func(control.Status)
// Server handles incoming tunnel connections and proxies their traffic.
type Server struct {
ln link.Link
cipher *crypto.Cipher
conn *muxconn.Conn
session *smux.Session
controlStop context.CancelFunc
sessMu sync.RWMutex
reinstallMu sync.Mutex
healthMu sync.RWMutex
wg sync.WaitGroup
authHook handshake.AuthFunc
onOpen SessionOpenFunc
onClose SessionCloseFunc
onTraffic TrafficFunc
onHealth HealthFunc
deviceID string
sessionID string
dnsServer string
resolver *net.Resolver
socksProxyAddr string
socksProxyPort int
liveness control.Config
health control.Status
}
// ConnectRequest is a message from the client to establish a new connection.
@@ -106,6 +116,8 @@ type Config struct {
Engine string
URL string
Token string
Liveness control.Config
Traffic transport.TrafficConfig
// AuthHook is invoked after CLIENT_HELLO to authorize the client and
// return a session ID. If nil, every client is admitted with a random UUID.
@@ -117,6 +129,8 @@ type Config struct {
OnSessionClose SessionCloseFunc
// OnTraffic fires once per tunnel stream after both copy loops finish. Nil means no-op.
OnTraffic TrafficFunc
// OnHealth fires when liveness/reconnect status changes. Nil means no-op.
OnHealth HealthFunc
}
// Run starts the server with the given configuration.
@@ -145,6 +159,10 @@ func Run(ctx context.Context, cfg Config) error {
if onTraffic == nil {
onTraffic = func(string, string, uint64, uint64) {}
}
onHealth := cfg.OnHealth
if onHealth == nil {
onHealth = func(control.Status) {}
}
s := &Server{
cipher: cipher,
@@ -152,9 +170,11 @@ func Run(ctx context.Context, cfg Config) error {
onOpen: onOpen,
onClose: onClose,
onTraffic: onTraffic,
onHealth: onHealth,
dnsServer: cfg.DNSServer,
socksProxyAddr: cfg.SOCKSProxyAddr,
socksProxyPort: cfg.SOCKSProxyPort,
liveness: cfg.Liveness,
}
s.setupResolver()
@@ -216,11 +236,17 @@ func (s *Server) setupResolver() {
// smuxConfig mirrors the client side. Both peers must agree on Version and
// MaxFrameSize.
func smuxConfig() *smux.Config {
func smuxConfig(maxWirePayload ...int) *smux.Config {
cfg := smux.DefaultConfig()
cfg.Version = 2
cfg.KeepAliveDisabled = true
cfg.MaxFrameSize = 32768
if len(maxWirePayload) > 0 && maxWirePayload[0] > crypto.WireOverhead {
maxFrameSize := maxWirePayload[0] - crypto.WireOverhead
if maxFrameSize < cfg.MaxFrameSize {
cfg.MaxFrameSize = maxFrameSize
}
}
cfg.MaxReceiveBuffer = 16 * 1024 * 1024
cfg.MaxStreamBuffer = 1024 * 1024
cfg.KeepAliveInterval = 10 * time.Second
@@ -228,6 +254,14 @@ func smuxConfig() *smux.Config {
return cfg
}
func linkMaxPayload(ln link.Link) int {
provider, ok := ln.(link.FeaturesProvider)
if !ok {
return 0
}
return provider.Features().MaxPayloadSize
}
func (s *Server) bringUpLink(
ctx context.Context,
cfg Config,
@@ -262,6 +296,7 @@ func (s *Server) bringUpLink(
SEIBatchSize: cfg.SEIBatchSize,
SEIFragmentSize: cfg.SEIFragmentSize,
SEIAckTimeoutMS: cfg.SEIAckTimeoutMS,
Traffic: cfg.Traffic,
})
if err != nil {
return fmt.Errorf("failed to create link: %w", err)
@@ -298,7 +333,7 @@ func (s *Server) bringUpLink(
func (s *Server) installSession() {
conn := muxconn.New(s.ln, s.cipher)
sess, err := smux.Server(conn, smuxConfig())
sess, err := smux.Server(conn, smuxConfig(linkMaxPayload(s.ln)))
if err != nil {
logger.Warnf("smux server init failed: %v", err)
return
@@ -310,7 +345,8 @@ func (s *Server) installSession() {
}
func (s *Server) handleReconnect() {
logger.Infof("server link reconnect - tearing down smux session")
s.recordReconnect()
logger.Infof("server reconnect reason=carrier - tearing down smux session")
s.sessMu.RLock()
current := s.session
s.sessMu.RUnlock()
@@ -323,7 +359,7 @@ func (s *Server) reinstallSession(dead *smux.Session) {
// Pre-build the replacement so we can swap atomically below.
newConn := muxconn.New(s.ln, s.cipher)
newSess, err := smux.Server(newConn, smuxConfig())
newSess, err := smux.Server(newConn, smuxConfig(linkMaxPayload(s.ln)))
if err != nil {
logger.Warnf("smux server init failed: %v", err)
_ = newConn.Close()
@@ -340,13 +376,18 @@ func (s *Server) reinstallSession(dead *smux.Session) {
}
oldSess := s.session
oldConn := s.conn
oldControlStop := s.controlStop
oldSID := s.sessionID
s.session = newSess
s.conn = newConn
s.controlStop = nil
s.sessionID = ""
s.deviceID = ""
s.sessMu.Unlock()
if oldControlStop != nil {
oldControlStop()
}
if oldSess != nil {
_ = oldSess.Close()
}
@@ -362,13 +403,18 @@ func (s *Server) closeSession() {
s.sessMu.Lock()
sess := s.session
conn := s.conn
controlStop := s.controlStop
s.session = nil
s.conn = nil
s.controlStop = nil
oldSID := s.sessionID
s.sessionID = ""
s.deviceID = ""
s.sessMu.Unlock()
if controlStop != nil {
controlStop()
}
if conn != nil {
_ = conn.Close()
}
@@ -476,27 +522,120 @@ func (s *Server) acceptHandshake(ctx context.Context, sess *smux.Session) bool {
s.deviceID = hello.DeviceID
s.sessionID = sid
s.sessMu.Unlock()
s.recordSession(sid)
s.onOpen(sid, hello.DeviceID, hello.Claims)
logger.Infof("session %s opened (device=%s)", sid, hello.DeviceID)
// The control stream stays open for the lifetime of the session;
// keep it parked in a goroutine so the smux session does not close it.
s.wg.Add(1)
go func() {
defer s.wg.Done()
s.parkControlStream(stream)
}()
s.startControlLoop(ctx, sess, stream)
return true
}
// parkControlStream blocks reading from the control stream until it closes.
// Future control messages (kick, rate updates, etc.) would be dispatched here.
func (s *Server) parkControlStream(stream *smux.Stream) {
defer func() { _ = stream.Close() }()
buf := make([]byte, 64)
for {
if _, err := stream.Read(buf); err != nil {
func (s *Server) startControlLoop(ctx context.Context, sess *smux.Session, stream *smux.Stream) {
controlCtx, stop := context.WithCancel(ctx)
s.sessMu.Lock()
s.controlStop = stop
s.sessMu.Unlock()
liveness := s.liveness
onPong := liveness.OnPong
onMissedPong := liveness.OnMissedPong
onUnhealthy := liveness.OnUnhealthy
liveness.OnPong = func(h control.Health) {
s.sessMu.RLock()
sid := s.sessionID
s.sessMu.RUnlock()
s.recordPong(h)
logger.Debugf("control alive session=%s rtt=%v seq=%d", sid, h.RTT, h.Seq)
if onPong != nil {
onPong(h)
}
}
liveness.OnMissedPong = func(missed int) {
s.recordMissed(missed)
logger.Warnf("control missed pong on server: missed_pongs=%d", missed)
if onMissedPong != nil {
onMissedPong(missed)
}
}
liveness.OnUnhealthy = func(missed int) {
s.recordUnhealthy(missed)
logger.Warnf("control stream unhealthy on server: missed_pongs=%d", missed)
if onUnhealthy != nil {
onUnhealthy(missed)
}
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
defer func() { _ = stream.Close() }()
err := control.Run(controlCtx, stream, liveness)
if controlCtx.Err() != nil || ctx.Err() != nil {
return
}
if err != nil {
logger.Warnf("server control stream ended: %v", err)
}
s.recordReconnect()
logger.Infof("server reconnect reason=liveness - reinstalling smux session")
s.reinstallSession(sess)
}()
}
// Status returns the latest server-side control health snapshot.
func (s *Server) Status() control.Status {
s.healthMu.RLock()
defer s.healthMu.RUnlock()
return s.health
}
func (s *Server) recordSession(sessionID string) {
s.healthMu.Lock()
s.health.SessionID = sessionID
s.health.MissedPongs = 0
status := s.health
s.healthMu.Unlock()
s.notifyHealth(status)
}
func (s *Server) recordPong(h control.Health) {
s.healthMu.Lock()
s.health.LastPong = h.LastSeen
s.health.LastRTT = h.RTT
s.health.MissedPongs = 0
status := s.health
s.healthMu.Unlock()
s.notifyHealth(status)
}
func (s *Server) recordMissed(missed int) {
s.healthMu.Lock()
s.health.MissedPongs = missed
status := s.health
s.healthMu.Unlock()
s.notifyHealth(status)
}
func (s *Server) recordUnhealthy(missed int) {
s.healthMu.Lock()
s.health.MissedPongs = missed
s.health.UnhealthyEvents++
s.health.LastUnhealthy = time.Now()
status := s.health
s.healthMu.Unlock()
s.notifyHealth(status)
}
func (s *Server) recordReconnect() {
s.healthMu.Lock()
s.health.Reconnects++
status := s.health
s.healthMu.Unlock()
s.notifyHealth(status)
}
func (s *Server) notifyHealth(status control.Status) {
if s.onHealth != nil {
s.onHealth(status)
}
}

View File

@@ -11,6 +11,7 @@ import (
"testing"
"time"
"github.com/openlibrecommunity/olcrtc/internal/control"
cryptopkg "github.com/openlibrecommunity/olcrtc/internal/crypto"
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
"github.com/xtaci/smux"
@@ -49,6 +50,11 @@ func TestSmuxConfig(t *testing.T) {
if cfg.Version != 2 || !cfg.KeepAliveDisabled || cfg.MaxFrameSize != 32768 || cfg.MaxReceiveBuffer != 16*1024*1024 {
t.Fatalf("smuxConfig() = %+v", cfg)
}
capped := smuxConfig(4096)
if capped.MaxFrameSize != 4096-cryptopkg.WireOverhead {
t.Fatalf("smuxConfig(4096).MaxFrameSize = %d, want %d",
capped.MaxFrameSize, 4096-cryptopkg.WireOverhead)
}
}
func TestParseConnectRequest(t *testing.T) {
@@ -373,6 +379,103 @@ func TestReinstallSessionFiresOnClose(t *testing.T) {
}
}
func TestStartControlLoopReportsPong(t *testing.T) {
a, b := net.Pipe()
defer func() {
_ = a.Close()
_ = b.Close()
}()
serverSess, err := smux.Server(a, smuxConfig())
if err != nil {
t.Fatalf("smux.Server() error = %v", err)
}
defer func() { _ = serverSess.Close() }()
clientSess, err := smux.Client(b, smuxConfig())
if err != nil {
t.Fatalf("smux.Client() error = %v", err)
}
defer func() { _ = clientSess.Close() }()
serverStreamCh := make(chan *smux.Stream, 1)
go func() {
stream, err := serverSess.AcceptStream()
if err == nil {
serverStreamCh <- stream
}
}()
clientStream, err := clientSess.OpenStream()
if err != nil {
t.Fatalf("OpenStream() error = %v", err)
}
serverStream := <-serverStreamCh
ctx, cancel := context.WithCancel(context.Background())
got := make(chan control.Health, 1)
s := &Server{
sessionID: "sid-control",
liveness: control.Config{
Interval: 10 * time.Millisecond,
Timeout: 100 * time.Millisecond,
Failures: 2,
OnPong: func(h control.Health) {
select {
case got <- h:
default:
}
},
},
}
s.recordSession("sid-control")
defer func() {
cancel()
s.wg.Wait()
}()
s.startControlLoop(ctx, serverSess, serverStream)
go func() {
_ = control.Run(ctx, clientStream, control.Config{
Interval: 10 * time.Millisecond,
Timeout: 100 * time.Millisecond,
Failures: 2,
})
}()
select {
case h := <-got:
if h.Seq == 0 {
t.Fatal("Health.Seq = 0")
}
case <-time.After(time.Second):
t.Fatal("timed out waiting for control pong")
}
status := s.Status()
if status.SessionID != "sid-control" {
t.Fatalf("Status.SessionID = %q, want sid-control", status.SessionID)
}
if status.LastPong.IsZero() || status.LastRTT < 0 || status.MissedPongs != 0 {
t.Fatalf("Status() = %+v", status)
}
}
func TestStatusRecordsReconnectAndUnhealthy(t *testing.T) {
updates := 0
s := &Server{onHealth: func(control.Status) { updates++ }}
s.recordSession("sid-1")
s.recordMissed(2)
s.recordUnhealthy(3)
s.recordReconnect()
status := s.Status()
if status.SessionID != "sid-1" || status.MissedPongs != 3 ||
status.UnhealthyEvents != 1 || status.Reconnects != 1 || status.LastUnhealthy.IsZero() {
t.Fatalf("Status() = %+v", status)
}
if updates != 4 {
t.Fatalf("health updates = %d, want 4", updates)
}
}
//nolint:cyclop // integration-style test needs setup, proxying, and traffic assertions together.
func TestDispatchFiresOnTraffic(t *testing.T) {
var lc net.ListenConfig

View File

@@ -0,0 +1,229 @@
// Package supervisor runs ordered session profiles with failover.
package supervisor
import (
"context"
"errors"
"fmt"
"time"
"github.com/openlibrecommunity/olcrtc/internal/app/session"
)
const DefaultRetryDelay = 2 * time.Second
const DefaultHistoryLimit = 20
const (
// EventProfileStart marks a profile attempt starting.
EventProfileStart = "profile_start"
// EventProfileEnd marks a profile attempt ending.
EventProfileEnd = "profile_end"
)
var (
// ErrNoProfiles is returned when the supervisor is started without profiles.
ErrNoProfiles = errors.New("supervisor: no profiles configured")
// ErrMaxCyclesExceeded is returned after MaxCycles complete profile-list passes.
ErrMaxCyclesExceeded = errors.New("supervisor: max failover cycles exceeded")
)
// Profile is one runnable session configuration in an ordered failover list.
type Profile struct {
Name string
Config session.Config
}
// ProfileStatus summarizes one profile's failover history.
type ProfileStatus struct {
Name string
Starts int
Failures int
CleanEnds int
LastStarted time.Time
LastEnded time.Time
LastError string
}
// Event is one bounded failover history entry.
type Event struct {
Time time.Time
Type string
Profile string
Cycle int
Error string
}
// Status is a point-in-time view of the supervisor.
type Status struct {
Cycle int
ActiveProfile string
ActiveProfileIndex int
Profiles []ProfileStatus
History []Event
LastError string
}
// Runner starts one session profile and blocks until it ends or fails.
type Runner func(ctx context.Context, cfg session.Config) error
// Config controls ordered failover behavior.
type Config struct {
Profiles []Profile
RetryDelay time.Duration
MaxCycles int
OnProfileStart func(profile Profile, cycle int)
OnProfileEnd func(profile Profile, cycle int, err error)
OnStatus func(status Status)
HistoryLimit int
}
// Run starts profiles in order. If a profile exits while ctx is still active,
// the supervisor waits RetryDelay and advances to the next profile.
func Run(ctx context.Context, cfg Config, run Runner) error {
if len(cfg.Profiles) == 0 {
return ErrNoProfiles
}
if cfg.RetryDelay == 0 {
cfg.RetryDelay = DefaultRetryDelay
}
state := newStatusTracker(cfg.Profiles, cfg.HistoryLimit, cfg.OnStatus)
var lastErr error
for cycle := 1; ; cycle++ {
for i, profile := range cfg.Profiles {
if ctx.Err() != nil {
return nil
}
state.start(i, cycle)
if cfg.OnProfileStart != nil {
cfg.OnProfileStart(profile, cycle)
}
err := run(ctx, profile.Config)
if ctx.Err() != nil {
return nil
}
if err != nil {
lastErr = fmt.Errorf("profile %q: %w", profile.Name, err)
} else {
lastErr = fmt.Errorf("profile %q ended", profile.Name)
}
state.end(i, cycle, err)
if cfg.OnProfileEnd != nil {
cfg.OnProfileEnd(profile, cycle, err)
}
if cfg.MaxCycles > 0 && cycle >= cfg.MaxCycles && i == len(cfg.Profiles)-1 {
return fmt.Errorf("%w after %d cycle(s): %w", ErrMaxCyclesExceeded, cycle, lastErr)
}
if err := waitRetryDelay(ctx, cfg.RetryDelay); err != nil {
return nil
}
}
}
}
type statusTracker struct {
status Status
notify func(Status)
historyLimit int
}
func newStatusTracker(profiles []Profile, historyLimit int, notify func(Status)) *statusTracker {
if historyLimit == 0 {
historyLimit = DefaultHistoryLimit
}
statusProfiles := make([]ProfileStatus, 0, len(profiles))
for _, profile := range profiles {
statusProfiles = append(statusProfiles, ProfileStatus{Name: profile.Name})
}
return &statusTracker{
status: Status{
ActiveProfileIndex: -1,
Profiles: statusProfiles,
},
notify: notify,
historyLimit: historyLimit,
}
}
func (t *statusTracker) start(profileIndex, cycle int) {
now := time.Now()
profile := &t.status.Profiles[profileIndex]
profile.Starts++
profile.LastStarted = now
t.status.Cycle = cycle
t.status.ActiveProfile = profile.Name
t.status.ActiveProfileIndex = profileIndex
t.appendHistory(Event{
Time: now,
Type: EventProfileStart,
Profile: profile.Name,
Cycle: cycle,
})
t.emit()
}
func (t *statusTracker) end(profileIndex, cycle int, err error) {
now := time.Now()
profile := &t.status.Profiles[profileIndex]
profile.LastEnded = now
event := Event{
Time: now,
Type: EventProfileEnd,
Profile: profile.Name,
Cycle: cycle,
}
if err != nil {
profile.Failures++
profile.LastError = err.Error()
t.status.LastError = fmt.Sprintf("profile %q: %v", profile.Name, err)
event.Error = err.Error()
} else {
profile.CleanEnds++
profile.LastError = ""
t.status.LastError = fmt.Sprintf("profile %q ended", profile.Name)
}
t.status.ActiveProfile = ""
t.status.ActiveProfileIndex = -1
t.appendHistory(event)
t.emit()
}
func (t *statusTracker) appendHistory(event Event) {
if t.historyLimit < 0 {
return
}
t.status.History = append(t.status.History, event)
if len(t.status.History) > t.historyLimit {
t.status.History = t.status.History[len(t.status.History)-t.historyLimit:]
}
}
func (t *statusTracker) emit() {
if t.notify == nil {
return
}
t.notify(cloneStatus(t.status))
}
func cloneStatus(status Status) Status {
status.Profiles = append([]ProfileStatus(nil), status.Profiles...)
status.History = append([]Event(nil), status.History...)
return status
}
func waitRetryDelay(ctx context.Context, delay time.Duration) error {
if delay <= 0 {
return nil
}
timer := time.NewTimer(delay)
defer timer.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
return nil
}
}

View File

@@ -0,0 +1,170 @@
package supervisor
import (
"context"
"errors"
"testing"
"time"
"github.com/openlibrecommunity/olcrtc/internal/app/session"
)
var errRunnerBoom = errors.New("boom")
func TestRunRequiresProfiles(t *testing.T) {
err := Run(context.Background(), Config{}, func(context.Context, session.Config) error { return nil })
if !errors.Is(err, ErrNoProfiles) {
t.Fatalf("Run() error = %v, want %v", err, ErrNoProfiles)
}
}
func TestRunAdvancesProfilesAndStopsAtMaxCycles(t *testing.T) {
profiles := []Profile{
{Name: "first", Config: session.Config{Auth: "wbstream"}},
{Name: "second", Config: session.Config{Auth: "jitsi"}},
}
var started []string
var ended []string
err := Run(context.Background(), Config{
Profiles: profiles,
RetryDelay: -1,
MaxCycles: 1,
OnProfileStart: func(profile Profile, cycle int) {
started = append(started, profile.Name)
if cycle != 1 {
t.Fatalf("cycle = %d, want 1", cycle)
}
},
OnProfileEnd: func(profile Profile, _ int, err error) {
ended = append(ended, profile.Name)
if !errors.Is(err, errRunnerBoom) {
t.Fatalf("profile %s err = %v, want %v", profile.Name, err, errRunnerBoom)
}
},
}, func(_ context.Context, cfg session.Config) error {
if cfg.Auth == "" {
t.Fatal("runner received empty auth")
}
return errRunnerBoom
})
if !errors.Is(err, ErrMaxCyclesExceeded) {
t.Fatalf("Run() error = %v, want %v", err, ErrMaxCyclesExceeded)
}
if got, want := started, []string{"first", "second"}; !equalStrings(got, want) {
t.Fatalf("started = %v, want %v", got, want)
}
if got, want := ended, []string{"first", "second"}; !equalStrings(got, want) {
t.Fatalf("ended = %v, want %v", got, want)
}
}
func TestRunEmitsStatusHistory(t *testing.T) {
profiles := []Profile{
{Name: "first", Config: session.Config{Auth: "wbstream"}},
{Name: "second", Config: session.Config{Auth: "jitsi"}},
}
var snapshots []Status
err := Run(context.Background(), Config{
Profiles: profiles,
RetryDelay: -1,
MaxCycles: 1,
HistoryLimit: 3,
OnStatus: func(status Status) {
snapshots = append(snapshots, status)
},
}, func(_ context.Context, cfg session.Config) error {
if cfg.Auth == "first" {
t.Fatal("runner received profile name instead of config")
}
return errRunnerBoom
})
if !errors.Is(err, ErrMaxCyclesExceeded) {
t.Fatalf("Run() error = %v, want %v", err, ErrMaxCyclesExceeded)
}
if len(snapshots) != 4 {
t.Fatalf("status snapshots = %d, want 4", len(snapshots))
}
first := snapshots[0]
if first.ActiveProfile != "first" || first.ActiveProfileIndex != 0 || first.Cycle != 1 {
t.Fatalf("first status = %+v", first)
}
if first.Profiles[0].Starts != 1 || first.Profiles[0].LastStarted.IsZero() {
t.Fatalf("first profile start status = %+v", first.Profiles[0])
}
last := snapshots[len(snapshots)-1]
if last.ActiveProfile != "" || last.ActiveProfileIndex != -1 {
t.Fatalf("last active status = %+v", last)
}
if last.Profiles[0].Failures != 1 || last.Profiles[1].Failures != 1 {
t.Fatalf("profile failures = %+v", last.Profiles)
}
if last.LastError == "" || last.Profiles[1].LastError == "" {
t.Fatalf("last errors missing: %+v", last)
}
if len(last.History) != 3 {
t.Fatalf("history length = %d, want 3", len(last.History))
}
if last.History[0].Type != EventProfileEnd || last.History[0].Profile != "first" {
t.Fatalf("oldest bounded history event = %+v", last.History[0])
}
if last.History[2].Type != EventProfileEnd || last.History[2].Profile != "second" ||
last.History[2].Error == "" {
t.Fatalf("last history event = %+v", last.History[2])
}
}
func TestRunStatusSnapshotIsImmutable(t *testing.T) {
var first Status
var second Status
err := Run(context.Background(), Config{
Profiles: []Profile{{Name: "one"}},
RetryDelay: -1,
MaxCycles: 1,
OnStatus: func(status Status) {
if first.Profiles == nil {
first = status
first.Profiles[0].Starts = 99
first.History[0].Profile = "mutated"
return
}
second = status
},
}, func(context.Context, session.Config) error {
return errRunnerBoom
})
if !errors.Is(err, ErrMaxCyclesExceeded) {
t.Fatalf("Run() error = %v, want %v", err, ErrMaxCyclesExceeded)
}
if first.Profiles[0].Starts != 99 || first.History[0].Profile != "mutated" {
t.Fatalf("test mutation did not apply to snapshot: %+v", first)
}
if second.Profiles[0].Starts != 1 || second.History[0].Profile != "one" {
t.Fatalf("snapshot mutation leaked into later status: %+v", second)
}
}
func TestRunReturnsNilOnContextCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
err := Run(ctx, Config{
Profiles: []Profile{{Name: "one"}},
RetryDelay: time.Hour,
}, func(context.Context, session.Config) error {
cancel()
return nil
})
if err != nil {
t.Fatalf("Run() error = %v, want nil", err)
}
}
func equalStrings(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}

View File

@@ -35,6 +35,7 @@ const (
protocolVersion byte = 1
frameTypeData byte = 1
frameTypeAck byte = 2
frameTypeHello byte = 3
)
var (
@@ -86,6 +87,7 @@ type streamTransport struct {
nextSeq atomic.Uint32
closed atomic.Bool
writerUp atomic.Bool
peerReady atomic.Bool
sendMu sync.Mutex
startWriter sync.Once
ackMu sync.Mutex
@@ -286,7 +288,7 @@ func (p *streamTransport) WatchConnection(ctx context.Context) {
// CanSend reports whether transport is ready for sending.
func (p *streamTransport) CanSend() bool {
return !p.closed.Load() && p.stream.CanSend()
return !p.closed.Load() && p.peerReady.Load() && p.stream.CanSend()
}
// Features describes the current seichannel transport semantics.
@@ -333,7 +335,7 @@ func (p *streamTransport) writerLoop() {
ticker := time.NewTicker(p.effectiveFrameInterval())
defer ticker.Stop()
idle := buildVideoAccessUnit(nil)
idle := buildVideoAccessUnit(encodeHelloFrame())
for {
select {
@@ -443,9 +445,13 @@ func (p *streamTransport) handleSample(sample []byte) {
}
switch frame.typ {
case frameTypeHello:
p.peerReady.Store(true)
case frameTypeAck:
p.peerReady.Store(true)
p.resolveAck(frame.seq, frame.crc)
case frameTypeData:
p.peerReady.Store(true)
p.handleInboundFrame(frame)
}
}
@@ -562,8 +568,8 @@ func encodeDataFrame(seq, crc uint32, totalLen, fragIdx, fragTotal int, payload
out[5] = frameTypeData
binary.BigEndian.PutUint32(out[6:10], seq)
binary.BigEndian.PutUint32(out[10:14], crc)
binary.BigEndian.PutUint32(out[14:18], uint32(totalLen)) //nolint:gosec,lll // G115: bounded conversion verified by surrounding logic
binary.BigEndian.PutUint16(out[18:20], uint16(fragIdx)) //nolint:gosec,lll // G115: bounded conversion verified by surrounding logic
binary.BigEndian.PutUint32(out[14:18], uint32(totalLen)) //nolint:gosec,lll // G115: bounded conversion verified by surrounding logic
binary.BigEndian.PutUint16(out[18:20], uint16(fragIdx)) //nolint:gosec,lll // G115: bounded conversion verified by surrounding logic
binary.BigEndian.PutUint16(out[20:22], uint16(fragTotal)) //nolint:gosec,lll // G115: bounded conversion verified by surrounding logic
copy(out[22:], payload)
return out
@@ -579,6 +585,14 @@ func encodeAckFrame(seq, crc uint32) []byte {
return out
}
func encodeHelloFrame() []byte {
out := make([]byte, 6)
binary.BigEndian.PutUint32(out[0:4], protocolMagic)
out[4] = protocolVersion
out[5] = frameTypeHello
return out
}
func decodeTransportFrame(data []byte) (transportFrame, error) {
if len(data) < 6 {
return transportFrame{}, ErrFrameTooShort
@@ -592,6 +606,8 @@ func decodeTransportFrame(data []byte) (transportFrame, error) {
frame := transportFrame{typ: data[5]}
switch frame.typ {
case frameTypeHello:
return frame, nil
case frameTypeAck:
if len(data) < 14 {
return transportFrame{}, ErrAckTooShort

View File

@@ -78,3 +78,13 @@ func TestTransportFrameRoundTrip(t *testing.T) {
t.Fatalf("payload mismatch: got=%q", decoded.payload)
}
}
func TestHelloFrameRoundTrip(t *testing.T) {
hello, err := decodeTransportFrame(encodeHelloFrame())
if err != nil {
t.Fatalf("decodeTransportFrame(hello) failed: %v", err)
}
if hello.typ != frameTypeHello {
t.Fatalf("hello frame type = %d, want %d", hello.typ, frameTypeHello)
}
}

View File

@@ -103,8 +103,12 @@ func TestNewConnectCallbacksAndFeatures(t *testing.T) {
if stream.reconnect == nil || stream.should == nil || stream.ended == nil || !stream.watched {
t.Fatal("callbacks/watch were not forwarded")
}
if tr.CanSend() {
t.Fatal("CanSend() = true before peer hello")
}
tr.handleSample(buildVideoAccessUnit(encodeHelloFrame()))
if !tr.CanSend() {
t.Fatal("CanSend() = false, want true")
t.Fatal("CanSend() = false after peer hello")
}
if features := tr.Features(); !features.Reliable || !features.Ordered || !features.MessageOriented || features.MaxPayloadSize == 0 { //nolint:lll // long test description
t.Fatalf("Features() = %+v", features)

View File

@@ -0,0 +1,91 @@
package transport
import (
"context"
"errors"
"fmt"
"math/rand/v2"
"sync"
"time"
)
var ErrTrafficPayloadTooLarge = errors.New("traffic payload exceeds max_payload_size")
type trafficTransport struct {
inner Transport
maxPayloadSize int
minDelay time.Duration
maxDelay time.Duration
sendMu sync.Mutex
}
// WithTraffic wraps tr with optional payload caps and send pacing.
func WithTraffic(tr Transport, cfg TrafficConfig) Transport {
if tr == nil {
return nil
}
cfg = effectiveTrafficConfig(tr.Features(), cfg)
if cfg.MaxPayloadSize <= 0 && cfg.MinDelay <= 0 && cfg.MaxDelay <= 0 {
return tr
}
return &trafficTransport{
inner: tr,
maxPayloadSize: cfg.MaxPayloadSize,
minDelay: cfg.MinDelay,
maxDelay: cfg.MaxDelay,
}
}
func effectiveTrafficConfig(features Features, cfg TrafficConfig) TrafficConfig {
if cfg.MaxPayloadSize > 0 && features.MaxPayloadSize > 0 && features.MaxPayloadSize < cfg.MaxPayloadSize {
cfg.MaxPayloadSize = features.MaxPayloadSize
}
return cfg
}
func (t *trafficTransport) Connect(ctx context.Context) error { return t.inner.Connect(ctx) }
func (t *trafficTransport) Send(data []byte) error {
t.sendMu.Lock()
defer t.sendMu.Unlock()
if t.maxPayloadSize > 0 && len(data) > t.maxPayloadSize {
return fmt.Errorf("%w: size=%d max=%d", ErrTrafficPayloadTooLarge, len(data), t.maxPayloadSize)
}
if delay := t.nextDelay(); delay > 0 {
time.Sleep(delay)
}
return t.inner.Send(data)
}
func (t *trafficTransport) Close() error { return t.inner.Close() }
func (t *trafficTransport) SetReconnectCallback(cb func()) { t.inner.SetReconnectCallback(cb) }
func (t *trafficTransport) SetShouldReconnect(fn func() bool) { t.inner.SetShouldReconnect(fn) }
func (t *trafficTransport) SetEndedCallback(cb func(string)) { t.inner.SetEndedCallback(cb) }
func (t *trafficTransport) WatchConnection(ctx context.Context) { t.inner.WatchConnection(ctx) }
func (t *trafficTransport) CanSend() bool { return t.inner.CanSend() }
func (t *trafficTransport) Features() Features {
features := t.inner.Features()
if t.maxPayloadSize > 0 &&
(features.MaxPayloadSize == 0 || t.maxPayloadSize < features.MaxPayloadSize) {
features.MaxPayloadSize = t.maxPayloadSize
}
return features
}
func (t *trafficTransport) nextDelay() time.Duration {
if t.maxDelay <= 0 && t.minDelay <= 0 {
return 0
}
minDelay := t.minDelay
maxDelay := t.maxDelay
if maxDelay <= minDelay {
return minDelay
}
return minDelay + time.Duration(rand.Int64N(int64(maxDelay-minDelay))) //nolint:gosec,lll // G404: non-cryptographic pacing jitter
}

View File

@@ -0,0 +1,67 @@
package transport
import (
"context"
"errors"
"testing"
"time"
)
type trafficStubTransport struct {
features Features
sent [][]byte
}
func (s *trafficStubTransport) Connect(context.Context) error { return nil }
func (s *trafficStubTransport) Send(data []byte) error {
s.sent = append(s.sent, append([]byte(nil), data...))
return nil
}
func (s *trafficStubTransport) Close() error { return nil }
func (s *trafficStubTransport) SetReconnectCallback(func()) {}
func (s *trafficStubTransport) SetShouldReconnect(func() bool) {}
func (s *trafficStubTransport) SetEndedCallback(func(string)) {}
func (s *trafficStubTransport) WatchConnection(context.Context) {}
func (s *trafficStubTransport) CanSend() bool { return true }
func (s *trafficStubTransport) Features() Features { return s.features }
func TestWithTrafficReturnsInnerWhenDisabled(t *testing.T) {
inner := &trafficStubTransport{}
got := WithTraffic(inner, TrafficConfig{})
if got != inner {
t.Fatalf("WithTraffic disabled returned %T, want inner", got)
}
}
func TestTrafficWrapperRejectsOversizedPayloadAndClampsFeatures(t *testing.T) {
inner := &trafficStubTransport{features: Features{MaxPayloadSize: 5}}
tr := WithTraffic(inner, TrafficConfig{MaxPayloadSize: 10})
if features := tr.Features(); features.MaxPayloadSize != 5 {
t.Fatalf("Features().MaxPayloadSize = %d, want 5", features.MaxPayloadSize)
}
err := tr.Send([]byte("123456"))
if !errors.Is(err, ErrTrafficPayloadTooLarge) {
t.Fatalf("Send() error = %v, want %v", err, ErrTrafficPayloadTooLarge)
}
if len(inner.sent) != 0 {
t.Fatalf("inner sent %d payloads, want 0", len(inner.sent))
}
if err := tr.Send([]byte("12345")); err != nil {
t.Fatalf("Send(max sized) error = %v", err)
}
if got := string(inner.sent[0]); got != "12345" {
t.Fatalf("inner payload = %q, want 12345", got)
}
}
func TestTrafficWrapperAppliesMinimumDelay(t *testing.T) {
inner := &trafficStubTransport{}
tr := WithTraffic(inner, TrafficConfig{MinDelay: 2 * time.Millisecond})
start := time.Now()
if err := tr.Send([]byte("x")); err != nil {
t.Fatalf("Send() error = %v", err)
}
if elapsed := time.Since(start); elapsed < 2*time.Millisecond {
t.Fatalf("Send() elapsed = %v, want at least 2ms", elapsed)
}
}

View File

@@ -4,6 +4,7 @@ package transport
import (
"context"
"errors"
"time"
)
var (
@@ -32,10 +33,17 @@ type Transport interface {
Features() Features
}
// TrafficConfig controls optional reliability-oriented send shaping.
type TrafficConfig struct {
MaxPayloadSize int
MinDelay time.Duration
MaxDelay time.Duration
}
// Config holds common transport configuration.
type Config struct {
Carrier string
RoomURL string
Carrier string
RoomURL string
// Engine, URL, Token are forwarded to carrier.Config for the "none" auth
// carrier (direct engine access without a service-specific auth flow).
Engine string
@@ -63,6 +71,7 @@ type Config struct {
SEIBatchSize int
SEIFragmentSize int
SEIAckTimeoutMS int
Traffic TrafficConfig
}
// Factory creates a transport instance.
@@ -81,7 +90,11 @@ func New(ctx context.Context, name string, cfg Config) (Transport, error) {
if !ok {
return nil, ErrTransportNotFound
}
return factory(ctx, cfg)
tr, err := factory(ctx, cfg)
if err != nil {
return nil, err
}
return WithTraffic(tr, cfg.Traffic), nil
}
// Available returns a list of registered transport names.

View File

@@ -15,6 +15,7 @@ import (
"github.com/openlibrecommunity/olcrtc/internal/app/session"
"github.com/openlibrecommunity/olcrtc/internal/client"
"github.com/openlibrecommunity/olcrtc/internal/control"
"github.com/openlibrecommunity/olcrtc/internal/logger"
"github.com/openlibrecommunity/olcrtc/internal/protect"
@@ -65,23 +66,26 @@ const (
)
var (
mu sync.Mutex //nolint:gochecknoglobals // package-level state intentional
defaults mobileConfig //nolint:gochecknoglobals // package-level state intentional
defaultsSet sync.Once //nolint:gochecknoglobals // package-level state intentional
registerSet sync.Once //nolint:gochecknoglobals // package-level state intentional
mu sync.Mutex //nolint:gochecknoglobals // package-level state intentional
defaults mobileConfig //nolint:gochecknoglobals // package-level state intentional
defaultsSet sync.Once //nolint:gochecknoglobals // package-level state intentional
registerSet sync.Once //nolint:gochecknoglobals // package-level state intentional
runClientWithReady = client.RunWithReady //nolint:gochecknoglobals // package-level state intentional
cancel context.CancelFunc //nolint:gochecknoglobals // package-level state intentional
done chan struct{} //nolint:gochecknoglobals // package-level state intentional
ready chan struct{} //nolint:gochecknoglobals // package-level state intentional
cancel context.CancelFunc //nolint:gochecknoglobals // package-level state intentional
done chan struct{} //nolint:gochecknoglobals // package-level state intentional
ready chan struct{} //nolint:gochecknoglobals // package-level state intentional
errRun error
)
type mobileConfig struct {
link string
transport string
dnsServer string
vp8FPS int
vp8BatchSize int
link string
transport string
dnsServer string
vp8FPS int
vp8BatchSize int
livenessInterval time.Duration
livenessTimeout time.Duration
livenessFailures int
}
// SetProtector sets the Android VPN socket protector.
@@ -143,6 +147,21 @@ func SetVP8Options(fps, batchSize int) {
defaults.vp8BatchSize = clampAtLeastOne(batchSize, 64)
}
// SetLivenessOptions configures control-stream ping/pong checks.
// Values <= 0 reset that field to its default. Durations are milliseconds.
func SetLivenessOptions(intervalMillis, timeoutMillis, failures int) {
mu.Lock()
defer mu.Unlock()
ensureDefaultConfigLocked()
defaults.livenessInterval = durationFromMillisOrDefault(intervalMillis, control.DefaultInterval)
defaults.livenessTimeout = durationFromMillisOrDefault(timeoutMillis, control.DefaultTimeout)
if failures <= 0 {
defaults.livenessFailures = control.DefaultFailures
return
}
defaults.livenessFailures = failures
}
// SetDebug enables or disables verbose logging.
func SetDebug(enabled bool) {
logger.SetVerbose(enabled)
@@ -195,6 +214,11 @@ func Check(
vp8BatchSize int,
) (int64, error) {
registerDefaults()
mu.Lock()
ensureDefaultConfigLocked()
cfg := defaults
mu.Unlock()
carrierName = normalizeCarrier(carrierName)
transportName = normalizeTransport(transportName)
if err := validateStartArgs(carrierName, roomID, clientID, keyHex); err != nil {
@@ -227,6 +251,7 @@ func Check(
DNSServer: defaultDNSServer,
VP8FPS: clampAtLeastOne(vp8FPS, 120),
VP8BatchSize: clampAtLeastOne(vp8BatchSize, 64),
Liveness: livenessConfig(cfg),
},
func() {
readyOnce.Do(func() {
@@ -271,6 +296,11 @@ func Ping(
vp8BatchSize int,
) (int64, error) {
registerDefaults()
mu.Lock()
ensureDefaultConfigLocked()
cfg := defaults
mu.Unlock()
carrierName = normalizeCarrier(carrierName)
transportName = normalizeTransport(transportName)
@@ -310,6 +340,7 @@ func Ping(
DNSServer: defaultDNSServer,
VP8FPS: clampAtLeastOne(vp8FPS, 120),
VP8BatchSize: clampAtLeastOne(vp8BatchSize, 64),
Liveness: livenessConfig(cfg),
},
func() {
readyOnce.Do(func() {
@@ -557,6 +588,7 @@ func startWithConfig(
SOCKSPass: socksPass,
VP8FPS: cfg.vp8FPS,
VP8BatchSize: cfg.vp8BatchSize,
Liveness: livenessConfig(cfg),
},
func() {
readyOnce.Do(func() {
@@ -576,6 +608,7 @@ func startWithConfig(
}
// WaitReady blocks until the selected transport is connected and the local SOCKS5 listener is ready.
//
//nolint:cyclop // straightforward state-machine waits with multiple terminal conditions
func WaitReady(timeoutMillis int) error {
mu.Lock()
@@ -666,15 +699,38 @@ func waitForCheckDone(doneCh <-chan error) {
func ensureDefaultConfigLocked() {
defaultsSet.Do(func() {
defaults = mobileConfig{
link: defaultLink,
transport: defaultTransport,
dnsServer: defaultDNSServer,
vp8FPS: 60,
vp8BatchSize: 8,
link: defaultLink,
transport: defaultTransport,
dnsServer: defaultDNSServer,
vp8FPS: 60,
vp8BatchSize: 8,
livenessInterval: control.DefaultInterval,
livenessTimeout: control.DefaultTimeout,
livenessFailures: control.DefaultFailures,
}
})
}
func livenessConfig(cfg mobileConfig) control.Config {
interval := cfg.livenessInterval
if interval <= 0 {
interval = control.DefaultInterval
}
timeout := cfg.livenessTimeout
if timeout <= 0 {
timeout = control.DefaultTimeout
}
failures := cfg.livenessFailures
if failures <= 0 {
failures = control.DefaultFailures
}
return control.Config{
Interval: interval,
Timeout: timeout,
Failures: failures,
}
}
func normalizeTransport(value string) string {
switch value {
case dataTransport, "data", "dc":
@@ -734,6 +790,17 @@ func clampAtLeastOne(value, maxValue int) int {
return value
}
func durationFromMillisOrDefault(value int, def time.Duration) time.Duration {
if value <= 0 {
return def
}
d := time.Duration(value) * time.Millisecond
if d <= 0 {
return def
}
return d
}
// logBridge adapts LogWriter to io.Writer.
type logBridge struct {
w LogWriter

View File

@@ -10,6 +10,7 @@ import (
"time"
"github.com/openlibrecommunity/olcrtc/internal/client"
"github.com/openlibrecommunity/olcrtc/internal/control"
"github.com/openlibrecommunity/olcrtc/internal/logger"
"github.com/openlibrecommunity/olcrtc/internal/protect"
)
@@ -83,12 +84,15 @@ func TestDefaultsAndSetters(t *testing.T) {
SetLink("direct")
SetDNS("9.9.9.9:53")
SetVP8Options(-1, 999)
SetLivenessOptions(2500, 750, -1)
mu.Lock()
got := defaults
mu.Unlock()
if got.transport != dataTransport || got.link != defaultLink || got.dnsServer != "9.9.9.9:53" ||
got.vp8FPS != 1 || got.vp8BatchSize != 64 {
got.vp8FPS != 1 || got.vp8BatchSize != 64 ||
got.livenessInterval != 2500*time.Millisecond || got.livenessTimeout != 750*time.Millisecond ||
got.livenessFailures != control.DefaultFailures {
t.Fatalf("defaults = %+v", got)
}
@@ -168,15 +172,19 @@ func TestStartWithInjectedRunnerLifecycle(t *testing.T) {
t.Cleanup(func() {
resetMobileGlobals(t)
})
SetLivenessOptions(2500, 750, 4)
runClientWithReady = func(ctx context.Context, cfg client.Config, onReady func()) error {
if cfg.Link != defaultLink || cfg.Transport != dataTransport || cfg.Carrier != carrierJazz ||
cfg.RoomURL != "any" || cfg.DeviceID != "client" || cfg.LocalAddr != "127.0.0.1:1080" ||
cfg.DNSServer != defaultDNSServer || cfg.VP8FPS != 60 || cfg.VP8BatchSize != 8 {
cfg.DNSServer != defaultDNSServer || cfg.VP8FPS != 60 || cfg.VP8BatchSize != 8 ||
cfg.Liveness.Interval != 2500*time.Millisecond ||
cfg.Liveness.Timeout != 750*time.Millisecond ||
cfg.Liveness.Failures != 4 {
t.Fatalf(
"RunWithReady args mismatch: link=%q transport=%q carrier=%q room=%q client=%q local=%q dns=%q vp8=%d/%d",
"RunWithReady args mismatch: link=%q transport=%q carrier=%q room=%q client=%q local=%q dns=%q vp8=%d/%d liveness=%+v",
cfg.Link, cfg.Transport, cfg.Carrier, cfg.RoomURL, cfg.DeviceID,
cfg.LocalAddr, cfg.DNSServer, cfg.VP8FPS, cfg.VP8BatchSize,
cfg.LocalAddr, cfg.DNSServer, cfg.VP8FPS, cfg.VP8BatchSize, cfg.Liveness,
)
}
onReady()
@@ -208,9 +216,12 @@ func TestStartUsesDefaultsAndCheckWithInjectedRunner(t *testing.T) {
runClientWithReady = func(ctx context.Context, cfg client.Config, onReady func()) error {
if cfg.Transport != defaultTransport || cfg.RoomURL != "https://telemost.yandex.ru/j/room" ||
cfg.LocalAddr != "127.0.0.1:1081" || cfg.SOCKSUser != "u" || cfg.SOCKSPass != "p" {
t.Fatalf("Start args mismatch: transport=%q room=%q local=%q user/pass=%q/%q",
cfg.Transport, cfg.RoomURL, cfg.LocalAddr, cfg.SOCKSUser, cfg.SOCKSPass)
cfg.LocalAddr != "127.0.0.1:1081" || cfg.SOCKSUser != "u" || cfg.SOCKSPass != "p" ||
cfg.Liveness.Interval != control.DefaultInterval ||
cfg.Liveness.Timeout != control.DefaultTimeout ||
cfg.Liveness.Failures != control.DefaultFailures {
t.Fatalf("Start args mismatch: transport=%q room=%q local=%q user/pass=%q/%q liveness=%+v",
cfg.Transport, cfg.RoomURL, cfg.LocalAddr, cfg.SOCKSUser, cfg.SOCKSPass, cfg.Liveness)
}
onReady()
<-ctx.Done()
@@ -225,9 +236,14 @@ func TestStartUsesDefaultsAndCheckWithInjectedRunner(t *testing.T) {
}
Stop()
SetLivenessOptions(3000, 1000, 5)
runClientWithReady = func(ctx context.Context, cfg client.Config, onReady func()) error {
if cfg.Transport != dataTransport || cfg.VP8FPS != 1 || cfg.VP8BatchSize != 64 {
t.Fatalf("Check args mismatch: transport=%q vp8=%d/%d", cfg.Transport, cfg.VP8FPS, cfg.VP8BatchSize)
if cfg.Transport != dataTransport || cfg.VP8FPS != 1 || cfg.VP8BatchSize != 64 ||
cfg.Liveness.Interval != 3000*time.Millisecond ||
cfg.Liveness.Timeout != time.Second ||
cfg.Liveness.Failures != 5 {
t.Fatalf("Check args mismatch: transport=%q vp8=%d/%d liveness=%+v",
cfg.Transport, cfg.VP8FPS, cfg.VP8BatchSize, cfg.Liveness)
}
onReady()
<-ctx.Done()
@@ -242,6 +258,32 @@ func TestStartUsesDefaultsAndCheckWithInjectedRunner(t *testing.T) {
}
}
func TestPingPassesLiveness(t *testing.T) {
resetMobileGlobals(t)
t.Cleanup(func() {
resetMobileGlobals(t)
})
SetLivenessOptions(4000, 1500, 6)
seen := make(chan control.Config, 1)
runClientWithReady = func(ctx context.Context, cfg client.Config, onReady func()) error {
seen <- cfg.Liveness
onReady()
<-ctx.Done()
return nil
}
_, _ = Ping("jazz", "dc", "", "client", "key", 1085, 100, "http://127.0.0.1/", 30, 1)
select {
case got := <-seen:
if got.Interval != 4000*time.Millisecond || got.Timeout != 1500*time.Millisecond || got.Failures != 6 {
t.Fatalf("Ping liveness = %+v", got)
}
default:
t.Fatal("Ping did not start client")
}
}
func TestCheckTimeoutAndRunError(t *testing.T) {
resetMobileGlobals(t)
t.Cleanup(func() {