From b0fc3bd0f1ff668ab52d5654102ca8139cc1c59e Mon Sep 17 00:00:00 2001 From: cyber-debug Date: Sat, 16 May 2026 00:25:24 +0300 Subject: [PATCH] feat: add control stream liveness --- cmd/olcrtc/main.go | 3 +- docs/client.example.yaml | 5 + docs/configuration.md | 22 ++ docs/failover.example.yaml | 5 + docs/project-map.md | 19 +- docs/server.example.yaml | 5 + docs/settings.md | 10 +- internal/app/session/session.go | 150 ++++++++++--- internal/app/session/session_test.go | 47 ++++ internal/client/client.go | 91 +++++++- internal/client/client_test.go | 68 ++++++ internal/config/config.go | 37 ++- internal/config/config_test.go | 47 ++-- internal/control/control.go | 321 +++++++++++++++++++++++++++ internal/control/control_test.go | 128 +++++++++++ internal/handshake/handshake.go | 4 +- internal/server/server.go | 67 ++++-- internal/server/server_test.go | 72 ++++++ 18 files changed, 1012 insertions(+), 89 deletions(-) create mode 100644 internal/control/control.go create mode 100644 internal/control/control_test.go diff --git a/cmd/olcrtc/main.go b/cmd/olcrtc/main.go index 777949b..af7b87f 100644 --- a/cmd/olcrtc/main.go +++ b/cmd/olcrtc/main.go @@ -140,6 +140,7 @@ func runWithConfig(cfg loadedConfig) error { return fmt.Errorf("validate config: %w", err) } scfg = session.ApplyTransportDefaults(scfg) + scfg = session.ApplyLivenessDefaults(scfg) if scfg.Mode == modeGen { if len(cfg.profiles) > 0 { @@ -166,7 +167,7 @@ func prepareProfiles(profiles []supervisor.Profile) ([]supervisor.Profile, error if err != nil { return nil, fmt.Errorf("validate profile %q: %w", profile.Name, err) } - profile.Config = session.ApplyTransportDefaults(scfg) + profile.Config = session.ApplyLivenessDefaults(session.ApplyTransportDefaults(scfg)) out = append(out, profile) } return out, nil diff --git a/docs/client.example.yaml b/docs/client.example.yaml index fe83e0d..a074a6a 100644 --- a/docs/client.example.yaml +++ b/docs/client.example.yaml @@ -21,6 +21,11 @@ net: transport: datachannel # 
must match the server dns: "8.8.8.8:53" +liveness: + interval: 10s + timeout: 5s + failures: 3 + # Local SOCKS5 listener exposed to applications socks: host: "127.0.0.1" diff --git a/docs/configuration.md b/docs/configuration.md index 46edd07..8c067ad 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -31,6 +31,9 @@ olcrtc /etc/olcrtc/server.yaml | `video.*` | videochannel tuning | | `vp8.*` | vp8channel tuning | | `sei.fps` / `.batch_size` / `.fragment_size` / `.ack_timeout_ms` | seichannel tuning | +| `liveness.interval` | control-stream ping interval, default `10s` | +| `liveness.timeout` | pong timeout, default `5s` | +| `liveness.failures` | missed pongs before reconnect, default `3` | | `gen.amount` | gen mode: number of rooms to create | | `profiles[]` | ordered srv/cnc failover profiles | | `failover.retry_delay` | delay before trying the next profile, e.g. `2s` | @@ -45,6 +48,25 @@ olcrtc /etc/olcrtc/server.yaml `crypto.key_file` is resolved relative to the YAML file. Do not set it together with `crypto.key`. +## Liveness + +After `CLIENT_HELLO` / `SERVER_WELCOME`, the first smux stream stays open as +an encrypted control stream. olcrtc now sends `CONTROL_PING` / `CONTROL_PONG` +messages over that stream to prove the real tunnel path still round-trips. +This detects states where a provider or WebRTC layer looks connected but the +encrypted smux path is no longer usable. + +```yaml +liveness: + interval: 10s + timeout: 5s + failures: 3 +``` + +When the failure threshold is reached, the current smux session is rebuilt. +In failover mode, a profile that exits after liveness-triggered reconnect +failure lets the supervisor advance to the next profile. + ## Failover Profiles `mode: srv` and `mode: cnc` can define `profiles`. 
Top-level fields are used diff --git a/docs/failover.example.yaml b/docs/failover.example.yaml index 7aa8149..e956a35 100644 --- a/docs/failover.example.yaml +++ b/docs/failover.example.yaml @@ -10,6 +10,11 @@ crypto: net: dns: "1.1.1.1:53" +liveness: + interval: 10s + timeout: 5s + failures: 3 + data: data profiles: diff --git a/docs/project-map.md b/docs/project-map.md index c4c8791..e1b2134 100644 --- a/docs/project-map.md +++ b/docs/project-map.md @@ -72,6 +72,7 @@ Important fields: | `net.dns` | `DNSServer` | Resolver used by server-side target dials and provider HTTP where wired. | | `socks.*` | SOCKS fields | Client listener and optional server egress proxy. | | `engine.*` | direct engine fields | Used only with `auth.provider: none`. | +| `liveness.*` | control liveness | Ping/pong interval, timeout, and missed-pong threshold. | `internal/app/session` is the main router: @@ -151,6 +152,18 @@ SERVER_REJECT { version, reason } The handshake has a 64 KiB frame cap and a default 15 second timeout. +After handshake, `internal/control` keeps that same encrypted smux stream open +and exchanges length-prefixed JSON control messages: + +```text +CONTROL_PING { version, seq, sent_unix_nano } +CONTROL_PONG { version, seq, sent_unix_nano } +``` + +Defaults are `liveness.interval: 10s`, `liveness.timeout: 5s`, and +`liveness.failures: 3`. Missed pongs mark the smux session unhealthy and +trigger a session rebuild/reconnect path. + ## Registries And Plugin Shape The universal-carrier refactor centers on small registries: @@ -320,9 +333,9 @@ adaptive instead of static YAML knobs. ### 3. Control Stream Protocol -The first smux stream is parked after handshake. It is the natural place for: +The first smux stream now carries control ping/pong after handshake. It is +still the natural place for: -- Ping/pong and peer liveness. - Server policy updates. - Graceful reconnect notifications. - Drain/start markers for failover. 
@@ -330,7 +343,7 @@ The first smux stream is parked after handshake. It is the natural place for: Likely files: -- `internal/handshake` +- `internal/control` - `internal/server` - `internal/client` diff --git a/docs/server.example.yaml b/docs/server.example.yaml index 9f5ee38..c20b1e5 100644 --- a/docs/server.example.yaml +++ b/docs/server.example.yaml @@ -23,6 +23,11 @@ net: transport: datachannel # datachannel | videochannel | seichannel | vp8channel dns: "8.8.8.8:53" +liveness: + interval: 10s + timeout: 5s + failures: 3 + # Outbound SOCKS5 proxy for server-side egress (optional) socks: proxy_addr: "" # e.g. "127.0.0.1" diff --git a/docs/settings.md b/docs/settings.md index 28855ce..2e2d78a 100644 --- a/docs/settings.md +++ b/docs/settings.md @@ -63,12 +63,20 @@ | `profiles` | Список профилей failover для `srv`/`cnc` | | `failover.retry_delay` | Пауза перед следующим профилем, например `2s` | | `failover.max_cycles` | Сколько полных проходов по профилям сделать; `0` = бесконечно | +| `liveness.interval` | Интервал ping по control stream, по умолчанию `10s` | +| `liveness.timeout` | Сколько ждать pong, по умолчанию `5s` | +| `liveness.failures` | Сколько pong можно пропустить перед rebuild, по умолчанию `3` | `crypto.key_file` читается относительно YAML-файла. Не указывай `crypto.key` и `crypto.key_file` одновременно. Если задан `profiles`, поля верхнего уровня становятся общими defaults, а каждый профиль переопределяет только свои `auth`, `room`, `net`, `engine` и -настройки транспорта. Порядок профилей должен совпадать на сервере и клиенте. +настройки транспорта/liveness. Порядок профилей должен совпадать на сервере и +клиенте. + +`liveness` проверяет именно зашифрованный smux control stream после handshake, +а не только статус WebRTC/provider соединения. Если pong не приходит несколько +раз подряд, текущая smux-сессия пересоздается. 
--- diff --git a/internal/app/session/session.go b/internal/app/session/session.go index 89de5f5..360d96a 100644 --- a/internal/app/session/session.go +++ b/internal/app/session/session.go @@ -13,6 +13,7 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/carrier" "github.com/openlibrecommunity/olcrtc/internal/carrier/builtin" "github.com/openlibrecommunity/olcrtc/internal/client" + "github.com/openlibrecommunity/olcrtc/internal/control" "github.com/openlibrecommunity/olcrtc/internal/link" "github.com/openlibrecommunity/olcrtc/internal/link/direct" "github.com/openlibrecommunity/olcrtc/internal/logger" @@ -120,43 +121,56 @@ var ( // ErrSOCKSAuthRequired indicates that a non-loopback SOCKS listener requires authentication. ErrSOCKSAuthRequired = errors.New( "socks auth required when binding outside loopback (set socks.user and socks.pass)") + + // ErrLivenessIntervalInvalid indicates that liveness.interval is not a positive duration. + ErrLivenessIntervalInvalid = errors.New( + "invalid liveness interval (set liveness.interval to a duration > 0)") + // ErrLivenessTimeoutInvalid indicates that liveness.timeout is not a positive duration. + ErrLivenessTimeoutInvalid = errors.New( + "invalid liveness timeout (set liveness.timeout to a duration > 0)") + // ErrLivenessFailuresInvalid indicates that liveness.failures is not positive. + ErrLivenessFailuresInvalid = errors.New( + "invalid liveness failures (set liveness.failures to a value > 0)") ) // Config holds runtime session settings. 
type Config struct { - Mode string - Link string - Transport string - Auth string - Engine string - URL string - Token string - RoomID string - KeyHex string - SOCKSHost string - SOCKSPort int - SOCKSUser string - SOCKSPass string - DNSServer string - SOCKSProxyAddr string - SOCKSProxyPort int - VideoWidth int - VideoHeight int - VideoFPS int - VideoBitrate string - VideoHW string - VideoQRSize int - VideoQRRecovery string - VideoCodec string - VideoTileModule int - VideoTileRS int - VP8FPS int - VP8BatchSize int - SEIFPS int - SEIBatchSize int - SEIFragmentSize int - SEIAckTimeoutMS int - Amount int + Mode string + Link string + Transport string + Auth string + Engine string + URL string + Token string + RoomID string + KeyHex string + SOCKSHost string + SOCKSPort int + SOCKSUser string + SOCKSPass string + DNSServer string + SOCKSProxyAddr string + SOCKSProxyPort int + VideoWidth int + VideoHeight int + VideoFPS int + VideoBitrate string + VideoHW string + VideoQRSize int + VideoQRRecovery string + VideoCodec string + VideoTileModule int + VideoTileRS int + VP8FPS int + VP8BatchSize int + SEIFPS int + SEIBatchSize int + SEIFragmentSize int + SEIAckTimeoutMS int + LivenessInterval string + LivenessTimeout string + LivenessFailures int + Amount int } // RegisterDefaults registers built-in carriers and transports. @@ -212,6 +226,20 @@ func ApplyTransportDefaults(cfg Config) Config { } } +// ApplyLivenessDefaults fills documented control-stream liveness defaults. 
+func ApplyLivenessDefaults(cfg Config) Config { + if cfg.LivenessInterval == "" { + cfg.LivenessInterval = control.DefaultInterval.String() + } + if cfg.LivenessTimeout == "" { + cfg.LivenessTimeout = control.DefaultTimeout.String() + } + if cfg.LivenessFailures == 0 { + cfg.LivenessFailures = control.DefaultFailures + } + return cfg +} + func applyVideoDefaults(cfg Config) Config { if cfg.VideoCodec == "" { cfg.VideoCodec = videoCodecQRCode @@ -292,6 +320,9 @@ func Validate(cfg Config) error { if err := validateTransportConfig(cfg); err != nil { return err } + if err := validateLivenessConfig(cfg); err != nil { + return err + } return validateModeConfig(cfg) } @@ -431,6 +462,52 @@ func validateModeConfig(cfg Config) error { return nil } +func validateLivenessConfig(cfg Config) error { + if _, err := parseLivenessDuration(cfg.LivenessInterval, control.DefaultInterval); err != nil { + return fmt.Errorf("%w: %v", ErrLivenessIntervalInvalid, err) + } + if _, err := parseLivenessDuration(cfg.LivenessTimeout, control.DefaultTimeout); err != nil { + return fmt.Errorf("%w: %v", ErrLivenessTimeoutInvalid, err) + } + if cfg.LivenessFailures < 0 { + return ErrLivenessFailuresInvalid + } + return nil +} + +func parseLivenessDuration(value string, def time.Duration) (time.Duration, error) { + if value == "" { + return def, nil + } + d, err := time.ParseDuration(value) + if err != nil { + return 0, err + } + if d <= 0 { + return 0, fmt.Errorf("duration must be > 0") + } + return d, nil +} + +func livenessConfig(cfg Config) (control.Config, error) { + interval, err := parseLivenessDuration(cfg.LivenessInterval, control.DefaultInterval) + if err != nil { + return control.Config{}, fmt.Errorf("%w: %v", ErrLivenessIntervalInvalid, err) + } + timeout, err := parseLivenessDuration(cfg.LivenessTimeout, control.DefaultTimeout) + if err != nil { + return control.Config{}, fmt.Errorf("%w: %v", ErrLivenessTimeoutInvalid, err) + } + failures := cfg.LivenessFailures + if failures == 0 { + 
failures = control.DefaultFailures + } + if failures < 0 { + return control.Config{}, ErrLivenessFailuresInvalid + } + return control.Config{Interval: interval, Timeout: timeout, Failures: failures}, nil +} + func isLoopbackListenHost(host string) bool { if host == "localhost" { return true @@ -442,7 +519,12 @@ func isLoopbackListenHost(host string) bool { // Run starts the configured mode. func Run(ctx context.Context, cfg Config) error { cfg = ApplyTransportDefaults(cfg) + cfg = ApplyLivenessDefaults(cfg) roomURL := cfg.RoomID + liveness, err := livenessConfig(cfg) + if err != nil { + return err + } switch cfg.Mode { case modeSRV: @@ -474,6 +556,7 @@ func Run(ctx context.Context, cfg Config) error { Engine: cfg.Engine, URL: cfg.URL, Token: cfg.Token, + Liveness: liveness, OnSessionOpen: func(sessionID, deviceID string, claims map[string]any) { logger.Infof("session opened: id=%s device=%s claims=%v", sessionID, deviceID, claims) }, @@ -517,6 +600,7 @@ func Run(ctx context.Context, cfg Config) error { Engine: cfg.Engine, URL: cfg.URL, Token: cfg.Token, + Liveness: liveness, }); err != nil { return fmt.Errorf("client: %w", err) } diff --git a/internal/app/session/session_test.go b/internal/app/session/session_test.go index f20e70d..95270b2 100644 --- a/internal/app/session/session_test.go +++ b/internal/app/session/session_test.go @@ -4,6 +4,8 @@ import ( "context" "errors" "testing" + + "github.com/openlibrecommunity/olcrtc/internal/control" ) func TestApplyTransportDefaults(t *testing.T) { @@ -85,6 +87,24 @@ func TestApplyTransportDefaults(t *testing.T) { } } +func TestApplyLivenessDefaults(t *testing.T) { + got := ApplyLivenessDefaults(Config{}) + if got.LivenessInterval != control.DefaultInterval.String() { + t.Fatalf("LivenessInterval = %q, want %q", got.LivenessInterval, control.DefaultInterval.String()) + } + if got.LivenessTimeout != control.DefaultTimeout.String() { + t.Fatalf("LivenessTimeout = %q, want %q", got.LivenessTimeout, 
control.DefaultTimeout.String()) + } + if got.LivenessFailures != control.DefaultFailures { + t.Fatalf("LivenessFailures = %d, want %d", got.LivenessFailures, control.DefaultFailures) + } + + explicit := Config{LivenessInterval: "1s", LivenessTimeout: "500ms", LivenessFailures: 9} + if got := ApplyLivenessDefaults(explicit); got != explicit { + t.Fatalf("ApplyLivenessDefaults() = %+v, want %+v", got, explicit) + } +} + //nolint:maintidx // table-driven validation test naturally has many cases func TestValidate(t *testing.T) { RegisterDefaults() @@ -422,6 +442,33 @@ func TestValidate(t *testing.T) { return cfg }(), }, + { + name: "liveness rejects bad interval", + cfg: func() Config { + cfg := base + cfg.LivenessInterval = "nope" + return cfg + }(), + want: ErrLivenessIntervalInvalid, + }, + { + name: "liveness rejects zero timeout", + cfg: func() Config { + cfg := base + cfg.LivenessTimeout = "0s" + return cfg + }(), + want: ErrLivenessTimeoutInvalid, + }, + { + name: "liveness rejects negative failures", + cfg: func() Config { + cfg := base + cfg.LivenessFailures = -1 + return cfg + }(), + want: ErrLivenessFailuresInvalid, + }, } for _, tt := range tests { diff --git a/internal/client/client.go b/internal/client/client.go index 0d73bd9..13be135 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -17,6 +17,7 @@ import ( "time" "github.com/google/uuid" + "github.com/openlibrecommunity/olcrtc/internal/control" "github.com/openlibrecommunity/olcrtc/internal/crypto" "github.com/openlibrecommunity/olcrtc/internal/handshake" "github.com/openlibrecommunity/olcrtc/internal/link" @@ -54,7 +55,9 @@ type Client struct { conn *muxconn.Conn session *smux.Session controlStrm *smux.Stream + controlStop context.CancelFunc sessMu sync.RWMutex + reconnectMu sync.Mutex deviceID string sessionID string claims map[string]any @@ -93,6 +96,7 @@ type Config struct { Engine string URL string Token string + Liveness control.Config // DeviceID overrides the persistent 
client-side device identifier. Leave // empty to derive one from DeviceIDPath (or generate a random one if both @@ -217,7 +221,9 @@ func (c *Client) bringUpLink( if ctx.Err() != nil { return } - c.handleReconnect() + if !c.handleReconnect(ctx, cfg, cancel) { + cancel() + } }) if err := ln.Connect(ctx); err != nil { @@ -243,14 +249,15 @@ func (c *Client) bringUpLink( c.controlStrm = control c.sessionID = sid c.sessMu.Unlock() + c.startControlLoop(ctx, cfg, cancel, control) go ln.WatchConnection(ctx) return nil } // openControlStream opens stream #1 on sess and performs the handshake. -// The stream stays open for the lifetime of the smux session — the server -// holds it parked, and it would carry future control messages. +// The stream stays open for the lifetime of the smux session and carries +// post-handshake control messages. func openControlStream( sess *smux.Session, deviceID string, @@ -326,7 +333,10 @@ func smuxConfig() *smux.Config { return cfg } -func (c *Client) handleReconnect() { +func (c *Client) handleReconnect(ctx context.Context, cfg Config, cancel context.CancelFunc) bool { + c.reconnectMu.Lock() + defer c.reconnectMu.Unlock() + logger.Infof("client link reconnect - tearing down smux session") // Install a fresh muxconn immediately so onData never hits nil while @@ -336,14 +346,19 @@ func (c *Client) handleReconnect() { c.sessMu.Lock() oldControl := c.controlStrm + oldControlStop := c.controlStop oldSess := c.session oldConn := c.conn c.conn = newConn c.session = nil c.controlStrm = nil + c.controlStop = nil c.sessionID = "" c.sessMu.Unlock() + if oldControlStop != nil { + oldControlStop() + } if oldControl != nil { _ = oldControl.Close() } @@ -364,15 +379,25 @@ func (c *Client) handleReconnect() { attemptDelay = 300 * time.Millisecond ) for attempt := 1; attempt <= maxAttempts; attempt++ { - if c.tryReopenSession(attempt) { - return + if c.tryReopenSession(ctx, cfg, cancel, attempt) { + return true + } + select { + case <-ctx.Done(): + return 
false + case <-time.After(attemptDelay): } - time.Sleep(attemptDelay) } logger.Warnf("client reconnect: exhausted %d handshake attempts", maxAttempts) + return false } -func (c *Client) tryReopenSession(attempt int) bool { +func (c *Client) tryReopenSession( + ctx context.Context, + cfg Config, + cancel context.CancelFunc, + attempt int, +) bool { conn := muxconn.New(c.ln, c.cipher) c.sessMu.Lock() @@ -400,19 +425,69 @@ func (c *Client) tryReopenSession(attempt int) bool { c.controlStrm = control c.sessionID = sid c.sessMu.Unlock() + c.startControlLoop(ctx, cfg, cancel, control) return true } +func (c *Client) startControlLoop( + ctx context.Context, + cfg Config, + cancel context.CancelFunc, + stream *smux.Stream, +) { + controlCtx, stop := context.WithCancel(ctx) + c.sessMu.Lock() + c.controlStop = stop + c.sessMu.Unlock() + + liveness := cfg.Liveness + onPong := liveness.OnPong + onUnhealthy := liveness.OnUnhealthy + liveness.OnPong = func(h control.Health) { + c.sessMu.RLock() + sid := c.sessionID + c.sessMu.RUnlock() + logger.Debugf("control alive session=%s rtt=%v seq=%d", sid, h.RTT, h.Seq) + if onPong != nil { + onPong(h) + } + } + liveness.OnUnhealthy = func(missed int) { + logger.Warnf("control stream unhealthy on client: missed_pongs=%d", missed) + if onUnhealthy != nil { + onUnhealthy(missed) + } + } + + go func() { + err := control.Run(controlCtx, stream, liveness) + if controlCtx.Err() != nil || ctx.Err() != nil { + return + } + if err != nil { + logger.Warnf("client control stream ended: %v", err) + } + if !c.handleReconnect(ctx, cfg, cancel) { + cancel() + } + }() +} + func (c *Client) shutdown() { c.sessMu.Lock() control := c.controlStrm + controlStop := c.controlStop sess := c.session conn := c.conn c.controlStrm = nil + c.controlStop = nil c.session = nil c.conn = nil c.sessMu.Unlock() + if controlStop != nil { + controlStop() + } if conn != nil { _ = conn.Close() } diff --git a/internal/client/client_test.go b/internal/client/client_test.go 
index 48976fe..f5d836b 100644 --- a/internal/client/client_test.go +++ b/internal/client/client_test.go @@ -11,6 +11,7 @@ import ( "testing" "time" + "github.com/openlibrecommunity/olcrtc/internal/control" cryptopkg "github.com/openlibrecommunity/olcrtc/internal/crypto" "github.com/openlibrecommunity/olcrtc/internal/muxconn" "github.com/xtaci/smux" @@ -517,3 +518,70 @@ func TestShutdownClosesLinkAndConn(t *testing.T) { t.Fatal("shutdown() did not close link") } } + +func TestStartControlLoopReportsPong(t *testing.T) { + a, b := net.Pipe() + defer func() { + _ = a.Close() + _ = b.Close() + }() + + serverSess, err := smux.Server(a, smuxConfig()) + if err != nil { + t.Fatalf("smux.Server() error = %v", err) + } + defer func() { _ = serverSess.Close() }() + clientSess, err := smux.Client(b, smuxConfig()) + if err != nil { + t.Fatalf("smux.Client() error = %v", err) + } + defer func() { _ = clientSess.Close() }() + + peerStreamCh := make(chan *smux.Stream, 1) + go func() { + stream, err := serverSess.AcceptStream() + if err == nil { + peerStreamCh <- stream + } + }() + + stream, err := clientSess.OpenStream() + if err != nil { + t.Fatalf("OpenStream() error = %v", err) + } + peerStream := <-peerStreamCh + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + got := make(chan control.Health, 1) + c := &Client{sessionID: "sid-control"} + c.startControlLoop(ctx, Config{ + Liveness: control.Config{ + Interval: 10 * time.Millisecond, + Timeout: 100 * time.Millisecond, + Failures: 2, + OnPong: func(h control.Health) { + select { + case got <- h: + default: + } + }, + }, + }, cancel, stream) + go func() { + _ = control.Run(ctx, peerStream, control.Config{ + Interval: 10 * time.Millisecond, + Timeout: 100 * time.Millisecond, + Failures: 2, + }) + }() + + select { + case h := <-got: + if h.Seq == 0 { + t.Fatal("Health.Seq = 0") + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for control pong") + } +} diff --git 
a/internal/config/config.go b/internal/config/config.go index 5fe206c..9524363 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -41,6 +41,7 @@ type File struct { Video Video `yaml:"video"` VP8 VP8 `yaml:"vp8"` SEI SEI `yaml:"sei"` + Liveness Liveness `yaml:"liveness"` Gen Gen `yaml:"gen"` Profiles []Profile `yaml:"profiles"` Failover Failover `yaml:"failover"` @@ -51,17 +52,18 @@ type File struct { // Profile is a failover entry that overrides top-level runtime fields. type Profile struct { - Name string `yaml:"name"` - Link string `yaml:"link"` - Auth Auth `yaml:"auth"` - Room Room `yaml:"room"` - Crypto Crypto `yaml:"crypto"` - Net Net `yaml:"net"` - SOCKS SOCKS `yaml:"socks"` - Engine Engine `yaml:"engine"` - Video Video `yaml:"video"` - VP8 VP8 `yaml:"vp8"` - SEI SEI `yaml:"sei"` + Name string `yaml:"name"` + Link string `yaml:"link"` + Auth Auth `yaml:"auth"` + Room Room `yaml:"room"` + Crypto Crypto `yaml:"crypto"` + Net Net `yaml:"net"` + SOCKS SOCKS `yaml:"socks"` + Engine Engine `yaml:"engine"` + Video Video `yaml:"video"` + VP8 VP8 `yaml:"vp8"` + SEI SEI `yaml:"sei"` + Liveness Liveness `yaml:"liveness"` } // Failover controls ordered profile failover. @@ -137,6 +139,13 @@ type SEI struct { AckTimeoutMS int `yaml:"ack_timeout_ms"` } +// Liveness tunes the post-handshake control stream ping/pong checks. +type Liveness struct { + Interval string `yaml:"interval"` + Timeout string `yaml:"timeout"` + Failures int `yaml:"failures"` +} + // Gen controls room-generation mode. 
type Gen struct { Amount int `yaml:"amount"` @@ -248,6 +257,9 @@ func Apply(dst session.Config, f File) session.Config { dst.SEIBatchSize = pickInt(dst.SEIBatchSize, f.SEI.BatchSize) dst.SEIFragmentSize = pickInt(dst.SEIFragmentSize, f.SEI.FragmentSize) dst.SEIAckTimeoutMS = pickInt(dst.SEIAckTimeoutMS, f.SEI.AckTimeoutMS) + dst.LivenessInterval = pickString(dst.LivenessInterval, f.Liveness.Interval) + dst.LivenessTimeout = pickString(dst.LivenessTimeout, f.Liveness.Timeout) + dst.LivenessFailures = pickInt(dst.LivenessFailures, f.Liveness.Failures) dst.Amount = pickInt(dst.Amount, f.Gen.Amount) return dst } @@ -286,6 +298,9 @@ func ApplyProfile(base session.Config, p Profile) session.Config { dst.SEIBatchSize = overlayInt(dst.SEIBatchSize, p.SEI.BatchSize) dst.SEIFragmentSize = overlayInt(dst.SEIFragmentSize, p.SEI.FragmentSize) dst.SEIAckTimeoutMS = overlayInt(dst.SEIAckTimeoutMS, p.SEI.AckTimeoutMS) + dst.LivenessInterval = overlayString(dst.LivenessInterval, p.Liveness.Interval) + dst.LivenessTimeout = overlayString(dst.LivenessTimeout, p.Liveness.Timeout) + dst.LivenessFailures = overlayInt(dst.LivenessFailures, p.Liveness.Failures) return dst } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 7504110..b41604c 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -39,6 +39,10 @@ socks: vp8: fps: 25 batch_size: 4 +liveness: + interval: 2s + timeout: 500ms + failures: 4 gen: amount: 3 debug: true @@ -76,20 +80,23 @@ func requireLoadedFile(t *testing.T, f File) { func requireAppliedConfig(t *testing.T, got session.Config) { t.Helper() want := session.Config{ - Mode: testModeSrv, - Link: "direct", - Auth: testAuthProvider, - RoomID: testRoomID, - KeyHex: testCryptoKey, - Transport: "datachannel", - DNSServer: "1.1.1.1:53", - SOCKSHost: "127.0.0.1", - SOCKSPort: 1080, - SOCKSUser: "u", - SOCKSPass: "p", - VP8FPS: 25, - VP8BatchSize: 4, - Amount: 3, + Mode: testModeSrv, + Link: "direct", + Auth: 
testAuthProvider, + RoomID: testRoomID, + KeyHex: testCryptoKey, + Transport: "datachannel", + DNSServer: "1.1.1.1:53", + SOCKSHost: "127.0.0.1", + SOCKSPort: 1080, + SOCKSUser: "u", + SOCKSPass: "p", + VP8FPS: 25, + VP8BatchSize: 4, + LivenessInterval: "2s", + LivenessTimeout: "500ms", + LivenessFailures: 4, + Amount: 3, } if got != want { t.Fatalf("Apply produced wrong config: %+v, want %+v", got, want) @@ -132,6 +139,10 @@ crypto: key: shared-key net: dns: 1.1.1.1:53 +liveness: + interval: 5s + timeout: 2s + failures: 5 profiles: - name: wb-vp8 auth: @@ -142,6 +153,8 @@ profiles: transport: vp8channel vp8: fps: 30 + liveness: + interval: 1s - name: jitsi-dc auth: provider: jitsi @@ -174,7 +187,8 @@ failover: if first.Auth != "wbstream" || first.Transport != "vp8channel" || first.RoomID != "wb-room" { t.Fatalf("first profile = %+v", first) } - if first.KeyHex != "shared-key" || first.DNSServer != "1.1.1.1:53" || first.VP8FPS != 30 { + if first.KeyHex != "shared-key" || first.DNSServer != "1.1.1.1:53" || first.VP8FPS != 30 || + first.LivenessInterval != "1s" || first.LivenessTimeout != "2s" || first.LivenessFailures != 5 { t.Fatalf("first inherited/overlaid fields = %+v", first) } second := ApplyProfile(base, f.Profiles[1]) @@ -182,6 +196,9 @@ failover: second.RoomID != "https://meet.example/room" || second.DNSServer != "8.8.8.8:53" { t.Fatalf("second profile = %+v", second) } + if second.LivenessInterval != "5s" || second.LivenessTimeout != "2s" || second.LivenessFailures != 5 { + t.Fatalf("second liveness fields = %+v", second) + } } func TestLoadProfileCryptoKeyFile(t *testing.T) { diff --git a/internal/control/control.go b/internal/control/control.go new file mode 100644 index 0000000..a6bd50f --- /dev/null +++ b/internal/control/control.go @@ -0,0 +1,321 @@ +// Package control implements the post-handshake control stream protocol. +// +// The control stream is the first smux stream after the olcrtc handshake. 
It +// stays inside the encrypted muxconn path, so ping/pong proves that the actual +// tunnel path still round-trips, not merely that the provider connection is up. +// +// Wire format matches the handshake framing: a 4-byte big-endian length +// followed by a JSON message. +// +//nolint:tagliatelle // JSON keys are the stable wire protocol schema. +package control + +import ( + "context" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "io" + "sync" + "time" +) + +const ( + // ProtoVersion identifies the control stream wire format. + ProtoVersion = 1 + // MaxMessageSize caps one control frame. + MaxMessageSize = 16 * 1024 + // DefaultInterval is the default interval between ping probes. + DefaultInterval = 10 * time.Second + // DefaultTimeout is the default time to wait for a pong. + DefaultTimeout = 5 * time.Second + // DefaultFailures is the default number of consecutive missed pongs before + // the stream is marked unhealthy. + DefaultFailures = 3 +) + +// MsgType labels a control message. +type MsgType string + +const ( + // TypePing is sent periodically to prove control-stream liveness. + TypePing MsgType = "CONTROL_PING" + // TypePong replies to a ping with the same sequence and timestamp. + TypePong MsgType = "CONTROL_PONG" +) + +var ( + // ErrUnhealthy is returned when the stream misses too many pong replies. + ErrUnhealthy = errors.New("control stream unhealthy") + // ErrProtocolVersion is returned when the peer announces an incompatible version. + ErrProtocolVersion = errors.New("incompatible control protocol version") + // ErrUnexpectedMessage is returned for unknown or malformed control message types. + ErrUnexpectedMessage = errors.New("unexpected control message") + // ErrFrameTooLarge is returned when a frame exceeds [MaxMessageSize]. + ErrFrameTooLarge = errors.New("control frame too large") +) + +// Message is one control-stream frame. 
+type Message struct {
+	Version      int     `json:"version"`
+	Type         MsgType `json:"type"`
+	Seq          uint64  `json:"seq,omitempty"`
+	SentUnixNano int64   `json:"sent_unix_nano,omitempty"`
+}
+
+// Health is reported when a ping round trip completes.
+//
+// Seq is the sequence number of the answered ping, RTT is the time between
+// queuing that ping and handling its pong, and LastSeen is the local time at
+// which the pong was handled.
+type Health struct {
+	Seq      uint64
+	RTT      time.Duration
+	LastSeen time.Time
+}
+
+// Config controls the liveness loop.
+type Config struct {
+	// Interval is the period between pings, Timeout is how long a ping may
+	// stay unanswered before it counts as missed, and Failures is the number
+	// of misses (without a matching pong in between) that makes Run return
+	// [ErrUnhealthy]. Zero or negative values are replaced by the package
+	// defaults.
+	Interval time.Duration
+	Timeout  time.Duration
+	Failures int
+
+	// OnPong is called after a matching pong is received.
+	OnPong func(Health)
+	// OnUnhealthy is called before Run returns [ErrUnhealthy].
+	OnUnhealthy func(missed int)
+}
+
+// withDefaults returns a copy of cfg in which every non-positive Interval,
+// Timeout, or Failures value is replaced by the corresponding package default.
+func (cfg Config) withDefaults() Config {
+	if cfg.Interval <= 0 {
+		cfg.Interval = DefaultInterval
+	}
+	if cfg.Timeout <= 0 {
+		cfg.Timeout = DefaultTimeout
+	}
+	if cfg.Failures <= 0 {
+		cfg.Failures = DefaultFailures
+	}
+	return cfg
+}
+
+// Run drives bidirectional ping/pong liveness until ctx is canceled, rw closes,
+// or the configured failure threshold is reached.
+//
+// Run starts a read loop, a probe loop, and a write loop and returns as soon as
+// the first of them stops. rw is closed once ctx is canceled or Run returns, so
+// the remaining loops unblock. A context cancellation or deadline error is
+// reported as nil; any other error (for example [ErrUnhealthy] or a read or
+// parse failure) is returned as is.
+func Run(ctx context.Context, rw io.ReadWriteCloser, cfg Config) error {
+	cfg = cfg.withDefaults()
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
+	state := &state{
+		rw:      rw,
+		cfg:     cfg,
+		pending: make(map[uint64]time.Time),
+		now:     time.Now,
+		out:     make(chan Message, 16),
+	}
+
+	// One slot per loop: the loops that finish later never block on send.
+	errCh := make(chan error, 3)
+	go func() {
+		// Closing rw is what unblocks a loop parked in Read or Write.
+		<-ctx.Done()
+		_ = rw.Close()
+	}()
+	go func() { errCh <- state.readLoop(ctx) }()
+	go func() { errCh <- state.probeLoop(ctx) }()
+	go func() { errCh <- state.writeLoop(ctx) }()
+
+	err := <-errCh
+	cancel()
+	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
+		return nil
+	}
+	return err
+}
+
+// state is the shared bookkeeping of one Run invocation.
+type state struct {
+	rw  io.ReadWriteCloser
+	cfg Config
+	now func() time.Time
+
+	// out carries frames from the read and probe loops to the write loop.
+	out chan Message
+
+	// mu guards pending, nextSeq, and failures.
+	mu       sync.Mutex
+	pending  map[uint64]time.Time
+	nextSeq  uint64
+	failures int
+}
+
+// readLoop reads frames from rw until an error occurs. Pings are answered with
+// a pong that echoes Seq and SentUnixNano; pongs are handed to handlePong. When
+// ctx has already ended, the context error is returned instead of the I/O
+// error.
+func (s *state) readLoop(ctx context.Context) error {
+	for {
+		raw, err := readFrame(s.rw)
+		if err != nil {
+			if ctx.Err() != nil {
+				return ctx.Err()
+			}
+			return err
+		}
+		msg, err := parseMessage(raw)
+		if err != nil {
+			return err
+		}
+		switch msg.Type {
+		case TypePing:
+			if err := s.enqueue(ctx, Message{
+				Version:      ProtoVersion,
+				Type:         TypePong,
+				Seq:          msg.Seq,
+				SentUnixNano: msg.SentUnixNano,
+			}); err != nil {
+				if ctx.Err() != nil {
+					return ctx.Err()
+				}
+				return err
+			}
+		case TypePong:
+			s.handlePong(msg)
+		default:
+			return fmt.Errorf("%w: got %q", ErrUnexpectedMessage, msg.Type)
+		}
+	}
+}
+
+// probeLoop calls sendProbe once per cfg.Interval until ctx ends or sendProbe
+// fails.
+func (s *state) probeLoop(ctx context.Context) error {
+	ticker := time.NewTicker(s.cfg.Interval)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-ticker.C:
+			if err := s.sendProbe(ctx); err != nil {
+				return err
+			}
+		}
+	}
+}
+
+// sendProbe runs once per tick: it expires pings older than cfg.Timeout,
+// returns [ErrUnhealthy] once the failure count reaches cfg.Failures, and
+// otherwise registers and queues the next ping.
+func (s *state) sendProbe(ctx context.Context) error {
+	now := s.now()
+
+	s.mu.Lock()
+	for seq, sent := range s.pending {
+		if now.Sub(sent) < s.cfg.Timeout {
+			continue
+		}
+		delete(s.pending, seq)
+		s.failures++
+	}
+	if s.failures >= s.cfg.Failures {
+		missed := s.failures
+		s.mu.Unlock()
+		// The callback runs without mu held.
+		if s.cfg.OnUnhealthy != nil {
+			s.cfg.OnUnhealthy(missed)
+		}
+		return fmt.Errorf("%w: missed %d pong(s)", ErrUnhealthy, missed)
+	}
+
+	s.nextSeq++
+	seq := s.nextSeq
+	s.pending[seq] = now
+	s.mu.Unlock()
+
+	ping := Message{
+		Version:      ProtoVersion,
+		Type:         TypePing,
+		Seq:          seq,
+		SentUnixNano: now.UnixNano(),
+	}
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	case s.out <- ping:
+	default:
+		// The outbound queue is full, so the writer is stalled. Blocking here
+		// would also stop the expiry check above from running on later ticks;
+		// leave the ping in pending so it is counted as missed after Timeout.
+	}
+	return nil
+}
+
+// handlePong matches msg against the pending pings. A match clears the failure
+// count and is reported through cfg.OnPong; an unknown or already expired Seq
+// is ignored.
+func (s *state) handlePong(msg Message) {
+	now := s.now()
+
+	s.mu.Lock()
+	sent, ok := s.pending[msg.Seq]
+	if ok {
+		delete(s.pending, msg.Seq)
+		s.failures = 0
+	}
+	s.mu.Unlock()
+
+	if !ok || s.cfg.OnPong == nil {
+		return
+	}
+	s.cfg.OnPong(Health{
+		Seq:      msg.Seq,
+		RTT:      now.Sub(sent),
+		LastSeen: now,
+	})
+}
+
+// enqueue hands msg to the write loop, blocking until there is room in the
+// queue or ctx ends.
+func (s *state) enqueue(ctx context.Context, msg Message) error {
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	case s.out <- msg:
+		return nil
+	}
+}
+
+// writeLoop writes queued messages to rw until ctx ends or a write fails. When
+// ctx has already ended, the context error is returned instead of the I/O
+// error.
+func (s *state) writeLoop(ctx context.Context) error {
+	for {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case msg := <-s.out:
+			if err := writeFrame(s.rw, msg); err != nil {
+				if ctx.Err() != nil {
+					return ctx.Err()
+				}
+				return err
+			}
+		}
+	}
+}
+
+// parseMessage decodes a JSON frame body and rejects messages whose version is
+// not ProtoVersion ([ErrProtocolVersion]) or whose type is neither ping nor
+// pong ([ErrUnexpectedMessage]).
+func parseMessage(raw []byte) (Message, error) {
+	var msg Message
+	if err := json.Unmarshal(raw, &msg); err != nil {
+		return Message{}, fmt.Errorf("parse control message: %w", err)
+	}
+	if msg.Version != ProtoVersion {
+		return Message{}, fmt.Errorf("%w: peer v%d, local v%d",
+			ErrProtocolVersion, msg.Version, ProtoVersion)
+	}
+	if msg.Type != TypePing && msg.Type != TypePong {
+		return Message{}, fmt.Errorf("%w: got %q", ErrUnexpectedMessage, msg.Type)
+	}
+	return msg, nil
+}
+
+// writeFrame encodes msg as JSON and writes it to w as one frame: a 4-byte
+// big-endian body length followed by the body. Bodies larger than
+// MaxMessageSize are rejected with [ErrFrameTooLarge].
+func writeFrame(w io.Writer, msg Message) error {
+	body, err := json.Marshal(msg)
+	if err != nil {
+		return fmt.Errorf("marshal control message: %w", err)
+	}
+	if len(body) > MaxMessageSize {
+		return fmt.Errorf("%w: %d > %d", ErrFrameTooLarge, len(body), MaxMessageSize)
+	}
+	// Header and body go out in a single Write so a failure cannot leave a
+	// bare header on the stream and a muxed transport sees one write per
+	// message.
+	frame := make([]byte, 4+len(body))
+	binary.BigEndian.PutUint32(frame, uint32(len(body))) //nolint:gosec // len(body) bounded by MaxMessageSize
+	copy(frame[4:], body)
+	if _, err := w.Write(frame); err != nil {
+		return fmt.Errorf("write control frame: %w", err)
+	}
+	return nil
+}
+
+// readFrame reads one frame from r: a 4-byte big-endian length followed by that
+// many body bytes. A declared length above MaxMessageSize is rejected with
+// [ErrFrameTooLarge] before any body byte is read.
+func readFrame(r io.Reader) ([]byte, error) {
+	var hdr [4]byte
+	if _, err := io.ReadFull(r, hdr[:]); err != nil {
+		return nil, fmt.Errorf("read control hdr: %w", err)
+	}
+	n := binary.BigEndian.Uint32(hdr[:])
+	if n > MaxMessageSize {
+		return nil, fmt.Errorf("%w: %d > %d", ErrFrameTooLarge, n, MaxMessageSize)
+	}
+	buf := make([]byte, n)
+	if _, err := io.ReadFull(r, buf); err != nil {
+		return nil, fmt.Errorf("read control body: %w", err)
+	}
+	return buf, nil
+}
diff --git a/internal/control/control_test.go b/internal/control/control_test.go
new file mode 100644
index 0000000..3c52bf6
--- /dev/null
+++ b/internal/control/control_test.go
@@ -0,0 +1,128 @@
+package control
+
+import (
+	"context"
+	"encoding/binary"
+	"errors"
+	"io"
+	"net"
+	"testing"
+	"time"
+)
+
+func controlPair(t *testing.T) (net.Conn, net.Conn) {
+	t.Helper()
+	a, b := net.Pipe()
+	t.Cleanup(func() {
+		_ = a.Close()
+		_ = b.Close()
+	})
+	return a, b
+}
+
+func TestRunPingPongReportsRTT(t *testing.T) {
+	a, b := controlPair(t)
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	got := make(chan Health, 1)
+	cfg := Config{
+		Interval: 10 * time.Millisecond,
+		Timeout:  100 * time.Millisecond,
+		Failures: 2,
+		OnPong: func(h Health) {
+			select {
+			case got <- h:
+			default:
+			}
+		},
+	}
+	errCh := make(chan error, 2)
+	go func() { errCh <- Run(ctx, a, cfg) }()
+	go func() { errCh <- Run(ctx, b, cfg) }()
+
+	select {
+	case h := <-got:
+		if h.Seq == 0 {
+			t.Fatal("Health.Seq = 0")
+		}
+		if h.RTT < 0 {
+			t.Fatalf("Health.RTT = %v", h.RTT)
+		}
+	case <-time.After(time.Second):
+		t.Fatal("timed out waiting for pong health")
+	}
+
+	cancel()
+	for range 2 {
+		if err := <-errCh; err != nil {
+			t.Fatalf("Run() after cancel = %v", err)
+		}
+	}
+}
+
+func TestRunMarksUnhealthyAfterMissedPongs(t *testing.T) {
+	a, b := controlPair(t)
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	go func() {
+		_, _ = io.Copy(io.Discard, b)
+	}()
+
+	missedCh := make(chan int, 1)
+	errCh := make(chan error, 1)
+	go func() {
+		errCh <- Run(ctx, a, Config{
+			Interval:    10 * time.Millisecond,
+			Timeout:     5 * time.Millisecond,
+			Failures:    2,
+			OnUnhealthy: func(missed int) { missedCh <- missed },
+		})
+	}()
+
+	select {
+	case err := <-errCh:
+		if !errors.Is(err, ErrUnhealthy) {
+			t.Fatalf("Run() error = %v, want ErrUnhealthy", err)
+		}
+	case <-time.After(time.Second):
+		t.Fatal("timed out waiting for unhealthy result")
+	}
+	if missed := <-missedCh; missed < 2 {
+		t.Fatalf("missed = %d, want >= 2", missed)
+	}
+}
+
+func TestRunRejectsBadProtocolVersion(t *testing.T) {
+	a, b := controlPair(t)
+	errCh := make(chan error, 1)
+	go func() {
+		errCh <- Run(context.Background(), a, Config{Interval: time.Hour})
+	}()
+	if err := writeFrame(b, Message{Version: 999, Type: TypePing, Seq: 1}); err != nil {
+		t.Fatalf("writeFrame() error = %v", err)
+	}
+
+	select {
+	case err := <-errCh:
+		if !errors.Is(err, ErrProtocolVersion) {
+			t.Fatalf("Run() error = %v, want ErrProtocolVersion", err)
+		}
+	case <-time.After(time.Second):
+		t.Fatal("timed out waiting for protocol error")
+	}
+}
+
+func TestReadFrameRejectsTooLarge(t *testing.T) {
+	a, b := controlPair(t)
+	go func() {
+		var hdr [4]byte
+		binary.BigEndian.PutUint32(hdr[:], MaxMessageSize+1)
+		_, _ = b.Write(hdr[:])
+	}()
+	_, err := readFrame(a)
+	if !errors.Is(err, ErrFrameTooLarge) {
+		t.Fatalf("readFrame() error = %v, want ErrFrameTooLarge", err)
+	}
+}
diff --git a/internal/handshake/handshake.go b/internal/handshake/handshake.go
index bec84a7..5d34f6f 100644
--- a/internal/handshake/handshake.go
+++ b/internal/handshake/handshake.go
@@ -13,8 +13,8 @@
 // │ │
 //
 // After the exchange the control stream stays open; tunnel traffic flows over
-// additional smux streams opened by the client. The control stream may carry
-// keepalives or future control messages.
+// additional smux streams opened by the client. The control stream then
+// carries ping/pong liveness and future control messages.
 //
 //nolint:tagliatelle // JSON keys are the stable wire protocol schema.
package handshake diff --git a/internal/server/server.go b/internal/server/server.go index a720a25..4954ad4 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -14,6 +14,7 @@ import ( "time" "github.com/google/uuid" + "github.com/openlibrecommunity/olcrtc/internal/control" "github.com/openlibrecommunity/olcrtc/internal/crypto" "github.com/openlibrecommunity/olcrtc/internal/handshake" "github.com/openlibrecommunity/olcrtc/internal/link" @@ -55,6 +56,7 @@ type Server struct { cipher *crypto.Cipher conn *muxconn.Conn session *smux.Session + controlStop context.CancelFunc sessMu sync.RWMutex reinstallMu sync.Mutex wg sync.WaitGroup @@ -68,6 +70,7 @@ type Server struct { resolver *net.Resolver socksProxyAddr string socksProxyPort int + liveness control.Config } // ConnectRequest is a message from the client to establish a new connection. @@ -106,6 +109,7 @@ type Config struct { Engine string URL string Token string + Liveness control.Config // AuthHook is invoked after CLIENT_HELLO to authorize the client and // return a session ID. If nil, every client is admitted with a random UUID. 
@@ -155,6 +159,7 @@ func Run(ctx context.Context, cfg Config) error {
 		dnsServer:      cfg.DNSServer,
 		socksProxyAddr: cfg.SOCKSProxyAddr,
 		socksProxyPort: cfg.SOCKSProxyPort,
+		liveness:       cfg.Liveness,
 	}
 
 	s.setupResolver()
@@ -340,13 +345,18 @@ func (s *Server) reinstallSession(dead *smux.Session) {
 	}
 	oldSess := s.session
 	oldConn := s.conn
+	oldControlStop := s.controlStop
 	oldSID := s.sessionID
 	s.session = newSess
 	s.conn = newConn
+	s.controlStop = nil
 	s.sessionID = ""
 	s.deviceID = ""
 	s.sessMu.Unlock()
 
+	if oldControlStop != nil {
+		oldControlStop()
+	}
 	if oldSess != nil {
 		_ = oldSess.Close()
 	}
@@ -362,13 +372,18 @@ func (s *Server) closeSession() {
 	s.sessMu.Lock()
 	sess := s.session
 	conn := s.conn
+	controlStop := s.controlStop
 	s.session = nil
 	s.conn = nil
+	s.controlStop = nil
 	oldSID := s.sessionID
 	s.sessionID = ""
 	s.deviceID = ""
 	s.sessMu.Unlock()
 
+	if controlStop != nil {
+		controlStop()
+	}
 	if conn != nil {
 		_ = conn.Close()
 	}
@@ -478,26 +493,58 @@ func (s *Server) acceptHandshake(ctx context.Context, sess *smux.Session) bool {
 	s.sessMu.Unlock()
 	s.onOpen(sid, hello.DeviceID, hello.Claims)
 	logger.Infof("session %s opened (device=%s)", sid, hello.DeviceID)
-	// The control stream stays open for the lifetime of the session;
-	// keep it parked in a goroutine so the smux session does not close it.
-	s.wg.Add(1)
-	go func() {
-		defer s.wg.Done()
-		s.parkControlStream(stream)
-	}()
+	s.startControlLoop(ctx, sess, stream)
 	return true
 }
 
-// parkControlStream blocks reading from the control stream until it closes.
-// Future control messages (kick, rate updates, etc.) would be dispatched here.
-func (s *Server) parkControlStream(stream *smux.Stream) {
-	defer func() { _ = stream.Close() }()
-	buf := make([]byte, 64)
-	for {
-		if _, err := stream.Read(buf); err != nil {
-			return
+// startControlLoop runs ping/pong liveness on the session's control stream in
+// a goroutine tracked by s.wg. The loop's cancel func is stored in
+// s.controlStop under sessMu so closeSession and reinstallSession can stop it.
+// The configured OnPong and OnUnhealthy hooks are wrapped with log lines. When
+// the loop ends without having been canceled, sess is passed to
+// reinstallSession.
+func (s *Server) startControlLoop(ctx context.Context, sess *smux.Session, stream *smux.Stream) {
+	controlCtx, stop := context.WithCancel(ctx)
+	s.sessMu.Lock()
+	s.controlStop = stop
+	s.sessMu.Unlock()
+
+	liveness := s.liveness
+	onPong := liveness.OnPong
+	onUnhealthy := liveness.OnUnhealthy
+	liveness.OnPong = func(h control.Health) {
+		s.sessMu.RLock()
+		sid := s.sessionID
+		s.sessMu.RUnlock()
+		logger.Debugf("control alive session=%s rtt=%v seq=%d", sid, h.RTT, h.Seq)
+		if onPong != nil {
+			onPong(h)
 		}
 	}
+	liveness.OnUnhealthy = func(missed int) {
+		logger.Warnf("control stream unhealthy on server: missed_pongs=%d", missed)
+		if onUnhealthy != nil {
+			onUnhealthy(missed)
+		}
+	}
+
+	s.wg.Add(1)
+	go func() {
+		defer s.wg.Done()
+		// Release controlCtx as soon as the loop exits instead of relying on a
+		// later closeSession/reinstallSession call; calling stop twice is harmless.
+		defer stop()
+		defer func() { _ = stream.Close() }()
+		err := control.Run(controlCtx, stream, liveness)
+		if controlCtx.Err() != nil || ctx.Err() != nil {
+			// Canceled via s.controlStop or the parent ctx: not a liveness failure.
+			return
+		}
+		if err != nil {
+			logger.Warnf("server control stream ended: %v", err)
+		}
+		s.reinstallSession(sess)
+	}()
 }
 
 func (s *Server) shutdown() {
diff --git a/internal/server/server_test.go b/internal/server/server_test.go
index f6034bf..d5a6f6d 100644
--- a/internal/server/server_test.go
+++ b/internal/server/server_test.go
@@ -11,6 +11,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/openlibrecommunity/olcrtc/internal/control"
 	cryptopkg "github.com/openlibrecommunity/olcrtc/internal/crypto"
 	"github.com/openlibrecommunity/olcrtc/internal/muxconn"
 	"github.com/xtaci/smux"
@@ -373,6 +374,77 @@ func TestReinstallSessionFiresOnClose(t *testing.T) {
 	}
 }
 
+func TestStartControlLoopReportsPong(t *testing.T) {
+	a, b := net.Pipe()
+	defer func() {
+		_ = a.Close()
+		_ = b.Close()
+	}()
+
+	serverSess, err := smux.Server(a, smuxConfig())
+	if err != nil {
+		t.Fatalf("smux.Server() error = %v", err)
+	}
+	defer func() { _ = serverSess.Close() }()
+	clientSess, err := smux.Client(b, smuxConfig())
+	if err != nil {
+		t.Fatalf("smux.Client() error = %v", err)
+	}
+	defer func() { _ = clientSess.Close() }()
+
+	serverStreamCh := make(chan *smux.Stream, 1)
+	go func() {
+		stream, err := serverSess.AcceptStream()
+		if err == nil {
+			serverStreamCh <- stream
+		}
+	}()
+
+	clientStream, err := clientSess.OpenStream()
+	if err != nil {
+		t.Fatalf("OpenStream() error = %v", err)
+	}
+	serverStream := <-serverStreamCh
+
+	ctx, cancel := context.WithCancel(context.Background())
+	got := make(chan control.Health, 1)
+	s := &Server{
+		sessionID: "sid-control",
+		liveness: control.Config{
+			Interval: 10 * time.Millisecond,
+			Timeout:  100 * time.Millisecond,
+			Failures: 2,
+			OnPong: func(h control.Health) {
+				select {
+				case got <- h:
+				default:
+				}
+			},
+		},
+	}
+	defer func() {
+		cancel()
+		s.wg.Wait()
+	}()
+	s.startControlLoop(ctx, serverSess, serverStream)
+	go func() {
+		_ = control.Run(ctx, clientStream, control.Config{
+			Interval: 10 * time.Millisecond,
+			Timeout:  100 * time.Millisecond,
+			Failures: 2,
+		})
+	}()
+
+	select {
+	case h := <-got:
+		if h.Seq == 0 {
+			t.Fatal("Health.Seq = 0")
+		}
+	case <-time.After(time.Second):
+		t.Fatal("timed out waiting for control pong")
+	}
+}
+
 //nolint:cyclop // integration-style test needs setup, proxying, and traffic assertions together.
 func TestDispatchFiresOnTraffic(t *testing.T) {
 	var lc net.ListenConfig