mirror of
https://github.com/openlibrecommunity/olcrtc.git
synced 2026-05-26 07:08:11 +00:00
feat: add control stream liveness
This commit is contained in:
@@ -140,6 +140,7 @@ func runWithConfig(cfg loadedConfig) error {
|
||||
return fmt.Errorf("validate config: %w", err)
|
||||
}
|
||||
scfg = session.ApplyTransportDefaults(scfg)
|
||||
scfg = session.ApplyLivenessDefaults(scfg)
|
||||
|
||||
if scfg.Mode == modeGen {
|
||||
if len(cfg.profiles) > 0 {
|
||||
@@ -166,7 +167,7 @@ func prepareProfiles(profiles []supervisor.Profile) ([]supervisor.Profile, error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validate profile %q: %w", profile.Name, err)
|
||||
}
|
||||
profile.Config = session.ApplyTransportDefaults(scfg)
|
||||
profile.Config = session.ApplyLivenessDefaults(session.ApplyTransportDefaults(scfg))
|
||||
out = append(out, profile)
|
||||
}
|
||||
return out, nil
|
||||
|
||||
@@ -21,6 +21,11 @@ net:
|
||||
transport: datachannel # must match the server
|
||||
dns: "8.8.8.8:53"
|
||||
|
||||
liveness:
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
failures: 3
|
||||
|
||||
# Local SOCKS5 listener exposed to applications
|
||||
socks:
|
||||
host: "127.0.0.1"
|
||||
|
||||
@@ -31,6 +31,9 @@ olcrtc /etc/olcrtc/server.yaml
|
||||
| `video.*` | videochannel tuning |
|
||||
| `vp8.*` | vp8channel tuning |
|
||||
| `sei.fps` / `.batch_size` / `.fragment_size` / `.ack_timeout_ms` | seichannel tuning |
|
||||
| `liveness.interval` | control-stream ping interval, default `10s` |
|
||||
| `liveness.timeout` | pong timeout, default `5s` |
|
||||
| `liveness.failures` | missed pongs before reconnect, default `3` |
|
||||
| `gen.amount` | gen mode: number of rooms to create |
|
||||
| `profiles[]` | ordered srv/cnc failover profiles |
|
||||
| `failover.retry_delay` | delay before trying the next profile, e.g. `2s` |
|
||||
@@ -45,6 +48,25 @@ olcrtc /etc/olcrtc/server.yaml
|
||||
`crypto.key_file` is resolved relative to the YAML file. Do not set it
|
||||
together with `crypto.key`.
|
||||
|
||||
## Liveness
|
||||
|
||||
After `CLIENT_HELLO` / `SERVER_WELCOME`, the first smux stream stays open as
|
||||
an encrypted control stream. olcrtc now sends `CONTROL_PING` / `CONTROL_PONG`
|
||||
messages over that stream to prove the real tunnel path still round-trips.
|
||||
This detects states where a provider or WebRTC layer looks connected but the
|
||||
encrypted smux path is no longer usable.
|
||||
|
||||
```yaml
|
||||
liveness:
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
failures: 3
|
||||
```
|
||||
|
||||
When the failure threshold is reached, the current smux session is rebuilt.
|
||||
In failover mode, a profile that exits after liveness-triggered reconnect
|
||||
failure lets the supervisor advance to the next profile.
|
||||
|
||||
## Failover Profiles
|
||||
|
||||
`mode: srv` and `mode: cnc` can define `profiles`. Top-level fields are used
|
||||
|
||||
@@ -10,6 +10,11 @@ crypto:
|
||||
net:
|
||||
dns: "1.1.1.1:53"
|
||||
|
||||
liveness:
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
failures: 3
|
||||
|
||||
data: data
|
||||
|
||||
profiles:
|
||||
|
||||
@@ -72,6 +72,7 @@ Important fields:
|
||||
| `net.dns` | `DNSServer` | Resolver used by server-side target dials and provider HTTP where wired. |
|
||||
| `socks.*` | SOCKS fields | Client listener and optional server egress proxy. |
|
||||
| `engine.*` | direct engine fields | Used only with `auth.provider: none`. |
|
||||
| `liveness.*` | control liveness | Ping/pong interval, timeout, and missed-pong threshold. |
|
||||
|
||||
`internal/app/session` is the main router:
|
||||
|
||||
@@ -151,6 +152,18 @@ SERVER_REJECT { version, reason }
|
||||
|
||||
The handshake has a 64 KiB frame cap and a default 15 second timeout.
|
||||
|
||||
After handshake, `internal/control` keeps that same encrypted smux stream open
|
||||
and exchanges length-prefixed JSON control messages:
|
||||
|
||||
```text
|
||||
CONTROL_PING { version, seq, sent_unix_nano }
|
||||
CONTROL_PONG { version, seq, sent_unix_nano }
|
||||
```
|
||||
|
||||
Defaults are `liveness.interval: 10s`, `liveness.timeout: 5s`, and
|
||||
`liveness.failures: 3`. Missed pongs mark the smux session unhealthy and
|
||||
trigger a session rebuild/reconnect path.
|
||||
|
||||
## Registries And Plugin Shape
|
||||
|
||||
The universal-carrier refactor centers on small registries:
|
||||
@@ -320,9 +333,9 @@ adaptive instead of static YAML knobs.
|
||||
|
||||
### 3. Control Stream Protocol
|
||||
|
||||
The first smux stream is parked after handshake. It is the natural place for:
|
||||
The first smux stream now carries control ping/pong after handshake. It is
|
||||
still the natural place for:
|
||||
|
||||
- Ping/pong and peer liveness.
|
||||
- Server policy updates.
|
||||
- Graceful reconnect notifications.
|
||||
- Drain/start markers for failover.
|
||||
@@ -330,7 +343,7 @@ The first smux stream is parked after handshake. It is the natural place for:
|
||||
|
||||
Likely files:
|
||||
|
||||
- `internal/handshake`
|
||||
- `internal/control`
|
||||
- `internal/server`
|
||||
- `internal/client`
|
||||
|
||||
|
||||
@@ -23,6 +23,11 @@ net:
|
||||
transport: datachannel # datachannel | videochannel | seichannel | vp8channel
|
||||
dns: "8.8.8.8:53"
|
||||
|
||||
liveness:
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
failures: 3
|
||||
|
||||
# Outbound SOCKS5 proxy for server-side egress (optional)
|
||||
socks:
|
||||
proxy_addr: "" # e.g. "127.0.0.1"
|
||||
|
||||
@@ -63,12 +63,20 @@
|
||||
| `profiles` | Список профилей failover для `srv`/`cnc` |
|
||||
| `failover.retry_delay` | Пауза перед следующим профилем, например `2s` |
|
||||
| `failover.max_cycles` | Сколько полных проходов по профилям сделать; `0` = бесконечно |
|
||||
| `liveness.interval` | Интервал ping по control stream, по умолчанию `10s` |
|
||||
| `liveness.timeout` | Сколько ждать pong, по умолчанию `5s` |
|
||||
| `liveness.failures` | Сколько pong можно пропустить перед rebuild, по умолчанию `3` |
|
||||
|
||||
`crypto.key_file` читается относительно YAML-файла. Не указывай `crypto.key` и `crypto.key_file` одновременно.
|
||||
|
||||
Если задан `profiles`, поля верхнего уровня становятся общими defaults, а
|
||||
каждый профиль переопределяет только свои `auth`, `room`, `net`, `engine` и
|
||||
настройки транспорта. Порядок профилей должен совпадать на сервере и клиенте.
|
||||
настройки транспорта/liveness. Порядок профилей должен совпадать на сервере и
|
||||
клиенте.
|
||||
|
||||
`liveness` проверяет именно зашифрованный smux control stream после handshake,
|
||||
а не только статус WebRTC/provider соединения. Если pong не приходит несколько
|
||||
раз подряд, текущая smux-сессия пересоздается.
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/openlibrecommunity/olcrtc/internal/carrier"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/carrier/builtin"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/client"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/control"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/link"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/link/direct"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/logger"
|
||||
@@ -120,43 +121,56 @@ var (
|
||||
// ErrSOCKSAuthRequired indicates that a non-loopback SOCKS listener requires authentication.
|
||||
ErrSOCKSAuthRequired = errors.New(
|
||||
"socks auth required when binding outside loopback (set socks.user and socks.pass)")
|
||||
|
||||
// ErrLivenessIntervalInvalid indicates that liveness.interval is not a positive duration.
|
||||
ErrLivenessIntervalInvalid = errors.New(
|
||||
"invalid liveness interval (set liveness.interval to a duration > 0)")
|
||||
// ErrLivenessTimeoutInvalid indicates that liveness.timeout is not a positive duration.
|
||||
ErrLivenessTimeoutInvalid = errors.New(
|
||||
"invalid liveness timeout (set liveness.timeout to a duration > 0)")
|
||||
// ErrLivenessFailuresInvalid indicates that liveness.failures is not positive.
|
||||
ErrLivenessFailuresInvalid = errors.New(
|
||||
"invalid liveness failures (set liveness.failures to a value > 0)")
|
||||
)
|
||||
|
||||
// Config holds runtime session settings.
|
||||
type Config struct {
|
||||
Mode string
|
||||
Link string
|
||||
Transport string
|
||||
Auth string
|
||||
Engine string
|
||||
URL string
|
||||
Token string
|
||||
RoomID string
|
||||
KeyHex string
|
||||
SOCKSHost string
|
||||
SOCKSPort int
|
||||
SOCKSUser string
|
||||
SOCKSPass string
|
||||
DNSServer string
|
||||
SOCKSProxyAddr string
|
||||
SOCKSProxyPort int
|
||||
VideoWidth int
|
||||
VideoHeight int
|
||||
VideoFPS int
|
||||
VideoBitrate string
|
||||
VideoHW string
|
||||
VideoQRSize int
|
||||
VideoQRRecovery string
|
||||
VideoCodec string
|
||||
VideoTileModule int
|
||||
VideoTileRS int
|
||||
VP8FPS int
|
||||
VP8BatchSize int
|
||||
SEIFPS int
|
||||
SEIBatchSize int
|
||||
SEIFragmentSize int
|
||||
SEIAckTimeoutMS int
|
||||
Amount int
|
||||
Mode string
|
||||
Link string
|
||||
Transport string
|
||||
Auth string
|
||||
Engine string
|
||||
URL string
|
||||
Token string
|
||||
RoomID string
|
||||
KeyHex string
|
||||
SOCKSHost string
|
||||
SOCKSPort int
|
||||
SOCKSUser string
|
||||
SOCKSPass string
|
||||
DNSServer string
|
||||
SOCKSProxyAddr string
|
||||
SOCKSProxyPort int
|
||||
VideoWidth int
|
||||
VideoHeight int
|
||||
VideoFPS int
|
||||
VideoBitrate string
|
||||
VideoHW string
|
||||
VideoQRSize int
|
||||
VideoQRRecovery string
|
||||
VideoCodec string
|
||||
VideoTileModule int
|
||||
VideoTileRS int
|
||||
VP8FPS int
|
||||
VP8BatchSize int
|
||||
SEIFPS int
|
||||
SEIBatchSize int
|
||||
SEIFragmentSize int
|
||||
SEIAckTimeoutMS int
|
||||
LivenessInterval string
|
||||
LivenessTimeout string
|
||||
LivenessFailures int
|
||||
Amount int
|
||||
}
|
||||
|
||||
// RegisterDefaults registers built-in carriers and transports.
|
||||
@@ -212,6 +226,20 @@ func ApplyTransportDefaults(cfg Config) Config {
|
||||
}
|
||||
}
|
||||
|
||||
// ApplyLivenessDefaults fills documented control-stream liveness defaults.
|
||||
func ApplyLivenessDefaults(cfg Config) Config {
|
||||
if cfg.LivenessInterval == "" {
|
||||
cfg.LivenessInterval = control.DefaultInterval.String()
|
||||
}
|
||||
if cfg.LivenessTimeout == "" {
|
||||
cfg.LivenessTimeout = control.DefaultTimeout.String()
|
||||
}
|
||||
if cfg.LivenessFailures == 0 {
|
||||
cfg.LivenessFailures = control.DefaultFailures
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func applyVideoDefaults(cfg Config) Config {
|
||||
if cfg.VideoCodec == "" {
|
||||
cfg.VideoCodec = videoCodecQRCode
|
||||
@@ -292,6 +320,9 @@ func Validate(cfg Config) error {
|
||||
if err := validateTransportConfig(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateLivenessConfig(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
return validateModeConfig(cfg)
|
||||
}
|
||||
|
||||
@@ -431,6 +462,52 @@ func validateModeConfig(cfg Config) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateLivenessConfig(cfg Config) error {
|
||||
if _, err := parseLivenessDuration(cfg.LivenessInterval, control.DefaultInterval); err != nil {
|
||||
return fmt.Errorf("%w: %v", ErrLivenessIntervalInvalid, err)
|
||||
}
|
||||
if _, err := parseLivenessDuration(cfg.LivenessTimeout, control.DefaultTimeout); err != nil {
|
||||
return fmt.Errorf("%w: %v", ErrLivenessTimeoutInvalid, err)
|
||||
}
|
||||
if cfg.LivenessFailures < 0 {
|
||||
return ErrLivenessFailuresInvalid
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseLivenessDuration(value string, def time.Duration) (time.Duration, error) {
|
||||
if value == "" {
|
||||
return def, nil
|
||||
}
|
||||
d, err := time.ParseDuration(value)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if d <= 0 {
|
||||
return 0, fmt.Errorf("duration must be > 0")
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func livenessConfig(cfg Config) (control.Config, error) {
|
||||
interval, err := parseLivenessDuration(cfg.LivenessInterval, control.DefaultInterval)
|
||||
if err != nil {
|
||||
return control.Config{}, fmt.Errorf("%w: %v", ErrLivenessIntervalInvalid, err)
|
||||
}
|
||||
timeout, err := parseLivenessDuration(cfg.LivenessTimeout, control.DefaultTimeout)
|
||||
if err != nil {
|
||||
return control.Config{}, fmt.Errorf("%w: %v", ErrLivenessTimeoutInvalid, err)
|
||||
}
|
||||
failures := cfg.LivenessFailures
|
||||
if failures == 0 {
|
||||
failures = control.DefaultFailures
|
||||
}
|
||||
if failures < 0 {
|
||||
return control.Config{}, ErrLivenessFailuresInvalid
|
||||
}
|
||||
return control.Config{Interval: interval, Timeout: timeout, Failures: failures}, nil
|
||||
}
|
||||
|
||||
func isLoopbackListenHost(host string) bool {
|
||||
if host == "localhost" {
|
||||
return true
|
||||
@@ -442,7 +519,12 @@ func isLoopbackListenHost(host string) bool {
|
||||
// Run starts the configured mode.
|
||||
func Run(ctx context.Context, cfg Config) error {
|
||||
cfg = ApplyTransportDefaults(cfg)
|
||||
cfg = ApplyLivenessDefaults(cfg)
|
||||
roomURL := cfg.RoomID
|
||||
liveness, err := livenessConfig(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch cfg.Mode {
|
||||
case modeSRV:
|
||||
@@ -474,6 +556,7 @@ func Run(ctx context.Context, cfg Config) error {
|
||||
Engine: cfg.Engine,
|
||||
URL: cfg.URL,
|
||||
Token: cfg.Token,
|
||||
Liveness: liveness,
|
||||
OnSessionOpen: func(sessionID, deviceID string, claims map[string]any) {
|
||||
logger.Infof("session opened: id=%s device=%s claims=%v", sessionID, deviceID, claims)
|
||||
},
|
||||
@@ -517,6 +600,7 @@ func Run(ctx context.Context, cfg Config) error {
|
||||
Engine: cfg.Engine,
|
||||
URL: cfg.URL,
|
||||
Token: cfg.Token,
|
||||
Liveness: liveness,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("client: %w", err)
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/control"
|
||||
)
|
||||
|
||||
func TestApplyTransportDefaults(t *testing.T) {
|
||||
@@ -85,6 +87,24 @@ func TestApplyTransportDefaults(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyLivenessDefaults(t *testing.T) {
|
||||
got := ApplyLivenessDefaults(Config{})
|
||||
if got.LivenessInterval != control.DefaultInterval.String() {
|
||||
t.Fatalf("LivenessInterval = %q, want %q", got.LivenessInterval, control.DefaultInterval.String())
|
||||
}
|
||||
if got.LivenessTimeout != control.DefaultTimeout.String() {
|
||||
t.Fatalf("LivenessTimeout = %q, want %q", got.LivenessTimeout, control.DefaultTimeout.String())
|
||||
}
|
||||
if got.LivenessFailures != control.DefaultFailures {
|
||||
t.Fatalf("LivenessFailures = %d, want %d", got.LivenessFailures, control.DefaultFailures)
|
||||
}
|
||||
|
||||
explicit := Config{LivenessInterval: "1s", LivenessTimeout: "500ms", LivenessFailures: 9}
|
||||
if got := ApplyLivenessDefaults(explicit); got != explicit {
|
||||
t.Fatalf("ApplyLivenessDefaults() = %+v, want %+v", got, explicit)
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:maintidx // table-driven validation test naturally has many cases
|
||||
func TestValidate(t *testing.T) {
|
||||
RegisterDefaults()
|
||||
@@ -422,6 +442,33 @@ func TestValidate(t *testing.T) {
|
||||
return cfg
|
||||
}(),
|
||||
},
|
||||
{
|
||||
name: "liveness rejects bad interval",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.LivenessInterval = "nope"
|
||||
return cfg
|
||||
}(),
|
||||
want: ErrLivenessIntervalInvalid,
|
||||
},
|
||||
{
|
||||
name: "liveness rejects zero timeout",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.LivenessTimeout = "0s"
|
||||
return cfg
|
||||
}(),
|
||||
want: ErrLivenessTimeoutInvalid,
|
||||
},
|
||||
{
|
||||
name: "liveness rejects negative failures",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.LivenessFailures = -1
|
||||
return cfg
|
||||
}(),
|
||||
want: ErrLivenessFailuresInvalid,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/control"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/crypto"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/handshake"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/link"
|
||||
@@ -54,7 +55,9 @@ type Client struct {
|
||||
conn *muxconn.Conn
|
||||
session *smux.Session
|
||||
controlStrm *smux.Stream
|
||||
controlStop context.CancelFunc
|
||||
sessMu sync.RWMutex
|
||||
reconnectMu sync.Mutex
|
||||
deviceID string
|
||||
sessionID string
|
||||
claims map[string]any
|
||||
@@ -93,6 +96,7 @@ type Config struct {
|
||||
Engine string
|
||||
URL string
|
||||
Token string
|
||||
Liveness control.Config
|
||||
|
||||
// DeviceID overrides the persistent client-side device identifier. Leave
|
||||
// empty to derive one from DeviceIDPath (or generate a random one if both
|
||||
@@ -217,7 +221,9 @@ func (c *Client) bringUpLink(
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
c.handleReconnect()
|
||||
if !c.handleReconnect(ctx, cfg, cancel) {
|
||||
cancel()
|
||||
}
|
||||
})
|
||||
|
||||
if err := ln.Connect(ctx); err != nil {
|
||||
@@ -243,14 +249,15 @@ func (c *Client) bringUpLink(
|
||||
c.controlStrm = control
|
||||
c.sessionID = sid
|
||||
c.sessMu.Unlock()
|
||||
c.startControlLoop(ctx, cfg, cancel, control)
|
||||
|
||||
go ln.WatchConnection(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
// openControlStream opens stream #1 on sess and performs the handshake.
|
||||
// The stream stays open for the lifetime of the smux session — the server
|
||||
// holds it parked, and it would carry future control messages.
|
||||
// The stream stays open for the lifetime of the smux session and carries
|
||||
// post-handshake control messages.
|
||||
func openControlStream(
|
||||
sess *smux.Session,
|
||||
deviceID string,
|
||||
@@ -326,7 +333,10 @@ func smuxConfig() *smux.Config {
|
||||
return cfg
|
||||
}
|
||||
|
||||
func (c *Client) handleReconnect() {
|
||||
func (c *Client) handleReconnect(ctx context.Context, cfg Config, cancel context.CancelFunc) bool {
|
||||
c.reconnectMu.Lock()
|
||||
defer c.reconnectMu.Unlock()
|
||||
|
||||
logger.Infof("client link reconnect - tearing down smux session")
|
||||
|
||||
// Install a fresh muxconn immediately so onData never hits nil while
|
||||
@@ -336,14 +346,19 @@ func (c *Client) handleReconnect() {
|
||||
|
||||
c.sessMu.Lock()
|
||||
oldControl := c.controlStrm
|
||||
oldControlStop := c.controlStop
|
||||
oldSess := c.session
|
||||
oldConn := c.conn
|
||||
c.conn = newConn
|
||||
c.session = nil
|
||||
c.controlStrm = nil
|
||||
c.controlStop = nil
|
||||
c.sessionID = ""
|
||||
c.sessMu.Unlock()
|
||||
|
||||
if oldControlStop != nil {
|
||||
oldControlStop()
|
||||
}
|
||||
if oldControl != nil {
|
||||
_ = oldControl.Close()
|
||||
}
|
||||
@@ -364,15 +379,25 @@ func (c *Client) handleReconnect() {
|
||||
attemptDelay = 300 * time.Millisecond
|
||||
)
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
if c.tryReopenSession(attempt) {
|
||||
return
|
||||
if c.tryReopenSession(ctx, cfg, cancel, attempt) {
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case <-time.After(attemptDelay):
|
||||
}
|
||||
time.Sleep(attemptDelay)
|
||||
}
|
||||
logger.Warnf("client reconnect: exhausted %d handshake attempts", maxAttempts)
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Client) tryReopenSession(attempt int) bool {
|
||||
func (c *Client) tryReopenSession(
|
||||
ctx context.Context,
|
||||
cfg Config,
|
||||
cancel context.CancelFunc,
|
||||
attempt int,
|
||||
) bool {
|
||||
conn := muxconn.New(c.ln, c.cipher)
|
||||
|
||||
c.sessMu.Lock()
|
||||
@@ -400,19 +425,69 @@ func (c *Client) tryReopenSession(attempt int) bool {
|
||||
c.controlStrm = control
|
||||
c.sessionID = sid
|
||||
c.sessMu.Unlock()
|
||||
c.startControlLoop(ctx, cfg, cancel, control)
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *Client) startControlLoop(
|
||||
ctx context.Context,
|
||||
cfg Config,
|
||||
cancel context.CancelFunc,
|
||||
stream *smux.Stream,
|
||||
) {
|
||||
controlCtx, stop := context.WithCancel(ctx)
|
||||
c.sessMu.Lock()
|
||||
c.controlStop = stop
|
||||
c.sessMu.Unlock()
|
||||
|
||||
liveness := cfg.Liveness
|
||||
onPong := liveness.OnPong
|
||||
onUnhealthy := liveness.OnUnhealthy
|
||||
liveness.OnPong = func(h control.Health) {
|
||||
c.sessMu.RLock()
|
||||
sid := c.sessionID
|
||||
c.sessMu.RUnlock()
|
||||
logger.Debugf("control alive session=%s rtt=%v seq=%d", sid, h.RTT, h.Seq)
|
||||
if onPong != nil {
|
||||
onPong(h)
|
||||
}
|
||||
}
|
||||
liveness.OnUnhealthy = func(missed int) {
|
||||
logger.Warnf("control stream unhealthy on client: missed_pongs=%d", missed)
|
||||
if onUnhealthy != nil {
|
||||
onUnhealthy(missed)
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := control.Run(controlCtx, stream, liveness)
|
||||
if controlCtx.Err() != nil || ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
logger.Warnf("client control stream ended: %v", err)
|
||||
}
|
||||
if !c.handleReconnect(ctx, cfg, cancel) {
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (c *Client) shutdown() {
|
||||
c.sessMu.Lock()
|
||||
control := c.controlStrm
|
||||
controlStop := c.controlStop
|
||||
sess := c.session
|
||||
conn := c.conn
|
||||
c.controlStrm = nil
|
||||
c.controlStop = nil
|
||||
c.session = nil
|
||||
c.conn = nil
|
||||
c.sessMu.Unlock()
|
||||
|
||||
if controlStop != nil {
|
||||
controlStop()
|
||||
}
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/control"
|
||||
cryptopkg "github.com/openlibrecommunity/olcrtc/internal/crypto"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
|
||||
"github.com/xtaci/smux"
|
||||
@@ -517,3 +518,70 @@ func TestShutdownClosesLinkAndConn(t *testing.T) {
|
||||
t.Fatal("shutdown() did not close link")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartControlLoopReportsPong(t *testing.T) {
|
||||
a, b := net.Pipe()
|
||||
defer func() {
|
||||
_ = a.Close()
|
||||
_ = b.Close()
|
||||
}()
|
||||
|
||||
serverSess, err := smux.Server(a, smuxConfig())
|
||||
if err != nil {
|
||||
t.Fatalf("smux.Server() error = %v", err)
|
||||
}
|
||||
defer func() { _ = serverSess.Close() }()
|
||||
clientSess, err := smux.Client(b, smuxConfig())
|
||||
if err != nil {
|
||||
t.Fatalf("smux.Client() error = %v", err)
|
||||
}
|
||||
defer func() { _ = clientSess.Close() }()
|
||||
|
||||
peerStreamCh := make(chan *smux.Stream, 1)
|
||||
go func() {
|
||||
stream, err := serverSess.AcceptStream()
|
||||
if err == nil {
|
||||
peerStreamCh <- stream
|
||||
}
|
||||
}()
|
||||
|
||||
stream, err := clientSess.OpenStream()
|
||||
if err != nil {
|
||||
t.Fatalf("OpenStream() error = %v", err)
|
||||
}
|
||||
peerStream := <-peerStreamCh
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
got := make(chan control.Health, 1)
|
||||
c := &Client{sessionID: "sid-control"}
|
||||
c.startControlLoop(ctx, Config{
|
||||
Liveness: control.Config{
|
||||
Interval: 10 * time.Millisecond,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
Failures: 2,
|
||||
OnPong: func(h control.Health) {
|
||||
select {
|
||||
case got <- h:
|
||||
default:
|
||||
}
|
||||
},
|
||||
},
|
||||
}, cancel, stream)
|
||||
go func() {
|
||||
_ = control.Run(ctx, peerStream, control.Config{
|
||||
Interval: 10 * time.Millisecond,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
Failures: 2,
|
||||
})
|
||||
}()
|
||||
|
||||
select {
|
||||
case h := <-got:
|
||||
if h.Seq == 0 {
|
||||
t.Fatal("Health.Seq = 0")
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for control pong")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,6 +41,7 @@ type File struct {
|
||||
Video Video `yaml:"video"`
|
||||
VP8 VP8 `yaml:"vp8"`
|
||||
SEI SEI `yaml:"sei"`
|
||||
Liveness Liveness `yaml:"liveness"`
|
||||
Gen Gen `yaml:"gen"`
|
||||
Profiles []Profile `yaml:"profiles"`
|
||||
Failover Failover `yaml:"failover"`
|
||||
@@ -51,17 +52,18 @@ type File struct {
|
||||
|
||||
// Profile is a failover entry that overrides top-level runtime fields.
|
||||
type Profile struct {
|
||||
Name string `yaml:"name"`
|
||||
Link string `yaml:"link"`
|
||||
Auth Auth `yaml:"auth"`
|
||||
Room Room `yaml:"room"`
|
||||
Crypto Crypto `yaml:"crypto"`
|
||||
Net Net `yaml:"net"`
|
||||
SOCKS SOCKS `yaml:"socks"`
|
||||
Engine Engine `yaml:"engine"`
|
||||
Video Video `yaml:"video"`
|
||||
VP8 VP8 `yaml:"vp8"`
|
||||
SEI SEI `yaml:"sei"`
|
||||
Name string `yaml:"name"`
|
||||
Link string `yaml:"link"`
|
||||
Auth Auth `yaml:"auth"`
|
||||
Room Room `yaml:"room"`
|
||||
Crypto Crypto `yaml:"crypto"`
|
||||
Net Net `yaml:"net"`
|
||||
SOCKS SOCKS `yaml:"socks"`
|
||||
Engine Engine `yaml:"engine"`
|
||||
Video Video `yaml:"video"`
|
||||
VP8 VP8 `yaml:"vp8"`
|
||||
SEI SEI `yaml:"sei"`
|
||||
Liveness Liveness `yaml:"liveness"`
|
||||
}
|
||||
|
||||
// Failover controls ordered profile failover.
|
||||
@@ -137,6 +139,13 @@ type SEI struct {
|
||||
AckTimeoutMS int `yaml:"ack_timeout_ms"`
|
||||
}
|
||||
|
||||
// Liveness tunes the post-handshake control stream ping/pong checks.
|
||||
type Liveness struct {
|
||||
Interval string `yaml:"interval"`
|
||||
Timeout string `yaml:"timeout"`
|
||||
Failures int `yaml:"failures"`
|
||||
}
|
||||
|
||||
// Gen controls room-generation mode.
|
||||
type Gen struct {
|
||||
Amount int `yaml:"amount"`
|
||||
@@ -248,6 +257,9 @@ func Apply(dst session.Config, f File) session.Config {
|
||||
dst.SEIBatchSize = pickInt(dst.SEIBatchSize, f.SEI.BatchSize)
|
||||
dst.SEIFragmentSize = pickInt(dst.SEIFragmentSize, f.SEI.FragmentSize)
|
||||
dst.SEIAckTimeoutMS = pickInt(dst.SEIAckTimeoutMS, f.SEI.AckTimeoutMS)
|
||||
dst.LivenessInterval = pickString(dst.LivenessInterval, f.Liveness.Interval)
|
||||
dst.LivenessTimeout = pickString(dst.LivenessTimeout, f.Liveness.Timeout)
|
||||
dst.LivenessFailures = pickInt(dst.LivenessFailures, f.Liveness.Failures)
|
||||
dst.Amount = pickInt(dst.Amount, f.Gen.Amount)
|
||||
return dst
|
||||
}
|
||||
@@ -286,6 +298,9 @@ func ApplyProfile(base session.Config, p Profile) session.Config {
|
||||
dst.SEIBatchSize = overlayInt(dst.SEIBatchSize, p.SEI.BatchSize)
|
||||
dst.SEIFragmentSize = overlayInt(dst.SEIFragmentSize, p.SEI.FragmentSize)
|
||||
dst.SEIAckTimeoutMS = overlayInt(dst.SEIAckTimeoutMS, p.SEI.AckTimeoutMS)
|
||||
dst.LivenessInterval = overlayString(dst.LivenessInterval, p.Liveness.Interval)
|
||||
dst.LivenessTimeout = overlayString(dst.LivenessTimeout, p.Liveness.Timeout)
|
||||
dst.LivenessFailures = overlayInt(dst.LivenessFailures, p.Liveness.Failures)
|
||||
return dst
|
||||
}
|
||||
|
||||
|
||||
@@ -39,6 +39,10 @@ socks:
|
||||
vp8:
|
||||
fps: 25
|
||||
batch_size: 4
|
||||
liveness:
|
||||
interval: 2s
|
||||
timeout: 500ms
|
||||
failures: 4
|
||||
gen:
|
||||
amount: 3
|
||||
debug: true
|
||||
@@ -76,20 +80,23 @@ func requireLoadedFile(t *testing.T, f File) {
|
||||
func requireAppliedConfig(t *testing.T, got session.Config) {
|
||||
t.Helper()
|
||||
want := session.Config{
|
||||
Mode: testModeSrv,
|
||||
Link: "direct",
|
||||
Auth: testAuthProvider,
|
||||
RoomID: testRoomID,
|
||||
KeyHex: testCryptoKey,
|
||||
Transport: "datachannel",
|
||||
DNSServer: "1.1.1.1:53",
|
||||
SOCKSHost: "127.0.0.1",
|
||||
SOCKSPort: 1080,
|
||||
SOCKSUser: "u",
|
||||
SOCKSPass: "p",
|
||||
VP8FPS: 25,
|
||||
VP8BatchSize: 4,
|
||||
Amount: 3,
|
||||
Mode: testModeSrv,
|
||||
Link: "direct",
|
||||
Auth: testAuthProvider,
|
||||
RoomID: testRoomID,
|
||||
KeyHex: testCryptoKey,
|
||||
Transport: "datachannel",
|
||||
DNSServer: "1.1.1.1:53",
|
||||
SOCKSHost: "127.0.0.1",
|
||||
SOCKSPort: 1080,
|
||||
SOCKSUser: "u",
|
||||
SOCKSPass: "p",
|
||||
VP8FPS: 25,
|
||||
VP8BatchSize: 4,
|
||||
LivenessInterval: "2s",
|
||||
LivenessTimeout: "500ms",
|
||||
LivenessFailures: 4,
|
||||
Amount: 3,
|
||||
}
|
||||
if got != want {
|
||||
t.Fatalf("Apply produced wrong config: %+v, want %+v", got, want)
|
||||
@@ -132,6 +139,10 @@ crypto:
|
||||
key: shared-key
|
||||
net:
|
||||
dns: 1.1.1.1:53
|
||||
liveness:
|
||||
interval: 5s
|
||||
timeout: 2s
|
||||
failures: 5
|
||||
profiles:
|
||||
- name: wb-vp8
|
||||
auth:
|
||||
@@ -142,6 +153,8 @@ profiles:
|
||||
transport: vp8channel
|
||||
vp8:
|
||||
fps: 30
|
||||
liveness:
|
||||
interval: 1s
|
||||
- name: jitsi-dc
|
||||
auth:
|
||||
provider: jitsi
|
||||
@@ -174,7 +187,8 @@ failover:
|
||||
if first.Auth != "wbstream" || first.Transport != "vp8channel" || first.RoomID != "wb-room" {
|
||||
t.Fatalf("first profile = %+v", first)
|
||||
}
|
||||
if first.KeyHex != "shared-key" || first.DNSServer != "1.1.1.1:53" || first.VP8FPS != 30 {
|
||||
if first.KeyHex != "shared-key" || first.DNSServer != "1.1.1.1:53" || first.VP8FPS != 30 ||
|
||||
first.LivenessInterval != "1s" || first.LivenessTimeout != "2s" || first.LivenessFailures != 5 {
|
||||
t.Fatalf("first inherited/overlaid fields = %+v", first)
|
||||
}
|
||||
second := ApplyProfile(base, f.Profiles[1])
|
||||
@@ -182,6 +196,9 @@ failover:
|
||||
second.RoomID != "https://meet.example/room" || second.DNSServer != "8.8.8.8:53" {
|
||||
t.Fatalf("second profile = %+v", second)
|
||||
}
|
||||
if second.LivenessInterval != "5s" || second.LivenessTimeout != "2s" || second.LivenessFailures != 5 {
|
||||
t.Fatalf("second liveness fields = %+v", second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadProfileCryptoKeyFile(t *testing.T) {
|
||||
|
||||
321
internal/control/control.go
Normal file
321
internal/control/control.go
Normal file
@@ -0,0 +1,321 @@
|
||||
// Package control implements the post-handshake control stream protocol.
|
||||
//
|
||||
// The control stream is the first smux stream after the olcrtc handshake. It
|
||||
// stays inside the encrypted muxconn path, so ping/pong proves that the actual
|
||||
// tunnel path still round-trips, not merely that the provider connection is up.
|
||||
//
|
||||
// Wire format matches the handshake framing: a 4-byte big-endian length
|
||||
// followed by a JSON message.
|
||||
//
|
||||
//nolint:tagliatelle // JSON keys are the stable wire protocol schema.
|
||||
package control
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// ProtoVersion identifies the control stream wire format.
|
||||
ProtoVersion = 1
|
||||
// MaxMessageSize caps one control frame.
|
||||
MaxMessageSize = 16 * 1024
|
||||
// DefaultInterval is the default interval between ping probes.
|
||||
DefaultInterval = 10 * time.Second
|
||||
// DefaultTimeout is the default time to wait for a pong.
|
||||
DefaultTimeout = 5 * time.Second
|
||||
// DefaultFailures is the default number of consecutive missed pongs before
|
||||
// the stream is marked unhealthy.
|
||||
DefaultFailures = 3
|
||||
)
|
||||
|
||||
// MsgType labels a control message.
|
||||
type MsgType string
|
||||
|
||||
const (
|
||||
// TypePing is sent periodically to prove control-stream liveness.
|
||||
TypePing MsgType = "CONTROL_PING"
|
||||
// TypePong replies to a ping with the same sequence and timestamp.
|
||||
TypePong MsgType = "CONTROL_PONG"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrUnhealthy is returned when the stream misses too many pong replies.
|
||||
ErrUnhealthy = errors.New("control stream unhealthy")
|
||||
// ErrProtocolVersion is returned when the peer announces an incompatible version.
|
||||
ErrProtocolVersion = errors.New("incompatible control protocol version")
|
||||
// ErrUnexpectedMessage is returned for unknown or malformed control message types.
|
||||
ErrUnexpectedMessage = errors.New("unexpected control message")
|
||||
// ErrFrameTooLarge is returned when a frame exceeds [MaxMessageSize].
|
||||
ErrFrameTooLarge = errors.New("control frame too large")
|
||||
)
|
||||
|
||||
// Message is one control-stream frame.
|
||||
type Message struct {
|
||||
Version int `json:"version"`
|
||||
Type MsgType `json:"type"`
|
||||
Seq uint64 `json:"seq,omitempty"`
|
||||
SentUnixNano int64 `json:"sent_unix_nano,omitempty"`
|
||||
}
|
||||
|
||||
// Health is reported when a ping round trip completes.
|
||||
type Health struct {
|
||||
Seq uint64
|
||||
RTT time.Duration
|
||||
LastSeen time.Time
|
||||
}
|
||||
|
||||
// Config controls the liveness loop.
|
||||
type Config struct {
|
||||
Interval time.Duration
|
||||
Timeout time.Duration
|
||||
Failures int
|
||||
|
||||
// OnPong is called after a matching pong is received.
|
||||
OnPong func(Health)
|
||||
// OnUnhealthy is called before Run returns [ErrUnhealthy].
|
||||
OnUnhealthy func(missed int)
|
||||
}
|
||||
|
||||
func (cfg Config) withDefaults() Config {
|
||||
if cfg.Interval <= 0 {
|
||||
cfg.Interval = DefaultInterval
|
||||
}
|
||||
if cfg.Timeout <= 0 {
|
||||
cfg.Timeout = DefaultTimeout
|
||||
}
|
||||
if cfg.Failures <= 0 {
|
||||
cfg.Failures = DefaultFailures
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// Run drives bidirectional ping/pong liveness until ctx is canceled, rw closes,
|
||||
// or the configured failure threshold is reached.
|
||||
func Run(ctx context.Context, rw io.ReadWriteCloser, cfg Config) error {
|
||||
cfg = cfg.withDefaults()
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
state := &state{
|
||||
rw: rw,
|
||||
cfg: cfg,
|
||||
pending: make(map[uint64]time.Time),
|
||||
now: time.Now,
|
||||
out: make(chan Message, 16),
|
||||
}
|
||||
|
||||
errCh := make(chan error, 3)
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
_ = rw.Close()
|
||||
}()
|
||||
go func() { errCh <- state.readLoop(ctx) }()
|
||||
go func() { errCh <- state.probeLoop(ctx) }()
|
||||
go func() { errCh <- state.writeLoop(ctx) }()
|
||||
|
||||
err := <-errCh
|
||||
cancel()
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type state struct {
|
||||
rw io.ReadWriteCloser
|
||||
cfg Config
|
||||
now func() time.Time
|
||||
|
||||
out chan Message
|
||||
|
||||
mu sync.Mutex
|
||||
pending map[uint64]time.Time
|
||||
nextSeq uint64
|
||||
failures int
|
||||
}
|
||||
|
||||
func (s *state) readLoop(ctx context.Context) error {
|
||||
for {
|
||||
raw, err := readFrame(s.rw)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
return err
|
||||
}
|
||||
msg, err := parseMessage(raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch msg.Type {
|
||||
case TypePing:
|
||||
if err := s.enqueue(ctx, Message{
|
||||
Version: ProtoVersion,
|
||||
Type: TypePong,
|
||||
Seq: msg.Seq,
|
||||
SentUnixNano: msg.SentUnixNano,
|
||||
}); err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
return err
|
||||
}
|
||||
case TypePong:
|
||||
s.handlePong(msg)
|
||||
default:
|
||||
return fmt.Errorf("%w: got %q", ErrUnexpectedMessage, msg.Type)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *state) probeLoop(ctx context.Context) error {
|
||||
ticker := time.NewTicker(s.cfg.Interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
if err := s.sendProbe(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *state) sendProbe(ctx context.Context) error {
|
||||
now := s.now()
|
||||
|
||||
s.mu.Lock()
|
||||
for seq, sent := range s.pending {
|
||||
if now.Sub(sent) < s.cfg.Timeout {
|
||||
continue
|
||||
}
|
||||
delete(s.pending, seq)
|
||||
s.failures++
|
||||
}
|
||||
if s.failures >= s.cfg.Failures {
|
||||
missed := s.failures
|
||||
s.mu.Unlock()
|
||||
if s.cfg.OnUnhealthy != nil {
|
||||
s.cfg.OnUnhealthy(missed)
|
||||
}
|
||||
return fmt.Errorf("%w: missed %d pong(s)", ErrUnhealthy, missed)
|
||||
}
|
||||
|
||||
s.nextSeq++
|
||||
seq := s.nextSeq
|
||||
s.pending[seq] = now
|
||||
s.mu.Unlock()
|
||||
|
||||
return s.enqueue(ctx, Message{
|
||||
Version: ProtoVersion,
|
||||
Type: TypePing,
|
||||
Seq: seq,
|
||||
SentUnixNano: now.UnixNano(),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *state) handlePong(msg Message) {
|
||||
now := s.now()
|
||||
|
||||
s.mu.Lock()
|
||||
sent, ok := s.pending[msg.Seq]
|
||||
if ok {
|
||||
delete(s.pending, msg.Seq)
|
||||
s.failures = 0
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if !ok || s.cfg.OnPong == nil {
|
||||
return
|
||||
}
|
||||
s.cfg.OnPong(Health{
|
||||
Seq: msg.Seq,
|
||||
RTT: now.Sub(sent),
|
||||
LastSeen: now,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *state) enqueue(ctx context.Context, msg Message) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case s.out <- msg:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *state) writeLoop(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case msg := <-s.out:
|
||||
if err := writeFrame(s.rw, msg); err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseMessage(raw []byte) (Message, error) {
|
||||
var msg Message
|
||||
if err := json.Unmarshal(raw, &msg); err != nil {
|
||||
return Message{}, fmt.Errorf("parse control message: %w", err)
|
||||
}
|
||||
if msg.Version != ProtoVersion {
|
||||
return Message{}, fmt.Errorf("%w: peer v%d, local v%d",
|
||||
ErrProtocolVersion, msg.Version, ProtoVersion)
|
||||
}
|
||||
if msg.Type != TypePing && msg.Type != TypePong {
|
||||
return Message{}, fmt.Errorf("%w: got %q", ErrUnexpectedMessage, msg.Type)
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func writeFrame(w io.Writer, msg Message) error {
|
||||
body, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal control message: %w", err)
|
||||
}
|
||||
if len(body) > MaxMessageSize {
|
||||
return fmt.Errorf("%w: %d > %d", ErrFrameTooLarge, len(body), MaxMessageSize)
|
||||
}
|
||||
var hdr [4]byte
|
||||
binary.BigEndian.PutUint32(hdr[:], uint32(len(body))) //nolint:gosec // len(body) bounded by MaxMessageSize
|
||||
if _, err := w.Write(hdr[:]); err != nil {
|
||||
return fmt.Errorf("write control hdr: %w", err)
|
||||
}
|
||||
if _, err := w.Write(body); err != nil {
|
||||
return fmt.Errorf("write control body: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func readFrame(r io.Reader) ([]byte, error) {
|
||||
var hdr [4]byte
|
||||
if _, err := io.ReadFull(r, hdr[:]); err != nil {
|
||||
return nil, fmt.Errorf("read control hdr: %w", err)
|
||||
}
|
||||
n := binary.BigEndian.Uint32(hdr[:])
|
||||
if n > MaxMessageSize {
|
||||
return nil, fmt.Errorf("%w: %d > %d", ErrFrameTooLarge, n, MaxMessageSize)
|
||||
}
|
||||
buf := make([]byte, n)
|
||||
if _, err := io.ReadFull(r, buf); err != nil {
|
||||
return nil, fmt.Errorf("read control body: %w", err)
|
||||
}
|
||||
return buf, nil
|
||||
}
|
||||
128
internal/control/control_test.go
Normal file
128
internal/control/control_test.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package control
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func controlPair(t *testing.T) (net.Conn, net.Conn) {
|
||||
t.Helper()
|
||||
a, b := net.Pipe()
|
||||
t.Cleanup(func() {
|
||||
_ = a.Close()
|
||||
_ = b.Close()
|
||||
})
|
||||
return a, b
|
||||
}
|
||||
|
||||
func TestRunPingPongReportsRTT(t *testing.T) {
|
||||
a, b := controlPair(t)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
got := make(chan Health, 1)
|
||||
cfg := Config{
|
||||
Interval: 10 * time.Millisecond,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
Failures: 2,
|
||||
OnPong: func(h Health) {
|
||||
select {
|
||||
case got <- h:
|
||||
default:
|
||||
}
|
||||
},
|
||||
}
|
||||
errCh := make(chan error, 2)
|
||||
go func() { errCh <- Run(ctx, a, cfg) }()
|
||||
go func() { errCh <- Run(ctx, b, cfg) }()
|
||||
|
||||
select {
|
||||
case h := <-got:
|
||||
if h.Seq == 0 {
|
||||
t.Fatal("Health.Seq = 0")
|
||||
}
|
||||
if h.RTT < 0 {
|
||||
t.Fatalf("Health.RTT = %v", h.RTT)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for pong health")
|
||||
}
|
||||
|
||||
cancel()
|
||||
for range 2 {
|
||||
if err := <-errCh; err != nil {
|
||||
t.Fatalf("Run() after cancel = %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunMarksUnhealthyAfterMissedPongs(t *testing.T) {
|
||||
a, b := controlPair(t)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
_, _ = io.Copy(io.Discard, b)
|
||||
}()
|
||||
|
||||
missedCh := make(chan int, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- Run(ctx, a, Config{
|
||||
Interval: 10 * time.Millisecond,
|
||||
Timeout: 5 * time.Millisecond,
|
||||
Failures: 2,
|
||||
OnUnhealthy: func(missed int) { missedCh <- missed },
|
||||
})
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if !errors.Is(err, ErrUnhealthy) {
|
||||
t.Fatalf("Run() error = %v, want ErrUnhealthy", err)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for unhealthy result")
|
||||
}
|
||||
if missed := <-missedCh; missed < 2 {
|
||||
t.Fatalf("missed = %d, want >= 2", missed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunRejectsBadProtocolVersion(t *testing.T) {
|
||||
a, b := controlPair(t)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- Run(context.Background(), a, Config{Interval: time.Hour})
|
||||
}()
|
||||
if err := writeFrame(b, Message{Version: 999, Type: TypePing, Seq: 1}); err != nil {
|
||||
t.Fatalf("writeFrame() error = %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if !errors.Is(err, ErrProtocolVersion) {
|
||||
t.Fatalf("Run() error = %v, want ErrProtocolVersion", err)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for protocol error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFrameRejectsTooLarge(t *testing.T) {
|
||||
a, b := controlPair(t)
|
||||
go func() {
|
||||
var hdr [4]byte
|
||||
binary.BigEndian.PutUint32(hdr[:], MaxMessageSize+1)
|
||||
_, _ = b.Write(hdr[:])
|
||||
}()
|
||||
_, err := readFrame(a)
|
||||
if !errors.Is(err, ErrFrameTooLarge) {
|
||||
t.Fatalf("readFrame() error = %v, want ErrFrameTooLarge", err)
|
||||
}
|
||||
}
|
||||
@@ -13,8 +13,8 @@
|
||||
// │ │
|
||||
//
|
||||
// After the exchange the control stream stays open; tunnel traffic flows over
|
||||
// additional smux streams opened by the client. The control stream may carry
|
||||
// keepalives or future control messages.
|
||||
// additional smux streams opened by the client. The control stream then
|
||||
// carries ping/pong liveness and future control messages.
|
||||
//
|
||||
//nolint:tagliatelle // JSON keys are the stable wire protocol schema.
|
||||
package handshake
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/control"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/crypto"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/handshake"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/link"
|
||||
@@ -55,6 +56,7 @@ type Server struct {
|
||||
cipher *crypto.Cipher
|
||||
conn *muxconn.Conn
|
||||
session *smux.Session
|
||||
controlStop context.CancelFunc
|
||||
sessMu sync.RWMutex
|
||||
reinstallMu sync.Mutex
|
||||
wg sync.WaitGroup
|
||||
@@ -68,6 +70,7 @@ type Server struct {
|
||||
resolver *net.Resolver
|
||||
socksProxyAddr string
|
||||
socksProxyPort int
|
||||
liveness control.Config
|
||||
}
|
||||
|
||||
// ConnectRequest is a message from the client to establish a new connection.
|
||||
@@ -106,6 +109,7 @@ type Config struct {
|
||||
Engine string
|
||||
URL string
|
||||
Token string
|
||||
Liveness control.Config
|
||||
|
||||
// AuthHook is invoked after CLIENT_HELLO to authorize the client and
|
||||
// return a session ID. If nil, every client is admitted with a random UUID.
|
||||
@@ -155,6 +159,7 @@ func Run(ctx context.Context, cfg Config) error {
|
||||
dnsServer: cfg.DNSServer,
|
||||
socksProxyAddr: cfg.SOCKSProxyAddr,
|
||||
socksProxyPort: cfg.SOCKSProxyPort,
|
||||
liveness: cfg.Liveness,
|
||||
}
|
||||
s.setupResolver()
|
||||
|
||||
@@ -340,13 +345,18 @@ func (s *Server) reinstallSession(dead *smux.Session) {
|
||||
}
|
||||
oldSess := s.session
|
||||
oldConn := s.conn
|
||||
oldControlStop := s.controlStop
|
||||
oldSID := s.sessionID
|
||||
s.session = newSess
|
||||
s.conn = newConn
|
||||
s.controlStop = nil
|
||||
s.sessionID = ""
|
||||
s.deviceID = ""
|
||||
s.sessMu.Unlock()
|
||||
|
||||
if oldControlStop != nil {
|
||||
oldControlStop()
|
||||
}
|
||||
if oldSess != nil {
|
||||
_ = oldSess.Close()
|
||||
}
|
||||
@@ -362,13 +372,18 @@ func (s *Server) closeSession() {
|
||||
s.sessMu.Lock()
|
||||
sess := s.session
|
||||
conn := s.conn
|
||||
controlStop := s.controlStop
|
||||
s.session = nil
|
||||
s.conn = nil
|
||||
s.controlStop = nil
|
||||
oldSID := s.sessionID
|
||||
s.sessionID = ""
|
||||
s.deviceID = ""
|
||||
s.sessMu.Unlock()
|
||||
|
||||
if controlStop != nil {
|
||||
controlStop()
|
||||
}
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
@@ -478,26 +493,48 @@ func (s *Server) acceptHandshake(ctx context.Context, sess *smux.Session) bool {
|
||||
s.sessMu.Unlock()
|
||||
s.onOpen(sid, hello.DeviceID, hello.Claims)
|
||||
logger.Infof("session %s opened (device=%s)", sid, hello.DeviceID)
|
||||
// The control stream stays open for the lifetime of the session;
|
||||
// keep it parked in a goroutine so the smux session does not close it.
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.parkControlStream(stream)
|
||||
}()
|
||||
s.startControlLoop(ctx, sess, stream)
|
||||
return true
|
||||
}
|
||||
|
||||
// parkControlStream blocks reading from the control stream until it closes.
|
||||
// Future control messages (kick, rate updates, etc.) would be dispatched here.
|
||||
func (s *Server) parkControlStream(stream *smux.Stream) {
|
||||
defer func() { _ = stream.Close() }()
|
||||
buf := make([]byte, 64)
|
||||
for {
|
||||
if _, err := stream.Read(buf); err != nil {
|
||||
return
|
||||
func (s *Server) startControlLoop(ctx context.Context, sess *smux.Session, stream *smux.Stream) {
|
||||
controlCtx, stop := context.WithCancel(ctx)
|
||||
s.sessMu.Lock()
|
||||
s.controlStop = stop
|
||||
s.sessMu.Unlock()
|
||||
|
||||
liveness := s.liveness
|
||||
onPong := liveness.OnPong
|
||||
onUnhealthy := liveness.OnUnhealthy
|
||||
liveness.OnPong = func(h control.Health) {
|
||||
s.sessMu.RLock()
|
||||
sid := s.sessionID
|
||||
s.sessMu.RUnlock()
|
||||
logger.Debugf("control alive session=%s rtt=%v seq=%d", sid, h.RTT, h.Seq)
|
||||
if onPong != nil {
|
||||
onPong(h)
|
||||
}
|
||||
}
|
||||
liveness.OnUnhealthy = func(missed int) {
|
||||
logger.Warnf("control stream unhealthy on server: missed_pongs=%d", missed)
|
||||
if onUnhealthy != nil {
|
||||
onUnhealthy(missed)
|
||||
}
|
||||
}
|
||||
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
defer func() { _ = stream.Close() }()
|
||||
err := control.Run(controlCtx, stream, liveness)
|
||||
if controlCtx.Err() != nil || ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
logger.Warnf("server control stream ended: %v", err)
|
||||
}
|
||||
s.reinstallSession(sess)
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Server) shutdown() {
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/control"
|
||||
cryptopkg "github.com/openlibrecommunity/olcrtc/internal/crypto"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
|
||||
"github.com/xtaci/smux"
|
||||
@@ -373,6 +374,77 @@ func TestReinstallSessionFiresOnClose(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartControlLoopReportsPong(t *testing.T) {
|
||||
a, b := net.Pipe()
|
||||
defer func() {
|
||||
_ = a.Close()
|
||||
_ = b.Close()
|
||||
}()
|
||||
|
||||
serverSess, err := smux.Server(a, smuxConfig())
|
||||
if err != nil {
|
||||
t.Fatalf("smux.Server() error = %v", err)
|
||||
}
|
||||
defer func() { _ = serverSess.Close() }()
|
||||
clientSess, err := smux.Client(b, smuxConfig())
|
||||
if err != nil {
|
||||
t.Fatalf("smux.Client() error = %v", err)
|
||||
}
|
||||
defer func() { _ = clientSess.Close() }()
|
||||
|
||||
serverStreamCh := make(chan *smux.Stream, 1)
|
||||
go func() {
|
||||
stream, err := serverSess.AcceptStream()
|
||||
if err == nil {
|
||||
serverStreamCh <- stream
|
||||
}
|
||||
}()
|
||||
|
||||
clientStream, err := clientSess.OpenStream()
|
||||
if err != nil {
|
||||
t.Fatalf("OpenStream() error = %v", err)
|
||||
}
|
||||
serverStream := <-serverStreamCh
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
got := make(chan control.Health, 1)
|
||||
s := &Server{
|
||||
sessionID: "sid-control",
|
||||
liveness: control.Config{
|
||||
Interval: 10 * time.Millisecond,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
Failures: 2,
|
||||
OnPong: func(h control.Health) {
|
||||
select {
|
||||
case got <- h:
|
||||
default:
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
defer func() {
|
||||
cancel()
|
||||
s.wg.Wait()
|
||||
}()
|
||||
s.startControlLoop(ctx, serverSess, serverStream)
|
||||
go func() {
|
||||
_ = control.Run(ctx, clientStream, control.Config{
|
||||
Interval: 10 * time.Millisecond,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
Failures: 2,
|
||||
})
|
||||
}()
|
||||
|
||||
select {
|
||||
case h := <-got:
|
||||
if h.Seq == 0 {
|
||||
t.Fatal("Health.Seq = 0")
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for control pong")
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:cyclop // integration-style test needs setup, proxying, and traffic assertions together.
|
||||
func TestDispatchFiresOnTraffic(t *testing.T) {
|
||||
var lc net.ListenConfig
|
||||
|
||||
Reference in New Issue
Block a user