mirror of
https://github.com/openlibrecommunity/olcrtc.git
synced 2026-06-07 04:49:43 +00:00
Merge pull request #58 from cyber-debug/refine/livekit-reconnect
refine livekit reconnect and liveness
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,5 +1,6 @@
|
||||
# Prerequisites
|
||||
*.d
|
||||
.DS_Store
|
||||
|
||||
# Object files
|
||||
*.o
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
configpkg "github.com/openlibrecommunity/olcrtc/internal/config"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/logger"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/names"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/supervisor"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/transport/videochannel"
|
||||
)
|
||||
|
||||
@@ -35,6 +36,9 @@ var ErrConfigPathRequired = errors.New("usage: olcrtc <config.yaml>")
|
||||
// ErrDataDirRequired is returned when the YAML config does not specify a data directory.
|
||||
var ErrDataDirRequired = errors.New("data directory required (set 'data:' in YAML)")
|
||||
|
||||
// ErrProfilesUnsupportedForGen is returned when failover profiles are configured for gen mode.
|
||||
var ErrProfilesUnsupportedForGen = errors.New("profiles are only supported for srv and cnc modes")
|
||||
|
||||
//nolint:gochecknoglobals // Tests replace the long-running session runner with a bounded function.
|
||||
var runSession = session.Run
|
||||
|
||||
@@ -44,11 +48,18 @@ var runGen = execGen
|
||||
// loadedConfig bundles the parsed YAML file and the derived session config.
|
||||
type loadedConfig struct {
|
||||
scfg session.Config
|
||||
profiles []supervisor.Profile
|
||||
failover failoverConfig
|
||||
dataDir string
|
||||
debug bool
|
||||
ffmpegPath string
|
||||
}
|
||||
|
||||
// failoverConfig is the parsed form of the YAML failover section. Both values
// are forwarded unchanged to supervisor.Config when failover mode starts.
type failoverConfig struct {
	// retryDelay is passed to the supervisor as RetryDelay.
	retryDelay time.Duration
	// maxCycles is passed to the supervisor as MaxCycles.
	maxCycles int
}
|
||||
|
||||
func main() {
|
||||
if err := run(); err != nil {
|
||||
logger.Error(err)
|
||||
@@ -79,14 +90,44 @@ func loadConfig(path string) (loadedConfig, error) {
|
||||
if err != nil {
|
||||
return loadedConfig{}, fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
base := configpkg.Apply(session.Config{}, f)
|
||||
profiles := make([]supervisor.Profile, 0, len(f.Profiles))
|
||||
for i, profile := range f.Profiles {
|
||||
name := profile.Name
|
||||
if name == "" {
|
||||
name = fmt.Sprintf("profile-%d", i+1)
|
||||
}
|
||||
profiles = append(profiles, supervisor.Profile{
|
||||
Name: name,
|
||||
Config: configpkg.ApplyProfile(base, profile),
|
||||
})
|
||||
}
|
||||
failover, err := parseFailoverConfig(f.Failover)
|
||||
if err != nil {
|
||||
return loadedConfig{}, err
|
||||
}
|
||||
return loadedConfig{
|
||||
scfg: configpkg.Apply(session.Config{}, f),
|
||||
scfg: base,
|
||||
profiles: profiles,
|
||||
failover: failover,
|
||||
dataDir: f.Data,
|
||||
debug: f.Debug,
|
||||
ffmpegPath: f.FFmpeg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseFailoverConfig(f configpkg.Failover) (failoverConfig, error) {
|
||||
retryDelay := supervisor.DefaultRetryDelay
|
||||
if f.RetryDelay != "" {
|
||||
parsed, err := time.ParseDuration(f.RetryDelay)
|
||||
if err != nil {
|
||||
return failoverConfig{}, fmt.Errorf("parse failover.retry_delay: %w", err)
|
||||
}
|
||||
retryDelay = parsed
|
||||
}
|
||||
return failoverConfig{retryDelay: retryDelay, maxCycles: f.MaxCycles}, nil
|
||||
}
|
||||
|
||||
func runWithConfig(cfg loadedConfig) error {
|
||||
configureLogging(cfg.debug)
|
||||
|
||||
@@ -98,19 +139,116 @@ func runWithConfig(cfg loadedConfig) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("validate config: %w", err)
|
||||
}
|
||||
scfg = session.ApplyTransportDefaults(scfg)
|
||||
scfg = session.ApplyLivenessDefaults(scfg)
|
||||
|
||||
if scfg.Mode == modeGen {
|
||||
if len(cfg.profiles) > 0 {
|
||||
return ErrProfilesUnsupportedForGen
|
||||
}
|
||||
return runGen(scfg)
|
||||
}
|
||||
|
||||
if len(cfg.profiles) > 0 {
|
||||
profiles, err := prepareProfiles(cfg.profiles)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return runFailoverSessionMode(cfg.dataDir, profiles, cfg.failover)
|
||||
}
|
||||
|
||||
return runSessionMode(cfg.dataDir, scfg)
|
||||
}
|
||||
|
||||
func prepareProfiles(profiles []supervisor.Profile) ([]supervisor.Profile, error) {
|
||||
out := make([]supervisor.Profile, 0, len(profiles))
|
||||
for _, profile := range profiles {
|
||||
scfg, err := session.ApplyAuthDefaults(profile.Config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validate profile %q: %w", profile.Name, err)
|
||||
}
|
||||
profile.Config = session.ApplyLivenessDefaults(session.ApplyTransportDefaults(scfg))
|
||||
out = append(out, profile)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func runSessionMode(dataDir string, scfg session.Config) error {
|
||||
if err := session.Validate(scfg); err != nil {
|
||||
return fmt.Errorf("validate config: %w", err)
|
||||
}
|
||||
|
||||
if err := prepareRuntimeData(dataDir); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return runManaged(func(ctx context.Context) error {
|
||||
return runSession(ctx, scfg)
|
||||
})
|
||||
}
|
||||
|
||||
func runFailoverSessionMode(dataDir string, profiles []supervisor.Profile, failover failoverConfig) error {
|
||||
for _, profile := range profiles {
|
||||
if err := session.Validate(profile.Config); err != nil {
|
||||
return fmt.Errorf("validate profile %q: %w", profile.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := prepareRuntimeData(dataDir); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return runManaged(func(ctx context.Context) error {
|
||||
return supervisor.Run(ctx, supervisor.Config{
|
||||
Profiles: profiles,
|
||||
RetryDelay: failover.retryDelay,
|
||||
MaxCycles: failover.maxCycles,
|
||||
OnProfileStart: func(profile supervisor.Profile, cycle int) {
|
||||
logger.Infof("failover cycle=%d starting profile=%s carrier=%s transport=%s",
|
||||
cycle, profile.Name, profile.Config.Auth, profile.Config.Transport)
|
||||
},
|
||||
OnProfileEnd: func(profile supervisor.Profile, cycle int, err error) {
|
||||
if err != nil {
|
||||
logger.Warnf("failover cycle=%d profile=%s ended with error: %v", cycle, profile.Name, err)
|
||||
return
|
||||
}
|
||||
logger.Warnf("failover cycle=%d profile=%s ended", cycle, profile.Name)
|
||||
},
|
||||
OnStatus: logFailoverStatus,
|
||||
}, runSession)
|
||||
})
|
||||
}
|
||||
|
||||
func logFailoverStatus(status supervisor.Status) {
|
||||
if !logger.IsVerbose() {
|
||||
return
|
||||
}
|
||||
active := status.ActiveProfile
|
||||
if active == "" {
|
||||
active = "none"
|
||||
}
|
||||
logger.Debugf("failover status cycle=%d active=%s last_error=%q profiles=%s history=%d",
|
||||
status.Cycle, active, status.LastError, formatProfileStatuses(status.Profiles), len(status.History))
|
||||
}
|
||||
|
||||
func formatProfileStatuses(profiles []supervisor.ProfileStatus) string {
|
||||
if len(profiles) == 0 {
|
||||
return "[]"
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
buf.WriteByte('[')
|
||||
for i, profile := range profiles {
|
||||
if i > 0 {
|
||||
buf.WriteByte(' ')
|
||||
}
|
||||
fmt.Fprintf(&buf, "%s{starts=%d failures=%d clean=%d}",
|
||||
profile.Name, profile.Starts, profile.Failures, profile.CleanEnds)
|
||||
}
|
||||
buf.WriteByte(']')
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func prepareRuntimeData(dataDir string) error {
|
||||
if dataDir == "" {
|
||||
return ErrDataDirRequired
|
||||
}
|
||||
@@ -124,6 +262,10 @@ func runSessionMode(dataDir string, scfg session.Config) error {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func runManaged(run func(context.Context) error) error {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
@@ -132,7 +274,7 @@ func runSessionMode(dataDir string, scfg session.Config) error {
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- runSession(ctx, scfg)
|
||||
errCh <- run(ctx)
|
||||
}()
|
||||
|
||||
select {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/app/session"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/logger"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/supervisor"
|
||||
)
|
||||
|
||||
var errBoom = errors.New("boom")
|
||||
@@ -149,6 +150,112 @@ data: `+dir+`
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunWithArgsAppliesTransportDefaults(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := os.WriteFile(filepath.Join(dir, "names"), []byte("A\n"), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile(names) error = %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, "surnames"), []byte("B\n"), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile(surnames) error = %v", err)
|
||||
}
|
||||
|
||||
oldRunSession := runSession
|
||||
t.Cleanup(func() { runSession = oldRunSession })
|
||||
runSession = func(ctx context.Context, cfg session.Config) error {
|
||||
if cfg.VP8FPS != 25 || cfg.VP8BatchSize != 1 {
|
||||
t.Fatalf("VP8 defaults = fps %d batch %d, want 25/1", cfg.VP8FPS, cfg.VP8BatchSize)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
yamlPath := writeYAML(t, `
|
||||
mode: srv
|
||||
link: direct
|
||||
auth:
|
||||
provider: wbstream
|
||||
room:
|
||||
id: room
|
||||
crypto:
|
||||
key: key
|
||||
net:
|
||||
transport: vp8channel
|
||||
dns: 1.1.1.1:53
|
||||
data: `+dir+`
|
||||
`)
|
||||
|
||||
if err := runWithArgs([]string{yamlPath}); err != nil {
|
||||
t.Fatalf("runWithArgs() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunWithArgsFailoverProfiles(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := os.WriteFile(filepath.Join(dir, "names"), []byte("A\n"), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile(names) error = %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, "surnames"), []byte("B\n"), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile(surnames) error = %v", err)
|
||||
}
|
||||
|
||||
oldRunSession := runSession
|
||||
t.Cleanup(func() { runSession = oldRunSession })
|
||||
var seen []string
|
||||
runSession = func(ctx context.Context, cfg session.Config) error {
|
||||
seen = append(seen, cfg.Auth+"/"+cfg.Transport)
|
||||
if cfg.Auth == "wbstream" && (cfg.VP8FPS != 25 || cfg.VP8BatchSize != 1) {
|
||||
t.Fatalf("VP8 defaults = fps %d batch %d, want 25/1", cfg.VP8FPS, cfg.VP8BatchSize)
|
||||
}
|
||||
return errBoom
|
||||
}
|
||||
|
||||
yamlPath := writeYAML(t, `
|
||||
mode: srv
|
||||
link: direct
|
||||
crypto:
|
||||
key: key
|
||||
net:
|
||||
dns: 1.1.1.1:53
|
||||
profiles:
|
||||
- name: wb-primary
|
||||
auth:
|
||||
provider: wbstream
|
||||
room:
|
||||
id: room
|
||||
net:
|
||||
transport: vp8channel
|
||||
- name: jitsi-backup
|
||||
auth:
|
||||
provider: jitsi
|
||||
room:
|
||||
id: https://meet.example/room
|
||||
net:
|
||||
transport: datachannel
|
||||
failover:
|
||||
retry_delay: -1ns
|
||||
max_cycles: 1
|
||||
data: `+dir+`
|
||||
`)
|
||||
|
||||
err := runWithArgs([]string{yamlPath})
|
||||
if !errors.Is(err, supervisor.ErrMaxCyclesExceeded) {
|
||||
t.Fatalf("runWithArgs() error = %v, want %v", err, supervisor.ErrMaxCyclesExceeded)
|
||||
}
|
||||
want := []string{"wbstream/vp8channel", "jitsi/datachannel"}
|
||||
if !equalStrings(seen, want) {
|
||||
t.Fatalf("seen profiles = %v, want %v", seen, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunWithConfigRejectsProfilesInGenMode(t *testing.T) {
|
||||
cfg := loadedConfig{
|
||||
scfg: session.Config{Mode: modeGen},
|
||||
profiles: []supervisor.Profile{{Name: "one"}},
|
||||
}
|
||||
if err := runWithConfig(cfg); !errors.Is(err, ErrProfilesUnsupportedForGen) {
|
||||
t.Fatalf("runWithConfig() error = %v, want %v", err, ErrProfilesUnsupportedForGen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigureLogging(t *testing.T) {
|
||||
t.Setenv("PION_LOG_DISABLE", "")
|
||||
logger.SetVerbose(false)
|
||||
@@ -170,6 +277,18 @@ func TestConfigureLogging(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// equalStrings reports whether a and b have the same length and hold equal
// strings at every index. Nil and empty slices compare equal.
func equalStrings(a, b []string) bool {
	remaining := len(a)
	if remaining != len(b) {
		return false
	}
	// Compare from the back; order of comparison does not affect the result.
	for remaining > 0 {
		remaining--
		if a[remaining] != b[remaining] {
			return false
		}
	}
	return true
}
|
||||
|
||||
func TestResolveDataDir(t *testing.T) {
|
||||
abs := filepath.Join(t.TempDir(), "data")
|
||||
got, err := resolveDataDir(abs)
|
||||
|
||||
@@ -234,7 +234,7 @@ internal/e2e/ E2E тесты на реальных провайдер
|
||||
|
||||
| Файл | Что делает |
|
||||
|---|---|
|
||||
| `session.go` | Главная точка конфигурации. `RegisterDefaults()` регистрирует все carriers, links, transports. `Validate()` проверяет все настройки. `Run()` роутит в `server.Run` или `client.Run`. `Gen()` генерирует Room ID для jazz с ретраями (wbstream больше не поддерживает автогенерацию - руму нужно создавать вручную через stream.wb.ru) |
|
||||
| `session.go` | Главная точка конфигурации. `RegisterDefaults()` регистрирует все carriers, links, transports. `Validate()` проверяет все настройки. `Run()` роутит в `server.Run` или `client.Run`. `Gen()` генерирует Room ID для auth-провайдеров с `RoomCreator` и ретраями |
|
||||
| `session_test.go` | Тесты валидации конфига |
|
||||
|
||||
### `internal/config/`
|
||||
@@ -452,7 +452,7 @@ Carrier - это WebRTC сервис видеозвонков, через кот
|
||||
- Минимальная прослойка, почти прямой relay
|
||||
- Работает с vp8channel, seichannel, videochannel
|
||||
- DataChannel **не работает** в обычном guest flow: WB Stream выдаёт токены с `canPublishData=false`, DC не маршрутизирует данные (expected fail в E2E тестах)
|
||||
- Room ID нужно создавать вручную через stream.wb.ru
|
||||
- Room ID можно создать вручную через stream.wb.ru или через `mode: gen`
|
||||
- Инициализация звонка автоматически
|
||||
|
||||
---
|
||||
|
||||
@@ -14,12 +14,28 @@ room:
|
||||
id: "https://meet.cryptopro.ru/REPLACE_WITH_ROOM_NAME"
|
||||
|
||||
crypto:
|
||||
# Or use key_file: "./olcrtc.key" to keep the secret out of this file.
|
||||
key: "REPLACE_ME_WITH_64_HEX_CHARS" # must match the server
|
||||
|
||||
net:
|
||||
transport: datachannel # must match the server
|
||||
dns: "8.8.8.8:53"
|
||||
|
||||
liveness:
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
failures: 3
|
||||
|
||||
# Optional planned rebuild for long-running calls.
|
||||
# lifecycle:
|
||||
# max_session_duration: 6h
|
||||
|
||||
# Optional reliability shaping for encrypted wire messages.
|
||||
# traffic:
|
||||
# max_payload_size: 4096
|
||||
# min_delay: 5ms
|
||||
# max_delay: 30ms
|
||||
|
||||
# Local SOCKS5 listener exposed to applications
|
||||
socks:
|
||||
host: "127.0.0.1"
|
||||
|
||||
@@ -11,6 +11,7 @@ olcrtc /etc/olcrtc/server.yaml
|
||||
|
||||
- [`server.example.yaml`](./server.example.yaml)
|
||||
- [`client.example.yaml`](./client.example.yaml)
|
||||
- [`failover.example.yaml`](./failover.example.yaml)
|
||||
|
||||
## Схема
|
||||
|
||||
@@ -20,7 +21,7 @@ olcrtc /etc/olcrtc/server.yaml
|
||||
| `link` | `direct` |
|
||||
| `auth.provider` | `jitsi`, `telemost`, `jazz`, `wbstream`, `none` |
|
||||
| `room.id` | conference room id |
|
||||
| `crypto.key` | 64-char hex (32 bytes) |
|
||||
| `crypto.key` / `crypto.key_file` | 64-char hex (32 bytes), inline or read from file |
|
||||
| `net.transport` | `datachannel`, `videochannel`, `seichannel`, `vp8channel` |
|
||||
| `net.dns` | resolver `host:port` |
|
||||
| `socks.host` / `.port` | client-side listener |
|
||||
@@ -30,7 +31,126 @@ olcrtc /etc/olcrtc/server.yaml
|
||||
| `video.*` | videochannel tuning |
|
||||
| `vp8.*` | vp8channel tuning |
|
||||
| `sei.fps` / `.batch_size` / `.fragment_size` / `.ack_timeout_ms` | seichannel tuning |
|
||||
| `liveness.interval` | control-stream ping interval, default `10s` |
|
||||
| `liveness.timeout` | pong timeout, default `5s` |
|
||||
| `liveness.failures` | missed pongs before reconnect, default `3` |
|
||||
| `lifecycle.max_session_duration` | planned session rebuild interval, e.g. `6h`; unset = off |
|
||||
| `traffic.max_payload_size` | safe encrypted wire-message cap; `0` = transport default |
|
||||
| `traffic.min_delay` / `.max_delay` | optional send pacing jitter, e.g. `5ms` / `30ms` |
|
||||
| `gen.amount` | gen mode: number of rooms to create |
|
||||
| `profiles[]` | ordered srv/cnc failover profiles |
|
||||
| `failover.retry_delay` | delay before trying the next profile, e.g. `2s` |
|
||||
| `failover.max_cycles` | stop after N full profile-list passes; `0` = forever |
|
||||
| `data` | path to data directory |
|
||||
| `debug` | verbose logging |
|
||||
| `ffmpeg` | path to ffmpeg binary |
|
||||
|
||||
`mode: cnc` refuses non-loopback `socks.host` values unless both
|
||||
`socks.user` and `socks.pass` are set.
|
||||
|
||||
`crypto.key_file` is resolved relative to the YAML file. Do not set it
|
||||
together with `crypto.key`.
|
||||
|
||||
## Liveness
|
||||
|
||||
After `CLIENT_HELLO` / `SERVER_WELCOME`, the first smux stream stays open as
|
||||
an encrypted control stream. olcrtc now sends `CONTROL_PING` / `CONTROL_PONG`
|
||||
messages over that stream to prove the real tunnel path still round-trips.
|
||||
This detects states where a provider or WebRTC layer looks connected but the
|
||||
encrypted smux path is no longer usable.
|
||||
|
||||
```yaml
|
||||
liveness:
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
failures: 3
|
||||
```
|
||||
|
||||
When the failure threshold is reached, the current smux session is rebuilt.
|
||||
In failover mode, a profile that exits after liveness-triggered reconnect
|
||||
failure lets the supervisor advance to the next profile.
|
||||
|
||||
## Lifecycle Rotation
|
||||
|
||||
`lifecycle.max_session_duration` sets a planned upper bound for one provider
|
||||
call/session. When the duration expires, olcrtc cancels the active server or
|
||||
client session and starts a fresh one with the same config. While this option
|
||||
is enabled, clean session endings are also restarted so the peer that did not
|
||||
fire the timer can follow the rebuild. This is useful for long-running
|
||||
deployments where provider calls get stale, accumulate media state, or should
|
||||
be periodically re-created.
|
||||
|
||||
```yaml
|
||||
lifecycle:
|
||||
max_session_duration: 6h
|
||||
```
|
||||
|
||||
The field is optional and disabled when omitted. Values use Go duration syntax
|
||||
such as `30m`, `2h`, or `6h`; zero and negative durations are rejected.
|
||||
|
||||
## Traffic Shaping
|
||||
|
||||
`traffic` applies a shared reliability-oriented wrapper around the selected
|
||||
transport. It can cap encrypted wire-message size and add small send pacing
|
||||
delays without truncating data. When a payload would exceed the effective cap,
|
||||
the send fails clearly instead of cutting bytes and corrupting smux.
|
||||
|
||||
```yaml
|
||||
traffic:
|
||||
max_payload_size: 4096
|
||||
min_delay: 5ms
|
||||
max_delay: 30ms
|
||||
```
|
||||
|
||||
The wrapper clamps the configured payload cap to the selected transport's
|
||||
advertised `MaxPayloadSize`. Client and server also reduce smux frame size to
|
||||
fit the effective encrypted payload cap, accounting for crypto overhead. `0`
|
||||
adds no extra cap beyond the selected transport's advertised limit. Delays use
|
||||
Go duration syntax; if only `min_delay` is set, it is a fixed delay. Use the
|
||||
same traffic settings on both peers.
|
||||
|
||||
## Failover Profiles
|
||||
|
||||
`mode: srv` and `mode: cnc` can define `profiles`. Top-level fields are used
|
||||
as common defaults; each profile overrides only the fields it sets. The CLI
|
||||
runs profiles in order. If a profile fails or ends while the process is still
|
||||
alive, olcrtc waits `failover.retry_delay` and starts the next profile.
|
||||
|
||||
```yaml
|
||||
mode: srv
|
||||
link: direct
|
||||
crypto:
|
||||
key_file: ./olcrtc.key
|
||||
net:
|
||||
dns: "1.1.1.1:53"
|
||||
data: data
|
||||
|
||||
profiles:
|
||||
- name: wb-vp8
|
||||
auth:
|
||||
provider: wbstream
|
||||
room:
|
||||
id: "WB_ROOM_ID"
|
||||
net:
|
||||
transport: vp8channel
|
||||
|
||||
- name: jitsi-dc
|
||||
auth:
|
||||
provider: jitsi
|
||||
room:
|
||||
id: "https://meet.example.org/olcrtc-room"
|
||||
net:
|
||||
transport: datachannel
|
||||
|
||||
failover:
|
||||
retry_delay: 2s
|
||||
max_cycles: 0
|
||||
```
|
||||
|
||||
Both peers must use compatible profile order and room settings. This first
|
||||
failover layer rebuilds the session on the next profile; active smux streams
|
||||
do not migrate, but new connections can recover on the next profile.
|
||||
|
||||
When `debug: true` is enabled, the CLI also emits a compact supervisor status
|
||||
snapshot with the active profile, per-profile start/failure counters, and
|
||||
bounded failover history size.
|
||||
|
||||
49
docs/failover.example.yaml
Normal file
49
docs/failover.example.yaml
Normal file
@@ -0,0 +1,49 @@
|
||||
# olcrtc failover config example
|
||||
# Use the same profile order on both peers.
|
||||
|
||||
mode: srv
|
||||
link: direct
|
||||
|
||||
crypto:
|
||||
key_file: "./olcrtc.key"
|
||||
|
||||
net:
|
||||
dns: "1.1.1.1:53"
|
||||
|
||||
liveness:
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
failures: 3
|
||||
|
||||
# Optional planned rebuild for each active profile.
|
||||
# lifecycle:
|
||||
# max_session_duration: 6h
|
||||
|
||||
# Optional reliability shaping for encrypted wire messages.
|
||||
# traffic:
|
||||
# max_payload_size: 4096
|
||||
# min_delay: 5ms
|
||||
# max_delay: 30ms
|
||||
|
||||
data: data
|
||||
|
||||
profiles:
|
||||
- name: wb-vp8
|
||||
auth:
|
||||
provider: wbstream
|
||||
room:
|
||||
id: "REPLACE_WITH_WB_ROOM_ID"
|
||||
net:
|
||||
transport: vp8channel
|
||||
|
||||
- name: jitsi-datachannel
|
||||
auth:
|
||||
provider: jitsi
|
||||
room:
|
||||
id: "https://meet.example.org/REPLACE_WITH_ROOM_NAME"
|
||||
net:
|
||||
transport: datachannel
|
||||
|
||||
failover:
|
||||
retry_delay: 2s
|
||||
max_cycles: 0
|
||||
@@ -177,7 +177,7 @@ data: data
|
||||
|
||||
### wbstream + vp8channel (альтернатива)
|
||||
|
||||
Сначала создай руму вручную через сайт [wbstream](https://stream.wb.ru) (автогенерация через `mode: gen` для wbstream больше не поддерживается) и сохрани её ID.
|
||||
Создай руму через сайт [wbstream](https://stream.wb.ru) или заранее сгенерируй ID через `mode: gen` с `auth.provider: wbstream`.
|
||||
|
||||
`wbstream + datachannel` **не работает** в обычном guest flow — WB Stream выдаёт токены с `canPublishData=false`, и DC не маршрутизирует данные. Для обычного использования выбирай `vp8channel`.
|
||||
|
||||
|
||||
425
docs/project-map.md
Normal file
425
docs/project-map.md
Normal file
@@ -0,0 +1,425 @@
|
||||
# olcRTC Project Map
|
||||
|
||||
This is a developer map for finding the useful parts of the project quickly.
|
||||
It focuses on code ownership, runtime flow, extension points, and areas that
|
||||
are worth deeper work.
|
||||
|
||||
## One-Sentence Model
|
||||
|
||||
olcRTC is an encrypted TCP-over-WebRTC tunnel: the client exposes a local
|
||||
SOCKS5 listener, the server dials requested TCP targets, and both sides carry
|
||||
the smux byte stream through a selected WebRTC carrier and transport.
|
||||
|
||||
## Runtime Stack
|
||||
|
||||
```text
|
||||
YAML config
|
||||
-> cmd/olcrtc
|
||||
-> internal/config
|
||||
-> internal/app/session
|
||||
-> internal/server or internal/client
|
||||
-> internal/link/direct
|
||||
-> internal/transport/{datachannel,vp8channel,seichannel,videochannel}
|
||||
-> internal/carrier/builtin
|
||||
-> internal/auth/<provider> + internal/engine/<engine>
|
||||
-> external service SFU / signaling
|
||||
```
|
||||
|
||||
Tunnel data path:
|
||||
|
||||
```text
|
||||
local app
|
||||
-> client SOCKS5
|
||||
-> smux stream
|
||||
-> muxconn AEAD encrypt
|
||||
-> link.Send
|
||||
-> transport encoding
|
||||
-> carrier/engine
|
||||
-> SFU/service
|
||||
-> peer engine/carrier
|
||||
-> transport decoding
|
||||
-> muxconn AEAD decrypt
|
||||
-> smux stream
|
||||
-> server TCP dial
|
||||
-> target host
|
||||
```
|
||||
|
||||
## Entrypoints
|
||||
|
||||
| Path | Purpose |
|
||||
|---|---|
|
||||
| `cmd/olcrtc/main.go` | Main CLI. Accepts one YAML file, applies auth, transport, and liveness defaults, then starts `srv`, `cnc`, or `gen`. `srv` and `cnc` can also run through ordered failover profiles. |
|
||||
| `cmd/olcrtc-cgo/main.go` | Small c-shared entrypoint for desktop/native consumers. |
|
||||
| `pkg/olcrtc` | Embeddable lower-level API that returns a `net.Conn`-like handle over an engine data path. |
|
||||
| `pkg/olcrtc/tunnel` | Embeddable server-side tunnel API with auth and traffic hooks. |
|
||||
| `mobile/mobile.go` | gomobile API for Android clients, including VPN socket protection. |
|
||||
| `script/srv.sh`, `script/cnc.sh` | Interactive shell launchers that generate YAML and run/build the app. |
|
||||
| `Dockerfile`, `script/docker/*` | Container build and server entrypoint/healthcheck. |
|
||||
|
||||
## Config And Session Layer
|
||||
|
||||
`internal/config` owns YAML parsing and file-backed secret loading.
|
||||
|
||||
Important fields:
|
||||
|
||||
| YAML | Runtime field | Notes |
|
||||
|---|---|---|
|
||||
| `mode` | `session.Config.Mode` | `srv`, `cnc`, or `gen`. |
|
||||
| `auth.provider` | `Auth` | `jitsi`, `telemost`, `jazz`, `wbstream`, or `none`. |
|
||||
| `room.id` | `RoomID` | Carrier-specific room reference. |
|
||||
| `crypto.key` / `crypto.key_file` | `KeyHex` | Shared 32-byte key encoded as 64 hex chars. |
|
||||
| `net.transport` | `Transport` | `datachannel`, `vp8channel`, `seichannel`, or `videochannel`. |
|
||||
| `net.dns` | `DNSServer` | Resolver used by server-side target dials and provider HTTP where wired. |
|
||||
| `socks.*` | SOCKS fields | Client listener and optional server egress proxy. |
|
||||
| `engine.*` | direct engine fields | Used only with `auth.provider: none`. |
|
||||
| `liveness.*` | control liveness | Ping/pong interval, timeout, and missed-pong threshold. |
|
||||
| `lifecycle.*` | session lifecycle | Planned call/session rotation. |
|
||||
| `traffic.*` | send shaping | Encrypted wire-message size cap and optional pacing jitter. |
|
||||
|
||||
`internal/app/session` is the main router:
|
||||
|
||||
1. Registers built-ins via `RegisterDefaults`.
|
||||
2. Applies auth defaults: auth provider decides engine and default service URL.
|
||||
3. Applies transport defaults: documented defaults for `vp8`, `sei`, and `video`.
|
||||
4. Validates mode, auth, link, transport, room, key, DNS, transport options, and SOCKS listener safety.
|
||||
5. Runs `server.Run`, `client.Run`, or `Gen`.
|
||||
|
||||
## Server Side
|
||||
|
||||
`internal/server` accepts encrypted smux sessions from the peer and proxies
|
||||
each smux stream to a TCP target.
|
||||
|
||||
Core pieces:
|
||||
|
||||
| Symbol | Role |
|
||||
|---|---|
|
||||
| `server.Run` | Creates cipher, link, smux server, and serve loop. |
|
||||
| `bringUpLink` | Builds `link.Link`, wires reconnect callbacks, connects carrier. |
|
||||
| `installSession` / `reinstallSession` | Creates or replaces `muxconn + smux.Session`. |
|
||||
| `acceptHandshake` | First smux stream; runs `handshake.Server`. |
|
||||
| `handleStream` | Reads connect JSON and dispatches a tunnel stream. |
|
||||
| `dispatch` | Dials target, sends ready byte, copies both directions. |
|
||||
| `AuthHook` | Embedders can authorize clients after `CLIENT_HELLO`. |
|
||||
| `OnSessionOpen`, `OnSessionClose`, `OnTraffic` | Observability hooks. |
|
||||
|
||||
Server risk areas:
|
||||
|
||||
- Target dialing is powerful by design. Any real product wrapper should add
|
||||
an `AuthHook` and probably destination policy.
|
||||
- `defaultAuthHook` admits everyone who knows the room and key.
|
||||
- Reconnect rebuilds smux sessions; active streams are sacrificed.
|
||||
|
||||
## Client Side
|
||||
|
||||
`internal/client` exposes a local SOCKS5 listener and opens one smux stream
|
||||
per SOCKS CONNECT request.
|
||||
|
||||
Core pieces:
|
||||
|
||||
| Symbol | Role |
|
||||
|---|---|
|
||||
| `RunWithReady` | Starts link, opens smux client, listens on local SOCKS. |
|
||||
| `openControlStream` | First smux stream; runs `handshake.Client`. |
|
||||
| `handleSocks5` | SOCKS method negotiation and CONNECT parsing. |
|
||||
| `sendConnectRequest` | Sends server-side target JSON and waits for ready byte. |
|
||||
| `handleReconnect` | Rebuilds smux and control stream after carrier reconnect. |
|
||||
| `resolveDeviceID` | Optional persistent client identity for hooks. |
|
||||
|
||||
Client risk areas:
|
||||
|
||||
- A non-loopback SOCKS listener must require `socks.user` and `socks.pass`.
|
||||
- SOCKS credentials are simple static credentials, not a full account system.
|
||||
- Existing streams do not survive reconnect; new SOCKS connections can recover.
|
||||
|
||||
## Wire Protocol Above WebRTC
|
||||
|
||||
`internal/muxconn` adapts `link.Link` to `io.ReadWriteCloser`.
|
||||
|
||||
- Every smux write is encrypted with `internal/crypto`.
|
||||
- Every inbound link message is decrypted and appended to an internal byte buffer.
|
||||
- Bad AEAD frames are dropped.
|
||||
- `CanSend` provides backpressure before encrypting and sending.
|
||||
|
||||
`internal/crypto` uses XChaCha20-Poly1305 with a random nonce prepended to
|
||||
each ciphertext.
|
||||
|
||||
`internal/handshake` runs on the first smux stream:
|
||||
|
||||
```text
|
||||
CLIENT_HELLO { version, device_id, claims }
|
||||
SERVER_WELCOME { version, session_id }
|
||||
or
|
||||
SERVER_REJECT { version, reason }
|
||||
```
|
||||
|
||||
The handshake has a 64 KiB frame cap and a default 15 second timeout.
|
||||
|
||||
After handshake, `internal/control` keeps that same encrypted smux stream open
|
||||
and exchanges length-prefixed JSON control messages:
|
||||
|
||||
```text
|
||||
CONTROL_PING { version, seq, sent_unix_nano }
|
||||
CONTROL_PONG { version, seq, sent_unix_nano }
|
||||
```
|
||||
|
||||
Defaults are `liveness.interval: 10s`, `liveness.timeout: 5s`, and
|
||||
`liveness.failures: 3`. Missed pongs mark the smux session unhealthy and
|
||||
trigger a session rebuild/reconnect path.
|
||||
|
||||
Client and server runtimes also maintain a `control.Status` snapshot with
|
||||
session ID, last pong time, RTT, missed pongs, reconnect count, and unhealthy
|
||||
event count. Embedders can consume it through the client/server health
|
||||
callbacks.
|
||||
|
||||
## Registries And Plugin Shape
|
||||
|
||||
The universal-carrier refactor centers on small registries:
|
||||
|
||||
| Registry | Package | Registers |
|
||||
|---|---|---|
|
||||
| Auth providers | `internal/auth` | Service-specific credential and room creation flows. |
|
||||
| Engines | `internal/engine` | Wire-level SFU protocol implementations. |
|
||||
| Carriers | `internal/carrier` | Auth + engine adapters exposed as byte/video capability providers. |
|
||||
| Transports | `internal/transport` | Byte transport strategy over carrier primitives. |
|
||||
| Links | `internal/link` | Higher-level link abstraction; currently only `direct`. |
|
||||
|
||||
`internal/carrier/builtin` connects the auth and engine worlds:
|
||||
|
||||
```text
|
||||
carrier "wbstream" -> auth/wbstream -> engine/livekit
|
||||
carrier "jazz" -> auth/salutejazz -> engine/salutejazz
|
||||
carrier "telemost" -> auth/telemost -> engine/goolom
|
||||
carrier "jitsi" -> auth/jitsi -> engine/jitsi
|
||||
carrier "none" -> direct user-supplied engine/url/token
|
||||
```
|
||||
|
||||
## Auth Providers
|
||||
|
||||
| Provider | Engine | Room generation | Notes |
|
||||
|---|---|---:|---|
|
||||
| `jitsi` | `jitsi` | No | Parses host/room from a public or self-hosted Jitsi URL. No HTTP auth. |
|
||||
| `telemost` | `goolom` | No | Calls Telemost room-info flow and returns Goolom credentials. |
|
||||
| `wbstream` | `livekit` | Yes | Registers guest, optionally creates room, joins room, fetches LiveKit token. |
|
||||
| `jazz` / `salutejazz` | `salutejazz` | Yes | Creates or joins SaluteJazz room and returns room/password tuple. |
|
||||
| `none` | chosen by config | No | Direct engine mode for downstream tools or self-hosted SFUs. |
|
||||
|
||||
## Engines
|
||||
|
||||
Engines expose the low-level service/SFU protocol.
|
||||
|
||||
| Engine | Package | Byte stream | Video track | Main job |
|
||||
|---|---|---:|---:|---|
|
||||
| `livekit` | `internal/engine/livekit` | Yes | Yes | LiveKit SDK room, data packets, local/remote tracks, reconnect with credential refresh. |
|
||||
| `goolom` | `internal/engine/goolom` | Yes | Yes | Yandex Telemost/Goolom signaling, split publisher/subscriber peer connections, telemetry/keepalive. |
|
||||
| `jitsi` | `internal/engine/jitsi` | Yes | Best effort | Jitsi MUC/Jingle/colibri-ws plus optional video track negotiation. |
|
||||
| `salutejazz` | `internal/engine/salutejazz` | Yes | Yes | SaluteJazz WebSocket signaling and split media peer connections. |
|
||||
|
||||
Engine work is where most provider breakage and reconnect complexity lives.
|
||||
|
||||
## Transports
|
||||
|
||||
Transports decide how raw tunnel bytes are carried once the carrier provides
|
||||
either a byte stream or a video track.
|
||||
|
||||
| Transport | Primitive | Reliability model | Best fit | Notes |
|
||||
|---|---|---|---|---|
|
||||
| `datachannel` | Carrier byte stream | Native reliable ordered messages | Jitsi, direct engines, some Jazz cases | Simple pass-through with 12 KiB message cap. |
|
||||
| `vp8channel` | VP8 video track | KCP over VP8-looking frames | WB Stream and Telemost-style video paths | Highest-performance video-path transport. Uses epochs and binding tokens to survive restarts/loopback. |
|
||||
| `seichannel` | H264 SEI video track | Custom fragments + ACK/retry | WB Stream fallback | Carries data in SEI NAL units with fragmentation, CRC, ACK. |
|
||||
| `videochannel` | Visual frames via ffmpeg | QR/tile frames + ACK/retry | Experimental/inspection-friendly path | Encodes visual payload frames, requires ffmpeg, supports QR and tile codecs. |
|
||||
|
||||
Transport work is where throughput, loss recovery, and adaptive tuning should
|
||||
happen.
|
||||
|
||||
## Public/Embedding Surfaces
|
||||
|
||||
| Package | User |
|
||||
|---|---|
|
||||
| `pkg/olcrtc` | Go programs that want a `net.Conn` over a selected auth/engine. |
|
||||
| `pkg/olcrtc/tunnel` | Go programs that want to embed the server-side tunnel with auth/traffic hooks. |
|
||||
| `mobile` | Android app bindings. Wraps client mode, VPN socket protection, logging, simple health checks. |
|
||||
| `cmd/olcrtc-cgo` | Native desktop/client integrations using c-shared Go export. |
|
||||
|
||||
These surfaces are important if the CLI becomes only one frontend among many.
|
||||
|
||||
## Tests
|
||||
|
||||
The project has broad unit coverage:
|
||||
|
||||
- Config/session validation and defaults.
|
||||
- Auth provider HTTP flows with test servers.
|
||||
- Engine helper logic and reconnect paths.
|
||||
- SOCKS parsing, smux handshake, server dispatch.
|
||||
- Crypto, muxconn, names, protect, logging.
|
||||
- Transport frame codecs, ACK paths, KCP loopback, ffmpeg helpers.
|
||||
- Memory-backed E2E tunnel tests and optional real-provider E2E matrix.
|
||||
|
||||
Useful commands:
|
||||
|
||||
```sh
|
||||
go test -count=1 ./...
|
||||
go test -race -count=1 ./cmd/olcrtc ./internal/app/session ./internal/config ./internal/engine/livekit
|
||||
go test -race -count=1 -v ./internal/e2e
|
||||
E2E_CARRIERS=wbstream E2E_TRANSPORTS=vp8channel mage e2e
|
||||
go build -trimpath -o build/olcrtc ./cmd/olcrtc
|
||||
```
|
||||
|
||||
## High-Value Coding Areas
|
||||
|
||||
### 1. Supervisor And Multi-Profile Failover
|
||||
|
||||
The first supervisor layer exists in `internal/supervisor`: the CLI can run a
|
||||
prioritized list of carrier/transport profiles and move to the next profile
|
||||
when the active one fails or ends.
|
||||
|
||||
```yaml
|
||||
mode: srv
|
||||
link: direct
|
||||
crypto:
|
||||
key_file: ./olcrtc.key
|
||||
net:
|
||||
dns: "1.1.1.1:53"
|
||||
profiles:
|
||||
- name: wb-vp8
|
||||
auth:
|
||||
provider: wbstream
|
||||
room:
|
||||
id: WB_ROOM_ID
|
||||
net:
|
||||
transport: vp8channel
|
||||
- name: jitsi-dc
|
||||
auth:
|
||||
provider: jitsi
|
||||
room:
|
||||
id: https://meet.example.org/olcrtc-room
|
||||
net:
|
||||
transport: datachannel
|
||||
failover:
|
||||
retry_delay: 2s
|
||||
max_cycles: 0
|
||||
```
|
||||
|
||||
Implemented:
|
||||
|
||||
- Config schema for `profiles[]`.
|
||||
- Ordered supervisor loop.
|
||||
- `failover.retry_delay`.
|
||||
- `failover.max_cycles`.
|
||||
- Profile start/end logs.
|
||||
- Planned session rotation with `lifecycle.max_session_duration`.
|
||||
- Shared supervisor status snapshots with bounded failover history.
|
||||
- Shared traffic wrapper with payload cap, pacing jitter, and smux frame sizing.
|
||||
|
||||
Still valuable:
|
||||
|
||||
- Health scoring per profile.
|
||||
- Control-stream coordination before switching.
|
||||
- Stream draining and migration instead of dropping active smux streams.
|
||||
- User-facing status endpoint/export for the active profile and failover history.
|
||||
|
||||
Likely files:
|
||||
|
||||
- `internal/config/config.go`
|
||||
- `internal/app/session/session.go`
|
||||
- `internal/supervisor`
|
||||
- `internal/server`
|
||||
- `internal/client`
|
||||
- `docs/configuration.md`
|
||||
- `internal/e2e/tunnel_test.go`
|
||||
|
||||
### 2. Transport Telemetry And Adaptive Tuning
|
||||
|
||||
Add metrics from transport to link/session:
|
||||
|
||||
- Send queue depth.
|
||||
- ACK latency.
|
||||
- Retries.
|
||||
- Reconnect count.
|
||||
- Dropped/decrypt-failed frames.
|
||||
- KCP RTT/loss where available.
|
||||
|
||||
Then make `vp8.batch_size`, `sei.fragment_size`, ACK timeout, and pacing
|
||||
adaptive instead of static YAML knobs.
|
||||
|
||||
### 3. Control Stream Protocol
|
||||
|
||||
The first smux stream now carries control ping/pong after handshake. It is
|
||||
still the natural place for:
|
||||
|
||||
- Server policy updates.
|
||||
- Graceful reconnect notifications.
|
||||
- Drain/start markers for failover.
|
||||
- More per-session stats.
|
||||
|
||||
Likely files:
|
||||
|
||||
- `internal/control`
|
||||
- `internal/server`
|
||||
- `internal/client`
|
||||
|
||||
### 4. Destination Policy And Real Auth
|
||||
|
||||
The tunnel can dial arbitrary server-side TCP targets. A production wrapper
|
||||
should use `AuthHook` and enforce:
|
||||
|
||||
- Allowed destination CIDRs/domains/ports.
|
||||
- Per-device or per-plan policy.
|
||||
- Session expiration.
|
||||
- Traffic accounting limits.
|
||||
- Sanitized rejection reasons.
|
||||
|
||||
This mostly belongs in `pkg/olcrtc/tunnel` and `internal/server`.
|
||||
|
||||
### 5. Provider Hardening
|
||||
|
||||
Provider APIs can drift. Worth adding:
|
||||
|
||||
- Central protected HTTP/WebSocket client creation with TLS 1.2+,
|
||||
environment proxy support, HTTP/2 for HTTP, and bounded timeouts.
|
||||
- Better typed errors from auth providers.
|
||||
- Provider health probes.
|
||||
- Fixture-based contract tests for API response changes.
|
||||
- Per-provider rate/backoff policy.
|
||||
- Safer secret/log redaction.
|
||||
|
||||
Likely files:
|
||||
|
||||
- `internal/auth/*`
|
||||
- `internal/engine/*`
|
||||
- `internal/carrier/builtin`
|
||||
|
||||
### 6. Codebase Hygiene
|
||||
|
||||
Some public-facing text and comments are not suitable for a serious external
|
||||
project. Cleaning that up would improve maintainability and downstream trust.
|
||||
The most obvious targets are top-level docs and a large hostile block comment
|
||||
in `internal/transport/vp8channel/transport.go`.
|
||||
|
||||
## Where To Look First
|
||||
|
||||
| Goal | Start here |
|
||||
|---|---|
|
||||
| Change YAML schema | `internal/config/config.go`, `cmd/olcrtc/main.go`, docs examples. |
|
||||
| Change validation/defaults | `internal/app/session/session.go`. |
|
||||
| Add a new auth provider | `internal/auth`, then register in `internal/carrier/builtin/register.go`. |
|
||||
| Add a new SFU protocol | `internal/engine`, then connect through auth/carrier. |
|
||||
| Add a new byte transport | `internal/transport`, then register in `session.RegisterDefaults`. |
|
||||
| Add link behavior above transports | `internal/link`; currently only `direct`. |
|
||||
| Improve SOCKS behavior | `internal/client`. |
|
||||
| Improve server target dialing or policy | `internal/server`, `pkg/olcrtc/tunnel`. |
|
||||
| Improve reconnect | Engines first, then `internal/client` and `internal/server` smux rebuild behavior. |
|
||||
| Improve Android app integration | `mobile`, `internal/protect`, `client.RunWithReady`. |
|
||||
|
||||
## Mental Model For Big Changes
|
||||
|
||||
Prefer to keep the layer boundaries:
|
||||
|
||||
- Auth creates credentials; it should not know transport details.
|
||||
- Engine speaks service/SFU protocol; it should not know SOCKS or smux.
|
||||
- Carrier adapts auth+engine into byte/video capabilities.
|
||||
- Transport turns byte/video capabilities into reliable-ish tunnel bytes.
|
||||
- Link is policy above transport.
|
||||
- Client/server own SOCKS, smux, handshake, target dialing, and session hooks.
|
||||
|
||||
If a change crosses more than two layers, it probably deserves a new
|
||||
orchestrator package instead of pushing more state into an engine or transport.
|
||||
@@ -16,12 +16,28 @@ room:
|
||||
|
||||
crypto:
|
||||
# 32-byte hex (64 chars). Generate with: openssl rand -hex 32
|
||||
# Or use key_file: "./olcrtc.key" to keep the secret out of this file.
|
||||
key: "REPLACE_ME_WITH_64_HEX_CHARS"
|
||||
|
||||
net:
|
||||
transport: datachannel # datachannel | videochannel | seichannel | vp8channel
|
||||
dns: "8.8.8.8:53"
|
||||
|
||||
liveness:
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
failures: 3
|
||||
|
||||
# Optional planned rebuild for long-running calls.
|
||||
# lifecycle:
|
||||
# max_session_duration: 6h
|
||||
|
||||
# Optional reliability shaping for encrypted wire messages.
|
||||
# traffic:
|
||||
# max_payload_size: 4096
|
||||
# min_delay: 5ms
|
||||
# max_delay: 30ms
|
||||
|
||||
# Outbound SOCKS5 proxy for server-side egress (optional)
|
||||
socks:
|
||||
proxy_addr: "" # e.g. "127.0.0.1"
|
||||
|
||||
@@ -48,7 +48,7 @@
|
||||
| `auth.provider` | `telemost`, `jazz`, `wbstream` или `jitsi` |
|
||||
| `net.transport` | `datachannel`, `vp8channel`, `seichannel` или `videochannel` |
|
||||
| `room.id` | Room ID |
|
||||
| `crypto.key` | Ключ шифрования hex 64 символа. Генерация: `openssl rand -hex 32` |
|
||||
| `crypto.key` или `crypto.key_file` | Ключ шифрования hex 64 символа. Генерация: `openssl rand -hex 32` |
|
||||
| `link` | Всегда `direct` |
|
||||
| `data` | Всегда `data` |
|
||||
| `net.dns` | DNS-сервер, например `1.1.1.1:53` |
|
||||
@@ -60,18 +60,52 @@
|
||||
| YAML поле | Описание |
|
||||
|-----------|----------|
|
||||
| `debug` | `true` для подробных логов соединений |
|
||||
| `profiles` | Список профилей failover для `srv`/`cnc` |
|
||||
| `failover.retry_delay` | Пауза перед следующим профилем, например `2s` |
|
||||
| `failover.max_cycles` | Сколько полных проходов по профилям сделать; `0` = бесконечно |
|
||||
| `liveness.interval` | Интервал ping по control stream, по умолчанию `10s` |
|
||||
| `liveness.timeout` | Сколько ждать pong, по умолчанию `5s` |
|
||||
| `liveness.failures` | Сколько pong можно пропустить перед rebuild, по умолчанию `3` |
|
||||
| `lifecycle.max_session_duration` | Плановый rebuild сессии после указанного времени, например `6h`; если поле не задано, выключено |
|
||||
| `traffic.max_payload_size` | Лимит размера зашифрованного wire-message; `0` = лимит транспорта |
|
||||
| `traffic.min_delay` / `.max_delay` | Необязательный pacing отправки, например `5ms` / `30ms` |
|
||||
|
||||
`crypto.key_file` читается относительно YAML-файла. Не указывай `crypto.key` и `crypto.key_file` одновременно.
|
||||
|
||||
Если задан `profiles`, поля верхнего уровня становятся общими defaults, а
|
||||
каждый профиль переопределяет только свои `auth`, `room`, `net`, `engine` и
|
||||
настройки транспорта/liveness. Порядок профилей должен совпадать на сервере и
|
||||
клиенте.
|
||||
|
||||
`liveness` проверяет именно зашифрованный smux control stream после handshake,
|
||||
а не только статус WebRTC/provider соединения. Если pong не приходит несколько
|
||||
раз подряд, текущая smux-сессия пересоздается.
|
||||
|
||||
`lifecycle.max_session_duration` ограничивает длительность одного звонка /
|
||||
provider session. Когда таймер истекает, текущая `srv` или `cnc` сессия
|
||||
закрывается и стартует заново с тем же конфигом. Пока эта настройка включена,
|
||||
чистое завершение сессии тоже перезапускается, чтобы второй peer мог догнать
|
||||
плановый rebuild. Формат значения: `30m`, `2h`, `6h`; `0s` и отрицательные
|
||||
значения не принимаются.
|
||||
|
||||
`traffic` добавляет общий wrapper над выбранным transport. Он может ограничить
|
||||
размер зашифрованного сообщения и добавить небольшую задержку перед отправкой.
|
||||
Данные не обрезаются: если сообщение не помещается в эффективный лимит, send
|
||||
возвращает явную ошибку. При заданном `max_payload_size` smux frame size также
|
||||
уменьшается с учетом crypto overhead; при `0` остается лимит выбранного
|
||||
transport. Используй одинаковые traffic-настройки на обеих сторонах.
|
||||
|
||||
---
|
||||
|
||||
## mode: gen
|
||||
|
||||
Генерирует Room ID заранее, не запуская сервер. Поддерживается только для `jazz`. Для `wbstream` создавай руму вручную через [stream.wb.ru](https://stream.wb.ru) (автогенерация отключена со стороны WB).
|
||||
Генерирует Room ID заранее, не запуская сервер. Поддерживается для auth-провайдеров с автосозданием комнат: `jazz` и `wbstream`. Для `telemost` комнату нужно создавать вручную через сайт.
|
||||
|
||||
**Обязательные поля:**
|
||||
|
||||
| YAML поле | Описание |
|
||||
|-----------|----------|
|
||||
| `auth.provider` | `jazz` |
|
||||
| `auth.provider` | `jazz` или `wbstream` |
|
||||
| `net.dns` | DNS-сервер |
|
||||
| `gen.amount` | Количество комнат |
|
||||
|
||||
@@ -79,7 +113,7 @@
|
||||
# gen.yaml
|
||||
mode: gen
|
||||
auth:
|
||||
provider: jazz
|
||||
provider: wbstream
|
||||
net:
|
||||
dns: "1.1.1.1:53"
|
||||
gen:
|
||||
@@ -116,6 +150,9 @@ gen:
|
||||
Если `socks.user` не задан - аутентификация отключена (любой локальный клиент может подключиться).
|
||||
Если задан - клиент принимает только подключения с правильным логином и паролем (RFC 1929).
|
||||
|
||||
Если `socks.host` не loopback (`127.0.0.1`, `::1`, `localhost`), `socks.user` и `socks.pass` обязательны.
|
||||
Это защита от случайного открытого SOCKS5-прокси в локальной сети или интернете.
|
||||
|
||||
---
|
||||
|
||||
## datachannel
|
||||
|
||||
@@ -5,13 +5,17 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/auth"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/carrier"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/carrier/builtin"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/client"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/control"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/crypto"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/link"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/link/direct"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/logger"
|
||||
@@ -37,18 +41,35 @@ const (
|
||||
videoCodecTile = "tile"
|
||||
)
|
||||
|
||||
// Fallback values used by applyVideoDefaults for unset video.* settings.
// Width and height here apply to non-tile codecs; the tile codec falls back
// to 1080x1080 instead.
const (
	defaultVideoWidth      = 1920
	defaultVideoHeight     = 1080
	defaultVideoFPS        = 30
	defaultVideoBitrate    = "2M"
	defaultVideoHW         = "none"
	defaultVideoQRRecovery = "low"
)

// Fallback values used by applyVP8Defaults for unset vp8.* settings.
const (
	defaultVP8FPS       = 25
	defaultVP8BatchSize = 1
)

// Fallback values used by applySEIDefaults for unset sei.* settings.
const (
	defaultSEIFPS          = 60
	defaultSEIBatchSize    = 64
	defaultSEIFragmentSize = 900
	defaultSEIAckTimeoutMS = 2000
)
|
||||
|
||||
// sessionRestartDelay is a two-second delay; presumably the pause before a
// session is started again after it ends. NOTE(review): declared as a var
// rather than a const, possibly so tests can shorten it; confirm against its
// call sites.
var sessionRestartDelay = 2 * time.Second
|
||||
|
||||
var (
|
||||
// ErrRoomIDRequired indicates that a room id is required for the selected carrier.
|
||||
ErrRoomIDRequired = errors.New("room ID required (use -id <id>)")
|
||||
ErrRoomIDRequired = errors.New("room ID required (set room.id)")
|
||||
// ErrModeRequired indicates that mode is not one of the supported values.
|
||||
ErrModeRequired = errors.New("mode required (use -mode srv, -mode cnc or -mode gen)")
|
||||
// ErrAmountRequired indicates that -amount is required for gen mode.
|
||||
ErrAmountRequired = errors.New("amount required for gen mode (use -amount <n>)")
|
||||
ErrModeRequired = errors.New("mode required (set mode to srv, cnc or gen)")
|
||||
// ErrAmountRequired indicates that gen.amount is required for gen mode.
|
||||
ErrAmountRequired = errors.New("amount required for gen mode (set gen.amount)")
|
||||
// ErrAuthRequired indicates that no auth provider was selected.
|
||||
ErrAuthRequired = errors.New(
|
||||
"auth provider required (use -auth jitsi, -auth telemost, -auth jazz, -auth wbstream or -auth none)")
|
||||
// ErrURLRequired indicates that -url must be provided when the auth provider has no default URL.
|
||||
ErrURLRequired = errors.New("SFU URL required (use -url wss://...)")
|
||||
"auth provider required (set auth.provider to jitsi, telemost, jazz, wbstream or none)")
|
||||
// ErrURLRequired indicates that auth.url must be provided when the auth provider has no default URL.
|
||||
ErrURLRequired = errors.New("SFU URL required (set auth.url)")
|
||||
// ErrUnsupportedCarrier indicates that carrier is not registered.
|
||||
ErrUnsupportedCarrier = errors.New("unsupported carrier")
|
||||
// ErrUnsupportedLink indicates that link is not registered.
|
||||
@@ -57,88 +78,119 @@ var (
|
||||
ErrUnsupportedTransport = errors.New("unsupported transport")
|
||||
|
||||
// ErrLinkRequired indicates that link is not provided.
|
||||
ErrLinkRequired = errors.New("link required (use -link direct)")
|
||||
ErrLinkRequired = errors.New("link required (set link to direct)")
|
||||
// ErrTransportRequired indicates that transport is not provided.
|
||||
ErrTransportRequired = errors.New(
|
||||
"transport required (use -transport datachannel, -transport videochannel, " +
|
||||
"-transport seichannel or -transport vp8channel)")
|
||||
"transport required (set transport to datachannel, videochannel, seichannel or vp8channel)")
|
||||
// ErrKeyRequired indicates that encryption key is not provided.
|
||||
ErrKeyRequired = errors.New("key required (use -key <hex>)")
|
||||
ErrKeyRequired = errors.New("key required (set crypto.key)")
|
||||
// ErrDNSServerRequired indicates that dns server is not provided.
|
||||
ErrDNSServerRequired = errors.New("dns server required (use -dns 1.1.1.1:53)")
|
||||
ErrDNSServerRequired = errors.New("dns server required (set net.dns)")
|
||||
|
||||
// ErrVideoWidthRequired indicates that video width is required for videochannel.
|
||||
ErrVideoWidthRequired = errors.New("video width required for videochannel (use -video-w)")
|
||||
ErrVideoWidthRequired = errors.New("video width required for videochannel (set video.width)")
|
||||
// ErrVideoHeightRequired indicates that video height is required for videochannel.
|
||||
ErrVideoHeightRequired = errors.New("video height required for videochannel (use -video-h)")
|
||||
ErrVideoHeightRequired = errors.New("video height required for videochannel (set video.height)")
|
||||
// ErrVideoFPSRequired indicates that video fps is required for videochannel.
|
||||
ErrVideoFPSRequired = errors.New("video fps required for videochannel (use -video-fps)")
|
||||
ErrVideoFPSRequired = errors.New("video fps required for videochannel (set video.fps)")
|
||||
// ErrVideoBitrateRequired indicates that video bitrate is required for videochannel.
|
||||
ErrVideoBitrateRequired = errors.New(
|
||||
"video bitrate required for videochannel (use -video-bitrate)")
|
||||
"video bitrate required for videochannel (set video.bitrate)")
|
||||
// ErrVideoHWRequired indicates that video hardware acceleration is required.
|
||||
ErrVideoHWRequired = errors.New(
|
||||
"video hardware acceleration required for videochannel (use -video-hw none/nvenc)")
|
||||
"video hardware acceleration required for videochannel (set video.hw to none or nvenc)")
|
||||
// ErrVideoCodecInvalid indicates that the video codec is not valid.
|
||||
ErrVideoCodecInvalid = errors.New(
|
||||
"invalid video codec for videochannel (use -video-codec qrcode or -video-codec tile)")
|
||||
"invalid video codec for videochannel (set video.codec to qrcode or tile)")
|
||||
// ErrTileCodecDimensions indicates that tile codec requires 1080x1080 dimensions.
|
||||
ErrTileCodecDimensions = errors.New("tile codec requires -video-w 1080 -video-h 1080")
|
||||
ErrTileCodecDimensions = errors.New("tile codec requires video.width: 1080 and video.height: 1080")
|
||||
|
||||
// ErrVP8FPSRequired indicates that vp8 fps is required for vp8channel.
|
||||
ErrVP8FPSRequired = errors.New("vp8 fps required for vp8channel (use -vp8-fps)")
|
||||
ErrVP8FPSRequired = errors.New("vp8 fps required for vp8channel (set vp8.fps)")
|
||||
// ErrVP8BatchSizeRequired indicates that vp8 batch size is required for vp8channel.
|
||||
ErrVP8BatchSizeRequired = errors.New("vp8 batch size required for vp8channel (use -vp8-batch)")
|
||||
ErrVP8BatchSizeRequired = errors.New("vp8 batch size required for vp8channel (set vp8.batch_size)")
|
||||
// ErrSEIFPSRequired indicates that seichannel fps is required.
|
||||
ErrSEIFPSRequired = errors.New("fps required for seichannel (use -fps)")
|
||||
ErrSEIFPSRequired = errors.New("fps required for seichannel (set sei.fps)")
|
||||
// ErrSEIBatchSizeRequired indicates that seichannel batch size is required.
|
||||
ErrSEIBatchSizeRequired = errors.New("batch size required for seichannel (use -batch)")
|
||||
ErrSEIBatchSizeRequired = errors.New("batch size required for seichannel (set sei.batch_size)")
|
||||
// ErrSEIFragmentSizeRequired indicates that seichannel fragment size is required.
|
||||
ErrSEIFragmentSizeRequired = errors.New("fragment size required for seichannel (use -frag)")
|
||||
ErrSEIFragmentSizeRequired = errors.New("fragment size required for seichannel (set sei.fragment_size)")
|
||||
// ErrSEIAckTimeoutRequired indicates that seichannel ack timeout is required.
|
||||
ErrSEIAckTimeoutRequired = errors.New("ack timeout required for seichannel (use -ack-ms)")
|
||||
ErrSEIAckTimeoutRequired = errors.New("ack timeout required for seichannel (set sei.ack_timeout_ms)")
|
||||
|
||||
// ErrSOCKSHostRequired indicates that socks host is required for cnc mode.
|
||||
ErrSOCKSHostRequired = errors.New("socks host required for cnc mode (use -socks-host)")
|
||||
ErrSOCKSHostRequired = errors.New("socks host required for cnc mode (set socks.host)")
|
||||
// ErrSOCKSPortRequired indicates that socks port is required for cnc mode.
|
||||
ErrSOCKSPortRequired = errors.New("socks port required for cnc mode (use -socks-port)")
|
||||
ErrSOCKSPortRequired = errors.New("socks port required for cnc mode (set socks.port)")
|
||||
// ErrSOCKSAuthRequired indicates that a non-loopback SOCKS listener requires authentication.
|
||||
ErrSOCKSAuthRequired = errors.New(
|
||||
"socks auth required when binding outside loopback (set socks.user and socks.pass)")
|
||||
|
||||
// ErrLivenessIntervalInvalid indicates that liveness.interval is not a positive duration.
|
||||
ErrLivenessIntervalInvalid = errors.New(
|
||||
"invalid liveness interval (set liveness.interval to a duration > 0)")
|
||||
// ErrLivenessTimeoutInvalid indicates that liveness.timeout is not a positive duration.
|
||||
ErrLivenessTimeoutInvalid = errors.New(
|
||||
"invalid liveness timeout (set liveness.timeout to a duration > 0)")
|
||||
// ErrLivenessFailuresInvalid indicates that liveness.failures is not positive.
|
||||
ErrLivenessFailuresInvalid = errors.New(
|
||||
"invalid liveness failures (set liveness.failures to a value > 0)")
|
||||
// ErrLifecycleMaxSessionDurationInvalid indicates that lifecycle.max_session_duration is not a positive duration.
|
||||
ErrLifecycleMaxSessionDurationInvalid = errors.New(
|
||||
"invalid max session duration (set lifecycle.max_session_duration to a duration > 0)")
|
||||
// ErrTrafficMaxPayloadSizeInvalid indicates that traffic.max_payload_size is not valid.
|
||||
ErrTrafficMaxPayloadSizeInvalid = errors.New(
|
||||
"invalid traffic max payload size (set traffic.max_payload_size to 0 or a value above crypto overhead)")
|
||||
// ErrTrafficMinDelayInvalid indicates that traffic.min_delay is not a non-negative duration.
|
||||
ErrTrafficMinDelayInvalid = errors.New(
|
||||
"invalid traffic min delay (set traffic.min_delay to a duration >= 0)")
|
||||
// ErrTrafficMaxDelayInvalid indicates that traffic.max_delay is not a non-negative duration.
|
||||
ErrTrafficMaxDelayInvalid = errors.New(
|
||||
"invalid traffic max delay (set traffic.max_delay to a duration >= 0 and >= traffic.min_delay)")
|
||||
)
|
||||
|
||||
// Config holds runtime session settings.
|
||||
type Config struct {
|
||||
Mode string
|
||||
Link string
|
||||
Transport string
|
||||
Auth string
|
||||
Engine string
|
||||
URL string
|
||||
Token string
|
||||
RoomID string
|
||||
KeyHex string
|
||||
SOCKSHost string
|
||||
SOCKSPort int
|
||||
SOCKSUser string
|
||||
SOCKSPass string
|
||||
DNSServer string
|
||||
SOCKSProxyAddr string
|
||||
SOCKSProxyPort int
|
||||
VideoWidth int
|
||||
VideoHeight int
|
||||
VideoFPS int
|
||||
VideoBitrate string
|
||||
VideoHW string
|
||||
VideoQRSize int
|
||||
VideoQRRecovery string
|
||||
VideoCodec string
|
||||
VideoTileModule int
|
||||
VideoTileRS int
|
||||
VP8FPS int
|
||||
VP8BatchSize int
|
||||
SEIFPS int
|
||||
SEIBatchSize int
|
||||
SEIFragmentSize int
|
||||
SEIAckTimeoutMS int
|
||||
Amount int
|
||||
Mode string
|
||||
Link string
|
||||
Transport string
|
||||
Auth string
|
||||
Engine string
|
||||
URL string
|
||||
Token string
|
||||
RoomID string
|
||||
KeyHex string
|
||||
SOCKSHost string
|
||||
SOCKSPort int
|
||||
SOCKSUser string
|
||||
SOCKSPass string
|
||||
DNSServer string
|
||||
SOCKSProxyAddr string
|
||||
SOCKSProxyPort int
|
||||
VideoWidth int
|
||||
VideoHeight int
|
||||
VideoFPS int
|
||||
VideoBitrate string
|
||||
VideoHW string
|
||||
VideoQRSize int
|
||||
VideoQRRecovery string
|
||||
VideoCodec string
|
||||
VideoTileModule int
|
||||
VideoTileRS int
|
||||
VP8FPS int
|
||||
VP8BatchSize int
|
||||
SEIFPS int
|
||||
SEIBatchSize int
|
||||
SEIFragmentSize int
|
||||
SEIAckTimeoutMS int
|
||||
LivenessInterval string
|
||||
LivenessTimeout string
|
||||
LivenessFailures int
|
||||
MaxSessionDuration string
|
||||
TrafficMaxPayloadSize int
|
||||
TrafficMinDelay string
|
||||
TrafficMaxDelay string
|
||||
Amount int
|
||||
}
|
||||
|
||||
// RegisterDefaults registers built-in carriers and transports.
|
||||
@@ -180,6 +232,94 @@ func ApplyAuthDefaults(cfg Config) (Config, error) {
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// ApplyTransportDefaults fills documented transport defaults without changing core routing fields.
|
||||
func ApplyTransportDefaults(cfg Config) Config {
|
||||
switch cfg.Transport {
|
||||
case transportVideo:
|
||||
return applyVideoDefaults(cfg)
|
||||
case transportVP8:
|
||||
return applyVP8Defaults(cfg)
|
||||
case transportSEI:
|
||||
return applySEIDefaults(cfg)
|
||||
default:
|
||||
return cfg
|
||||
}
|
||||
}
|
||||
|
||||
// ApplyLivenessDefaults fills documented control-stream liveness defaults.
|
||||
func ApplyLivenessDefaults(cfg Config) Config {
|
||||
if cfg.LivenessInterval == "" {
|
||||
cfg.LivenessInterval = control.DefaultInterval.String()
|
||||
}
|
||||
if cfg.LivenessTimeout == "" {
|
||||
cfg.LivenessTimeout = control.DefaultTimeout.String()
|
||||
}
|
||||
if cfg.LivenessFailures == 0 {
|
||||
cfg.LivenessFailures = control.DefaultFailures
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func applyVideoDefaults(cfg Config) Config {
|
||||
if cfg.VideoCodec == "" {
|
||||
cfg.VideoCodec = videoCodecQRCode
|
||||
}
|
||||
if cfg.VideoCodec == videoCodecTile {
|
||||
if cfg.VideoWidth == 0 {
|
||||
cfg.VideoWidth = 1080
|
||||
}
|
||||
if cfg.VideoHeight == 0 {
|
||||
cfg.VideoHeight = 1080
|
||||
}
|
||||
} else {
|
||||
if cfg.VideoWidth == 0 {
|
||||
cfg.VideoWidth = defaultVideoWidth
|
||||
}
|
||||
if cfg.VideoHeight == 0 {
|
||||
cfg.VideoHeight = defaultVideoHeight
|
||||
}
|
||||
}
|
||||
if cfg.VideoFPS == 0 {
|
||||
cfg.VideoFPS = defaultVideoFPS
|
||||
}
|
||||
if cfg.VideoBitrate == "" {
|
||||
cfg.VideoBitrate = defaultVideoBitrate
|
||||
}
|
||||
if cfg.VideoHW == "" {
|
||||
cfg.VideoHW = defaultVideoHW
|
||||
}
|
||||
if cfg.VideoQRRecovery == "" {
|
||||
cfg.VideoQRRecovery = defaultVideoQRRecovery
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func applyVP8Defaults(cfg Config) Config {
|
||||
if cfg.VP8FPS == 0 {
|
||||
cfg.VP8FPS = defaultVP8FPS
|
||||
}
|
||||
if cfg.VP8BatchSize == 0 {
|
||||
cfg.VP8BatchSize = defaultVP8BatchSize
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func applySEIDefaults(cfg Config) Config {
|
||||
if cfg.SEIFPS == 0 {
|
||||
cfg.SEIFPS = defaultSEIFPS
|
||||
}
|
||||
if cfg.SEIBatchSize == 0 {
|
||||
cfg.SEIBatchSize = defaultSEIBatchSize
|
||||
}
|
||||
if cfg.SEIFragmentSize == 0 {
|
||||
cfg.SEIFragmentSize = defaultSEIFragmentSize
|
||||
}
|
||||
if cfg.SEIAckTimeoutMS == 0 {
|
||||
cfg.SEIAckTimeoutMS = defaultSEIAckTimeoutMS
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// Validate verifies that the runtime config refers to registered components and all required fields are present.
|
||||
func Validate(cfg Config) error {
|
||||
if err := validateMode(cfg); err != nil {
|
||||
@@ -200,6 +340,15 @@ func Validate(cfg Config) error {
|
||||
if err := validateTransportConfig(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateLivenessConfig(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateLifecycleConfig(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateTrafficConfig(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
return validateModeConfig(cfg)
|
||||
}
|
||||
|
||||
@@ -333,13 +482,163 @@ func validateModeConfig(cfg Config) error {
|
||||
if cfg.SOCKSPort == 0 {
|
||||
return ErrSOCKSPortRequired
|
||||
}
|
||||
if !isLoopbackListenHost(cfg.SOCKSHost) && (cfg.SOCKSUser == "" || cfg.SOCKSPass == "") {
|
||||
return ErrSOCKSAuthRequired
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateLivenessConfig(cfg Config) error {
|
||||
if _, err := parseLivenessDuration(cfg.LivenessInterval, control.DefaultInterval); err != nil {
|
||||
return fmt.Errorf("%w: %v", ErrLivenessIntervalInvalid, err)
|
||||
}
|
||||
if _, err := parseLivenessDuration(cfg.LivenessTimeout, control.DefaultTimeout); err != nil {
|
||||
return fmt.Errorf("%w: %v", ErrLivenessTimeoutInvalid, err)
|
||||
}
|
||||
if cfg.LivenessFailures < 0 {
|
||||
return ErrLivenessFailuresInvalid
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateLifecycleConfig(cfg Config) error {
|
||||
if _, err := maxSessionDuration(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseLivenessDuration converts a liveness duration setting into a
// time.Duration.
//
// An empty value selects def without further checks. Otherwise the value must
// parse with time.ParseDuration and be strictly positive; the parse error or a
// "duration must be > 0" error is returned unwrapped so callers can attach
// their own sentinel.
func parseLivenessDuration(value string, def time.Duration) (time.Duration, error) {
	if value == "" {
		return def, nil
	}

	parsed, err := time.ParseDuration(value)
	switch {
	case err != nil:
		return 0, err
	case parsed <= 0:
		return 0, fmt.Errorf("duration must be > 0")
	}
	return parsed, nil
}
|
||||
|
||||
func livenessConfig(cfg Config) (control.Config, error) {
|
||||
interval, err := parseLivenessDuration(cfg.LivenessInterval, control.DefaultInterval)
|
||||
if err != nil {
|
||||
return control.Config{}, fmt.Errorf("%w: %v", ErrLivenessIntervalInvalid, err)
|
||||
}
|
||||
timeout, err := parseLivenessDuration(cfg.LivenessTimeout, control.DefaultTimeout)
|
||||
if err != nil {
|
||||
return control.Config{}, fmt.Errorf("%w: %v", ErrLivenessTimeoutInvalid, err)
|
||||
}
|
||||
failures := cfg.LivenessFailures
|
||||
if failures == 0 {
|
||||
failures = control.DefaultFailures
|
||||
}
|
||||
if failures < 0 {
|
||||
return control.Config{}, ErrLivenessFailuresInvalid
|
||||
}
|
||||
return control.Config{Interval: interval, Timeout: timeout, Failures: failures}, nil
|
||||
}
|
||||
|
||||
func maxSessionDuration(cfg Config) (time.Duration, error) {
|
||||
if cfg.MaxSessionDuration == "" {
|
||||
return 0, nil
|
||||
}
|
||||
d, err := time.ParseDuration(cfg.MaxSessionDuration)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("%w: %v", ErrLifecycleMaxSessionDurationInvalid, err)
|
||||
}
|
||||
if d <= 0 {
|
||||
return 0, ErrLifecycleMaxSessionDurationInvalid
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func validateTrafficConfig(cfg Config) error {
|
||||
_, err := trafficConfig(cfg)
|
||||
return err
|
||||
}
|
||||
|
||||
func trafficConfig(cfg Config) (transport.TrafficConfig, error) {
|
||||
if cfg.TrafficMaxPayloadSize < 0 || (cfg.TrafficMaxPayloadSize > 0 &&
|
||||
cfg.TrafficMaxPayloadSize <= crypto.WireOverhead) {
|
||||
return transport.TrafficConfig{}, ErrTrafficMaxPayloadSizeInvalid
|
||||
}
|
||||
minDelay, err := parseOptionalNonNegativeDuration(cfg.TrafficMinDelay)
|
||||
if err != nil {
|
||||
return transport.TrafficConfig{}, fmt.Errorf("%w: %v", ErrTrafficMinDelayInvalid, err)
|
||||
}
|
||||
maxDelay, err := parseOptionalNonNegativeDuration(cfg.TrafficMaxDelay)
|
||||
if err != nil {
|
||||
return transport.TrafficConfig{}, fmt.Errorf("%w: %v", ErrTrafficMaxDelayInvalid, err)
|
||||
}
|
||||
if maxDelay > 0 && maxDelay < minDelay {
|
||||
return transport.TrafficConfig{}, ErrTrafficMaxDelayInvalid
|
||||
}
|
||||
return transport.TrafficConfig{
|
||||
MaxPayloadSize: cfg.TrafficMaxPayloadSize,
|
||||
MinDelay: minDelay,
|
||||
MaxDelay: maxDelay,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// parseOptionalNonNegativeDuration parses an optional duration setting. An
// empty value yields 0 with no error. Any other value must be accepted by
// time.ParseDuration and must not be negative.
func parseOptionalNonNegativeDuration(value string) (time.Duration, error) {
	if value == "" {
		return 0, nil
	}
	parsed, err := time.ParseDuration(value)
	switch {
	case err != nil:
		return 0, err
	case parsed < 0:
		return 0, fmt.Errorf("duration must be >= 0")
	}
	return parsed, nil
}
|
||||
|
||||
// isLoopbackListenHost reports whether host is the literal name "localhost"
// or an IP address literal in a loopback range. Anything net.ParseIP cannot
// parse (including the empty string and other hostnames) is not loopback.
func isLoopbackListenHost(host string) bool {
	if host == "localhost" {
		return true
	}
	if addr := net.ParseIP(host); addr != nil {
		return addr.IsLoopback()
	}
	return false
}
|
||||
|
||||
// Run starts the configured mode.
|
||||
func Run(ctx context.Context, cfg Config) error {
|
||||
cfg = ApplyTransportDefaults(cfg)
|
||||
cfg = ApplyLivenessDefaults(cfg)
|
||||
roomURL := cfg.RoomID
|
||||
liveness, err := livenessConfig(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
maxDuration, err := maxSessionDuration(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
traffic, err := trafficConfig(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
run := func(ctx context.Context) error {
|
||||
return runOnce(ctx, cfg, roomURL, liveness, traffic)
|
||||
}
|
||||
if maxDuration > 0 {
|
||||
return runWithSessionRotation(ctx, maxDuration, run)
|
||||
}
|
||||
return run(ctx)
|
||||
}
|
||||
|
||||
func runOnce(
|
||||
ctx context.Context,
|
||||
cfg Config,
|
||||
roomURL string,
|
||||
liveness control.Config,
|
||||
traffic transport.TrafficConfig,
|
||||
) error {
|
||||
switch cfg.Mode {
|
||||
case modeSRV:
|
||||
if err := server.Run(ctx, server.Config{
|
||||
@@ -370,6 +669,8 @@ func Run(ctx context.Context, cfg Config) error {
|
||||
Engine: cfg.Engine,
|
||||
URL: cfg.URL,
|
||||
Token: cfg.Token,
|
||||
Liveness: liveness,
|
||||
Traffic: traffic,
|
||||
OnSessionOpen: func(sessionID, deviceID string, claims map[string]any) {
|
||||
logger.Infof("session opened: id=%s device=%s claims=%v", sessionID, deviceID, claims)
|
||||
},
|
||||
@@ -413,6 +714,8 @@ func Run(ctx context.Context, cfg Config) error {
|
||||
Engine: cfg.Engine,
|
||||
URL: cfg.URL,
|
||||
Token: cfg.Token,
|
||||
Liveness: liveness,
|
||||
Traffic: traffic,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("client: %w", err)
|
||||
}
|
||||
@@ -422,6 +725,52 @@ func Run(ctx context.Context, cfg Config) error {
|
||||
}
|
||||
}
|
||||
|
||||
func runWithSessionRotation(ctx context.Context, maxDuration time.Duration, run func(context.Context) error) error {
|
||||
for cycle := 1; ; cycle++ {
|
||||
currentCycle := cycle
|
||||
runCtx, cancel := context.WithCancel(ctx)
|
||||
var rotated atomic.Bool
|
||||
timer := time.AfterFunc(maxDuration, func() {
|
||||
rotated.Store(true)
|
||||
logger.Infof("session max duration reached: duration=%s cycle=%d", maxDuration, currentCycle)
|
||||
cancel()
|
||||
})
|
||||
|
||||
err := run(runCtx)
|
||||
cancel()
|
||||
timer.Stop()
|
||||
if ctx.Err() != nil {
|
||||
return nil
|
||||
}
|
||||
if rotated.Load() {
|
||||
if err != nil {
|
||||
logger.Warnf("session rotation ended with error: cycle=%d err=%v", currentCycle, err)
|
||||
}
|
||||
logger.Infof("session rotation restarting: next_cycle=%d", currentCycle+1)
|
||||
if err := waitSessionRestart(ctx); err != nil {
|
||||
return nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logger.Infof("session ended cleanly with lifecycle rotation enabled: next_cycle=%d", currentCycle+1)
|
||||
if err := waitSessionRestart(ctx); err != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func waitSessionRestart(ctx context.Context) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(sessionRestartDelay):
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateGen validates that the config contains enough fields to run gen mode.
|
||||
func ValidateGen(cfg Config) error {
|
||||
if cfg.Auth == "" {
|
||||
|
||||
@@ -3,9 +3,136 @@ package session
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/control"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/crypto"
|
||||
)
|
||||
|
||||
func TestApplyTransportDefaults(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in Config
|
||||
want Config
|
||||
}{
|
||||
{
|
||||
name: "vp8",
|
||||
in: Config{Transport: transportVP8},
|
||||
want: Config{Transport: transportVP8, VP8FPS: 25, VP8BatchSize: 1},
|
||||
},
|
||||
{
|
||||
name: "sei",
|
||||
in: Config{Transport: transportSEI},
|
||||
want: Config{
|
||||
Transport: transportSEI,
|
||||
SEIFPS: 60,
|
||||
SEIBatchSize: 64,
|
||||
SEIFragmentSize: 900,
|
||||
SEIAckTimeoutMS: 2000,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "video qrcode",
|
||||
in: Config{Transport: transportVideo},
|
||||
want: Config{
|
||||
Transport: transportVideo,
|
||||
VideoWidth: 1920,
|
||||
VideoHeight: 1080,
|
||||
VideoFPS: 30,
|
||||
VideoBitrate: "2M",
|
||||
VideoHW: "none",
|
||||
VideoQRRecovery: "low",
|
||||
VideoCodec: videoCodecQRCode,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "video tile dimensions",
|
||||
in: Config{Transport: transportVideo, VideoCodec: videoCodecTile},
|
||||
want: Config{
|
||||
Transport: transportVideo,
|
||||
VideoWidth: 1080,
|
||||
VideoHeight: 1080,
|
||||
VideoFPS: 30,
|
||||
VideoBitrate: "2M",
|
||||
VideoHW: "none",
|
||||
VideoQRRecovery: "low",
|
||||
VideoCodec: videoCodecTile,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "keeps explicit values",
|
||||
in: Config{
|
||||
Transport: transportSEI,
|
||||
SEIFPS: 10,
|
||||
SEIBatchSize: 2,
|
||||
SEIFragmentSize: 300,
|
||||
SEIAckTimeoutMS: 1500,
|
||||
},
|
||||
want: Config{
|
||||
Transport: transportSEI,
|
||||
SEIFPS: 10,
|
||||
SEIBatchSize: 2,
|
||||
SEIFragmentSize: 300,
|
||||
SEIAckTimeoutMS: 1500,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ApplyTransportDefaults(tt.in)
|
||||
if got != tt.want {
|
||||
t.Fatalf("ApplyTransportDefaults() = %+v, want %+v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyLivenessDefaults(t *testing.T) {
|
||||
got := ApplyLivenessDefaults(Config{})
|
||||
if got.LivenessInterval != control.DefaultInterval.String() {
|
||||
t.Fatalf("LivenessInterval = %q, want %q", got.LivenessInterval, control.DefaultInterval.String())
|
||||
}
|
||||
if got.LivenessTimeout != control.DefaultTimeout.String() {
|
||||
t.Fatalf("LivenessTimeout = %q, want %q", got.LivenessTimeout, control.DefaultTimeout.String())
|
||||
}
|
||||
if got.LivenessFailures != control.DefaultFailures {
|
||||
t.Fatalf("LivenessFailures = %d, want %d", got.LivenessFailures, control.DefaultFailures)
|
||||
}
|
||||
|
||||
explicit := Config{LivenessInterval: "1s", LivenessTimeout: "500ms", LivenessFailures: 9}
|
||||
if got := ApplyLivenessDefaults(explicit); got != explicit {
|
||||
t.Fatalf("ApplyLivenessDefaults() = %+v, want %+v", got, explicit)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunWithSessionRotationRestartsAfterMaxDuration(t *testing.T) {
|
||||
oldRestartDelay := sessionRestartDelay
|
||||
sessionRestartDelay = time.Millisecond
|
||||
t.Cleanup(func() { sessionRestartDelay = oldRestartDelay })
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
var calls atomic.Int32
|
||||
err := runWithSessionRotation(ctx, 5*time.Millisecond, func(ctx context.Context) error {
|
||||
if calls.Add(1) >= 2 {
|
||||
cancel()
|
||||
return nil
|
||||
}
|
||||
<-ctx.Done()
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runWithSessionRotation() error = %v", err)
|
||||
}
|
||||
if got := calls.Load(); got < 2 {
|
||||
t.Fatalf("run calls = %d, want at least 2", got)
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:maintidx // table-driven validation test naturally has many cases
|
||||
func TestValidate(t *testing.T) {
|
||||
RegisterDefaults()
|
||||
@@ -310,6 +437,148 @@ func TestValidate(t *testing.T) {
|
||||
}(),
|
||||
want: ErrSOCKSPortRequired,
|
||||
},
|
||||
{
|
||||
name: "cnc rejects unauthenticated wildcard socks bind",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.Mode = modeCNC
|
||||
cfg.SOCKSHost = "0.0.0.0"
|
||||
cfg.SOCKSPort = 1080
|
||||
return cfg
|
||||
}(),
|
||||
want: ErrSOCKSAuthRequired,
|
||||
},
|
||||
{
|
||||
name: "cnc allows authenticated wildcard socks bind",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.Mode = modeCNC
|
||||
cfg.SOCKSHost = "0.0.0.0"
|
||||
cfg.SOCKSPort = 1080
|
||||
cfg.SOCKSUser = "user"
|
||||
cfg.SOCKSPass = "pass"
|
||||
return cfg
|
||||
}(),
|
||||
},
|
||||
{
|
||||
name: "cnc allows localhost socks bind without auth",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.Mode = modeCNC
|
||||
cfg.SOCKSHost = "localhost"
|
||||
cfg.SOCKSPort = 1080
|
||||
return cfg
|
||||
}(),
|
||||
},
|
||||
{
|
||||
name: "liveness rejects bad interval",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.LivenessInterval = "nope"
|
||||
return cfg
|
||||
}(),
|
||||
want: ErrLivenessIntervalInvalid,
|
||||
},
|
||||
{
|
||||
name: "liveness rejects zero timeout",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.LivenessTimeout = "0s"
|
||||
return cfg
|
||||
}(),
|
||||
want: ErrLivenessTimeoutInvalid,
|
||||
},
|
||||
{
|
||||
name: "liveness rejects negative failures",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.LivenessFailures = -1
|
||||
return cfg
|
||||
}(),
|
||||
want: ErrLivenessFailuresInvalid,
|
||||
},
|
||||
{
|
||||
name: "lifecycle accepts max session duration",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.MaxSessionDuration = "1h"
|
||||
return cfg
|
||||
}(),
|
||||
},
|
||||
{
|
||||
name: "lifecycle rejects bad max session duration",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.MaxSessionDuration = "nope"
|
||||
return cfg
|
||||
}(),
|
||||
want: ErrLifecycleMaxSessionDurationInvalid,
|
||||
},
|
||||
{
|
||||
name: "lifecycle rejects zero max session duration",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.MaxSessionDuration = "0s"
|
||||
return cfg
|
||||
}(),
|
||||
want: ErrLifecycleMaxSessionDurationInvalid,
|
||||
},
|
||||
{
|
||||
name: "traffic accepts shaping",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.TrafficMaxPayloadSize = 4096
|
||||
cfg.TrafficMinDelay = "5ms"
|
||||
cfg.TrafficMaxDelay = "30ms"
|
||||
return cfg
|
||||
}(),
|
||||
},
|
||||
{
|
||||
name: "traffic rejects negative max payload",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.TrafficMaxPayloadSize = -1
|
||||
return cfg
|
||||
}(),
|
||||
want: ErrTrafficMaxPayloadSizeInvalid,
|
||||
},
|
||||
{
|
||||
name: "traffic rejects payload smaller than crypto overhead",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.TrafficMaxPayloadSize = crypto.WireOverhead
|
||||
return cfg
|
||||
}(),
|
||||
want: ErrTrafficMaxPayloadSizeInvalid,
|
||||
},
|
||||
{
|
||||
name: "traffic rejects bad min delay",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.TrafficMinDelay = "nope"
|
||||
return cfg
|
||||
}(),
|
||||
want: ErrTrafficMinDelayInvalid,
|
||||
},
|
||||
{
|
||||
name: "traffic rejects negative max delay",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.TrafficMaxDelay = "-1ms"
|
||||
return cfg
|
||||
}(),
|
||||
want: ErrTrafficMaxDelayInvalid,
|
||||
},
|
||||
{
|
||||
name: "traffic rejects max delay below min delay",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.TrafficMinDelay = "30ms"
|
||||
cfg.TrafficMaxDelay = "5ms"
|
||||
return cfg
|
||||
}(),
|
||||
want: ErrTrafficMaxDelayInvalid,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -9,9 +9,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/protect"
|
||||
@@ -122,7 +120,7 @@ func createMeeting(ctx context.Context, headers map[string]string) (*createRespo
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, statusError(errCreateRoomFailed, resp)
|
||||
return nil, protect.StatusError(errCreateRoomFailed, resp, 1024)
|
||||
}
|
||||
|
||||
var res createResponse
|
||||
@@ -174,7 +172,7 @@ func preconnect(ctx context.Context, roomID, password string, headers map[string
|
||||
defer func() { _ = preResp.Body.Close() }()
|
||||
|
||||
if preResp.StatusCode != http.StatusOK {
|
||||
return "", statusError(errPreconnectFailed, preResp)
|
||||
return "", protect.StatusError(errPreconnectFailed, preResp, 1024)
|
||||
}
|
||||
|
||||
var preconnectResp struct {
|
||||
@@ -186,15 +184,6 @@ func preconnect(ctx context.Context, roomID, password string, headers map[string
|
||||
return preconnectResp.ConnectorURL, nil
|
||||
}
|
||||
|
||||
// statusError wraps base with the HTTP status code of resp and, when present,
// a whitespace-trimmed excerpt of at most 1024 bytes read from the response
// body. The caller remains responsible for closing resp.Body.
func statusError(base error, resp *http.Response) error {
	// Best effort: a failed read simply leaves the excerpt short or empty.
	raw, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
	if excerpt := strings.TrimSpace(string(raw)); excerpt != "" {
		return fmt.Errorf("%w: status %d: %s", base, resp.StatusCode, excerpt)
	}
	return fmt.Errorf("%w: status %d", base, resp.StatusCode)
}
|
||||
|
||||
func joinRoom(ctx context.Context, roomID, password string) (*roomInfo, error) {
|
||||
headers := anonymousHeaders()
|
||||
connectorURL, err := preconnect(ctx, roomID, password, headers)
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
@@ -69,8 +68,7 @@ func GetConnectionInfo(ctx context.Context, roomURL, displayName string) (*Conne
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("%w %d: %s", ErrAPI, resp.StatusCode, body)
|
||||
return nil, protect.StatusError(ErrAPI, resp, 4096)
|
||||
}
|
||||
|
||||
var info ConnectionInfo
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/protect"
|
||||
@@ -84,8 +83,7 @@ func registerGuest(ctx context.Context, displayName string) (string, error) {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
return "", fmt.Errorf("%w: %d %s", errGuestRegister, resp.StatusCode, b)
|
||||
return "", protect.StatusError(errGuestRegister, resp, 4096)
|
||||
}
|
||||
|
||||
var res guestRegisterResponse
|
||||
@@ -122,8 +120,7 @@ func createRoom(ctx context.Context, accessToken string) (string, error) {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
return "", fmt.Errorf("%w: %d %s", errCreateRoom, resp.StatusCode, b)
|
||||
return "", protect.StatusError(errCreateRoom, resp, 4096)
|
||||
}
|
||||
|
||||
var res createRoomResponse
|
||||
@@ -151,8 +148,7 @@ func joinRoom(ctx context.Context, accessToken, roomID string) error {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("%w: %d %s", errJoinRoom, resp.StatusCode, b)
|
||||
return protect.StatusError(errJoinRoom, resp, 4096)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -180,8 +176,7 @@ func getToken(ctx context.Context, accessToken, roomID, displayName string) (tok
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
return tokenResponse{}, fmt.Errorf("%w: %d %s", errGetToken, resp.StatusCode, b)
|
||||
return tokenResponse{}, protect.StatusError(errGetToken, resp, 4096)
|
||||
}
|
||||
|
||||
var res tokenResponse
|
||||
|
||||
@@ -17,12 +17,14 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/control"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/crypto"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/handshake"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/link"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/logger"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/names"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/transport"
|
||||
"github.com/xtaci/smux"
|
||||
)
|
||||
|
||||
@@ -54,7 +56,12 @@ type Client struct {
|
||||
conn *muxconn.Conn
|
||||
session *smux.Session
|
||||
controlStrm *smux.Stream
|
||||
controlStop context.CancelFunc
|
||||
sessMu sync.RWMutex
|
||||
reconnectMu sync.Mutex
|
||||
healthMu sync.RWMutex
|
||||
health control.Status
|
||||
onHealth HealthFunc
|
||||
deviceID string
|
||||
sessionID string
|
||||
claims map[string]any
|
||||
@@ -63,6 +70,9 @@ type Client struct {
|
||||
socksPass string
|
||||
}
|
||||
|
||||
// HealthFunc is called when the client control health snapshot changes.
|
||||
type HealthFunc func(control.Status)
|
||||
|
||||
// Config holds runtime configuration for [Run] and [RunWithReady].
|
||||
type Config struct {
|
||||
Link string
|
||||
@@ -93,6 +103,8 @@ type Config struct {
|
||||
Engine string
|
||||
URL string
|
||||
Token string
|
||||
Liveness control.Config
|
||||
Traffic transport.TrafficConfig
|
||||
|
||||
// DeviceID overrides the persistent client-side device identifier. Leave
|
||||
// empty to derive one from DeviceIDPath (or generate a random one if both
|
||||
@@ -106,6 +118,9 @@ type Config struct {
|
||||
// Claims is sent to the server in CLIENT_HELLO and forwarded verbatim to
|
||||
// the server's AuthHook. Free-form key/value bag for plan, user, region, etc.
|
||||
Claims map[string]any
|
||||
|
||||
// OnHealth receives liveness/reconnect status updates. Nil means no-op.
|
||||
OnHealth HealthFunc
|
||||
}
|
||||
|
||||
// Run starts the client with the given configuration.
|
||||
@@ -135,6 +150,7 @@ func RunWithReady(ctx context.Context, cfg Config, onReady func()) error {
|
||||
dnsServer: cfg.DNSServer,
|
||||
socksUser: cfg.SOCKSUser,
|
||||
socksPass: cfg.SOCKSPass,
|
||||
onHealth: cfg.OnHealth,
|
||||
}
|
||||
|
||||
// shutdown is registered BEFORE bringUpLink so we always close any
|
||||
@@ -202,6 +218,7 @@ func (c *Client) bringUpLink(
|
||||
SEIBatchSize: cfg.SEIBatchSize,
|
||||
SEIFragmentSize: cfg.SEIFragmentSize,
|
||||
SEIAckTimeoutMS: cfg.SEIAckTimeoutMS,
|
||||
Traffic: cfg.Traffic,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create link: %w", err)
|
||||
@@ -217,7 +234,9 @@ func (c *Client) bringUpLink(
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
c.handleReconnect()
|
||||
if !c.handleReconnect(ctx, cfg, cancel, "carrier") {
|
||||
cancel()
|
||||
}
|
||||
})
|
||||
|
||||
if err := ln.Connect(ctx); err != nil {
|
||||
@@ -225,7 +244,7 @@ func (c *Client) bringUpLink(
|
||||
}
|
||||
|
||||
c.conn = muxconn.New(ln, c.cipher)
|
||||
sess, err := smux.Client(c.conn, smuxConfig())
|
||||
sess, err := smux.Client(c.conn, smuxConfig(linkMaxPayload(ln)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("smux client: %w", err)
|
||||
}
|
||||
@@ -243,14 +262,16 @@ func (c *Client) bringUpLink(
|
||||
c.controlStrm = control
|
||||
c.sessionID = sid
|
||||
c.sessMu.Unlock()
|
||||
c.recordSession(sid)
|
||||
c.startControlLoop(ctx, cfg, cancel, control)
|
||||
|
||||
go ln.WatchConnection(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
// openControlStream opens stream #1 on sess and performs the handshake.
|
||||
// The stream stays open for the lifetime of the smux session — the server
|
||||
// holds it parked, and it would carry future control messages.
|
||||
// The stream stays open for the lifetime of the smux session and carries
|
||||
// post-handshake control messages.
|
||||
func openControlStream(
|
||||
sess *smux.Session,
|
||||
deviceID string,
|
||||
@@ -314,11 +335,17 @@ func resolveDeviceID(deviceID, path string) (string, error) {
|
||||
}
|
||||
|
||||
// smuxConfig returns the tuned smux config used on both ends.
|
||||
func smuxConfig() *smux.Config {
|
||||
func smuxConfig(maxWirePayload ...int) *smux.Config {
|
||||
cfg := smux.DefaultConfig()
|
||||
cfg.Version = 2
|
||||
cfg.KeepAliveDisabled = true
|
||||
cfg.MaxFrameSize = 32768
|
||||
if len(maxWirePayload) > 0 && maxWirePayload[0] > crypto.WireOverhead {
|
||||
maxFrameSize := maxWirePayload[0] - crypto.WireOverhead
|
||||
if maxFrameSize < cfg.MaxFrameSize {
|
||||
cfg.MaxFrameSize = maxFrameSize
|
||||
}
|
||||
}
|
||||
cfg.MaxReceiveBuffer = 16 * 1024 * 1024
|
||||
cfg.MaxStreamBuffer = 1024 * 1024
|
||||
cfg.KeepAliveInterval = 10 * time.Second
|
||||
@@ -326,8 +353,20 @@ func smuxConfig() *smux.Config {
|
||||
return cfg
|
||||
}
|
||||
|
||||
func (c *Client) handleReconnect() {
|
||||
logger.Infof("client link reconnect - tearing down smux session")
|
||||
func linkMaxPayload(ln link.Link) int {
|
||||
provider, ok := ln.(link.FeaturesProvider)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
return provider.Features().MaxPayloadSize
|
||||
}
|
||||
|
||||
func (c *Client) handleReconnect(ctx context.Context, cfg Config, cancel context.CancelFunc, reason string) bool {
|
||||
c.reconnectMu.Lock()
|
||||
defer c.reconnectMu.Unlock()
|
||||
|
||||
c.recordReconnect()
|
||||
logger.Infof("client reconnect reason=%s - tearing down smux session", reason)
|
||||
|
||||
// Install a fresh muxconn immediately so onData never hits nil while
|
||||
// the old session is being torn down. tryReopenSession will swap it
|
||||
@@ -336,14 +375,19 @@ func (c *Client) handleReconnect() {
|
||||
|
||||
c.sessMu.Lock()
|
||||
oldControl := c.controlStrm
|
||||
oldControlStop := c.controlStop
|
||||
oldSess := c.session
|
||||
oldConn := c.conn
|
||||
c.conn = newConn
|
||||
c.session = nil
|
||||
c.controlStrm = nil
|
||||
c.controlStop = nil
|
||||
c.sessionID = ""
|
||||
c.sessMu.Unlock()
|
||||
|
||||
if oldControlStop != nil {
|
||||
oldControlStop()
|
||||
}
|
||||
if oldControl != nil {
|
||||
_ = oldControl.Close()
|
||||
}
|
||||
@@ -364,15 +408,26 @@ func (c *Client) handleReconnect() {
|
||||
attemptDelay = 300 * time.Millisecond
|
||||
)
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
if c.tryReopenSession(attempt) {
|
||||
return
|
||||
logger.Infof("client reconnect attempt=%d reason=%s", attempt, reason)
|
||||
if c.tryReopenSession(ctx, cfg, cancel, attempt) {
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case <-time.After(attemptDelay):
|
||||
}
|
||||
time.Sleep(attemptDelay)
|
||||
}
|
||||
logger.Warnf("client reconnect: exhausted %d handshake attempts", maxAttempts)
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Client) tryReopenSession(attempt int) bool {
|
||||
func (c *Client) tryReopenSession(
|
||||
ctx context.Context,
|
||||
cfg Config,
|
||||
cancel context.CancelFunc,
|
||||
attempt int,
|
||||
) bool {
|
||||
conn := muxconn.New(c.ln, c.cipher)
|
||||
|
||||
c.sessMu.Lock()
|
||||
@@ -383,7 +438,7 @@ func (c *Client) tryReopenSession(attempt int) bool {
|
||||
_ = old.Close()
|
||||
}
|
||||
|
||||
sess, err := smux.Client(conn, smuxConfig())
|
||||
sess, err := smux.Client(conn, smuxConfig(linkMaxPayload(c.ln)))
|
||||
if err != nil {
|
||||
logger.Warnf("smux re-init failed (attempt %d): %v", attempt, err)
|
||||
return false
|
||||
@@ -400,19 +455,138 @@ func (c *Client) tryReopenSession(attempt int) bool {
|
||||
c.controlStrm = control
|
||||
c.sessionID = sid
|
||||
c.sessMu.Unlock()
|
||||
c.recordSession(sid)
|
||||
c.startControlLoop(ctx, cfg, cancel, control)
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *Client) startControlLoop(
|
||||
ctx context.Context,
|
||||
cfg Config,
|
||||
cancel context.CancelFunc,
|
||||
stream *smux.Stream,
|
||||
) {
|
||||
controlCtx, stop := context.WithCancel(ctx)
|
||||
c.sessMu.Lock()
|
||||
c.controlStop = stop
|
||||
c.sessMu.Unlock()
|
||||
|
||||
liveness := cfg.Liveness
|
||||
onPong := liveness.OnPong
|
||||
onMissedPong := liveness.OnMissedPong
|
||||
onUnhealthy := liveness.OnUnhealthy
|
||||
liveness.OnPong = func(h control.Health) {
|
||||
c.sessMu.RLock()
|
||||
sid := c.sessionID
|
||||
c.sessMu.RUnlock()
|
||||
c.recordPong(h)
|
||||
logger.Debugf("control alive session=%s rtt=%v seq=%d", sid, h.RTT, h.Seq)
|
||||
if onPong != nil {
|
||||
onPong(h)
|
||||
}
|
||||
}
|
||||
liveness.OnMissedPong = func(missed int) {
|
||||
c.recordMissed(missed)
|
||||
logger.Warnf("control missed pong on client: missed_pongs=%d", missed)
|
||||
if onMissedPong != nil {
|
||||
onMissedPong(missed)
|
||||
}
|
||||
}
|
||||
liveness.OnUnhealthy = func(missed int) {
|
||||
c.recordUnhealthy(missed)
|
||||
logger.Warnf("control stream unhealthy on client: missed_pongs=%d", missed)
|
||||
if onUnhealthy != nil {
|
||||
onUnhealthy(missed)
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := control.Run(controlCtx, stream, liveness)
|
||||
if controlCtx.Err() != nil || ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
logger.Warnf("client control stream ended: %v", err)
|
||||
}
|
||||
if !c.handleReconnect(ctx, cfg, cancel, "liveness") {
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Status returns the latest client-side control health snapshot.
|
||||
func (c *Client) Status() control.Status {
|
||||
c.healthMu.RLock()
|
||||
defer c.healthMu.RUnlock()
|
||||
return c.health
|
||||
}
|
||||
|
||||
func (c *Client) recordSession(sessionID string) {
|
||||
c.healthMu.Lock()
|
||||
c.health.SessionID = sessionID
|
||||
c.health.MissedPongs = 0
|
||||
status := c.health
|
||||
c.healthMu.Unlock()
|
||||
c.notifyHealth(status)
|
||||
}
|
||||
|
||||
func (c *Client) recordPong(h control.Health) {
|
||||
c.healthMu.Lock()
|
||||
c.health.LastPong = h.LastSeen
|
||||
c.health.LastRTT = h.RTT
|
||||
c.health.MissedPongs = 0
|
||||
status := c.health
|
||||
c.healthMu.Unlock()
|
||||
c.notifyHealth(status)
|
||||
}
|
||||
|
||||
func (c *Client) recordMissed(missed int) {
|
||||
c.healthMu.Lock()
|
||||
c.health.MissedPongs = missed
|
||||
status := c.health
|
||||
c.healthMu.Unlock()
|
||||
c.notifyHealth(status)
|
||||
}
|
||||
|
||||
func (c *Client) recordUnhealthy(missed int) {
|
||||
c.healthMu.Lock()
|
||||
c.health.MissedPongs = missed
|
||||
c.health.UnhealthyEvents++
|
||||
c.health.LastUnhealthy = time.Now()
|
||||
status := c.health
|
||||
c.healthMu.Unlock()
|
||||
c.notifyHealth(status)
|
||||
}
|
||||
|
||||
func (c *Client) recordReconnect() {
|
||||
c.healthMu.Lock()
|
||||
c.health.Reconnects++
|
||||
status := c.health
|
||||
c.healthMu.Unlock()
|
||||
c.notifyHealth(status)
|
||||
}
|
||||
|
||||
func (c *Client) notifyHealth(status control.Status) {
|
||||
if c.onHealth != nil {
|
||||
c.onHealth(status)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) shutdown() {
|
||||
c.sessMu.Lock()
|
||||
control := c.controlStrm
|
||||
controlStop := c.controlStop
|
||||
sess := c.session
|
||||
conn := c.conn
|
||||
c.controlStrm = nil
|
||||
c.controlStop = nil
|
||||
c.session = nil
|
||||
c.conn = nil
|
||||
c.sessMu.Unlock()
|
||||
|
||||
if controlStop != nil {
|
||||
controlStop()
|
||||
}
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/control"
|
||||
cryptopkg "github.com/openlibrecommunity/olcrtc/internal/crypto"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
|
||||
"github.com/xtaci/smux"
|
||||
@@ -48,6 +49,11 @@ func TestSmuxConfig(t *testing.T) {
|
||||
if cfg.Version != 2 || !cfg.KeepAliveDisabled || cfg.MaxFrameSize != 32768 || cfg.MaxReceiveBuffer != 16*1024*1024 {
|
||||
t.Fatalf("smuxConfig() = %+v", cfg)
|
||||
}
|
||||
capped := smuxConfig(4096)
|
||||
if capped.MaxFrameSize != 4096-cryptopkg.WireOverhead {
|
||||
t.Fatalf("smuxConfig(4096).MaxFrameSize = %d, want %d",
|
||||
capped.MaxFrameSize, 4096-cryptopkg.WireOverhead)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSocks5Handshake(t *testing.T) {
|
||||
@@ -517,3 +523,96 @@ func TestShutdownClosesLinkAndConn(t *testing.T) {
|
||||
t.Fatal("shutdown() did not close link")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartControlLoopReportsPong(t *testing.T) {
|
||||
a, b := net.Pipe()
|
||||
defer func() {
|
||||
_ = a.Close()
|
||||
_ = b.Close()
|
||||
}()
|
||||
|
||||
serverSess, err := smux.Server(a, smuxConfig())
|
||||
if err != nil {
|
||||
t.Fatalf("smux.Server() error = %v", err)
|
||||
}
|
||||
defer func() { _ = serverSess.Close() }()
|
||||
clientSess, err := smux.Client(b, smuxConfig())
|
||||
if err != nil {
|
||||
t.Fatalf("smux.Client() error = %v", err)
|
||||
}
|
||||
defer func() { _ = clientSess.Close() }()
|
||||
|
||||
peerStreamCh := make(chan *smux.Stream, 1)
|
||||
go func() {
|
||||
stream, err := serverSess.AcceptStream()
|
||||
if err == nil {
|
||||
peerStreamCh <- stream
|
||||
}
|
||||
}()
|
||||
|
||||
stream, err := clientSess.OpenStream()
|
||||
if err != nil {
|
||||
t.Fatalf("OpenStream() error = %v", err)
|
||||
}
|
||||
peerStream := <-peerStreamCh
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
got := make(chan control.Health, 1)
|
||||
c := &Client{sessionID: "sid-control"}
|
||||
c.recordSession("sid-control")
|
||||
c.startControlLoop(ctx, Config{
|
||||
Liveness: control.Config{
|
||||
Interval: 10 * time.Millisecond,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
Failures: 2,
|
||||
OnPong: func(h control.Health) {
|
||||
select {
|
||||
case got <- h:
|
||||
default:
|
||||
}
|
||||
},
|
||||
},
|
||||
}, cancel, stream)
|
||||
go func() {
|
||||
_ = control.Run(ctx, peerStream, control.Config{
|
||||
Interval: 10 * time.Millisecond,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
Failures: 2,
|
||||
})
|
||||
}()
|
||||
|
||||
select {
|
||||
case h := <-got:
|
||||
if h.Seq == 0 {
|
||||
t.Fatal("Health.Seq = 0")
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for control pong")
|
||||
}
|
||||
status := c.Status()
|
||||
if status.SessionID != "sid-control" {
|
||||
t.Fatalf("Status.SessionID = %q, want sid-control", status.SessionID)
|
||||
}
|
||||
if status.LastPong.IsZero() || status.LastRTT < 0 || status.MissedPongs != 0 {
|
||||
t.Fatalf("Status() = %+v", status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusRecordsReconnectAndUnhealthy(t *testing.T) {
|
||||
updates := 0
|
||||
c := &Client{onHealth: func(control.Status) { updates++ }}
|
||||
c.recordSession("sid-1")
|
||||
c.recordMissed(2)
|
||||
c.recordUnhealthy(3)
|
||||
c.recordReconnect()
|
||||
|
||||
status := c.Status()
|
||||
if status.SessionID != "sid-1" || status.MissedPongs != 3 ||
|
||||
status.UnhealthyEvents != 1 || status.Reconnects != 1 || status.LastUnhealthy.IsZero() {
|
||||
t.Fatalf("Status() = %+v", status)
|
||||
}
|
||||
if updates != 4 {
|
||||
t.Fatalf("health updates = %d, want 4", updates)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
// Package config loads olcrtc runtime configuration from YAML files.
|
||||
//
|
||||
// The YAML schema mirrors [session.Config]. Fields left unset in the file
|
||||
// remain at their zero value, allowing CLI flags to fill them in. Use
|
||||
// [Apply] to merge a parsed [File] onto an existing [session.Config];
|
||||
// non-zero fields in the session config (typically populated from CLI flags)
|
||||
// take precedence over the YAML values.
|
||||
// remain at their zero value. Use [Apply] to map a parsed [File] onto an
|
||||
// existing [session.Config]; non-zero fields in the session config take
|
||||
// precedence over the YAML values.
|
||||
//
|
||||
//nolint:tagliatelle // YAML keys are the documented config file schema.
|
||||
package config
|
||||
@@ -13,31 +12,68 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/app/session"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ErrConfigNotFound is returned when a config file path is set but the file does not exist.
|
||||
var ErrConfigNotFound = errors.New("config file not found")
|
||||
// Sentinel errors reported while loading a config file.
var (
	// ErrConfigNotFound is returned when a config file path is set but the file does not exist.
	ErrConfigNotFound = errors.New("config file not found")

	// ErrCryptoKeyConflict is returned when crypto.key and crypto.key_file are
	// both set, either at the top level or inside a profile.
	ErrCryptoKeyConflict = errors.New("crypto.key and crypto.key_file cannot both be set")

	// ErrCryptoKeyFileEmpty is returned when the file named by crypto.key_file
	// holds nothing but whitespace.
	ErrCryptoKeyFileEmpty = errors.New("crypto key file is empty")
)
|
||||
|
||||
// File is the on-disk YAML schema.
|
||||
type File struct {
|
||||
Mode string `yaml:"mode"`
|
||||
Link string `yaml:"link"`
|
||||
Auth Auth `yaml:"auth"`
|
||||
Room Room `yaml:"room"`
|
||||
Crypto Crypto `yaml:"crypto"`
|
||||
Net Net `yaml:"net"`
|
||||
SOCKS SOCKS `yaml:"socks"`
|
||||
Engine Engine `yaml:"engine"`
|
||||
Video Video `yaml:"video"`
|
||||
VP8 VP8 `yaml:"vp8"`
|
||||
SEI SEI `yaml:"sei"`
|
||||
Gen Gen `yaml:"gen"`
|
||||
Data string `yaml:"data"`
|
||||
Debug bool `yaml:"debug"`
|
||||
FFmpeg string `yaml:"ffmpeg"`
|
||||
Mode string `yaml:"mode"`
|
||||
Link string `yaml:"link"`
|
||||
Auth Auth `yaml:"auth"`
|
||||
Room Room `yaml:"room"`
|
||||
Crypto Crypto `yaml:"crypto"`
|
||||
Net Net `yaml:"net"`
|
||||
SOCKS SOCKS `yaml:"socks"`
|
||||
Engine Engine `yaml:"engine"`
|
||||
Video Video `yaml:"video"`
|
||||
VP8 VP8 `yaml:"vp8"`
|
||||
SEI SEI `yaml:"sei"`
|
||||
Liveness Liveness `yaml:"liveness"`
|
||||
Lifecycle Lifecycle `yaml:"lifecycle"`
|
||||
Traffic Traffic `yaml:"traffic"`
|
||||
Gen Gen `yaml:"gen"`
|
||||
Profiles []Profile `yaml:"profiles"`
|
||||
Failover Failover `yaml:"failover"`
|
||||
Data string `yaml:"data"`
|
||||
Debug bool `yaml:"debug"`
|
||||
FFmpeg string `yaml:"ffmpeg"`
|
||||
}
|
||||
|
||||
// Profile is a failover entry that overrides top-level runtime fields.
|
||||
type Profile struct {
|
||||
Name string `yaml:"name"`
|
||||
Link string `yaml:"link"`
|
||||
Auth Auth `yaml:"auth"`
|
||||
Room Room `yaml:"room"`
|
||||
Crypto Crypto `yaml:"crypto"`
|
||||
Net Net `yaml:"net"`
|
||||
SOCKS SOCKS `yaml:"socks"`
|
||||
Engine Engine `yaml:"engine"`
|
||||
Video Video `yaml:"video"`
|
||||
VP8 VP8 `yaml:"vp8"`
|
||||
SEI SEI `yaml:"sei"`
|
||||
Liveness Liveness `yaml:"liveness"`
|
||||
Lifecycle Lifecycle `yaml:"lifecycle"`
|
||||
Traffic Traffic `yaml:"traffic"`
|
||||
}
|
||||
|
||||
// Failover holds the YAML "failover" section that controls ordered profile
// failover.
type Failover struct {
	// RetryDelay is kept as the raw YAML string (for example "100ms").
	// NOTE(review): parsing into a duration happens outside this type — confirm the accepted format there.
	RetryDelay string `yaml:"retry_delay"`
	// MaxCycles is the integer read from max_cycles.
	// NOTE(review): the meaning of zero is decided by the consumer — confirm.
	MaxCycles int `yaml:"max_cycles"`
}
|
||||
|
||||
// Auth selects the auth provider.
|
||||
@@ -52,7 +88,8 @@ type Room struct {
|
||||
|
||||
// Crypto holds the shared secret used to authenticate and encrypt the tunnel.
|
||||
type Crypto struct {
|
||||
Key string `yaml:"key"` // 64-char hex (32 bytes)
|
||||
Key string `yaml:"key"` // 64-char hex (32 bytes)
|
||||
KeyFile string `yaml:"key_file"` // path to a file containing crypto.key
|
||||
}
|
||||
|
||||
// Net groups network and transport selection.
|
||||
@@ -106,6 +143,25 @@ type SEI struct {
|
||||
AckTimeoutMS int `yaml:"ack_timeout_ms"`
|
||||
}
|
||||
|
||||
// Liveness holds the YAML "liveness" section, which tunes the post-handshake
// control stream ping/pong checks.
type Liveness struct {
	// Interval is kept as the raw YAML string (for example "2s").
	Interval string `yaml:"interval"`
	// Timeout is kept as the raw YAML string (for example "500ms").
	Timeout string `yaml:"timeout"`
	// Failures is the integer read from liveness.failures.
	Failures int `yaml:"failures"`
}
|
||||
|
||||
// Lifecycle holds the YAML "lifecycle" section, which controls planned
// session rebuilds.
type Lifecycle struct {
	// MaxSessionDuration is kept as the raw YAML string (for example "6h").
	MaxSessionDuration string `yaml:"max_session_duration"`
}
|
||||
|
||||
// Traffic holds the YAML "traffic" section, which controls optional
// reliability-oriented send shaping.
type Traffic struct {
	// MaxPayloadSize is the integer read from traffic.max_payload_size.
	MaxPayloadSize int `yaml:"max_payload_size"`
	// MinDelay is kept as the raw YAML string (for example "5ms").
	MinDelay string `yaml:"min_delay"`
	// MaxDelay is kept as the raw YAML string (for example "30ms").
	MaxDelay string `yaml:"max_delay"`
}
|
||||
|
||||
// Gen controls room-generation mode.
|
||||
type Gen struct {
|
||||
Amount int `yaml:"amount"`
|
||||
@@ -125,9 +181,63 @@ func Load(path string) (File, error) {
|
||||
if err := yaml.Unmarshal(data, &f); err != nil {
|
||||
return File{}, fmt.Errorf("parse config %s: %w", path, err)
|
||||
}
|
||||
if err := loadExternalSecrets(path, &f); err != nil {
|
||||
return File{}, err
|
||||
}
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func loadExternalSecrets(configPath string, f *File) error {
|
||||
if f.Crypto.KeyFile == "" {
|
||||
return loadProfileSecrets(configPath, f.Profiles)
|
||||
}
|
||||
if f.Crypto.Key != "" {
|
||||
return ErrCryptoKeyConflict
|
||||
}
|
||||
|
||||
key, err := readKeyFile(configPath, f.Crypto.KeyFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f.Crypto.Key = key
|
||||
return loadProfileSecrets(configPath, f.Profiles)
|
||||
}
|
||||
|
||||
func loadProfileSecrets(configPath string, profiles []Profile) error {
|
||||
for i := range profiles {
|
||||
if profiles[i].Crypto.KeyFile == "" {
|
||||
continue
|
||||
}
|
||||
if profiles[i].Crypto.Key != "" {
|
||||
return fmt.Errorf("profiles[%d]: %w", i, ErrCryptoKeyConflict)
|
||||
}
|
||||
key, err := readKeyFile(configPath, profiles[i].Crypto.KeyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("profiles[%d]: %w", i, err)
|
||||
}
|
||||
profiles[i].Crypto.Key = key
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func readKeyFile(configPath, keyFile string) (string, error) {
|
||||
keyPath := keyFile
|
||||
if !filepath.IsAbs(keyPath) {
|
||||
keyPath = filepath.Join(filepath.Dir(configPath), keyPath)
|
||||
}
|
||||
|
||||
// #nosec G304 -- key_file is an explicit path in the user's config file.
|
||||
data, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read crypto key file %s: %w", keyPath, err)
|
||||
}
|
||||
key := strings.TrimSpace(string(data))
|
||||
if key == "" {
|
||||
return "", ErrCryptoKeyFileEmpty
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// Apply merges f onto dst. CLI-set fields (non-zero values in dst) win;
|
||||
// YAML values fill in the rest.
|
||||
func Apply(dst session.Config, f File) session.Config {
|
||||
@@ -163,10 +273,61 @@ func Apply(dst session.Config, f File) session.Config {
|
||||
dst.SEIBatchSize = pickInt(dst.SEIBatchSize, f.SEI.BatchSize)
|
||||
dst.SEIFragmentSize = pickInt(dst.SEIFragmentSize, f.SEI.FragmentSize)
|
||||
dst.SEIAckTimeoutMS = pickInt(dst.SEIAckTimeoutMS, f.SEI.AckTimeoutMS)
|
||||
dst.LivenessInterval = pickString(dst.LivenessInterval, f.Liveness.Interval)
|
||||
dst.LivenessTimeout = pickString(dst.LivenessTimeout, f.Liveness.Timeout)
|
||||
dst.LivenessFailures = pickInt(dst.LivenessFailures, f.Liveness.Failures)
|
||||
dst.MaxSessionDuration = pickString(dst.MaxSessionDuration, f.Lifecycle.MaxSessionDuration)
|
||||
dst.TrafficMaxPayloadSize = pickInt(dst.TrafficMaxPayloadSize, f.Traffic.MaxPayloadSize)
|
||||
dst.TrafficMinDelay = pickString(dst.TrafficMinDelay, f.Traffic.MinDelay)
|
||||
dst.TrafficMaxDelay = pickString(dst.TrafficMaxDelay, f.Traffic.MaxDelay)
|
||||
dst.Amount = pickInt(dst.Amount, f.Gen.Amount)
|
||||
return dst
|
||||
}
|
||||
|
||||
// ApplyProfile overlays a failover profile onto an already-applied base config.
|
||||
func ApplyProfile(base session.Config, p Profile) session.Config {
|
||||
dst := base
|
||||
dst.Link = overlayString(dst.Link, p.Link)
|
||||
dst.Transport = overlayString(dst.Transport, p.Net.Transport)
|
||||
dst.Auth = overlayString(dst.Auth, p.Auth.Provider)
|
||||
dst.Engine = overlayString(dst.Engine, p.Engine.Name)
|
||||
dst.URL = overlayString(dst.URL, p.Engine.URL)
|
||||
dst.Token = overlayString(dst.Token, p.Engine.Token)
|
||||
dst.RoomID = overlayString(dst.RoomID, p.Room.ID)
|
||||
dst.KeyHex = overlayString(dst.KeyHex, p.Crypto.Key)
|
||||
dst.SOCKSHost = overlayString(dst.SOCKSHost, p.SOCKS.Host)
|
||||
dst.SOCKSPort = overlayInt(dst.SOCKSPort, p.SOCKS.Port)
|
||||
dst.SOCKSUser = overlayString(dst.SOCKSUser, p.SOCKS.User)
|
||||
dst.SOCKSPass = overlayString(dst.SOCKSPass, p.SOCKS.Pass)
|
||||
dst.DNSServer = overlayString(dst.DNSServer, p.Net.DNS)
|
||||
dst.SOCKSProxyAddr = overlayString(dst.SOCKSProxyAddr, p.SOCKS.ProxyAddr)
|
||||
dst.SOCKSProxyPort = overlayInt(dst.SOCKSProxyPort, p.SOCKS.ProxyPort)
|
||||
dst.VideoWidth = overlayInt(dst.VideoWidth, p.Video.Width)
|
||||
dst.VideoHeight = overlayInt(dst.VideoHeight, p.Video.Height)
|
||||
dst.VideoFPS = overlayInt(dst.VideoFPS, p.Video.FPS)
|
||||
dst.VideoBitrate = overlayString(dst.VideoBitrate, p.Video.Bitrate)
|
||||
dst.VideoHW = overlayString(dst.VideoHW, p.Video.HW)
|
||||
dst.VideoQRSize = overlayInt(dst.VideoQRSize, p.Video.QRSize)
|
||||
dst.VideoQRRecovery = overlayString(dst.VideoQRRecovery, p.Video.QRRecovery)
|
||||
dst.VideoCodec = overlayString(dst.VideoCodec, p.Video.Codec)
|
||||
dst.VideoTileModule = overlayInt(dst.VideoTileModule, p.Video.TileModule)
|
||||
dst.VideoTileRS = overlayInt(dst.VideoTileRS, p.Video.TileRS)
|
||||
dst.VP8FPS = overlayInt(dst.VP8FPS, p.VP8.FPS)
|
||||
dst.VP8BatchSize = overlayInt(dst.VP8BatchSize, p.VP8.BatchSize)
|
||||
dst.SEIFPS = overlayInt(dst.SEIFPS, p.SEI.FPS)
|
||||
dst.SEIBatchSize = overlayInt(dst.SEIBatchSize, p.SEI.BatchSize)
|
||||
dst.SEIFragmentSize = overlayInt(dst.SEIFragmentSize, p.SEI.FragmentSize)
|
||||
dst.SEIAckTimeoutMS = overlayInt(dst.SEIAckTimeoutMS, p.SEI.AckTimeoutMS)
|
||||
dst.LivenessInterval = overlayString(dst.LivenessInterval, p.Liveness.Interval)
|
||||
dst.LivenessTimeout = overlayString(dst.LivenessTimeout, p.Liveness.Timeout)
|
||||
dst.LivenessFailures = overlayInt(dst.LivenessFailures, p.Liveness.Failures)
|
||||
dst.MaxSessionDuration = overlayString(dst.MaxSessionDuration, p.Lifecycle.MaxSessionDuration)
|
||||
dst.TrafficMaxPayloadSize = overlayInt(dst.TrafficMaxPayloadSize, p.Traffic.MaxPayloadSize)
|
||||
dst.TrafficMinDelay = overlayString(dst.TrafficMinDelay, p.Traffic.MinDelay)
|
||||
dst.TrafficMaxDelay = overlayString(dst.TrafficMaxDelay, p.Traffic.MaxDelay)
|
||||
return dst
|
||||
}
|
||||
|
||||
func pickString(cli, yamlVal string) string {
|
||||
if cli != "" {
|
||||
return cli
|
||||
@@ -180,3 +341,17 @@ func pickInt(cli, yamlVal int) int {
|
||||
}
|
||||
return yamlVal
|
||||
}
|
||||
|
||||
// overlayString returns override unless it is empty, in which case base is
// returned.
func overlayString(base, override string) string {
	if override == "" {
		return base
	}
	return override
}
|
||||
|
||||
// overlayInt returns override unless it is zero, in which case base is
// returned.
func overlayInt(base, override int) int {
	if override == 0 {
		return base
	}
	return override
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -38,6 +39,16 @@ socks:
|
||||
vp8:
|
||||
fps: 25
|
||||
batch_size: 4
|
||||
liveness:
|
||||
interval: 2s
|
||||
timeout: 500ms
|
||||
failures: 4
|
||||
lifecycle:
|
||||
max_session_duration: 6h
|
||||
traffic:
|
||||
max_payload_size: 4096
|
||||
min_delay: 5ms
|
||||
max_delay: 30ms
|
||||
gen:
|
||||
amount: 3
|
||||
debug: true
|
||||
@@ -75,20 +86,27 @@ func requireLoadedFile(t *testing.T, f File) {
|
||||
func requireAppliedConfig(t *testing.T, got session.Config) {
|
||||
t.Helper()
|
||||
want := session.Config{
|
||||
Mode: testModeSrv,
|
||||
Link: "direct",
|
||||
Auth: testAuthProvider,
|
||||
RoomID: testRoomID,
|
||||
KeyHex: testCryptoKey,
|
||||
Transport: "datachannel",
|
||||
DNSServer: "1.1.1.1:53",
|
||||
SOCKSHost: "127.0.0.1",
|
||||
SOCKSPort: 1080,
|
||||
SOCKSUser: "u",
|
||||
SOCKSPass: "p",
|
||||
VP8FPS: 25,
|
||||
VP8BatchSize: 4,
|
||||
Amount: 3,
|
||||
Mode: testModeSrv,
|
||||
Link: "direct",
|
||||
Auth: testAuthProvider,
|
||||
RoomID: testRoomID,
|
||||
KeyHex: testCryptoKey,
|
||||
Transport: "datachannel",
|
||||
DNSServer: "1.1.1.1:53",
|
||||
SOCKSHost: "127.0.0.1",
|
||||
SOCKSPort: 1080,
|
||||
SOCKSUser: "u",
|
||||
SOCKSPass: "p",
|
||||
VP8FPS: 25,
|
||||
VP8BatchSize: 4,
|
||||
LivenessInterval: "2s",
|
||||
LivenessTimeout: "500ms",
|
||||
LivenessFailures: 4,
|
||||
MaxSessionDuration: "6h",
|
||||
TrafficMaxPayloadSize: 4096,
|
||||
TrafficMinDelay: "5ms",
|
||||
TrafficMaxDelay: "30ms",
|
||||
Amount: 3,
|
||||
}
|
||||
if got != want {
|
||||
t.Fatalf("Apply produced wrong config: %+v, want %+v", got, want)
|
||||
@@ -121,6 +139,182 @@ func TestApplyCLIWins(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadAndApplyProfile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "olcrtc.yaml")
|
||||
body := `
|
||||
mode: srv
|
||||
link: direct
|
||||
crypto:
|
||||
key: shared-key
|
||||
net:
|
||||
dns: 1.1.1.1:53
|
||||
liveness:
|
||||
interval: 5s
|
||||
timeout: 2s
|
||||
failures: 5
|
||||
lifecycle:
|
||||
max_session_duration: 6h
|
||||
traffic:
|
||||
max_payload_size: 8192
|
||||
min_delay: 10ms
|
||||
max_delay: 40ms
|
||||
profiles:
|
||||
- name: wb-vp8
|
||||
auth:
|
||||
provider: wbstream
|
||||
room:
|
||||
id: wb-room
|
||||
net:
|
||||
transport: vp8channel
|
||||
vp8:
|
||||
fps: 30
|
||||
liveness:
|
||||
interval: 1s
|
||||
lifecycle:
|
||||
max_session_duration: 30m
|
||||
traffic:
|
||||
max_payload_size: 4096
|
||||
max_delay: 20ms
|
||||
- name: jitsi-dc
|
||||
auth:
|
||||
provider: jitsi
|
||||
room:
|
||||
id: https://meet.example/room
|
||||
net:
|
||||
transport: datachannel
|
||||
dns: 8.8.8.8:53
|
||||
failover:
|
||||
retry_delay: 100ms
|
||||
max_cycles: 2
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(body), 0o600); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
|
||||
f, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Load: %v", err)
|
||||
}
|
||||
if len(f.Profiles) != 2 {
|
||||
t.Fatalf("profiles = %d, want 2", len(f.Profiles))
|
||||
}
|
||||
if f.Failover.RetryDelay != "100ms" || f.Failover.MaxCycles != 2 {
|
||||
t.Fatalf("Failover = %+v, want retry_delay 100ms max_cycles 2", f.Failover)
|
||||
}
|
||||
|
||||
base := Apply(session.Config{}, f)
|
||||
first := ApplyProfile(base, f.Profiles[0])
|
||||
if first.Auth != "wbstream" || first.Transport != "vp8channel" || first.RoomID != "wb-room" {
|
||||
t.Fatalf("first profile = %+v", first)
|
||||
}
|
||||
if first.KeyHex != "shared-key" || first.DNSServer != "1.1.1.1:53" || first.VP8FPS != 30 ||
|
||||
first.LivenessInterval != "1s" || first.LivenessTimeout != "2s" || first.LivenessFailures != 5 ||
|
||||
first.MaxSessionDuration != "30m" || first.TrafficMaxPayloadSize != 4096 ||
|
||||
first.TrafficMinDelay != "10ms" || first.TrafficMaxDelay != "20ms" {
|
||||
t.Fatalf("first inherited/overlaid fields = %+v", first)
|
||||
}
|
||||
second := ApplyProfile(base, f.Profiles[1])
|
||||
if second.Auth != "jitsi" || second.Transport != "datachannel" ||
|
||||
second.RoomID != "https://meet.example/room" || second.DNSServer != "8.8.8.8:53" {
|
||||
t.Fatalf("second profile = %+v", second)
|
||||
}
|
||||
if second.LivenessInterval != "5s" || second.LivenessTimeout != "2s" || second.LivenessFailures != 5 ||
|
||||
second.MaxSessionDuration != "6h" || second.TrafficMaxPayloadSize != 8192 ||
|
||||
second.TrafficMinDelay != "10ms" || second.TrafficMaxDelay != "40ms" {
|
||||
t.Fatalf("second lifecycle/liveness fields = %+v", second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadProfileCryptoKeyFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := os.WriteFile(filepath.Join(dir, "profile.key"), []byte(testCryptoKey+"\n"), 0o600); err != nil {
|
||||
t.Fatalf("write key: %v", err)
|
||||
}
|
||||
path := filepath.Join(dir, "olcrtc.yaml")
|
||||
body := `
|
||||
profiles:
|
||||
- name: file-key
|
||||
crypto:
|
||||
key_file: profile.key
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(body), 0o600); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
|
||||
f, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Load: %v", err)
|
||||
}
|
||||
if got := f.Profiles[0].Crypto.Key; got != testCryptoKey {
|
||||
t.Fatalf("profile key = %q, want %q", got, testCryptoKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCryptoKeyFileRelativeToConfig(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
keyPath := filepath.Join(dir, "secret.key")
|
||||
if err := os.WriteFile(keyPath, []byte(testCryptoKey+"\n"), 0o600); err != nil {
|
||||
t.Fatalf("write key: %v", err)
|
||||
}
|
||||
path := filepath.Join(dir, "olcrtc.yaml")
|
||||
body := `
|
||||
mode: srv
|
||||
crypto:
|
||||
key_file: secret.key
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(body), 0o600); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
|
||||
f, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Load: %v", err)
|
||||
}
|
||||
if f.Crypto.Key != testCryptoKey {
|
||||
t.Fatalf("Crypto.Key = %q, want %q", f.Crypto.Key, testCryptoKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCryptoKeyFileConflict(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "olcrtc.yaml")
|
||||
body := `
|
||||
crypto:
|
||||
key: deadbeef
|
||||
key_file: secret.key
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(body), 0o600); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
|
||||
_, err := Load(path)
|
||||
if !errors.Is(err, ErrCryptoKeyConflict) {
|
||||
t.Fatalf("Load() error = %v, want %v", err, ErrCryptoKeyConflict)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCryptoKeyFileEmpty(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
keyPath := filepath.Join(dir, "secret.key")
|
||||
if err := os.WriteFile(keyPath, []byte("\n"), 0o600); err != nil {
|
||||
t.Fatalf("write key: %v", err)
|
||||
}
|
||||
path := filepath.Join(dir, "olcrtc.yaml")
|
||||
body := `
|
||||
crypto:
|
||||
key_file: secret.key
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(body), 0o600); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
|
||||
_, err := Load(path)
|
||||
if !errors.Is(err, ErrCryptoKeyFileEmpty) {
|
||||
t.Fatalf("Load() error = %v, want %v", err, ErrCryptoKeyFileEmpty)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadMissing(t *testing.T) {
|
||||
_, err := Load(filepath.Join(t.TempDir(), "nope.yaml"))
|
||||
if err == nil {
|
||||
|
||||
343
internal/control/control.go
Normal file
343
internal/control/control.go
Normal file
@@ -0,0 +1,343 @@
|
||||
// Package control implements the post-handshake control stream protocol.
|
||||
//
|
||||
// The control stream is the first smux stream after the olcrtc handshake. It
|
||||
// stays inside the encrypted muxconn path, so ping/pong proves that the actual
|
||||
// tunnel path still round-trips, not merely that the provider connection is up.
|
||||
//
|
||||
// Wire format matches the handshake framing: a 4-byte big-endian length
|
||||
// followed by a JSON message.
|
||||
//
|
||||
//nolint:tagliatelle // JSON keys are the stable wire protocol schema.
|
||||
package control
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Protocol constants and the defaults applied by Config.withDefaults.
const (
	// ProtoVersion identifies the control stream wire format; parseMessage
	// rejects messages carrying any other version.
	ProtoVersion = 1

	// MaxMessageSize caps the JSON body of one control frame, in bytes.
	MaxMessageSize = 16 * 1024

	// DefaultInterval is the default interval between ping probes.
	DefaultInterval = 10 * time.Second

	// DefaultTimeout is the default time to wait for a pong.
	DefaultTimeout = 5 * time.Second

	// DefaultFailures is the default number of missed pongs, counted since the
	// last matching pong, before the stream is marked unhealthy.
	DefaultFailures = 3
)
|
||||
|
||||
// MsgType is the string label carried in the "type" field of a control
// message.
type MsgType string

// Message types accepted by parseMessage.
const (
	// TypePing is sent on every probe tick to prove control-stream liveness.
	TypePing MsgType = "CONTROL_PING"

	// TypePong answers a ping and echoes its sequence number and timestamp.
	TypePong MsgType = "CONTROL_PONG"
)
|
||||
|
||||
// Sentinel errors returned by the control loop and its framing helpers.
var (
	// ErrUnhealthy is returned when the missed-pong count reaches the
	// configured failure threshold.
	ErrUnhealthy = errors.New("control stream unhealthy")

	// ErrProtocolVersion is returned when a received message carries a version
	// other than ProtoVersion.
	ErrProtocolVersion = errors.New("incompatible control protocol version")

	// ErrUnexpectedMessage is returned for a message whose type is neither
	// ping nor pong.
	ErrUnexpectedMessage = errors.New("unexpected control message")

	// ErrFrameTooLarge is returned when a frame exceeds [MaxMessageSize].
	ErrFrameTooLarge = errors.New("control frame too large")
)
|
||||
|
||||
// Message is one control-stream frame.
|
||||
type Message struct {
|
||||
Version int `json:"version"`
|
||||
Type MsgType `json:"type"`
|
||||
Seq uint64 `json:"seq,omitempty"`
|
||||
SentUnixNano int64 `json:"sent_unix_nano,omitempty"`
|
||||
}
|
||||
|
||||
// Health describes one completed ping round trip and is passed to
// Config.OnPong.
type Health struct {
	// Seq is the sequence number of the matched ping.
	Seq uint64
	// RTT is the local time between sending the ping and handling its pong.
	RTT time.Duration
	// LastSeen is the local time at which the pong was handled.
	LastSeen time.Time
}
|
||||
|
||||
// Status is a point-in-time view of control-stream health. The control loop
// itself does not fill it in; callers that embed the loop maintain it.
type Status struct {
	SessionID string

	// Latest successful round trip.
	LastPong time.Time
	LastRTT  time.Duration

	// Miss and recovery bookkeeping.
	MissedPongs     int
	Reconnects      uint64
	UnhealthyEvents uint64
	LastUnhealthy   time.Time
}
|
||||
|
||||
// Config controls the liveness loop.
|
||||
type Config struct {
|
||||
Interval time.Duration
|
||||
Timeout time.Duration
|
||||
Failures int
|
||||
|
||||
// OnPong is called after a matching pong is received.
|
||||
OnPong func(Health)
|
||||
// OnMissedPong is called when one or more outstanding pongs time out.
|
||||
OnMissedPong func(missed int)
|
||||
// OnUnhealthy is called before Run returns [ErrUnhealthy].
|
||||
OnUnhealthy func(missed int)
|
||||
}
|
||||
|
||||
func (cfg Config) withDefaults() Config {
|
||||
if cfg.Interval <= 0 {
|
||||
cfg.Interval = DefaultInterval
|
||||
}
|
||||
if cfg.Timeout <= 0 {
|
||||
cfg.Timeout = DefaultTimeout
|
||||
}
|
||||
if cfg.Failures <= 0 {
|
||||
cfg.Failures = DefaultFailures
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// Run drives bidirectional ping/pong liveness until ctx is canceled, rw closes,
|
||||
// or the configured failure threshold is reached.
|
||||
func Run(ctx context.Context, rw io.ReadWriteCloser, cfg Config) error {
|
||||
cfg = cfg.withDefaults()
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
state := &state{
|
||||
rw: rw,
|
||||
cfg: cfg,
|
||||
pending: make(map[uint64]time.Time),
|
||||
now: time.Now,
|
||||
out: make(chan Message, 16),
|
||||
}
|
||||
|
||||
errCh := make(chan error, 3)
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
_ = rw.Close()
|
||||
}()
|
||||
go func() { errCh <- state.readLoop(ctx) }()
|
||||
go func() { errCh <- state.probeLoop(ctx) }()
|
||||
go func() { errCh <- state.writeLoop(ctx) }()
|
||||
|
||||
err := <-errCh
|
||||
cancel()
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type state struct {
|
||||
rw io.ReadWriteCloser
|
||||
cfg Config
|
||||
now func() time.Time
|
||||
|
||||
out chan Message
|
||||
|
||||
mu sync.Mutex
|
||||
pending map[uint64]time.Time
|
||||
nextSeq uint64
|
||||
failures int
|
||||
}
|
||||
|
||||
func (s *state) readLoop(ctx context.Context) error {
|
||||
for {
|
||||
raw, err := readFrame(s.rw)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
return err
|
||||
}
|
||||
msg, err := parseMessage(raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch msg.Type {
|
||||
case TypePing:
|
||||
if err := s.enqueue(ctx, Message{
|
||||
Version: ProtoVersion,
|
||||
Type: TypePong,
|
||||
Seq: msg.Seq,
|
||||
SentUnixNano: msg.SentUnixNano,
|
||||
}); err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
return err
|
||||
}
|
||||
case TypePong:
|
||||
s.handlePong(msg)
|
||||
default:
|
||||
return fmt.Errorf("%w: got %q", ErrUnexpectedMessage, msg.Type)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *state) probeLoop(ctx context.Context) error {
|
||||
ticker := time.NewTicker(s.cfg.Interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
if err := s.sendProbe(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *state) sendProbe(ctx context.Context) error {
|
||||
now := s.now()
|
||||
|
||||
s.mu.Lock()
|
||||
missedNow := 0
|
||||
for seq, sent := range s.pending {
|
||||
if now.Sub(sent) < s.cfg.Timeout {
|
||||
continue
|
||||
}
|
||||
delete(s.pending, seq)
|
||||
s.failures++
|
||||
missedNow++
|
||||
}
|
||||
missed := s.failures
|
||||
if s.failures >= s.cfg.Failures {
|
||||
s.mu.Unlock()
|
||||
if missedNow > 0 && s.cfg.OnMissedPong != nil {
|
||||
s.cfg.OnMissedPong(missed)
|
||||
}
|
||||
if s.cfg.OnUnhealthy != nil {
|
||||
s.cfg.OnUnhealthy(missed)
|
||||
}
|
||||
return fmt.Errorf("%w: missed %d pong(s)", ErrUnhealthy, missed)
|
||||
}
|
||||
|
||||
s.nextSeq++
|
||||
seq := s.nextSeq
|
||||
s.pending[seq] = now
|
||||
s.mu.Unlock()
|
||||
if missedNow > 0 && s.cfg.OnMissedPong != nil {
|
||||
s.cfg.OnMissedPong(missed)
|
||||
}
|
||||
|
||||
return s.enqueue(ctx, Message{
|
||||
Version: ProtoVersion,
|
||||
Type: TypePing,
|
||||
Seq: seq,
|
||||
SentUnixNano: now.UnixNano(),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *state) handlePong(msg Message) {
|
||||
now := s.now()
|
||||
|
||||
s.mu.Lock()
|
||||
sent, ok := s.pending[msg.Seq]
|
||||
if ok {
|
||||
delete(s.pending, msg.Seq)
|
||||
s.failures = 0
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if !ok || s.cfg.OnPong == nil {
|
||||
return
|
||||
}
|
||||
s.cfg.OnPong(Health{
|
||||
Seq: msg.Seq,
|
||||
RTT: now.Sub(sent),
|
||||
LastSeen: now,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *state) enqueue(ctx context.Context, msg Message) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case s.out <- msg:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *state) writeLoop(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case msg := <-s.out:
|
||||
if err := writeFrame(s.rw, msg); err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseMessage(raw []byte) (Message, error) {
|
||||
var msg Message
|
||||
if err := json.Unmarshal(raw, &msg); err != nil {
|
||||
return Message{}, fmt.Errorf("parse control message: %w", err)
|
||||
}
|
||||
if msg.Version != ProtoVersion {
|
||||
return Message{}, fmt.Errorf("%w: peer v%d, local v%d",
|
||||
ErrProtocolVersion, msg.Version, ProtoVersion)
|
||||
}
|
||||
if msg.Type != TypePing && msg.Type != TypePong {
|
||||
return Message{}, fmt.Errorf("%w: got %q", ErrUnexpectedMessage, msg.Type)
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func writeFrame(w io.Writer, msg Message) error {
|
||||
body, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal control message: %w", err)
|
||||
}
|
||||
if len(body) > MaxMessageSize {
|
||||
return fmt.Errorf("%w: %d > %d", ErrFrameTooLarge, len(body), MaxMessageSize)
|
||||
}
|
||||
var hdr [4]byte
|
||||
binary.BigEndian.PutUint32(hdr[:], uint32(len(body))) //nolint:gosec // len(body) bounded by MaxMessageSize
|
||||
if _, err := w.Write(hdr[:]); err != nil {
|
||||
return fmt.Errorf("write control hdr: %w", err)
|
||||
}
|
||||
if _, err := w.Write(body); err != nil {
|
||||
return fmt.Errorf("write control body: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func readFrame(r io.Reader) ([]byte, error) {
|
||||
var hdr [4]byte
|
||||
if _, err := io.ReadFull(r, hdr[:]); err != nil {
|
||||
return nil, fmt.Errorf("read control hdr: %w", err)
|
||||
}
|
||||
n := binary.BigEndian.Uint32(hdr[:])
|
||||
if n > MaxMessageSize {
|
||||
return nil, fmt.Errorf("%w: %d > %d", ErrFrameTooLarge, n, MaxMessageSize)
|
||||
}
|
||||
buf := make([]byte, n)
|
||||
if _, err := io.ReadFull(r, buf); err != nil {
|
||||
return nil, fmt.Errorf("read control body: %w", err)
|
||||
}
|
||||
return buf, nil
|
||||
}
|
||||
138
internal/control/control_test.go
Normal file
138
internal/control/control_test.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package control
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// controlPair returns the two ends of an in-memory net.Pipe and registers a
// cleanup that closes both when the test finishes.
func controlPair(t *testing.T) (net.Conn, net.Conn) {
	t.Helper()
	left, right := net.Pipe()
	t.Cleanup(func() {
		for _, end := range []net.Conn{left, right} {
			_ = end.Close()
		}
	})
	return left, right
}
|
||||
|
||||
func TestRunPingPongReportsRTT(t *testing.T) {
|
||||
a, b := controlPair(t)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
got := make(chan Health, 1)
|
||||
cfg := Config{
|
||||
Interval: 10 * time.Millisecond,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
Failures: 2,
|
||||
OnPong: func(h Health) {
|
||||
select {
|
||||
case got <- h:
|
||||
default:
|
||||
}
|
||||
},
|
||||
}
|
||||
errCh := make(chan error, 2)
|
||||
go func() { errCh <- Run(ctx, a, cfg) }()
|
||||
go func() { errCh <- Run(ctx, b, cfg) }()
|
||||
|
||||
select {
|
||||
case h := <-got:
|
||||
if h.Seq == 0 {
|
||||
t.Fatal("Health.Seq = 0")
|
||||
}
|
||||
if h.RTT < 0 {
|
||||
t.Fatalf("Health.RTT = %v", h.RTT)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for pong health")
|
||||
}
|
||||
|
||||
cancel()
|
||||
for range 2 {
|
||||
if err := <-errCh; err != nil {
|
||||
t.Fatalf("Run() after cancel = %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunMarksUnhealthyAfterMissedPongs(t *testing.T) {
|
||||
a, b := controlPair(t)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
_, _ = io.Copy(io.Discard, b)
|
||||
}()
|
||||
|
||||
missedCh := make(chan int, 1)
|
||||
missedCallbackCh := make(chan int, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- Run(ctx, a, Config{
|
||||
Interval: 10 * time.Millisecond,
|
||||
Timeout: 5 * time.Millisecond,
|
||||
Failures: 2,
|
||||
OnMissedPong: func(missed int) {
|
||||
select {
|
||||
case missedCallbackCh <- missed:
|
||||
default:
|
||||
}
|
||||
},
|
||||
OnUnhealthy: func(missed int) { missedCh <- missed },
|
||||
})
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if !errors.Is(err, ErrUnhealthy) {
|
||||
t.Fatalf("Run() error = %v, want ErrUnhealthy", err)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for unhealthy result")
|
||||
}
|
||||
if missed := <-missedCh; missed < 2 {
|
||||
t.Fatalf("missed = %d, want >= 2", missed)
|
||||
}
|
||||
if missed := <-missedCallbackCh; missed < 1 {
|
||||
t.Fatalf("missed callback = %d, want >= 1", missed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunRejectsBadProtocolVersion(t *testing.T) {
|
||||
a, b := controlPair(t)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- Run(context.Background(), a, Config{Interval: time.Hour})
|
||||
}()
|
||||
if err := writeFrame(b, Message{Version: 999, Type: TypePing, Seq: 1}); err != nil {
|
||||
t.Fatalf("writeFrame() error = %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if !errors.Is(err, ErrProtocolVersion) {
|
||||
t.Fatalf("Run() error = %v, want ErrProtocolVersion", err)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for protocol error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFrameRejectsTooLarge(t *testing.T) {
|
||||
a, b := controlPair(t)
|
||||
go func() {
|
||||
var hdr [4]byte
|
||||
binary.BigEndian.PutUint32(hdr[:], MaxMessageSize+1)
|
||||
_, _ = b.Write(hdr[:])
|
||||
}()
|
||||
_, err := readFrame(a)
|
||||
if !errors.Is(err, ErrFrameTooLarge) {
|
||||
t.Fatalf("readFrame() error = %v, want ErrFrameTooLarge", err)
|
||||
}
|
||||
}
|
||||
@@ -10,6 +10,9 @@ import (
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
// WireOverhead is the number of bytes added to each encrypted message.
|
||||
const WireOverhead = chacha20poly1305.NonceSizeX + chacha20poly1305.Overhead
|
||||
|
||||
var (
|
||||
// ErrInvalidKeySize is returned when the encryption key is not 32 bytes.
|
||||
ErrInvalidKeySize = errors.New("invalid key size")
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"github.com/openlibrecommunity/olcrtc/internal/client"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/link"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/server"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/supervisor"
|
||||
"github.com/pion/webrtc/v4"
|
||||
)
|
||||
|
||||
@@ -47,6 +48,7 @@ var (
|
||||
errSocksUnexpectedReply = errors.New("unexpected SOCKS5 reply")
|
||||
errSocksUnexpectedHello = errors.New("unexpected SOCKS5 greeting")
|
||||
errPayloadMismatchOffset = errors.New("payload mismatch at offset")
|
||||
errFailoverCarrier = errors.New("intentional failover carrier failure")
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -347,6 +349,17 @@ func registerMemoryCarrierAs(t *testing.T, name string) {
|
||||
})
|
||||
}
|
||||
|
||||
func registerFailingCarrier(t *testing.T) string {
|
||||
t.Helper()
|
||||
session.RegisterDefaults()
|
||||
|
||||
name := "e2e-fail-" + t.Name()
|
||||
carrier.Register(name, func(context.Context, carrier.Config) (carrier.Session, error) {
|
||||
return nil, errFailoverCarrier
|
||||
})
|
||||
return name
|
||||
}
|
||||
|
||||
// builtInCarrierNames returns the carrier names used by the provider
// matrix tests. A fresh slice is returned on every call.
func builtInCarrierNames() []string {
	return []string{"jazz", "telemost", "wbstream", "jitsi"} //nolint:goconst // test literal, repetition is intentional
}
|
||||
@@ -1008,9 +1021,7 @@ func TestDirectLinkConnectsFastProviderTransportMatrix(t *testing.T) {
|
||||
if err := ln.Connect(context.Background()); err != nil {
|
||||
t.Fatalf("Connect() error = %v", err)
|
||||
}
|
||||
if !ln.CanSend() {
|
||||
t.Fatal("CanSend() = false, want true")
|
||||
}
|
||||
assertLinkCanSendAfterConnect(t, ln, transportName)
|
||||
if err := ln.Close(); err != nil {
|
||||
t.Fatalf("Close() error = %v", err)
|
||||
}
|
||||
@@ -1020,6 +1031,20 @@ func TestDirectLinkConnectsFastProviderTransportMatrix(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func assertLinkCanSendAfterConnect(t *testing.T, ln link.Link, transportName string) {
|
||||
t.Helper()
|
||||
|
||||
if transportName == transportSEI {
|
||||
if ln.CanSend() {
|
||||
t.Fatal("CanSend() = true before peer seichannel frame")
|
||||
}
|
||||
return
|
||||
}
|
||||
if !ln.CanSend() {
|
||||
t.Fatal("CanSend() = false, want true")
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:cyclop // table-driven test naturally has many branches
|
||||
func TestRealProviderTransportMatrix(t *testing.T) {
|
||||
if !*realE2E {
|
||||
@@ -1163,6 +1188,186 @@ func TestFrequentReconnectsStillAllowNewSOCKSConnections(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestSupervisorFailoverProfilesReachWorkingSOCKS checks end-to-end failover:
// server and client supervisors each get a profile list whose first entry
// uses a carrier that always fails and whose second uses the in-memory
// carrier. The test then pushes a payload through the client's SOCKS listener
// to an echo server and verifies that all four profiles were started.
func TestSupervisorFailoverProfilesReachWorkingSOCKS(t *testing.T) {
	echoAddr := startEchoServer(t)
	failingCarrier := registerFailingCarrier(t)
	memoryCarrier, room := registerMemoryCarrier(t)

	ctx, cancel := context.WithCancel(context.Background())
	t.Cleanup(cancel)
	socksAddr := freeLocalAddr(ctx, t)
	socksHost, socksPort := splitHostPort(t, socksAddr)

	// Failing profile first, working profile second, on both sides.
	serverProfiles := []supervisor.Profile{
		{Name: "failing-server", Config: failoverSessionConfig("srv", failingCarrier, "", 0)},
		{Name: "memory-server", Config: failoverSessionConfig("srv", memoryCarrier, "", 0)},
	}
	clientProfiles := []supervisor.Profile{
		{Name: "failing-client", Config: failoverSessionConfig("cnc", failingCarrier, socksHost, socksPort)},
		{Name: "memory-client", Config: failoverSessionConfig("cnc", memoryCarrier, socksHost, socksPort)},
	}

	// started collects "<side>:<profile>" entries from both supervisors.
	started := make(chan string, 8)
	serverErr := make(chan error, 1)
	go func() {
		serverErr <- supervisor.Run(ctx, failoverE2EConfig(serverProfiles, started, "server"), session.Run)
	}()
	// NOTE(review): presumably blocks until one participant (the server) has
	// joined the in-memory room — confirm against waitConnected.
	room.waitConnected(t, 1)

	ready := make(chan struct{})
	var readyOnce sync.Once
	clientErr := make(chan error, 1)
	go func() {
		// The client runner is wrapped so the test learns when the client on
		// the in-memory carrier reports ready; readyOnce guards against
		// closing ready twice if that profile is started more than once.
		clientErr <- supervisor.Run(ctx, failoverE2EConfig(clientProfiles, started, "client"), func(ctx context.Context, cfg session.Config) error {
			return client.RunWithReady(ctx, clientConfigFromSession(cfg, socksAddr), func() {
				if cfg.Auth == memoryCarrier {
					readyOnce.Do(func() { close(ready) })
				}
			})
		})
	}()

	waitForReady(t, ready)
	conn := eventuallyConnectViaSOCKS(t, socksAddr, echoAddr)
	defer func() { _ = conn.Close() }()

	// Round-trip one newline-terminated payload through the tunnel.
	payload := []byte("olcrtc-failover-e2e\n")
	if _, err := conn.Write(payload); err != nil {
		t.Fatalf("write failover payload: %v", err)
	}
	if err := conn.SetReadDeadline(time.Now().Add(3 * time.Second)); err != nil {
		t.Fatalf("set failover read deadline: %v", err)
	}
	line, err := bufio.NewReader(conn).ReadBytes('\n')
	if err != nil {
		t.Fatalf("read failover echo: %v", err)
	}
	if !bytes.Equal(line, payload) {
		t.Fatalf("failover echo = %q, want %q", line, payload)
	}

	requireStartedProfiles(t, started, []string{
		"server:failing-server",
		"server:memory-server",
		"client:failing-client",
		"client:memory-client",
	})

	// Both supervisors are expected to stop without error once cancelled.
	cancel()
	waitSupervisorStopped(t, "client", clientErr)
	waitSupervisorStopped(t, "server", serverErr)
}
|
||||
|
||||
func failoverSessionConfig(mode, carrierName, socksHost string, socksPort int) session.Config {
|
||||
cfg := session.Config{
|
||||
Mode: mode,
|
||||
Link: linkDirect,
|
||||
Transport: transportData,
|
||||
Auth: carrierName,
|
||||
RoomID: testRoom,
|
||||
KeyHex: testKeyHex,
|
||||
DNSServer: localDNSServer,
|
||||
}
|
||||
if mode == "cnc" {
|
||||
cfg.SOCKSHost = socksHost
|
||||
cfg.SOCKSPort = socksPort
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// clientConfigFromSession translates a session.Config into the client.Config
// used to run the client directly in tests.
//
// Most fields are copied under the same name. The exceptions are Auth, which
// becomes Carrier; RoomID, which becomes RoomURL; LocalAddr, which is set to
// socksAddr; and DeviceID, which is always testClientDeviceID.
func clientConfigFromSession(cfg session.Config, socksAddr string) client.Config {
	return client.Config{
		Link:            cfg.Link,
		Transport:       cfg.Transport,
		Carrier:         cfg.Auth,
		RoomURL:         cfg.RoomID,
		KeyHex:          cfg.KeyHex,
		LocalAddr:       socksAddr,
		DNSServer:       cfg.DNSServer,
		DeviceID:        testClientDeviceID,
		VideoWidth:      cfg.VideoWidth,
		VideoHeight:     cfg.VideoHeight,
		VideoFPS:        cfg.VideoFPS,
		VideoBitrate:    cfg.VideoBitrate,
		VideoHW:         cfg.VideoHW,
		VideoQRSize:     cfg.VideoQRSize,
		VideoQRRecovery: cfg.VideoQRRecovery,
		VideoCodec:      cfg.VideoCodec,
		VideoTileModule: cfg.VideoTileModule,
		VideoTileRS:     cfg.VideoTileRS,
		VP8FPS:          cfg.VP8FPS,
		VP8BatchSize:    cfg.VP8BatchSize,
		SEIFPS:          cfg.SEIFPS,
		SEIBatchSize:    cfg.SEIBatchSize,
		SEIFragmentSize: cfg.SEIFragmentSize,
		SEIAckTimeoutMS: cfg.SEIAckTimeoutMS,
		Engine:          cfg.Engine,
		URL:             cfg.URL,
		Token:           cfg.Token,
	}
}
|
||||
|
||||
func failoverE2EConfig(
|
||||
profiles []supervisor.Profile,
|
||||
started chan<- string,
|
||||
side string,
|
||||
) supervisor.Config {
|
||||
return supervisor.Config{
|
||||
Profiles: profiles,
|
||||
RetryDelay: time.Millisecond,
|
||||
OnProfileStart: func(profile supervisor.Profile, _ int) {
|
||||
select {
|
||||
case started <- side + ":" + profile.Name:
|
||||
default:
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// splitHostPort splits addr into its host and numeric port, failing the test
// if addr is not a valid host:port pair or the port is not an integer.
func splitHostPort(t *testing.T, addr string) (string, int) {
	t.Helper()

	hostPart, portPart, splitErr := net.SplitHostPort(addr)
	if splitErr != nil {
		t.Fatalf("split host port %q: %v", addr, splitErr)
	}

	portNum, convErr := strconv.Atoi(portPart)
	if convErr != nil {
		t.Fatalf("parse port %q: %v", portPart, convErr)
	}
	return hostPart, portNum
}
|
||||
|
||||
// requireStartedProfiles waits until every name in want has been received on
// started, failing the test if that has not happened within three seconds.
//
// Names that are not in want are recorded for the failure message but do not
// count towards completion, so an unexpected entry cannot end the wait while
// wanted entries are still in flight. An empty want returns immediately.
func requireStartedProfiles(t *testing.T, started <-chan string, want []string) {
	t.Helper()

	pending := make(map[string]struct{}, len(want))
	for _, item := range want {
		pending[item] = struct{}{}
	}

	seen := make(map[string]bool)
	deadline := time.After(3 * time.Second)
	for len(pending) > 0 {
		select {
		case item := <-started:
			seen[item] = true
			delete(pending, item)
		case <-deadline:
			t.Fatalf("started profiles = %v, want all %v", seen, want)
		}
	}
}
|
||||
|
||||
// waitSupervisorStopped waits up to three seconds for a result on ch and
// fails the test if none arrives or if the result is a non-nil error. name
// identifies the supervisor in failure messages.
func waitSupervisorStopped(t *testing.T, name string, ch <-chan error) {
	t.Helper()

	timeout := time.NewTimer(3 * time.Second)
	defer timeout.Stop()

	select {
	case <-timeout.C:
		t.Fatalf("%s supervisor did not stop", name)
	case runErr := <-ch:
		if runErr == nil {
			return
		}
		t.Fatalf("%s supervisor returned error: %v", name, runErr)
	}
}
|
||||
|
||||
func TestEndedCallbackStopsClientAndServer(t *testing.T) {
|
||||
rt := startTunnel(t)
|
||||
rt.room.triggerEnded("conference ended")
|
||||
|
||||
@@ -112,10 +112,7 @@ func (s *Session) setupPeerConnections(config webrtc.Configuration) error {
|
||||
}
|
||||
|
||||
func (s *Session) dialWebSocket() error {
|
||||
wsDialer := websocket.Dialer{
|
||||
NetDialContext: protect.DialContext,
|
||||
HandshakeTimeout: wsHandshakeTimeout,
|
||||
}
|
||||
wsDialer := protect.NewWebSocketDialer(wsHandshakeTimeout)
|
||||
ws, resp, err := wsDialer.Dial(s.mediaServerURL, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial ws: %w", err)
|
||||
|
||||
@@ -19,13 +19,17 @@ import (
|
||||
protoLogger "github.com/livekit/protocol/logger"
|
||||
lksdk "github.com/livekit/server-sdk-go/v2"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/engine"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/logger"
|
||||
"github.com/pion/webrtc/v4"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultSendQueueSize = 5000
|
||||
dataPublishTopic = "olcrtc"
|
||||
videoTrackName = "videochannel"
|
||||
defaultSendQueueSize = 5000
|
||||
defaultSendQueueCapHard = 4000
|
||||
dataPublishTopic = "olcrtc"
|
||||
videoTrackName = "videochannel"
|
||||
reconnectWindow = 5 * time.Minute
|
||||
maxReconnects = 10
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -41,20 +45,98 @@ var (
|
||||
ErrTokenRequired = errors.New("livekit access token required")
|
||||
)
|
||||
|
||||
// roomHandle is the set of room operations the Session uses. It lets the
// Session work against something other than a concrete *lksdk.Room.
type roomHandle interface {
	// publishData sends one data payload to the room.
	publishData([]byte) error
	// publishTrack publishes a local track to the room.
	publishTrack(webrtc.TrackLocal) error
	// unpublishLocalTracks withdraws the tracks published by this side.
	unpublishLocalTracks()
	// disconnect leaves the room.
	disconnect()
	// connectionState reports the room's current connection state.
	connectionState() lksdk.ConnectionState
}
|
||||
|
||||
type sdkRoom struct {
|
||||
room *lksdk.Room
|
||||
}
|
||||
|
||||
func (r *sdkRoom) publishData(data []byte) error {
|
||||
return r.room.LocalParticipant.PublishDataPacket(
|
||||
lksdk.UserData(data),
|
||||
lksdk.WithDataPublishTopic(dataPublishTopic),
|
||||
lksdk.WithDataPublishReliable(true),
|
||||
)
|
||||
}
|
||||
|
||||
func (r *sdkRoom) publishTrack(track webrtc.TrackLocal) error {
|
||||
_, err := r.room.LocalParticipant.PublishTrack(track, &lksdk.TrackPublicationOptions{Name: videoTrackName})
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *sdkRoom) unpublishLocalTracks() {
|
||||
if r.room == nil || r.room.LocalParticipant == nil {
|
||||
return
|
||||
}
|
||||
for _, publication := range r.room.LocalParticipant.TrackPublications() {
|
||||
if publication.SID() == "" {
|
||||
continue
|
||||
}
|
||||
if err := r.room.LocalParticipant.UnpublishTrack(publication.SID()); err != nil {
|
||||
log.Printf("livekit unpublish track error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *sdkRoom) disconnect() {
|
||||
r.room.Disconnect()
|
||||
// LiveKit's Disconnect returns after local SDK teardown, before the
|
||||
// server necessarily evicts the participant. Give the signalling path a
|
||||
// short grace period so immediate reconnects do not inherit stale room
|
||||
// state from a ghost participant.
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
|
||||
func (r *sdkRoom) connectionState() lksdk.ConnectionState {
|
||||
return r.room.ConnectionState()
|
||||
}
|
||||
|
||||
// connectRoomFunc joins the room at url using token, wiring callback for room
// events, and returns a handle to the joined room. connectSDKRoom is the
// SDK-backed implementation.
type connectRoomFunc func(url, token string, callback *lksdk.RoomCallback) (roomHandle, error)
|
||||
|
||||
func connectSDKRoom(url, token string, callback *lksdk.RoomCallback) (roomHandle, error) {
|
||||
room, err := lksdk.ConnectToRoomWithToken(
|
||||
url,
|
||||
token,
|
||||
callback,
|
||||
lksdk.WithAutoSubscribe(true),
|
||||
lksdk.WithLogger(protoLogger.GetDiscardLogger()),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &sdkRoom{room: room}, nil
|
||||
}
|
||||
|
||||
// Session is the LiveKit engine handle.
|
||||
type Session struct {
|
||||
url string
|
||||
token string
|
||||
name string
|
||||
room *lksdk.Room
|
||||
refresh func(ctx context.Context) (engine.Credentials, error)
|
||||
connectRoom connectRoomFunc
|
||||
room roomHandle
|
||||
roomMu sync.RWMutex
|
||||
onData func([]byte)
|
||||
onReconnect func(*webrtc.DataChannel)
|
||||
shouldReconnect func() bool
|
||||
onEnded func(string)
|
||||
reconnectCh chan struct{}
|
||||
closeCh chan struct{}
|
||||
lastReconnect time.Time
|
||||
reconnectCount int
|
||||
sendQueue chan []byte
|
||||
closed atomic.Bool
|
||||
reconnecting atomic.Bool
|
||||
done chan struct{}
|
||||
cancel context.CancelFunc
|
||||
shutdownOnce sync.Once
|
||||
sendWorkerOnce sync.Once
|
||||
videoTrackMu sync.RWMutex
|
||||
videoTracks []webrtc.TrackLocal
|
||||
onVideoTrack func(*webrtc.TrackRemote, *webrtc.RTPReceiver)
|
||||
@@ -71,13 +153,17 @@ func New(ctx context.Context, cfg engine.Config) (engine.Session, error) {
|
||||
}
|
||||
_, cancel := context.WithCancel(ctx)
|
||||
return &Session{
|
||||
url: cfg.URL,
|
||||
token: cfg.Token,
|
||||
name: cfg.Name,
|
||||
onData: cfg.OnData,
|
||||
sendQueue: make(chan []byte, defaultSendQueueSize),
|
||||
done: make(chan struct{}),
|
||||
cancel: cancel,
|
||||
url: cfg.URL,
|
||||
token: cfg.Token,
|
||||
name: cfg.Name,
|
||||
refresh: cfg.Refresh,
|
||||
connectRoom: connectSDKRoom,
|
||||
onData: cfg.OnData,
|
||||
reconnectCh: make(chan struct{}, 1),
|
||||
closeCh: make(chan struct{}),
|
||||
sendQueue: make(chan []byte, defaultSendQueueSize),
|
||||
done: make(chan struct{}),
|
||||
cancel: cancel,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -87,7 +173,16 @@ func (s *Session) Capabilities() engine.Capabilities {
|
||||
}
|
||||
|
||||
// Connect joins the LiveKit room.
|
||||
func (s *Session) Connect(_ context.Context) error {
|
||||
func (s *Session) Connect(ctx context.Context) error {
|
||||
s.closed.Store(false)
|
||||
if err := s.connectSession(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
s.startSendWorker()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) connectSession(_ context.Context) error {
|
||||
roomCB := &lksdk.RoomCallback{
|
||||
ParticipantCallback: lksdk.ParticipantCallback{
|
||||
OnDataReceived: func(data []byte, _ lksdk.DataReceiveParams) {
|
||||
@@ -108,45 +203,49 @@ func (s *Session) Connect(_ context.Context) error {
|
||||
},
|
||||
},
|
||||
OnDisconnected: func() {
|
||||
if !s.closed.Load() && s.onEnded != nil {
|
||||
s.onEnded("disconnected from livekit")
|
||||
if s.closed.Load() || s.reconnecting.Load() {
|
||||
return
|
||||
}
|
||||
if !s.queueReconnect() {
|
||||
s.signalEnded("disconnected from livekit")
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
room, err := lksdk.ConnectToRoomWithToken(
|
||||
s.url,
|
||||
s.token,
|
||||
roomCB,
|
||||
lksdk.WithAutoSubscribe(true),
|
||||
lksdk.WithLogger(protoLogger.GetDiscardLogger()),
|
||||
)
|
||||
room, err := s.connectRoom(s.url, s.token, roomCB)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to room: %w", err)
|
||||
}
|
||||
|
||||
s.room = room
|
||||
s.setRoom(room)
|
||||
if err := s.publishPendingTracks(); err != nil {
|
||||
return err
|
||||
}
|
||||
s.wg.Add(1)
|
||||
go s.processSendQueue()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) publishPendingTracks() error {
|
||||
room := s.currentRoom()
|
||||
if room == nil {
|
||||
return ErrRoomNotConnected
|
||||
}
|
||||
s.videoTrackMu.RLock()
|
||||
defer s.videoTrackMu.RUnlock()
|
||||
for _, track := range s.videoTracks {
|
||||
if _, err := s.room.LocalParticipant.PublishTrack(track, &lksdk.TrackPublicationOptions{
|
||||
Name: videoTrackName,
|
||||
}); err != nil {
|
||||
if err := room.publishTrack(track); err != nil {
|
||||
return fmt.Errorf("failed to publish track: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// startSendWorker launches processSendQueue in its own goroutine, tracked by
// s.wg. sendWorkerOnce ensures at most one worker is started per Session no
// matter how often this is called.
func (s *Session) startSendWorker() {
	s.sendWorkerOnce.Do(func() {
		s.wg.Add(1)
		go s.processSendQueue()
	})
}
|
||||
|
||||
func (s *Session) processSendQueue() {
|
||||
defer s.wg.Done()
|
||||
for {
|
||||
@@ -157,17 +256,33 @@ func (s *Session) processSendQueue() {
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := s.room.LocalParticipant.PublishDataPacket(
|
||||
lksdk.UserData(data),
|
||||
lksdk.WithDataPublishTopic(dataPublishTopic),
|
||||
lksdk.WithDataPublishReliable(true),
|
||||
); err != nil {
|
||||
room := s.waitForConnectedRoom()
|
||||
if room == nil {
|
||||
return
|
||||
}
|
||||
if err := room.publishData(data); err != nil {
|
||||
log.Printf("livekit publish data error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) waitForConnectedRoom() roomHandle {
|
||||
ticker := time.NewTicker(50 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
room := s.currentRoom()
|
||||
if room != nil && room.connectionState() == lksdk.ConnectionStateConnected {
|
||||
return room
|
||||
}
|
||||
select {
|
||||
case <-s.done:
|
||||
return nil
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Send queues data for transmission.
|
||||
func (s *Session) Send(data []byte) error {
|
||||
if s.closed.Load() {
|
||||
@@ -183,55 +298,160 @@ func (s *Session) Send(data []byte) error {
|
||||
|
||||
// Close terminates the session.
|
||||
func (s *Session) Close() error {
|
||||
if s.closed.CompareAndSwap(false, true) {
|
||||
s.cancel()
|
||||
close(s.done)
|
||||
if s.room != nil {
|
||||
s.unpublishLocalTracks()
|
||||
s.room.Disconnect()
|
||||
// LiveKit's Disconnect() returns once the local SDK state
|
||||
// is torn down, not when the server has actually evicted
|
||||
// the participant. Without giving the signalling channel
|
||||
// time to flush the LEAVE_REQUEST and the server to act on
|
||||
// it, a back-to-back reconnect from the same identity in
|
||||
// the same room sees a still-alive ghost participant on
|
||||
// the SFU and inherits stale publication state.
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
close(s.sendQueue)
|
||||
s.wg.Wait()
|
||||
}
|
||||
s.closed.Store(true)
|
||||
s.shutdown()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) unpublishLocalTracks() {
|
||||
if s.room == nil || s.room.LocalParticipant == nil {
|
||||
return
|
||||
}
|
||||
for _, publication := range s.room.LocalParticipant.TrackPublications() {
|
||||
if publication.SID() == "" {
|
||||
continue
|
||||
func (s *Session) shutdown() {
|
||||
s.shutdownOnce.Do(func() {
|
||||
if s.cancel != nil {
|
||||
s.cancel()
|
||||
}
|
||||
if err := s.room.LocalParticipant.UnpublishTrack(publication.SID()); err != nil {
|
||||
log.Printf("livekit unpublish track error: %v", err)
|
||||
closeSignal(s.closeCh)
|
||||
closeSignal(s.done)
|
||||
if room := s.swapRoom(nil); room != nil {
|
||||
room.unpublishLocalTracks()
|
||||
room.disconnect()
|
||||
}
|
||||
}
|
||||
s.wg.Wait()
|
||||
})
|
||||
}
|
||||
|
||||
// SetReconnectCallback stores the reconnect callback (LiveKit reconnects internally; this is kept for API parity).
|
||||
// SetReconnectCallback stores the reconnect callback.
|
||||
func (s *Session) SetReconnectCallback(cb func(*webrtc.DataChannel)) { s.onReconnect = cb }
|
||||
|
||||
// SetShouldReconnect stores the reconnect predicate (kept for API parity).
|
||||
// SetShouldReconnect stores the reconnect predicate.
|
||||
func (s *Session) SetShouldReconnect(fn func() bool) { s.shouldReconnect = fn }
|
||||
|
||||
// SetEndedCallback registers a function to call when the session ends.
|
||||
func (s *Session) SetEndedCallback(cb func(string)) { s.onEnded = cb }
|
||||
|
||||
// WatchConnection is a no-op; LiveKit handles connection supervision itself.
|
||||
func (s *Session) WatchConnection(_ context.Context) {}
|
||||
// WatchConnection monitors the connection lifecycle and reconnects as needed.
|
||||
func (s *Session) WatchConnection(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-s.closeCh:
|
||||
return
|
||||
case <-s.reconnectCh:
|
||||
if s.handleReconnectAttempt(ctx) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) handleReconnectAttempt(ctx context.Context) bool {
|
||||
if time.Since(s.lastReconnect) > reconnectWindow {
|
||||
s.reconnectCount = 0
|
||||
}
|
||||
s.reconnectCount++
|
||||
s.lastReconnect = time.Now()
|
||||
|
||||
if s.reconnectCount > maxReconnects {
|
||||
s.signalEnded("reconnect limit reached")
|
||||
return true
|
||||
}
|
||||
|
||||
backoff := time.Duration(s.reconnectCount) * 2 * time.Second
|
||||
if backoff > 30*time.Second {
|
||||
backoff = 30 * time.Second
|
||||
}
|
||||
|
||||
for {
|
||||
if err := s.reconnect(ctx); err != nil {
|
||||
logger.Debugf("livekit reconnect failed: %v", err)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return true
|
||||
case <-s.closeCh:
|
||||
return true
|
||||
case <-time.After(backoff):
|
||||
continue
|
||||
}
|
||||
}
|
||||
s.drainReconnectQueue()
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// reconnect tears down the current room, if any, and joins again.
//
// The reconnecting flag is set for the duration of the call. When a refresh
// function is configured, it is called first and any non-empty credentials
// it returns replace the stored URL and token. After a successful join the
// reconnect callback, if set, is invoked with a nil data channel.
func (s *Session) reconnect(ctx context.Context) error {
	s.reconnecting.Store(true)
	defer s.reconnecting.Store(false)

	// Detach the old room before tearing it down so no other caller can
	// obtain it through currentRoom while it is being torn down.
	if room := s.swapRoom(nil); room != nil {
		room.unpublishLocalTracks()
		room.disconnect()
	}

	if s.refresh != nil {
		creds, err := s.refresh(ctx)
		if err != nil {
			return fmt.Errorf("refresh credentials: %w", err)
		}
		s.applyRefreshedCredentials(creds)
	}

	if err := s.connectSession(ctx); err != nil {
		return err
	}
	if s.onReconnect != nil {
		s.onReconnect(nil)
	}
	return nil
}
|
||||
|
||||
func (s *Session) applyRefreshedCredentials(creds engine.Credentials) {
|
||||
if creds.URL != "" {
|
||||
s.url = creds.URL
|
||||
}
|
||||
if creds.Token != "" {
|
||||
s.token = creds.Token
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) queueReconnect() bool {
|
||||
if s.closed.Load() || s.reconnecting.Load() {
|
||||
return false
|
||||
}
|
||||
if s.shouldReconnect != nil && !s.shouldReconnect() {
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case s.reconnectCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Session) drainReconnectQueue() {
|
||||
for {
|
||||
select {
|
||||
case <-s.reconnectCh:
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// signalEnded terminates the session and reports why: it marks the session
// closed, runs shutdown, and then calls the ended callback with reason if
// one is registered.
func (s *Session) signalEnded(reason string) {
	s.closed.Store(true)
	s.shutdown()
	if s.onEnded != nil {
		s.onEnded(reason)
	}
}
|
||||
|
||||
// CanSend reports whether the session is ready to accept data.
|
||||
func (s *Session) CanSend() bool { return !s.closed.Load() && s.room != nil }
|
||||
func (s *Session) CanSend() bool {
|
||||
if s.closed.Load() || s.reconnecting.Load() || len(s.sendQueue) >= defaultSendQueueCapHard {
|
||||
return false
|
||||
}
|
||||
room := s.currentRoom()
|
||||
return room != nil && room.connectionState() == lksdk.ConnectionStateConnected
|
||||
}
|
||||
|
||||
// GetSendQueue returns the channel that holds outbound payloads. The channel
// itself is returned, not a copy.
func (s *Session) GetSendQueue() chan []byte { return s.sendQueue }
|
||||
@@ -245,12 +465,11 @@ func (s *Session) AddVideoTrack(track webrtc.TrackLocal) error {
|
||||
s.videoTracks = append(s.videoTracks, track)
|
||||
s.videoTrackMu.Unlock()
|
||||
|
||||
if s.room == nil || s.room.LocalParticipant == nil {
|
||||
room := s.currentRoom()
|
||||
if room == nil {
|
||||
return nil
|
||||
}
|
||||
if _, err := s.room.LocalParticipant.PublishTrack(track, &lksdk.TrackPublicationOptions{
|
||||
Name: videoTrackName,
|
||||
}); err != nil {
|
||||
if err := room.publishTrack(track); err != nil {
|
||||
return fmt.Errorf("failed to publish track: %w", err)
|
||||
}
|
||||
return nil
|
||||
@@ -263,6 +482,34 @@ func (s *Session) SetVideoTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPR
|
||||
s.onVideoTrack = cb
|
||||
}
|
||||
|
||||
func (s *Session) currentRoom() roomHandle {
|
||||
s.roomMu.RLock()
|
||||
defer s.roomMu.RUnlock()
|
||||
return s.room
|
||||
}
|
||||
|
||||
func (s *Session) setRoom(room roomHandle) {
|
||||
s.roomMu.Lock()
|
||||
defer s.roomMu.Unlock()
|
||||
s.room = room
|
||||
}
|
||||
|
||||
func (s *Session) swapRoom(room roomHandle) roomHandle {
|
||||
s.roomMu.Lock()
|
||||
defer s.roomMu.Unlock()
|
||||
old := s.room
|
||||
s.room = room
|
||||
return old
|
||||
}
|
||||
|
||||
// closeSignal closes ch unless it is already closed, so it can be called
// more than once on the same channel. A nil channel is ignored, because
// closing a nil channel panics.
//
// The closed check and the close are not atomic: callers must not invoke
// closeSignal concurrently on the same channel.
func closeSignal(ch chan struct{}) {
	if ch == nil {
		return
	}
	select {
	case <-ch:
		// Already closed; nothing to do.
	default:
		close(ch)
	}
}
|
||||
|
||||
// init registers New as the constructor for the "livekit" engine.
func init() { //nolint:gochecknoinits // engine registration is the canonical Go pattern for plugins
	engine.Register("livekit", New)
}
|
||||
|
||||
306
internal/engine/livekit/livekit_test.go
Normal file
306
internal/engine/livekit/livekit_test.go
Normal file
@@ -0,0 +1,306 @@
|
||||
package livekit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
lksdk "github.com/livekit/server-sdk-go/v2"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/engine"
|
||||
"github.com/pion/webrtc/v4"
|
||||
)
|
||||
|
||||
// fakeRoom is an in-memory roomHandle for tests. It records what was asked
// of it; every field is guarded by mu.
type fakeRoom struct {
	mu           sync.Mutex
	state        lksdk.ConnectionState // value returned by connectionState
	published    [][]byte              // copies of payloads passed to publishData
	tracks       int                   // number of publishTrack calls
	unpublished  int                   // number of unpublishLocalTracks calls
	disconnected int                   // number of disconnect calls
}

// newFakeRoom returns a fakeRoom that reports itself as connected.
func newFakeRoom() *fakeRoom {
	return &fakeRoom{state: lksdk.ConnectionStateConnected}
}

// publishData records a copy of data and always succeeds.
func (r *fakeRoom) publishData(data []byte) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	// Copy so later changes to the caller's buffer do not alter the record.
	r.published = append(r.published, append([]byte(nil), data...))
	return nil
}

// publishTrack counts the call and always succeeds; the track is ignored.
func (r *fakeRoom) publishTrack(webrtc.TrackLocal) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.tracks++
	return nil
}

// unpublishLocalTracks counts the call.
func (r *fakeRoom) unpublishLocalTracks() {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.unpublished++
}

// disconnect counts the call and switches the state to disconnected.
func (r *fakeRoom) disconnect() {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.disconnected++
	r.state = lksdk.ConnectionStateDisconnected
}

// connectionState returns the current fake connection state.
func (r *fakeRoom) connectionState() lksdk.ConnectionState {
	r.mu.Lock()
	defer r.mu.Unlock()
	return r.state
}
|
||||
|
||||
type fakeConnector struct {
|
||||
mu sync.Mutex
|
||||
urls []string
|
||||
tokens []string
|
||||
callbacks []*lksdk.RoomCallback
|
||||
rooms []*fakeRoom
|
||||
connected chan struct{}
|
||||
err error
|
||||
}
|
||||
|
||||
func newFakeConnector() *fakeConnector {
|
||||
return &fakeConnector{connected: make(chan struct{}, 8)}
|
||||
}
|
||||
|
||||
func (c *fakeConnector) connect(url, token string, cb *lksdk.RoomCallback) (roomHandle, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.err != nil {
|
||||
return nil, c.err
|
||||
}
|
||||
room := newFakeRoom()
|
||||
c.urls = append(c.urls, url)
|
||||
c.tokens = append(c.tokens, token)
|
||||
c.callbacks = append(c.callbacks, cb)
|
||||
c.rooms = append(c.rooms, room)
|
||||
c.connected <- struct{}{}
|
||||
return room, nil
|
||||
}
|
||||
|
||||
func (c *fakeConnector) count() int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return len(c.rooms)
|
||||
}
|
||||
|
||||
func (c *fakeConnector) callback(i int) *lksdk.RoomCallback {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.callbacks[i]
|
||||
}
|
||||
|
||||
func (c *fakeConnector) room(i int) *fakeRoom {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.rooms[i]
|
||||
}
|
||||
|
||||
func (c *fakeConnector) snapshot() ([]string, []string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return append([]string(nil), c.urls...), append([]string(nil), c.tokens...)
|
||||
}
|
||||
|
||||
// waitFor polls cond every 10ms and returns as soon as it reports true.
// The test fails if cond has not held within two seconds.
func waitFor(t *testing.T, cond func() bool) {
	t.Helper()

	const (
		budget = 2 * time.Second
		step   = 10 * time.Millisecond
	)
	for giveUp := time.Now().Add(budget); time.Now().Before(giveUp); time.Sleep(step) {
		if cond() {
			return
		}
	}
	t.Fatal("condition was not met before timeout")
}
|
||||
|
||||
func TestReconnectRefreshesCredentialsAndReplacesRoom(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
refreshes := 0
|
||||
sess, err := New(ctx, engine.Config{
|
||||
URL: "wss://old",
|
||||
Token: "old-token",
|
||||
Refresh: func(context.Context) (engine.Credentials, error) {
|
||||
refreshes++
|
||||
return engine.Credentials{URL: "wss://new", Token: "new-token"}, nil
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
s := sess.(*Session)
|
||||
connector := newFakeConnector()
|
||||
s.connectRoom = connector.connect
|
||||
|
||||
reconnected := make(chan struct{}, 1)
|
||||
s.SetReconnectCallback(func(*webrtc.DataChannel) {
|
||||
reconnected <- struct{}{}
|
||||
})
|
||||
|
||||
if err := s.Connect(ctx); err != nil {
|
||||
t.Fatalf("Connect() error = %v", err)
|
||||
}
|
||||
go s.WatchConnection(ctx)
|
||||
|
||||
connector.callback(0).OnDisconnected()
|
||||
|
||||
waitFor(t, func() bool { return connector.count() == 2 })
|
||||
select {
|
||||
case <-reconnected:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("reconnect callback was not called")
|
||||
}
|
||||
|
||||
urls, tokens := connector.snapshot()
|
||||
if got, want := urls, []string{"wss://old", "wss://new"}; !equalStrings(got, want) {
|
||||
t.Fatalf("connect urls = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := tokens, []string{"old-token", "new-token"}; !equalStrings(got, want) {
|
||||
t.Fatalf("connect tokens = %v, want %v", got, want)
|
||||
}
|
||||
if refreshes != 1 {
|
||||
t.Fatalf("refreshes = %d, want 1", refreshes)
|
||||
}
|
||||
oldRoom := connector.room(0)
|
||||
oldRoom.mu.Lock()
|
||||
if oldRoom.disconnected != 1 || oldRoom.unpublished != 1 {
|
||||
t.Fatalf("old room cleanup disconnected=%d unpublished=%d, want 1/1",
|
||||
oldRoom.disconnected, oldRoom.unpublished)
|
||||
}
|
||||
oldRoom.mu.Unlock()
|
||||
if !s.CanSend() {
|
||||
t.Fatal("CanSend() = false after reconnect, want true")
|
||||
}
|
||||
|
||||
if err := s.Close(); err != nil {
|
||||
t.Fatalf("Close() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDisconnectedEndsWhenReconnectDisallowed(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
sess, err := New(ctx, engine.Config{URL: "wss://old", Token: "old-token"})
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
s := sess.(*Session)
|
||||
connector := newFakeConnector()
|
||||
s.connectRoom = connector.connect
|
||||
s.SetShouldReconnect(func() bool { return false })
|
||||
|
||||
ended := make(chan string, 1)
|
||||
s.SetEndedCallback(func(reason string) {
|
||||
ended <- reason
|
||||
})
|
||||
|
||||
if err := s.Connect(ctx); err != nil {
|
||||
t.Fatalf("Connect() error = %v", err)
|
||||
}
|
||||
connector.callback(0).OnDisconnected()
|
||||
|
||||
select {
|
||||
case reason := <-ended:
|
||||
if reason != "disconnected from livekit" {
|
||||
t.Fatalf("ended reason = %q, want disconnected from livekit", reason)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("ended callback was not called")
|
||||
}
|
||||
if !s.closed.Load() {
|
||||
t.Fatal("closed = false after terminal disconnect")
|
||||
}
|
||||
if connector.count() != 1 {
|
||||
t.Fatalf("connect count = %d, want 1", connector.count())
|
||||
}
|
||||
room := connector.room(0)
|
||||
room.mu.Lock()
|
||||
if room.disconnected != 1 || room.unpublished != 1 {
|
||||
t.Fatalf("terminal room cleanup disconnected=%d unpublished=%d, want 1/1",
|
||||
room.disconnected, room.unpublished)
|
||||
}
|
||||
room.mu.Unlock()
|
||||
|
||||
if err := s.Close(); err != nil {
|
||||
t.Fatalf("Close() error = %v", err)
|
||||
}
|
||||
room.mu.Lock()
|
||||
if room.disconnected != 1 || room.unpublished != 1 {
|
||||
t.Fatalf("second close cleanup disconnected=%d unpublished=%d, want still 1/1",
|
||||
room.disconnected, room.unpublished)
|
||||
}
|
||||
room.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestCanSendRequiresConnectedRoomAndQueueHeadroom(t *testing.T) {
|
||||
s := &Session{
|
||||
sendQueue: make(chan []byte, defaultSendQueueSize),
|
||||
done: make(chan struct{}),
|
||||
closeCh: make(chan struct{}),
|
||||
}
|
||||
if s.CanSend() {
|
||||
t.Fatal("CanSend() = true without room")
|
||||
}
|
||||
|
||||
room := newFakeRoom()
|
||||
room.state = lksdk.ConnectionStateDisconnected
|
||||
s.setRoom(room)
|
||||
if s.CanSend() {
|
||||
t.Fatal("CanSend() = true for disconnected room")
|
||||
}
|
||||
|
||||
room.state = lksdk.ConnectionStateConnected
|
||||
if !s.CanSend() {
|
||||
t.Fatal("CanSend() = false for connected room")
|
||||
}
|
||||
|
||||
for i := 0; i < defaultSendQueueCapHard; i++ {
|
||||
s.sendQueue <- []byte("x")
|
||||
}
|
||||
if s.CanSend() {
|
||||
t.Fatal("CanSend() = true at queue high watermark")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReconnectFailureRetriesUntilContextDone(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s := &Session{
|
||||
url: "wss://old",
|
||||
token: "old-token",
|
||||
connectRoom: func(string, string, *lksdk.RoomCallback) (roomHandle, error) {
|
||||
cancel()
|
||||
return nil, errors.New("boom")
|
||||
},
|
||||
reconnectCh: make(chan struct{}, 1),
|
||||
closeCh: make(chan struct{}),
|
||||
sendQueue: make(chan []byte, defaultSendQueueSize),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
if terminal := s.handleReconnectAttempt(ctx); !terminal {
|
||||
t.Fatal("handleReconnectAttempt() = false after context cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
// equalStrings reports whether a and b have the same length and the same
// elements in the same order. A nil slice and an empty slice compare equal.
func equalStrings(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}
	for idx, want := range b {
		if a[idx] != want {
			return false
		}
	}
	return true
}
|
||||
@@ -417,10 +417,7 @@ func (s *Session) waitForMediaReady(ctx context.Context, timeout time.Duration)
|
||||
}
|
||||
|
||||
func (s *Session) dialWebSocket() error {
|
||||
wsDialer := websocket.Dialer{
|
||||
NetDialContext: protect.DialContext,
|
||||
HandshakeTimeout: wsHandshakeTimeout,
|
||||
}
|
||||
wsDialer := protect.NewWebSocketDialer(wsHandshakeTimeout)
|
||||
|
||||
ws, resp, err := wsDialer.Dial(s.connectorURL, nil)
|
||||
if err != nil {
|
||||
|
||||
@@ -13,8 +13,8 @@
|
||||
// │ │
|
||||
//
|
||||
// After the exchange the control stream stays open; tunnel traffic flows over
|
||||
// additional smux streams opened by the client. The control stream may carry
|
||||
// keepalives or future control messages.
|
||||
// additional smux streams opened by the client. The control stream then
|
||||
// carries ping/pong liveness and future control messages.
|
||||
//
|
||||
//nolint:tagliatelle // JSON keys are the stable wire protocol schema.
|
||||
package handshake
|
||||
|
||||
@@ -43,6 +43,7 @@ func New(ctx context.Context, cfg link.Config) (link.Link, error) {
|
||||
SEIBatchSize: cfg.SEIBatchSize,
|
||||
SEIFragmentSize: cfg.SEIFragmentSize,
|
||||
SEIAckTimeoutMS: cfg.SEIAckTimeoutMS,
|
||||
Traffic: cfg.Traffic,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create transport for direct link: %w", err)
|
||||
@@ -79,3 +80,6 @@ func (d *directLink) WatchConnection(ctx context.Context) {
|
||||
d.transport.WatchConnection(ctx)
|
||||
}
|
||||
// CanSend forwards the underlying transport's CanSend result.
func (d *directLink) CanSend() bool { return d.transport.CanSend() }
|
||||
|
||||
// Features reports the direct link's underlying transport capabilities.
|
||||
func (d *directLink) Features() link.Features { return d.transport.Features() }
|
||||
|
||||
@@ -79,12 +79,14 @@ func TestNewForwardsConfigAndMethods(t *testing.T) {
|
||||
VideoTileRS: 20,
|
||||
VP8FPS: 25,
|
||||
VP8BatchSize: 8,
|
||||
Traffic: transport.TrafficConfig{MaxPayloadSize: 4096},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
if seen.DeviceID != "client" || seen.ProxyPort != 1080 || seen.VideoTileRS != 20 || seen.VP8BatchSize != 8 {
|
||||
if seen.DeviceID != "client" || seen.ProxyPort != 1080 || seen.VideoTileRS != 20 || seen.VP8BatchSize != 8 ||
|
||||
seen.Traffic.MaxPayloadSize != 4096 {
|
||||
t.Fatalf("forwarded config = %+v", seen)
|
||||
}
|
||||
|
||||
@@ -112,6 +114,9 @@ func TestNewForwardsConfigAndMethods(t *testing.T) {
|
||||
if !ln.CanSend() {
|
||||
t.Fatal("CanSend() = false, want true")
|
||||
}
|
||||
if features := ln.(link.FeaturesProvider).Features(); features.MaxPayloadSize != 4096 {
|
||||
t.Fatalf("Features() = %+v, want shaped max payload 4096", features)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWrapsFactoryError(t *testing.T) {
|
||||
|
||||
@@ -4,6 +4,8 @@ package link
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/transport"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -23,11 +25,19 @@ type Link interface {
|
||||
CanSend() bool
|
||||
}
|
||||
|
||||
// Features mirrors the underlying transport capabilities when a link can expose them.
|
||||
type Features = transport.Features
|
||||
|
||||
// FeaturesProvider is optionally implemented by links that can report wire limits.
|
||||
type FeaturesProvider interface {
|
||||
Features() Features
|
||||
}
|
||||
|
||||
// Config holds common link configuration.
|
||||
type Config struct {
|
||||
Transport string
|
||||
Carrier string
|
||||
RoomURL string
|
||||
Transport string
|
||||
Carrier string
|
||||
RoomURL string
|
||||
// Engine, URL, Token are forwarded for the "none" auth carrier.
|
||||
Engine string
|
||||
URL string
|
||||
@@ -54,6 +64,7 @@ type Config struct {
|
||||
SEIBatchSize int
|
||||
SEIFragmentSize int
|
||||
SEIAckTimeoutMS int
|
||||
Traffic transport.TrafficConfig
|
||||
}
|
||||
|
||||
// Factory creates a link instance.
|
||||
|
||||
@@ -3,13 +3,38 @@ package protect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// Shared defaults for the protected dialer, HTTP transport/client and
// WebSocket dialer built by this package.
const (
	// Socket-level dialing.
	defaultDialTimeout = 10 * time.Second
	defaultKeepAlive   = 30 * time.Second

	// HTTP transport and client.
	defaultIdleConnTimeout   = 30 * time.Second
	defaultTLSHandshake      = 10 * time.Second
	defaultResponseHeader    = 10 * time.Second
	defaultHTTPClientTimeout = 30 * time.Second

	// WebSocket handshake fallback used when the caller passes a non-positive timeout.
	defaultWebSocketTimeout = 10 * time.Second

	// Number of response-body bytes StatusError reads when the caller passes a non-positive limit.
	defaultStatusBodyLimit = 1024
)
|
||||
|
||||
// Redaction patterns. Each one captures the part to keep (field name plus
// separator, or the "Bearer " prefix) as group 1 and matches the secret value
// after it, so a "${1}..." replacement keeps the label and drops the value.
var (
	// sensitiveFieldRE matches token-like fields written as key:value or
	// key=value, with or without quotes, case-insensitively.
	sensitiveFieldRE = regexp.MustCompile(
		`(?i)((?:access[_-]?token|room[_-]?token|token|credentials)"?\s*[:=]\s*"?)[^",\s}]+`,
	)

	// sensitiveBearerRE matches the credential following a "Bearer" scheme.
	sensitiveBearerRE = regexp.MustCompile(`(?i)(bearer\s+)[A-Za-z0-9._~+/=-]+`)
) //nolint:gochecknoglobals // compiled once for provider error redaction
|
||||
|
||||
// Protector is called with a socket file descriptor before connect.
|
||||
// On Android, this calls VpnService.protect(fd) to bypass VPN routing.
|
||||
var Protector func(fd int) bool //nolint:gochecknoglobals // package-level state intentional
|
||||
@@ -33,24 +58,70 @@ func controlFunc(network, _ string, c syscall.RawConn) error {
|
||||
// NewDialer returns a net.Dialer that calls Protector on each new socket.
|
||||
func NewDialer() *net.Dialer {
|
||||
return &net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
Timeout: defaultDialTimeout,
|
||||
KeepAlive: defaultKeepAlive,
|
||||
Control: controlFunc,
|
||||
}
|
||||
}
|
||||
|
||||
// NewTLSConfig builds the TLS policy shared by provider HTTP and WebSocket
// clients: a fresh config per call with TLS 1.2 as the minimum version.
func NewTLSConfig() *tls.Config {
	cfg := new(tls.Config)
	cfg.MinVersion = tls.VersionTLS12
	return cfg
}
|
||||
|
||||
// NewHTTPTransport returns an HTTP transport using protected sockets and sane timeouts.
|
||||
func NewHTTPTransport() *http.Transport {
|
||||
dialer := NewDialer()
|
||||
return &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: dialer.DialContext,
|
||||
TLSClientConfig: NewTLSConfig(),
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 10,
|
||||
IdleConnTimeout: defaultIdleConnTimeout,
|
||||
TLSHandshakeTimeout: defaultTLSHandshake,
|
||||
ResponseHeaderTimeout: defaultResponseHeader,
|
||||
}
|
||||
}
|
||||
|
||||
// NewHTTPClient returns an http.Client using protected sockets.
|
||||
func NewHTTPClient() *http.Client {
|
||||
dialer := NewDialer()
|
||||
transport := &http.Transport{
|
||||
DialContext: dialer.DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 10,
|
||||
IdleConnTimeout: 30 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ResponseHeaderTimeout: 10 * time.Second,
|
||||
return &http.Client{
|
||||
Transport: NewHTTPTransport(),
|
||||
Timeout: defaultHTTPClientTimeout,
|
||||
}
|
||||
return &http.Client{Transport: transport}
|
||||
}
|
||||
|
||||
// NewWebSocketDialer returns a WebSocket dialer using protected sockets and shared TLS policy.
|
||||
func NewWebSocketDialer(handshakeTimeout time.Duration) websocket.Dialer {
|
||||
if handshakeTimeout <= 0 {
|
||||
handshakeTimeout = defaultWebSocketTimeout
|
||||
}
|
||||
return websocket.Dialer{
|
||||
NetDialContext: DialContext,
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
TLSClientConfig: NewTLSConfig(),
|
||||
HandshakeTimeout: handshakeTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// StatusError formats an upstream HTTP error while bounding and redacting the body.
|
||||
func StatusError(base error, resp *http.Response, limit int64) error {
|
||||
if limit <= 0 {
|
||||
limit = defaultStatusBodyLimit
|
||||
}
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, limit))
|
||||
bodyText := RedactSensitive(strings.TrimSpace(string(body)))
|
||||
if bodyText == "" {
|
||||
return fmt.Errorf("%w: status %d", base, resp.StatusCode)
|
||||
}
|
||||
return fmt.Errorf("%w: status %d: %s", base, resp.StatusCode, bodyText)
|
||||
}
|
||||
|
||||
// RedactSensitive removes common token-like values from provider error text.
|
||||
func RedactSensitive(text string) string {
|
||||
text = sensitiveBearerRE.ReplaceAllString(text, "${1}<redacted>")
|
||||
return sensitiveFieldRE.ReplaceAllString(text, "${1}<redacted>")
|
||||
}
|
||||
|
||||
// DialContext dials using a protected socket.
|
||||
|
||||
@@ -2,9 +2,11 @@ package protect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -88,13 +90,57 @@ func TestNewDialerAndHTTPClient(t *testing.T) {
|
||||
if !ok {
|
||||
t.Fatalf("Transport type = %T, want *http.Transport", client.Transport)
|
||||
}
|
||||
if tr.DialContext == nil || !tr.ForceAttemptHTTP2 || tr.MaxIdleConns != 10 ||
|
||||
if tr.Proxy == nil || tr.DialContext == nil || tr.TLSClientConfig == nil ||
|
||||
tr.TLSClientConfig.MinVersion != tls.VersionTLS12 || !tr.ForceAttemptHTTP2 || tr.MaxIdleConns != 10 ||
|
||||
tr.IdleConnTimeout != 30*time.Second || tr.TLSHandshakeTimeout != 10*time.Second ||
|
||||
tr.ResponseHeaderTimeout != 10*time.Second {
|
||||
tr.ResponseHeaderTimeout != 10*time.Second || client.Timeout != 30*time.Second {
|
||||
t.Fatalf("transport = %+v", tr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWebSocketDialer(t *testing.T) {
|
||||
dialer := NewWebSocketDialer(3 * time.Second)
|
||||
if dialer.NetDialContext == nil || dialer.Proxy == nil || dialer.TLSClientConfig == nil ||
|
||||
dialer.TLSClientConfig.MinVersion != tls.VersionTLS12 ||
|
||||
dialer.HandshakeTimeout != 3*time.Second {
|
||||
t.Fatalf("NewWebSocketDialer() = %+v", dialer)
|
||||
}
|
||||
|
||||
defaulted := NewWebSocketDialer(0)
|
||||
if defaulted.HandshakeTimeout != defaultWebSocketTimeout {
|
||||
t.Fatalf("default HandshakeTimeout = %v, want %v",
|
||||
defaulted.HandshakeTimeout, defaultWebSocketTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusErrorRedactsAndLimitsBody(t *testing.T) {
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusForbidden,
|
||||
Body: ioNopCloser{strings.NewReader(`{"accessToken":"secret","message":"no"}`)},
|
||||
}
|
||||
err := StatusError(errProtectBoom, resp, 1024)
|
||||
if err == nil {
|
||||
t.Fatal("StatusError() error = nil")
|
||||
}
|
||||
text := err.Error()
|
||||
if strings.Contains(text, "secret") || !strings.Contains(text, "<redacted>") {
|
||||
t.Fatalf("StatusError() = %q, want redacted token", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedactSensitiveBearer(t *testing.T) {
|
||||
got := RedactSensitive("Authorization: Bearer abc.def")
|
||||
if strings.Contains(got, "abc.def") || !strings.Contains(got, "Bearer <redacted>") {
|
||||
t.Fatalf("RedactSensitive() = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// ioNopCloser gives a *strings.Reader a no-op Close so tests can use it as an
// HTTP response body.
type ioNopCloser struct{ *strings.Reader }

// Close does nothing and always succeeds.
func (ioNopCloser) Close() error { return nil }
|
||||
|
||||
func TestDialContextAndProxyDialer(t *testing.T) {
|
||||
var lc net.ListenConfig
|
||||
ln, err := lc.Listen(context.Background(), "tcp4", "127.0.0.1:0")
|
||||
|
||||
@@ -14,12 +14,14 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/control"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/crypto"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/handshake"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/link"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/logger"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/names"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/transport"
|
||||
"github.com/xtaci/smux"
|
||||
)
|
||||
|
||||
@@ -49,25 +51,33 @@ type SessionCloseFunc func(sessionID, reason string)
|
||||
// bytesIn counts client→target bytes; bytesOut counts target→client bytes.
|
||||
type TrafficFunc func(sessionID, addr string, bytesIn, bytesOut uint64)
|
||||
|
||||
// HealthFunc is called when the server control health snapshot changes.
|
||||
type HealthFunc func(control.Status)
|
||||
|
||||
// Server handles incoming tunnel connections and proxies their traffic.
|
||||
type Server struct {
|
||||
ln link.Link
|
||||
cipher *crypto.Cipher
|
||||
conn *muxconn.Conn
|
||||
session *smux.Session
|
||||
controlStop context.CancelFunc
|
||||
sessMu sync.RWMutex
|
||||
reinstallMu sync.Mutex
|
||||
healthMu sync.RWMutex
|
||||
wg sync.WaitGroup
|
||||
authHook handshake.AuthFunc
|
||||
onOpen SessionOpenFunc
|
||||
onClose SessionCloseFunc
|
||||
onTraffic TrafficFunc
|
||||
onHealth HealthFunc
|
||||
deviceID string
|
||||
sessionID string
|
||||
dnsServer string
|
||||
resolver *net.Resolver
|
||||
socksProxyAddr string
|
||||
socksProxyPort int
|
||||
liveness control.Config
|
||||
health control.Status
|
||||
}
|
||||
|
||||
// ConnectRequest is a message from the client to establish a new connection.
|
||||
@@ -106,6 +116,8 @@ type Config struct {
|
||||
Engine string
|
||||
URL string
|
||||
Token string
|
||||
Liveness control.Config
|
||||
Traffic transport.TrafficConfig
|
||||
|
||||
// AuthHook is invoked after CLIENT_HELLO to authorize the client and
|
||||
// return a session ID. If nil, every client is admitted with a random UUID.
|
||||
@@ -117,6 +129,8 @@ type Config struct {
|
||||
OnSessionClose SessionCloseFunc
|
||||
// OnTraffic fires once per tunnel stream after both copy loops finish. Nil means no-op.
|
||||
OnTraffic TrafficFunc
|
||||
// OnHealth fires when liveness/reconnect status changes. Nil means no-op.
|
||||
OnHealth HealthFunc
|
||||
}
|
||||
|
||||
// Run starts the server with the given configuration.
|
||||
@@ -145,6 +159,10 @@ func Run(ctx context.Context, cfg Config) error {
|
||||
if onTraffic == nil {
|
||||
onTraffic = func(string, string, uint64, uint64) {}
|
||||
}
|
||||
onHealth := cfg.OnHealth
|
||||
if onHealth == nil {
|
||||
onHealth = func(control.Status) {}
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
cipher: cipher,
|
||||
@@ -152,9 +170,11 @@ func Run(ctx context.Context, cfg Config) error {
|
||||
onOpen: onOpen,
|
||||
onClose: onClose,
|
||||
onTraffic: onTraffic,
|
||||
onHealth: onHealth,
|
||||
dnsServer: cfg.DNSServer,
|
||||
socksProxyAddr: cfg.SOCKSProxyAddr,
|
||||
socksProxyPort: cfg.SOCKSProxyPort,
|
||||
liveness: cfg.Liveness,
|
||||
}
|
||||
s.setupResolver()
|
||||
|
||||
@@ -216,11 +236,17 @@ func (s *Server) setupResolver() {
|
||||
|
||||
// smuxConfig mirrors the client side. Both peers must agree on Version and
|
||||
// MaxFrameSize.
|
||||
func smuxConfig() *smux.Config {
|
||||
func smuxConfig(maxWirePayload ...int) *smux.Config {
|
||||
cfg := smux.DefaultConfig()
|
||||
cfg.Version = 2
|
||||
cfg.KeepAliveDisabled = true
|
||||
cfg.MaxFrameSize = 32768
|
||||
if len(maxWirePayload) > 0 && maxWirePayload[0] > crypto.WireOverhead {
|
||||
maxFrameSize := maxWirePayload[0] - crypto.WireOverhead
|
||||
if maxFrameSize < cfg.MaxFrameSize {
|
||||
cfg.MaxFrameSize = maxFrameSize
|
||||
}
|
||||
}
|
||||
cfg.MaxReceiveBuffer = 16 * 1024 * 1024
|
||||
cfg.MaxStreamBuffer = 1024 * 1024
|
||||
cfg.KeepAliveInterval = 10 * time.Second
|
||||
@@ -228,6 +254,14 @@ func smuxConfig() *smux.Config {
|
||||
return cfg
|
||||
}
|
||||
|
||||
func linkMaxPayload(ln link.Link) int {
|
||||
provider, ok := ln.(link.FeaturesProvider)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
return provider.Features().MaxPayloadSize
|
||||
}
|
||||
|
||||
func (s *Server) bringUpLink(
|
||||
ctx context.Context,
|
||||
cfg Config,
|
||||
@@ -262,6 +296,7 @@ func (s *Server) bringUpLink(
|
||||
SEIBatchSize: cfg.SEIBatchSize,
|
||||
SEIFragmentSize: cfg.SEIFragmentSize,
|
||||
SEIAckTimeoutMS: cfg.SEIAckTimeoutMS,
|
||||
Traffic: cfg.Traffic,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create link: %w", err)
|
||||
@@ -298,7 +333,7 @@ func (s *Server) bringUpLink(
|
||||
|
||||
func (s *Server) installSession() {
|
||||
conn := muxconn.New(s.ln, s.cipher)
|
||||
sess, err := smux.Server(conn, smuxConfig())
|
||||
sess, err := smux.Server(conn, smuxConfig(linkMaxPayload(s.ln)))
|
||||
if err != nil {
|
||||
logger.Warnf("smux server init failed: %v", err)
|
||||
return
|
||||
@@ -310,7 +345,8 @@ func (s *Server) installSession() {
|
||||
}
|
||||
|
||||
func (s *Server) handleReconnect() {
|
||||
logger.Infof("server link reconnect - tearing down smux session")
|
||||
s.recordReconnect()
|
||||
logger.Infof("server reconnect reason=carrier - tearing down smux session")
|
||||
s.sessMu.RLock()
|
||||
current := s.session
|
||||
s.sessMu.RUnlock()
|
||||
@@ -323,7 +359,7 @@ func (s *Server) reinstallSession(dead *smux.Session) {
|
||||
|
||||
// Pre-build the replacement so we can swap atomically below.
|
||||
newConn := muxconn.New(s.ln, s.cipher)
|
||||
newSess, err := smux.Server(newConn, smuxConfig())
|
||||
newSess, err := smux.Server(newConn, smuxConfig(linkMaxPayload(s.ln)))
|
||||
if err != nil {
|
||||
logger.Warnf("smux server init failed: %v", err)
|
||||
_ = newConn.Close()
|
||||
@@ -340,13 +376,18 @@ func (s *Server) reinstallSession(dead *smux.Session) {
|
||||
}
|
||||
oldSess := s.session
|
||||
oldConn := s.conn
|
||||
oldControlStop := s.controlStop
|
||||
oldSID := s.sessionID
|
||||
s.session = newSess
|
||||
s.conn = newConn
|
||||
s.controlStop = nil
|
||||
s.sessionID = ""
|
||||
s.deviceID = ""
|
||||
s.sessMu.Unlock()
|
||||
|
||||
if oldControlStop != nil {
|
||||
oldControlStop()
|
||||
}
|
||||
if oldSess != nil {
|
||||
_ = oldSess.Close()
|
||||
}
|
||||
@@ -362,13 +403,18 @@ func (s *Server) closeSession() {
|
||||
s.sessMu.Lock()
|
||||
sess := s.session
|
||||
conn := s.conn
|
||||
controlStop := s.controlStop
|
||||
s.session = nil
|
||||
s.conn = nil
|
||||
s.controlStop = nil
|
||||
oldSID := s.sessionID
|
||||
s.sessionID = ""
|
||||
s.deviceID = ""
|
||||
s.sessMu.Unlock()
|
||||
|
||||
if controlStop != nil {
|
||||
controlStop()
|
||||
}
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
@@ -476,27 +522,120 @@ func (s *Server) acceptHandshake(ctx context.Context, sess *smux.Session) bool {
|
||||
s.deviceID = hello.DeviceID
|
||||
s.sessionID = sid
|
||||
s.sessMu.Unlock()
|
||||
s.recordSession(sid)
|
||||
s.onOpen(sid, hello.DeviceID, hello.Claims)
|
||||
logger.Infof("session %s opened (device=%s)", sid, hello.DeviceID)
|
||||
// The control stream stays open for the lifetime of the session;
|
||||
// keep it parked in a goroutine so the smux session does not close it.
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.parkControlStream(stream)
|
||||
}()
|
||||
s.startControlLoop(ctx, sess, stream)
|
||||
return true
|
||||
}
|
||||
|
||||
// parkControlStream blocks reading from the control stream until it closes.
|
||||
// Future control messages (kick, rate updates, etc.) would be dispatched here.
|
||||
func (s *Server) parkControlStream(stream *smux.Stream) {
|
||||
defer func() { _ = stream.Close() }()
|
||||
buf := make([]byte, 64)
|
||||
for {
|
||||
if _, err := stream.Read(buf); err != nil {
|
||||
func (s *Server) startControlLoop(ctx context.Context, sess *smux.Session, stream *smux.Stream) {
|
||||
controlCtx, stop := context.WithCancel(ctx)
|
||||
s.sessMu.Lock()
|
||||
s.controlStop = stop
|
||||
s.sessMu.Unlock()
|
||||
|
||||
liveness := s.liveness
|
||||
onPong := liveness.OnPong
|
||||
onMissedPong := liveness.OnMissedPong
|
||||
onUnhealthy := liveness.OnUnhealthy
|
||||
liveness.OnPong = func(h control.Health) {
|
||||
s.sessMu.RLock()
|
||||
sid := s.sessionID
|
||||
s.sessMu.RUnlock()
|
||||
s.recordPong(h)
|
||||
logger.Debugf("control alive session=%s rtt=%v seq=%d", sid, h.RTT, h.Seq)
|
||||
if onPong != nil {
|
||||
onPong(h)
|
||||
}
|
||||
}
|
||||
liveness.OnMissedPong = func(missed int) {
|
||||
s.recordMissed(missed)
|
||||
logger.Warnf("control missed pong on server: missed_pongs=%d", missed)
|
||||
if onMissedPong != nil {
|
||||
onMissedPong(missed)
|
||||
}
|
||||
}
|
||||
liveness.OnUnhealthy = func(missed int) {
|
||||
s.recordUnhealthy(missed)
|
||||
logger.Warnf("control stream unhealthy on server: missed_pongs=%d", missed)
|
||||
if onUnhealthy != nil {
|
||||
onUnhealthy(missed)
|
||||
}
|
||||
}
|
||||
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
defer func() { _ = stream.Close() }()
|
||||
err := control.Run(controlCtx, stream, liveness)
|
||||
if controlCtx.Err() != nil || ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
logger.Warnf("server control stream ended: %v", err)
|
||||
}
|
||||
s.recordReconnect()
|
||||
logger.Infof("server reconnect reason=liveness - reinstalling smux session")
|
||||
s.reinstallSession(sess)
|
||||
}()
|
||||
}
|
||||
|
||||
// Status returns the latest server-side control health snapshot.
|
||||
func (s *Server) Status() control.Status {
|
||||
s.healthMu.RLock()
|
||||
defer s.healthMu.RUnlock()
|
||||
return s.health
|
||||
}
|
||||
|
||||
func (s *Server) recordSession(sessionID string) {
|
||||
s.healthMu.Lock()
|
||||
s.health.SessionID = sessionID
|
||||
s.health.MissedPongs = 0
|
||||
status := s.health
|
||||
s.healthMu.Unlock()
|
||||
s.notifyHealth(status)
|
||||
}
|
||||
|
||||
func (s *Server) recordPong(h control.Health) {
|
||||
s.healthMu.Lock()
|
||||
s.health.LastPong = h.LastSeen
|
||||
s.health.LastRTT = h.RTT
|
||||
s.health.MissedPongs = 0
|
||||
status := s.health
|
||||
s.healthMu.Unlock()
|
||||
s.notifyHealth(status)
|
||||
}
|
||||
|
||||
func (s *Server) recordMissed(missed int) {
|
||||
s.healthMu.Lock()
|
||||
s.health.MissedPongs = missed
|
||||
status := s.health
|
||||
s.healthMu.Unlock()
|
||||
s.notifyHealth(status)
|
||||
}
|
||||
|
||||
func (s *Server) recordUnhealthy(missed int) {
|
||||
s.healthMu.Lock()
|
||||
s.health.MissedPongs = missed
|
||||
s.health.UnhealthyEvents++
|
||||
s.health.LastUnhealthy = time.Now()
|
||||
status := s.health
|
||||
s.healthMu.Unlock()
|
||||
s.notifyHealth(status)
|
||||
}
|
||||
|
||||
func (s *Server) recordReconnect() {
|
||||
s.healthMu.Lock()
|
||||
s.health.Reconnects++
|
||||
status := s.health
|
||||
s.healthMu.Unlock()
|
||||
s.notifyHealth(status)
|
||||
}
|
||||
|
||||
func (s *Server) notifyHealth(status control.Status) {
|
||||
if s.onHealth != nil {
|
||||
s.onHealth(status)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/control"
|
||||
cryptopkg "github.com/openlibrecommunity/olcrtc/internal/crypto"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
|
||||
"github.com/xtaci/smux"
|
||||
@@ -49,6 +50,11 @@ func TestSmuxConfig(t *testing.T) {
|
||||
if cfg.Version != 2 || !cfg.KeepAliveDisabled || cfg.MaxFrameSize != 32768 || cfg.MaxReceiveBuffer != 16*1024*1024 {
|
||||
t.Fatalf("smuxConfig() = %+v", cfg)
|
||||
}
|
||||
capped := smuxConfig(4096)
|
||||
if capped.MaxFrameSize != 4096-cryptopkg.WireOverhead {
|
||||
t.Fatalf("smuxConfig(4096).MaxFrameSize = %d, want %d",
|
||||
capped.MaxFrameSize, 4096-cryptopkg.WireOverhead)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConnectRequest(t *testing.T) {
|
||||
@@ -373,6 +379,103 @@ func TestReinstallSessionFiresOnClose(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartControlLoopReportsPong(t *testing.T) {
|
||||
a, b := net.Pipe()
|
||||
defer func() {
|
||||
_ = a.Close()
|
||||
_ = b.Close()
|
||||
}()
|
||||
|
||||
serverSess, err := smux.Server(a, smuxConfig())
|
||||
if err != nil {
|
||||
t.Fatalf("smux.Server() error = %v", err)
|
||||
}
|
||||
defer func() { _ = serverSess.Close() }()
|
||||
clientSess, err := smux.Client(b, smuxConfig())
|
||||
if err != nil {
|
||||
t.Fatalf("smux.Client() error = %v", err)
|
||||
}
|
||||
defer func() { _ = clientSess.Close() }()
|
||||
|
||||
serverStreamCh := make(chan *smux.Stream, 1)
|
||||
go func() {
|
||||
stream, err := serverSess.AcceptStream()
|
||||
if err == nil {
|
||||
serverStreamCh <- stream
|
||||
}
|
||||
}()
|
||||
|
||||
clientStream, err := clientSess.OpenStream()
|
||||
if err != nil {
|
||||
t.Fatalf("OpenStream() error = %v", err)
|
||||
}
|
||||
serverStream := <-serverStreamCh
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
got := make(chan control.Health, 1)
|
||||
s := &Server{
|
||||
sessionID: "sid-control",
|
||||
liveness: control.Config{
|
||||
Interval: 10 * time.Millisecond,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
Failures: 2,
|
||||
OnPong: func(h control.Health) {
|
||||
select {
|
||||
case got <- h:
|
||||
default:
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
s.recordSession("sid-control")
|
||||
defer func() {
|
||||
cancel()
|
||||
s.wg.Wait()
|
||||
}()
|
||||
s.startControlLoop(ctx, serverSess, serverStream)
|
||||
go func() {
|
||||
_ = control.Run(ctx, clientStream, control.Config{
|
||||
Interval: 10 * time.Millisecond,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
Failures: 2,
|
||||
})
|
||||
}()
|
||||
|
||||
select {
|
||||
case h := <-got:
|
||||
if h.Seq == 0 {
|
||||
t.Fatal("Health.Seq = 0")
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for control pong")
|
||||
}
|
||||
status := s.Status()
|
||||
if status.SessionID != "sid-control" {
|
||||
t.Fatalf("Status.SessionID = %q, want sid-control", status.SessionID)
|
||||
}
|
||||
if status.LastPong.IsZero() || status.LastRTT < 0 || status.MissedPongs != 0 {
|
||||
t.Fatalf("Status() = %+v", status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusRecordsReconnectAndUnhealthy(t *testing.T) {
|
||||
updates := 0
|
||||
s := &Server{onHealth: func(control.Status) { updates++ }}
|
||||
s.recordSession("sid-1")
|
||||
s.recordMissed(2)
|
||||
s.recordUnhealthy(3)
|
||||
s.recordReconnect()
|
||||
|
||||
status := s.Status()
|
||||
if status.SessionID != "sid-1" || status.MissedPongs != 3 ||
|
||||
status.UnhealthyEvents != 1 || status.Reconnects != 1 || status.LastUnhealthy.IsZero() {
|
||||
t.Fatalf("Status() = %+v", status)
|
||||
}
|
||||
if updates != 4 {
|
||||
t.Fatalf("health updates = %d, want 4", updates)
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:cyclop // integration-style test needs setup, proxying, and traffic assertions together.
|
||||
func TestDispatchFiresOnTraffic(t *testing.T) {
|
||||
var lc net.ListenConfig
|
||||
|
||||
229
internal/supervisor/supervisor.go
Normal file
229
internal/supervisor/supervisor.go
Normal file
@@ -0,0 +1,229 @@
|
||||
// Package supervisor runs ordered session profiles with failover.
|
||||
package supervisor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/app/session"
|
||||
)
|
||||
|
||||
// DefaultRetryDelay is the pause Run uses between profile attempts when
// Config.RetryDelay is zero.
const DefaultRetryDelay = 2 * time.Second
|
||||
// DefaultHistoryLimit is the number of history events the status tracker
// retains when Config.HistoryLimit is zero.
const DefaultHistoryLimit = 20
|
||||
|
||||
const (
|
||||
// EventProfileStart marks a profile attempt starting.
|
||||
EventProfileStart = "profile_start"
|
||||
// EventProfileEnd marks a profile attempt ending.
|
||||
EventProfileEnd = "profile_end"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrNoProfiles is returned when the supervisor is started without profiles.
|
||||
ErrNoProfiles = errors.New("supervisor: no profiles configured")
|
||||
// ErrMaxCyclesExceeded is returned after MaxCycles complete profile-list passes.
|
||||
ErrMaxCyclesExceeded = errors.New("supervisor: max failover cycles exceeded")
|
||||
)
|
||||
|
||||
// Profile is one runnable session configuration in an ordered failover list.
|
||||
type Profile struct {
|
||||
Name string
|
||||
Config session.Config
|
||||
}
|
||||
|
||||
// ProfileStatus summarizes one profile's failover history.
|
||||
type ProfileStatus struct {
|
||||
Name string // copied from Profile.Name when the tracker is created
|
||||
Starts int // number of attempts started for this profile
|
||||
Failures int // attempts that ended with a non-nil runner error
|
||||
CleanEnds int // attempts that ended with a nil runner error
|
||||
LastStarted time.Time // wall-clock time of the most recent start
|
||||
LastEnded time.Time // wall-clock time of the most recent recorded end
|
||||
LastError string // text of the last failure; cleared again by a clean end
|
||||
}
|
||||
|
||||
// Event is one bounded failover history entry.
|
||||
type Event struct {
|
||||
Time time.Time // when the event was recorded
|
||||
Type string // EventProfileStart or EventProfileEnd
|
||||
Profile string // name of the profile the event refers to
|
||||
Cycle int // pass over the profile list during which the event occurred
|
||||
Error string // runner error text for failed end events; empty otherwise
|
||||
}
|
||||
|
||||
// Status is a point-in-time view of the supervisor.
|
||||
type Status struct {
|
||||
Cycle int // cycle of the most recently started profile attempt
|
||||
ActiveProfile string // name of the running profile; empty once its end is recorded
|
||||
ActiveProfileIndex int // index of the running profile; -1 when none is active
|
||||
Profiles []ProfileStatus // per-profile counters, in the order the profiles were configured
|
||||
History []Event // most recent events, oldest first, trimmed to the history limit
|
||||
LastError string // description of the most recent recorded profile end
|
||||
}
|
||||
|
||||
// Runner starts one session profile and blocks until it ends or fails.
|
||||
type Runner func(ctx context.Context, cfg session.Config) error
|
||||
|
||||
// Config controls ordered failover behavior.
|
||||
type Config struct {
|
||||
// Profiles are tried in order; Run returns ErrNoProfiles when the list is empty.
Profiles []Profile
|
||||
// RetryDelay is the pause after a profile ends before the next attempt.
// Zero selects DefaultRetryDelay; a negative value skips the wait.
RetryDelay time.Duration
|
||||
// MaxCycles, when positive, makes Run return ErrMaxCyclesExceeded once the
// last profile of that cycle has ended. Zero or negative means no limit.
MaxCycles int
|
||||
|
||||
// OnProfileStart, if set, is called right before each profile attempt.
OnProfileStart func(profile Profile, cycle int)
|
||||
// OnProfileEnd, if set, is called after an attempt ends while ctx is still
// active; err is the runner's result (nil for a clean end).
OnProfileEnd func(profile Profile, cycle int, err error)
|
||||
// OnStatus, if set, receives a copied Status after every recorded start and end.
OnStatus func(status Status)
|
||||
// HistoryLimit bounds Status.History. Zero selects DefaultHistoryLimit;
// a negative value disables history recording.
HistoryLimit int
|
||||
}
|
||||
|
||||
// Run starts profiles in order. If a profile exits while ctx is still active,
|
||||
// the supervisor waits RetryDelay and advances to the next profile.
|
||||
func Run(ctx context.Context, cfg Config, run Runner) error {
|
||||
if len(cfg.Profiles) == 0 {
|
||||
return ErrNoProfiles
|
||||
}
|
||||
if cfg.RetryDelay == 0 {
|
||||
cfg.RetryDelay = DefaultRetryDelay
|
||||
}
|
||||
state := newStatusTracker(cfg.Profiles, cfg.HistoryLimit, cfg.OnStatus)
|
||||
|
||||
var lastErr error
|
||||
for cycle := 1; ; cycle++ {
|
||||
for i, profile := range cfg.Profiles {
|
||||
if ctx.Err() != nil {
|
||||
return nil
|
||||
}
|
||||
state.start(i, cycle)
|
||||
if cfg.OnProfileStart != nil {
|
||||
cfg.OnProfileStart(profile, cycle)
|
||||
}
|
||||
|
||||
err := run(ctx, profile.Config)
|
||||
if ctx.Err() != nil {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("profile %q: %w", profile.Name, err)
|
||||
} else {
|
||||
lastErr = fmt.Errorf("profile %q ended", profile.Name)
|
||||
}
|
||||
state.end(i, cycle, err)
|
||||
if cfg.OnProfileEnd != nil {
|
||||
cfg.OnProfileEnd(profile, cycle, err)
|
||||
}
|
||||
|
||||
if cfg.MaxCycles > 0 && cycle >= cfg.MaxCycles && i == len(cfg.Profiles)-1 {
|
||||
return fmt.Errorf("%w after %d cycle(s): %w", ErrMaxCyclesExceeded, cycle, lastErr)
|
||||
}
|
||||
if err := waitRetryDelay(ctx, cfg.RetryDelay); err != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// statusTracker accumulates the supervisor Status and pushes copies of it to
// an optional observer.
type statusTracker struct {
|
||||
status Status // live state mutated by start/end
|
||||
notify func(Status) // optional observer; nil disables emission
|
||||
historyLimit int // max retained events; negative disables history
|
||||
}
|
||||
|
||||
func newStatusTracker(profiles []Profile, historyLimit int, notify func(Status)) *statusTracker {
|
||||
if historyLimit == 0 {
|
||||
historyLimit = DefaultHistoryLimit
|
||||
}
|
||||
statusProfiles := make([]ProfileStatus, 0, len(profiles))
|
||||
for _, profile := range profiles {
|
||||
statusProfiles = append(statusProfiles, ProfileStatus{Name: profile.Name})
|
||||
}
|
||||
return &statusTracker{
|
||||
status: Status{
|
||||
ActiveProfileIndex: -1,
|
||||
Profiles: statusProfiles,
|
||||
},
|
||||
notify: notify,
|
||||
historyLimit: historyLimit,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *statusTracker) start(profileIndex, cycle int) {
|
||||
now := time.Now()
|
||||
profile := &t.status.Profiles[profileIndex]
|
||||
profile.Starts++
|
||||
profile.LastStarted = now
|
||||
t.status.Cycle = cycle
|
||||
t.status.ActiveProfile = profile.Name
|
||||
t.status.ActiveProfileIndex = profileIndex
|
||||
t.appendHistory(Event{
|
||||
Time: now,
|
||||
Type: EventProfileStart,
|
||||
Profile: profile.Name,
|
||||
Cycle: cycle,
|
||||
})
|
||||
t.emit()
|
||||
}
|
||||
|
||||
func (t *statusTracker) end(profileIndex, cycle int, err error) {
|
||||
now := time.Now()
|
||||
profile := &t.status.Profiles[profileIndex]
|
||||
profile.LastEnded = now
|
||||
event := Event{
|
||||
Time: now,
|
||||
Type: EventProfileEnd,
|
||||
Profile: profile.Name,
|
||||
Cycle: cycle,
|
||||
}
|
||||
if err != nil {
|
||||
profile.Failures++
|
||||
profile.LastError = err.Error()
|
||||
t.status.LastError = fmt.Sprintf("profile %q: %v", profile.Name, err)
|
||||
event.Error = err.Error()
|
||||
} else {
|
||||
profile.CleanEnds++
|
||||
profile.LastError = ""
|
||||
t.status.LastError = fmt.Sprintf("profile %q ended", profile.Name)
|
||||
}
|
||||
t.status.ActiveProfile = ""
|
||||
t.status.ActiveProfileIndex = -1
|
||||
t.appendHistory(event)
|
||||
t.emit()
|
||||
}
|
||||
|
||||
func (t *statusTracker) appendHistory(event Event) {
|
||||
if t.historyLimit < 0 {
|
||||
return
|
||||
}
|
||||
t.status.History = append(t.status.History, event)
|
||||
if len(t.status.History) > t.historyLimit {
|
||||
t.status.History = t.status.History[len(t.status.History)-t.historyLimit:]
|
||||
}
|
||||
}
|
||||
|
||||
func (t *statusTracker) emit() {
|
||||
if t.notify == nil {
|
||||
return
|
||||
}
|
||||
t.notify(cloneStatus(t.status))
|
||||
}
|
||||
|
||||
func cloneStatus(status Status) Status {
|
||||
status.Profiles = append([]ProfileStatus(nil), status.Profiles...)
|
||||
status.History = append([]Event(nil), status.History...)
|
||||
return status
|
||||
}
|
||||
|
||||
// waitRetryDelay blocks for delay or until ctx is done, whichever comes
// first. It returns ctx.Err() when the context ended the wait and nil
// otherwise. A zero or negative delay returns nil immediately without
// consulting ctx.
func waitRetryDelay(ctx context.Context, delay time.Duration) error {
	if delay > 0 {
		pause := time.NewTimer(delay)
		defer pause.Stop()

		select {
		case <-pause.C:
		case <-ctx.Done():
			return ctx.Err()
		}
	}
	return nil
}
|
||||
170
internal/supervisor/supervisor_test.go
Normal file
170
internal/supervisor/supervisor_test.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package supervisor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/app/session"
|
||||
)
|
||||
|
||||
var errRunnerBoom = errors.New("boom")
|
||||
|
||||
func TestRunRequiresProfiles(t *testing.T) {
|
||||
err := Run(context.Background(), Config{}, func(context.Context, session.Config) error { return nil })
|
||||
if !errors.Is(err, ErrNoProfiles) {
|
||||
t.Fatalf("Run() error = %v, want %v", err, ErrNoProfiles)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunAdvancesProfilesAndStopsAtMaxCycles(t *testing.T) {
|
||||
profiles := []Profile{
|
||||
{Name: "first", Config: session.Config{Auth: "wbstream"}},
|
||||
{Name: "second", Config: session.Config{Auth: "jitsi"}},
|
||||
}
|
||||
var started []string
|
||||
var ended []string
|
||||
err := Run(context.Background(), Config{
|
||||
Profiles: profiles,
|
||||
RetryDelay: -1,
|
||||
MaxCycles: 1,
|
||||
OnProfileStart: func(profile Profile, cycle int) {
|
||||
started = append(started, profile.Name)
|
||||
if cycle != 1 {
|
||||
t.Fatalf("cycle = %d, want 1", cycle)
|
||||
}
|
||||
},
|
||||
OnProfileEnd: func(profile Profile, _ int, err error) {
|
||||
ended = append(ended, profile.Name)
|
||||
if !errors.Is(err, errRunnerBoom) {
|
||||
t.Fatalf("profile %s err = %v, want %v", profile.Name, err, errRunnerBoom)
|
||||
}
|
||||
},
|
||||
}, func(_ context.Context, cfg session.Config) error {
|
||||
if cfg.Auth == "" {
|
||||
t.Fatal("runner received empty auth")
|
||||
}
|
||||
return errRunnerBoom
|
||||
})
|
||||
if !errors.Is(err, ErrMaxCyclesExceeded) {
|
||||
t.Fatalf("Run() error = %v, want %v", err, ErrMaxCyclesExceeded)
|
||||
}
|
||||
if got, want := started, []string{"first", "second"}; !equalStrings(got, want) {
|
||||
t.Fatalf("started = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := ended, []string{"first", "second"}; !equalStrings(got, want) {
|
||||
t.Fatalf("ended = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunEmitsStatusHistory(t *testing.T) {
|
||||
profiles := []Profile{
|
||||
{Name: "first", Config: session.Config{Auth: "wbstream"}},
|
||||
{Name: "second", Config: session.Config{Auth: "jitsi"}},
|
||||
}
|
||||
var snapshots []Status
|
||||
err := Run(context.Background(), Config{
|
||||
Profiles: profiles,
|
||||
RetryDelay: -1,
|
||||
MaxCycles: 1,
|
||||
HistoryLimit: 3,
|
||||
OnStatus: func(status Status) {
|
||||
snapshots = append(snapshots, status)
|
||||
},
|
||||
}, func(_ context.Context, cfg session.Config) error {
|
||||
if cfg.Auth == "first" {
|
||||
t.Fatal("runner received profile name instead of config")
|
||||
}
|
||||
return errRunnerBoom
|
||||
})
|
||||
if !errors.Is(err, ErrMaxCyclesExceeded) {
|
||||
t.Fatalf("Run() error = %v, want %v", err, ErrMaxCyclesExceeded)
|
||||
}
|
||||
if len(snapshots) != 4 {
|
||||
t.Fatalf("status snapshots = %d, want 4", len(snapshots))
|
||||
}
|
||||
first := snapshots[0]
|
||||
if first.ActiveProfile != "first" || first.ActiveProfileIndex != 0 || first.Cycle != 1 {
|
||||
t.Fatalf("first status = %+v", first)
|
||||
}
|
||||
if first.Profiles[0].Starts != 1 || first.Profiles[0].LastStarted.IsZero() {
|
||||
t.Fatalf("first profile start status = %+v", first.Profiles[0])
|
||||
}
|
||||
last := snapshots[len(snapshots)-1]
|
||||
if last.ActiveProfile != "" || last.ActiveProfileIndex != -1 {
|
||||
t.Fatalf("last active status = %+v", last)
|
||||
}
|
||||
if last.Profiles[0].Failures != 1 || last.Profiles[1].Failures != 1 {
|
||||
t.Fatalf("profile failures = %+v", last.Profiles)
|
||||
}
|
||||
if last.LastError == "" || last.Profiles[1].LastError == "" {
|
||||
t.Fatalf("last errors missing: %+v", last)
|
||||
}
|
||||
if len(last.History) != 3 {
|
||||
t.Fatalf("history length = %d, want 3", len(last.History))
|
||||
}
|
||||
if last.History[0].Type != EventProfileEnd || last.History[0].Profile != "first" {
|
||||
t.Fatalf("oldest bounded history event = %+v", last.History[0])
|
||||
}
|
||||
if last.History[2].Type != EventProfileEnd || last.History[2].Profile != "second" ||
|
||||
last.History[2].Error == "" {
|
||||
t.Fatalf("last history event = %+v", last.History[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunStatusSnapshotIsImmutable(t *testing.T) {
|
||||
var first Status
|
||||
var second Status
|
||||
err := Run(context.Background(), Config{
|
||||
Profiles: []Profile{{Name: "one"}},
|
||||
RetryDelay: -1,
|
||||
MaxCycles: 1,
|
||||
OnStatus: func(status Status) {
|
||||
if first.Profiles == nil {
|
||||
first = status
|
||||
first.Profiles[0].Starts = 99
|
||||
first.History[0].Profile = "mutated"
|
||||
return
|
||||
}
|
||||
second = status
|
||||
},
|
||||
}, func(context.Context, session.Config) error {
|
||||
return errRunnerBoom
|
||||
})
|
||||
if !errors.Is(err, ErrMaxCyclesExceeded) {
|
||||
t.Fatalf("Run() error = %v, want %v", err, ErrMaxCyclesExceeded)
|
||||
}
|
||||
if first.Profiles[0].Starts != 99 || first.History[0].Profile != "mutated" {
|
||||
t.Fatalf("test mutation did not apply to snapshot: %+v", first)
|
||||
}
|
||||
if second.Profiles[0].Starts != 1 || second.History[0].Profile != "one" {
|
||||
t.Fatalf("snapshot mutation leaked into later status: %+v", second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunReturnsNilOnContextCancel(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
err := Run(ctx, Config{
|
||||
Profiles: []Profile{{Name: "one"}},
|
||||
RetryDelay: time.Hour,
|
||||
}, func(context.Context, session.Config) error {
|
||||
cancel()
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Run() error = %v, want nil", err)
|
||||
}
|
||||
}
|
||||
|
||||
// equalStrings reports whether a and b hold the same strings in the same
// order. Nil and empty slices compare equal.
func equalStrings(a, b []string) bool {
	n := len(a)
	if n != len(b) {
		return false
	}
	for i := 0; i < n; i++ {
		if a[i] != b[i] {
			return false
		}
	}
	return true
}
|
||||
@@ -35,6 +35,7 @@ const (
|
||||
protocolVersion byte = 1
|
||||
frameTypeData byte = 1
|
||||
frameTypeAck byte = 2
|
||||
frameTypeHello byte = 3
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -86,6 +87,7 @@ type streamTransport struct {
|
||||
nextSeq atomic.Uint32
|
||||
closed atomic.Bool
|
||||
writerUp atomic.Bool
|
||||
peerReady atomic.Bool
|
||||
sendMu sync.Mutex
|
||||
startWriter sync.Once
|
||||
ackMu sync.Mutex
|
||||
@@ -286,7 +288,7 @@ func (p *streamTransport) WatchConnection(ctx context.Context) {
|
||||
|
||||
// CanSend reports whether transport is ready for sending.
|
||||
func (p *streamTransport) CanSend() bool {
|
||||
return !p.closed.Load() && p.stream.CanSend()
|
||||
return !p.closed.Load() && p.peerReady.Load() && p.stream.CanSend()
|
||||
}
|
||||
|
||||
// Features describes the current seichannel transport semantics.
|
||||
@@ -333,7 +335,7 @@ func (p *streamTransport) writerLoop() {
|
||||
ticker := time.NewTicker(p.effectiveFrameInterval())
|
||||
defer ticker.Stop()
|
||||
|
||||
idle := buildVideoAccessUnit(nil)
|
||||
idle := buildVideoAccessUnit(encodeHelloFrame())
|
||||
|
||||
for {
|
||||
select {
|
||||
@@ -443,9 +445,13 @@ func (p *streamTransport) handleSample(sample []byte) {
|
||||
}
|
||||
|
||||
switch frame.typ {
|
||||
case frameTypeHello:
|
||||
p.peerReady.Store(true)
|
||||
case frameTypeAck:
|
||||
p.peerReady.Store(true)
|
||||
p.resolveAck(frame.seq, frame.crc)
|
||||
case frameTypeData:
|
||||
p.peerReady.Store(true)
|
||||
p.handleInboundFrame(frame)
|
||||
}
|
||||
}
|
||||
@@ -562,8 +568,8 @@ func encodeDataFrame(seq, crc uint32, totalLen, fragIdx, fragTotal int, payload
|
||||
out[5] = frameTypeData
|
||||
binary.BigEndian.PutUint32(out[6:10], seq)
|
||||
binary.BigEndian.PutUint32(out[10:14], crc)
|
||||
binary.BigEndian.PutUint32(out[14:18], uint32(totalLen)) //nolint:gosec,lll // G115: bounded conversion verified by surrounding logic
|
||||
binary.BigEndian.PutUint16(out[18:20], uint16(fragIdx)) //nolint:gosec,lll // G115: bounded conversion verified by surrounding logic
|
||||
binary.BigEndian.PutUint32(out[14:18], uint32(totalLen)) //nolint:gosec,lll // G115: bounded conversion verified by surrounding logic
|
||||
binary.BigEndian.PutUint16(out[18:20], uint16(fragIdx)) //nolint:gosec,lll // G115: bounded conversion verified by surrounding logic
|
||||
binary.BigEndian.PutUint16(out[20:22], uint16(fragTotal)) //nolint:gosec,lll // G115: bounded conversion verified by surrounding logic
|
||||
copy(out[22:], payload)
|
||||
return out
|
||||
@@ -579,6 +585,14 @@ func encodeAckFrame(seq, crc uint32) []byte {
|
||||
return out
|
||||
}
|
||||
|
||||
func encodeHelloFrame() []byte {
|
||||
out := make([]byte, 6)
|
||||
binary.BigEndian.PutUint32(out[0:4], protocolMagic)
|
||||
out[4] = protocolVersion
|
||||
out[5] = frameTypeHello
|
||||
return out
|
||||
}
|
||||
|
||||
func decodeTransportFrame(data []byte) (transportFrame, error) {
|
||||
if len(data) < 6 {
|
||||
return transportFrame{}, ErrFrameTooShort
|
||||
@@ -592,6 +606,8 @@ func decodeTransportFrame(data []byte) (transportFrame, error) {
|
||||
|
||||
frame := transportFrame{typ: data[5]}
|
||||
switch frame.typ {
|
||||
case frameTypeHello:
|
||||
return frame, nil
|
||||
case frameTypeAck:
|
||||
if len(data) < 14 {
|
||||
return transportFrame{}, ErrAckTooShort
|
||||
|
||||
@@ -78,3 +78,13 @@ func TestTransportFrameRoundTrip(t *testing.T) {
|
||||
t.Fatalf("payload mismatch: got=%q", decoded.payload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHelloFrameRoundTrip(t *testing.T) {
|
||||
hello, err := decodeTransportFrame(encodeHelloFrame())
|
||||
if err != nil {
|
||||
t.Fatalf("decodeTransportFrame(hello) failed: %v", err)
|
||||
}
|
||||
if hello.typ != frameTypeHello {
|
||||
t.Fatalf("hello frame type = %d, want %d", hello.typ, frameTypeHello)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,8 +103,12 @@ func TestNewConnectCallbacksAndFeatures(t *testing.T) {
|
||||
if stream.reconnect == nil || stream.should == nil || stream.ended == nil || !stream.watched {
|
||||
t.Fatal("callbacks/watch were not forwarded")
|
||||
}
|
||||
if tr.CanSend() {
|
||||
t.Fatal("CanSend() = true before peer hello")
|
||||
}
|
||||
tr.handleSample(buildVideoAccessUnit(encodeHelloFrame()))
|
||||
if !tr.CanSend() {
|
||||
t.Fatal("CanSend() = false, want true")
|
||||
t.Fatal("CanSend() = false after peer hello")
|
||||
}
|
||||
if features := tr.Features(); !features.Reliable || !features.Ordered || !features.MessageOriented || features.MaxPayloadSize == 0 { //nolint:lll // long test description
|
||||
t.Fatalf("Features() = %+v", features)
|
||||
|
||||
91
internal/transport/traffic.go
Normal file
91
internal/transport/traffic.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand/v2"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrTrafficPayloadTooLarge is returned by the traffic wrapper's Send when a
// payload is longer than the configured maximum payload size.
var ErrTrafficPayloadTooLarge = errors.New("traffic payload exceeds max_payload_size")
|
||||
|
||||
// trafficTransport wraps another Transport and applies a payload-size cap
// and a per-send delay before forwarding to it.
type trafficTransport struct {
|
||||
inner Transport // wrapped transport that performs the actual I/O
|
||||
maxPayloadSize int // payload cap in bytes; <= 0 disables the check
|
||||
minDelay time.Duration // lower bound of the per-send delay
|
||||
maxDelay time.Duration // upper bound (exclusive) of the per-send delay
|
||||
sendMu sync.Mutex // serializes Send, including the pacing sleep
|
||||
}
|
||||
|
||||
// WithTraffic wraps tr with optional payload caps and send pacing.
|
||||
func WithTraffic(tr Transport, cfg TrafficConfig) Transport {
|
||||
if tr == nil {
|
||||
return nil
|
||||
}
|
||||
cfg = effectiveTrafficConfig(tr.Features(), cfg)
|
||||
if cfg.MaxPayloadSize <= 0 && cfg.MinDelay <= 0 && cfg.MaxDelay <= 0 {
|
||||
return tr
|
||||
}
|
||||
return &trafficTransport{
|
||||
inner: tr,
|
||||
maxPayloadSize: cfg.MaxPayloadSize,
|
||||
minDelay: cfg.MinDelay,
|
||||
maxDelay: cfg.MaxDelay,
|
||||
}
|
||||
}
|
||||
|
||||
func effectiveTrafficConfig(features Features, cfg TrafficConfig) TrafficConfig {
|
||||
if cfg.MaxPayloadSize > 0 && features.MaxPayloadSize > 0 && features.MaxPayloadSize < cfg.MaxPayloadSize {
|
||||
cfg.MaxPayloadSize = features.MaxPayloadSize
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// Connect forwards to the wrapped transport unchanged.
func (t *trafficTransport) Connect(ctx context.Context) error { return t.inner.Connect(ctx) }
|
||||
|
||||
func (t *trafficTransport) Send(data []byte) error {
|
||||
t.sendMu.Lock()
|
||||
defer t.sendMu.Unlock()
|
||||
if t.maxPayloadSize > 0 && len(data) > t.maxPayloadSize {
|
||||
return fmt.Errorf("%w: size=%d max=%d", ErrTrafficPayloadTooLarge, len(data), t.maxPayloadSize)
|
||||
}
|
||||
if delay := t.nextDelay(); delay > 0 {
|
||||
time.Sleep(delay)
|
||||
}
|
||||
return t.inner.Send(data)
|
||||
}
|
||||
|
||||
// Close forwards to the wrapped transport.
func (t *trafficTransport) Close() error { return t.inner.Close() }
|
||||
|
||||
// SetReconnectCallback forwards cb to the wrapped transport.
func (t *trafficTransport) SetReconnectCallback(cb func()) { t.inner.SetReconnectCallback(cb) }
|
||||
|
||||
// SetShouldReconnect forwards fn to the wrapped transport.
func (t *trafficTransport) SetShouldReconnect(fn func() bool) { t.inner.SetShouldReconnect(fn) }
|
||||
|
||||
// SetEndedCallback forwards cb to the wrapped transport.
func (t *trafficTransport) SetEndedCallback(cb func(string)) { t.inner.SetEndedCallback(cb) }
|
||||
|
||||
// WatchConnection forwards to the wrapped transport.
func (t *trafficTransport) WatchConnection(ctx context.Context) { t.inner.WatchConnection(ctx) }
|
||||
|
||||
// CanSend reports the wrapped transport's readiness; the wrapper adds no state of its own.
func (t *trafficTransport) CanSend() bool { return t.inner.CanSend() }
|
||||
|
||||
func (t *trafficTransport) Features() Features {
|
||||
features := t.inner.Features()
|
||||
if t.maxPayloadSize > 0 &&
|
||||
(features.MaxPayloadSize == 0 || t.maxPayloadSize < features.MaxPayloadSize) {
|
||||
features.MaxPayloadSize = t.maxPayloadSize
|
||||
}
|
||||
return features
|
||||
}
|
||||
|
||||
func (t *trafficTransport) nextDelay() time.Duration {
|
||||
if t.maxDelay <= 0 && t.minDelay <= 0 {
|
||||
return 0
|
||||
}
|
||||
minDelay := t.minDelay
|
||||
maxDelay := t.maxDelay
|
||||
if maxDelay <= minDelay {
|
||||
return minDelay
|
||||
}
|
||||
return minDelay + time.Duration(rand.Int64N(int64(maxDelay-minDelay))) //nolint:gosec,lll // G404: non-cryptographic pacing jitter
|
||||
}
|
||||
67
internal/transport/traffic_test.go
Normal file
67
internal/transport/traffic_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type trafficStubTransport struct {
|
||||
features Features
|
||||
sent [][]byte
|
||||
}
|
||||
|
||||
func (s *trafficStubTransport) Connect(context.Context) error { return nil }
|
||||
func (s *trafficStubTransport) Send(data []byte) error {
|
||||
s.sent = append(s.sent, append([]byte(nil), data...))
|
||||
return nil
|
||||
}
|
||||
func (s *trafficStubTransport) Close() error { return nil }
|
||||
func (s *trafficStubTransport) SetReconnectCallback(func()) {}
|
||||
func (s *trafficStubTransport) SetShouldReconnect(func() bool) {}
|
||||
func (s *trafficStubTransport) SetEndedCallback(func(string)) {}
|
||||
func (s *trafficStubTransport) WatchConnection(context.Context) {}
|
||||
func (s *trafficStubTransport) CanSend() bool { return true }
|
||||
func (s *trafficStubTransport) Features() Features { return s.features }
|
||||
|
||||
func TestWithTrafficReturnsInnerWhenDisabled(t *testing.T) {
|
||||
inner := &trafficStubTransport{}
|
||||
got := WithTraffic(inner, TrafficConfig{})
|
||||
if got != inner {
|
||||
t.Fatalf("WithTraffic disabled returned %T, want inner", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrafficWrapperRejectsOversizedPayloadAndClampsFeatures(t *testing.T) {
|
||||
inner := &trafficStubTransport{features: Features{MaxPayloadSize: 5}}
|
||||
tr := WithTraffic(inner, TrafficConfig{MaxPayloadSize: 10})
|
||||
if features := tr.Features(); features.MaxPayloadSize != 5 {
|
||||
t.Fatalf("Features().MaxPayloadSize = %d, want 5", features.MaxPayloadSize)
|
||||
}
|
||||
err := tr.Send([]byte("123456"))
|
||||
if !errors.Is(err, ErrTrafficPayloadTooLarge) {
|
||||
t.Fatalf("Send() error = %v, want %v", err, ErrTrafficPayloadTooLarge)
|
||||
}
|
||||
if len(inner.sent) != 0 {
|
||||
t.Fatalf("inner sent %d payloads, want 0", len(inner.sent))
|
||||
}
|
||||
if err := tr.Send([]byte("12345")); err != nil {
|
||||
t.Fatalf("Send(max sized) error = %v", err)
|
||||
}
|
||||
if got := string(inner.sent[0]); got != "12345" {
|
||||
t.Fatalf("inner payload = %q, want 12345", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrafficWrapperAppliesMinimumDelay(t *testing.T) {
|
||||
inner := &trafficStubTransport{}
|
||||
tr := WithTraffic(inner, TrafficConfig{MinDelay: 2 * time.Millisecond})
|
||||
start := time.Now()
|
||||
if err := tr.Send([]byte("x")); err != nil {
|
||||
t.Fatalf("Send() error = %v", err)
|
||||
}
|
||||
if elapsed := time.Since(start); elapsed < 2*time.Millisecond {
|
||||
t.Fatalf("Send() elapsed = %v, want at least 2ms", elapsed)
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ package transport
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -32,10 +33,17 @@ type Transport interface {
|
||||
Features() Features
|
||||
}
|
||||
|
||||
// TrafficConfig controls optional reliability-oriented send shaping.
|
||||
type TrafficConfig struct {
|
||||
// MaxPayloadSize, when positive, caps the size of a single Send payload;
// zero or negative leaves payload size unchecked.
MaxPayloadSize int
|
||||
// MinDelay is the lower bound of the pause applied before each send.
MinDelay time.Duration
|
||||
// MaxDelay is the exclusive upper bound of that pause; when it does not
// exceed MinDelay the pause is fixed at MinDelay.
MaxDelay time.Duration
|
||||
}
|
||||
|
||||
// Config holds common transport configuration.
|
||||
type Config struct {
|
||||
Carrier string
|
||||
RoomURL string
|
||||
Carrier string
|
||||
RoomURL string
|
||||
// Engine, URL, Token are forwarded to carrier.Config for the "none" auth
|
||||
// carrier (direct engine access without a service-specific auth flow).
|
||||
Engine string
|
||||
@@ -63,6 +71,7 @@ type Config struct {
|
||||
SEIBatchSize int
|
||||
SEIFragmentSize int
|
||||
SEIAckTimeoutMS int
|
||||
Traffic TrafficConfig
|
||||
}
|
||||
|
||||
// Factory creates a transport instance.
|
||||
@@ -81,7 +90,11 @@ func New(ctx context.Context, name string, cfg Config) (Transport, error) {
|
||||
if !ok {
|
||||
return nil, ErrTransportNotFound
|
||||
}
|
||||
return factory(ctx, cfg)
|
||||
tr, err := factory(ctx, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return WithTraffic(tr, cfg.Traffic), nil
|
||||
}
|
||||
|
||||
// Available returns a list of registered transport names.
|
||||
|
||||
101
mobile/mobile.go
101
mobile/mobile.go
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/app/session"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/client"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/control"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/logger"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/protect"
|
||||
|
||||
@@ -65,23 +66,26 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
mu sync.Mutex //nolint:gochecknoglobals // package-level state intentional
|
||||
defaults mobileConfig //nolint:gochecknoglobals // package-level state intentional
|
||||
defaultsSet sync.Once //nolint:gochecknoglobals // package-level state intentional
|
||||
registerSet sync.Once //nolint:gochecknoglobals // package-level state intentional
|
||||
mu sync.Mutex //nolint:gochecknoglobals // package-level state intentional
|
||||
defaults mobileConfig //nolint:gochecknoglobals // package-level state intentional
|
||||
defaultsSet sync.Once //nolint:gochecknoglobals // package-level state intentional
|
||||
registerSet sync.Once //nolint:gochecknoglobals // package-level state intentional
|
||||
runClientWithReady = client.RunWithReady //nolint:gochecknoglobals // package-level state intentional
|
||||
cancel context.CancelFunc //nolint:gochecknoglobals // package-level state intentional
|
||||
done chan struct{} //nolint:gochecknoglobals // package-level state intentional
|
||||
ready chan struct{} //nolint:gochecknoglobals // package-level state intentional
|
||||
cancel context.CancelFunc //nolint:gochecknoglobals // package-level state intentional
|
||||
done chan struct{} //nolint:gochecknoglobals // package-level state intentional
|
||||
ready chan struct{} //nolint:gochecknoglobals // package-level state intentional
|
||||
errRun error
|
||||
)
|
||||
|
||||
type mobileConfig struct {
|
||||
link string
|
||||
transport string
|
||||
dnsServer string
|
||||
vp8FPS int
|
||||
vp8BatchSize int
|
||||
link string
|
||||
transport string
|
||||
dnsServer string
|
||||
vp8FPS int
|
||||
vp8BatchSize int
|
||||
livenessInterval time.Duration
|
||||
livenessTimeout time.Duration
|
||||
livenessFailures int
|
||||
}
|
||||
|
||||
// SetProtector sets the Android VPN socket protector.
|
||||
@@ -143,6 +147,21 @@ func SetVP8Options(fps, batchSize int) {
|
||||
defaults.vp8BatchSize = clampAtLeastOne(batchSize, 64)
|
||||
}
|
||||
|
||||
// SetLivenessOptions configures control-stream ping/pong checks.
|
||||
// Values <= 0 reset that field to its default. Durations are milliseconds.
|
||||
func SetLivenessOptions(intervalMillis, timeoutMillis, failures int) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
ensureDefaultConfigLocked()
|
||||
defaults.livenessInterval = durationFromMillisOrDefault(intervalMillis, control.DefaultInterval)
|
||||
defaults.livenessTimeout = durationFromMillisOrDefault(timeoutMillis, control.DefaultTimeout)
|
||||
if failures <= 0 {
|
||||
defaults.livenessFailures = control.DefaultFailures
|
||||
return
|
||||
}
|
||||
defaults.livenessFailures = failures
|
||||
}
|
||||
|
||||
// SetDebug enables or disables verbose logging.
|
||||
func SetDebug(enabled bool) {
|
||||
logger.SetVerbose(enabled)
|
||||
@@ -195,6 +214,11 @@ func Check(
|
||||
vp8BatchSize int,
|
||||
) (int64, error) {
|
||||
registerDefaults()
|
||||
mu.Lock()
|
||||
ensureDefaultConfigLocked()
|
||||
cfg := defaults
|
||||
mu.Unlock()
|
||||
|
||||
carrierName = normalizeCarrier(carrierName)
|
||||
transportName = normalizeTransport(transportName)
|
||||
if err := validateStartArgs(carrierName, roomID, clientID, keyHex); err != nil {
|
||||
@@ -227,6 +251,7 @@ func Check(
|
||||
DNSServer: defaultDNSServer,
|
||||
VP8FPS: clampAtLeastOne(vp8FPS, 120),
|
||||
VP8BatchSize: clampAtLeastOne(vp8BatchSize, 64),
|
||||
Liveness: livenessConfig(cfg),
|
||||
},
|
||||
func() {
|
||||
readyOnce.Do(func() {
|
||||
@@ -271,6 +296,11 @@ func Ping(
|
||||
vp8BatchSize int,
|
||||
) (int64, error) {
|
||||
registerDefaults()
|
||||
mu.Lock()
|
||||
ensureDefaultConfigLocked()
|
||||
cfg := defaults
|
||||
mu.Unlock()
|
||||
|
||||
carrierName = normalizeCarrier(carrierName)
|
||||
transportName = normalizeTransport(transportName)
|
||||
|
||||
@@ -310,6 +340,7 @@ func Ping(
|
||||
DNSServer: defaultDNSServer,
|
||||
VP8FPS: clampAtLeastOne(vp8FPS, 120),
|
||||
VP8BatchSize: clampAtLeastOne(vp8BatchSize, 64),
|
||||
Liveness: livenessConfig(cfg),
|
||||
},
|
||||
func() {
|
||||
readyOnce.Do(func() {
|
||||
@@ -557,6 +588,7 @@ func startWithConfig(
|
||||
SOCKSPass: socksPass,
|
||||
VP8FPS: cfg.vp8FPS,
|
||||
VP8BatchSize: cfg.vp8BatchSize,
|
||||
Liveness: livenessConfig(cfg),
|
||||
},
|
||||
func() {
|
||||
readyOnce.Do(func() {
|
||||
@@ -576,6 +608,7 @@ func startWithConfig(
|
||||
}
|
||||
|
||||
// WaitReady blocks until the selected transport is connected and the local SOCKS5 listener is ready.
|
||||
//
|
||||
//nolint:cyclop // straightforward state-machine waits with multiple terminal conditions
|
||||
func WaitReady(timeoutMillis int) error {
|
||||
mu.Lock()
|
||||
@@ -666,15 +699,38 @@ func waitForCheckDone(doneCh <-chan error) {
|
||||
func ensureDefaultConfigLocked() {
|
||||
defaultsSet.Do(func() {
|
||||
defaults = mobileConfig{
|
||||
link: defaultLink,
|
||||
transport: defaultTransport,
|
||||
dnsServer: defaultDNSServer,
|
||||
vp8FPS: 60,
|
||||
vp8BatchSize: 8,
|
||||
link: defaultLink,
|
||||
transport: defaultTransport,
|
||||
dnsServer: defaultDNSServer,
|
||||
vp8FPS: 60,
|
||||
vp8BatchSize: 8,
|
||||
livenessInterval: control.DefaultInterval,
|
||||
livenessTimeout: control.DefaultTimeout,
|
||||
livenessFailures: control.DefaultFailures,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func livenessConfig(cfg mobileConfig) control.Config {
|
||||
interval := cfg.livenessInterval
|
||||
if interval <= 0 {
|
||||
interval = control.DefaultInterval
|
||||
}
|
||||
timeout := cfg.livenessTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = control.DefaultTimeout
|
||||
}
|
||||
failures := cfg.livenessFailures
|
||||
if failures <= 0 {
|
||||
failures = control.DefaultFailures
|
||||
}
|
||||
return control.Config{
|
||||
Interval: interval,
|
||||
Timeout: timeout,
|
||||
Failures: failures,
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeTransport(value string) string {
|
||||
switch value {
|
||||
case dataTransport, "data", "dc":
|
||||
@@ -734,6 +790,17 @@ func clampAtLeastOne(value, maxValue int) int {
|
||||
return value
|
||||
}
|
||||
|
||||
// durationFromMillisOrDefault converts a millisecond count into a
// time.Duration. It returns def when value is zero or negative, and also when
// value is too large to be represented as a Duration.
//
// The bound is checked before multiplying: an overflowing product can wrap to
// a positive number, so inspecting the sign of the result afterwards is not a
// reliable overflow test.
func durationFromMillisOrDefault(value int, def time.Duration) time.Duration {
	// Largest millisecond count whose nanosecond equivalent still fits in int64.
	const maxMillis = int64(1<<63-1) / int64(time.Millisecond)

	if value <= 0 || int64(value) > maxMillis {
		return def
	}
	return time.Duration(value) * time.Millisecond
}
|
||||
|
||||
// logBridge adapts LogWriter to io.Writer.
|
||||
type logBridge struct {
|
||||
w LogWriter
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/client"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/control"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/logger"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/protect"
|
||||
)
|
||||
@@ -83,12 +84,15 @@ func TestDefaultsAndSetters(t *testing.T) {
|
||||
SetLink("direct")
|
||||
SetDNS("9.9.9.9:53")
|
||||
SetVP8Options(-1, 999)
|
||||
SetLivenessOptions(2500, 750, -1)
|
||||
|
||||
mu.Lock()
|
||||
got := defaults
|
||||
mu.Unlock()
|
||||
if got.transport != dataTransport || got.link != defaultLink || got.dnsServer != "9.9.9.9:53" ||
|
||||
got.vp8FPS != 1 || got.vp8BatchSize != 64 {
|
||||
got.vp8FPS != 1 || got.vp8BatchSize != 64 ||
|
||||
got.livenessInterval != 2500*time.Millisecond || got.livenessTimeout != 750*time.Millisecond ||
|
||||
got.livenessFailures != control.DefaultFailures {
|
||||
t.Fatalf("defaults = %+v", got)
|
||||
}
|
||||
|
||||
@@ -168,15 +172,19 @@ func TestStartWithInjectedRunnerLifecycle(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
resetMobileGlobals(t)
|
||||
})
|
||||
SetLivenessOptions(2500, 750, 4)
|
||||
|
||||
runClientWithReady = func(ctx context.Context, cfg client.Config, onReady func()) error {
|
||||
if cfg.Link != defaultLink || cfg.Transport != dataTransport || cfg.Carrier != carrierJazz ||
|
||||
cfg.RoomURL != "any" || cfg.DeviceID != "client" || cfg.LocalAddr != "127.0.0.1:1080" ||
|
||||
cfg.DNSServer != defaultDNSServer || cfg.VP8FPS != 60 || cfg.VP8BatchSize != 8 {
|
||||
cfg.DNSServer != defaultDNSServer || cfg.VP8FPS != 60 || cfg.VP8BatchSize != 8 ||
|
||||
cfg.Liveness.Interval != 2500*time.Millisecond ||
|
||||
cfg.Liveness.Timeout != 750*time.Millisecond ||
|
||||
cfg.Liveness.Failures != 4 {
|
||||
t.Fatalf(
|
||||
"RunWithReady args mismatch: link=%q transport=%q carrier=%q room=%q client=%q local=%q dns=%q vp8=%d/%d",
|
||||
"RunWithReady args mismatch: link=%q transport=%q carrier=%q room=%q client=%q local=%q dns=%q vp8=%d/%d liveness=%+v",
|
||||
cfg.Link, cfg.Transport, cfg.Carrier, cfg.RoomURL, cfg.DeviceID,
|
||||
cfg.LocalAddr, cfg.DNSServer, cfg.VP8FPS, cfg.VP8BatchSize,
|
||||
cfg.LocalAddr, cfg.DNSServer, cfg.VP8FPS, cfg.VP8BatchSize, cfg.Liveness,
|
||||
)
|
||||
}
|
||||
onReady()
|
||||
@@ -208,9 +216,12 @@ func TestStartUsesDefaultsAndCheckWithInjectedRunner(t *testing.T) {
|
||||
|
||||
runClientWithReady = func(ctx context.Context, cfg client.Config, onReady func()) error {
|
||||
if cfg.Transport != defaultTransport || cfg.RoomURL != "https://telemost.yandex.ru/j/room" ||
|
||||
cfg.LocalAddr != "127.0.0.1:1081" || cfg.SOCKSUser != "u" || cfg.SOCKSPass != "p" {
|
||||
t.Fatalf("Start args mismatch: transport=%q room=%q local=%q user/pass=%q/%q",
|
||||
cfg.Transport, cfg.RoomURL, cfg.LocalAddr, cfg.SOCKSUser, cfg.SOCKSPass)
|
||||
cfg.LocalAddr != "127.0.0.1:1081" || cfg.SOCKSUser != "u" || cfg.SOCKSPass != "p" ||
|
||||
cfg.Liveness.Interval != control.DefaultInterval ||
|
||||
cfg.Liveness.Timeout != control.DefaultTimeout ||
|
||||
cfg.Liveness.Failures != control.DefaultFailures {
|
||||
t.Fatalf("Start args mismatch: transport=%q room=%q local=%q user/pass=%q/%q liveness=%+v",
|
||||
cfg.Transport, cfg.RoomURL, cfg.LocalAddr, cfg.SOCKSUser, cfg.SOCKSPass, cfg.Liveness)
|
||||
}
|
||||
onReady()
|
||||
<-ctx.Done()
|
||||
@@ -225,9 +236,14 @@ func TestStartUsesDefaultsAndCheckWithInjectedRunner(t *testing.T) {
|
||||
}
|
||||
Stop()
|
||||
|
||||
SetLivenessOptions(3000, 1000, 5)
|
||||
runClientWithReady = func(ctx context.Context, cfg client.Config, onReady func()) error {
|
||||
if cfg.Transport != dataTransport || cfg.VP8FPS != 1 || cfg.VP8BatchSize != 64 {
|
||||
t.Fatalf("Check args mismatch: transport=%q vp8=%d/%d", cfg.Transport, cfg.VP8FPS, cfg.VP8BatchSize)
|
||||
if cfg.Transport != dataTransport || cfg.VP8FPS != 1 || cfg.VP8BatchSize != 64 ||
|
||||
cfg.Liveness.Interval != 3000*time.Millisecond ||
|
||||
cfg.Liveness.Timeout != time.Second ||
|
||||
cfg.Liveness.Failures != 5 {
|
||||
t.Fatalf("Check args mismatch: transport=%q vp8=%d/%d liveness=%+v",
|
||||
cfg.Transport, cfg.VP8FPS, cfg.VP8BatchSize, cfg.Liveness)
|
||||
}
|
||||
onReady()
|
||||
<-ctx.Done()
|
||||
@@ -242,6 +258,32 @@ func TestStartUsesDefaultsAndCheckWithInjectedRunner(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPingPassesLiveness(t *testing.T) {
|
||||
resetMobileGlobals(t)
|
||||
t.Cleanup(func() {
|
||||
resetMobileGlobals(t)
|
||||
})
|
||||
SetLivenessOptions(4000, 1500, 6)
|
||||
|
||||
seen := make(chan control.Config, 1)
|
||||
runClientWithReady = func(ctx context.Context, cfg client.Config, onReady func()) error {
|
||||
seen <- cfg.Liveness
|
||||
onReady()
|
||||
<-ctx.Done()
|
||||
return nil
|
||||
}
|
||||
|
||||
_, _ = Ping("jazz", "dc", "", "client", "key", 1085, 100, "http://127.0.0.1/", 30, 1)
|
||||
select {
|
||||
case got := <-seen:
|
||||
if got.Interval != 4000*time.Millisecond || got.Timeout != 1500*time.Millisecond || got.Failures != 6 {
|
||||
t.Fatalf("Ping liveness = %+v", got)
|
||||
}
|
||||
default:
|
||||
t.Fatal("Ping did not start client")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckTimeoutAndRunError(t *testing.T) {
|
||||
resetMobileGlobals(t)
|
||||
t.Cleanup(func() {
|
||||
|
||||
Reference in New Issue
Block a user