diff --git a/docs/client.example.yaml b/docs/client.example.yaml index 06b9b5e..c29fae5 100644 --- a/docs/client.example.yaml +++ b/docs/client.example.yaml @@ -30,6 +30,12 @@ liveness: # lifecycle: # max_session_duration: 6h +# Optional reliability shaping for encrypted wire messages. +# traffic: +# max_payload_size: 4096 +# min_delay: 5ms +# max_delay: 30ms + # Local SOCKS5 listener exposed to applications socks: host: "127.0.0.1" diff --git a/docs/configuration.md b/docs/configuration.md index 52123f1..07d1713 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -35,6 +35,8 @@ olcrtc /etc/olcrtc/server.yaml | `liveness.timeout` | pong timeout, default `5s` | | `liveness.failures` | missed pongs before reconnect, default `3` | | `lifecycle.max_session_duration` | planned session rebuild interval, e.g. `6h`; unset = off | +| `traffic.max_payload_size` | safe encrypted wire-message cap; `0` = transport default | +| `traffic.min_delay` / `.max_delay` | optional send pacing jitter, e.g. `5ms` / `30ms` | | `gen.amount` | gen mode: number of rooms to create | | `profiles[]` | ordered srv/cnc failover profiles | | `failover.retry_delay` | delay before trying the next profile, e.g. `2s` | @@ -86,6 +88,27 @@ lifecycle: The field is optional and disabled when omitted. Values use Go duration syntax such as `30m`, `2h`, or `6h`; zero and negative durations are rejected. +## Traffic Shaping + +`traffic` applies a shared reliability-oriented wrapper around the selected +transport. It can cap encrypted wire-message size and add small send pacing +delays without truncating data. When a payload would exceed the effective cap, +the send fails clearly instead of cutting bytes and corrupting smux. + +```yaml +traffic: + max_payload_size: 4096 + min_delay: 5ms + max_delay: 30ms +``` + +The wrapper clamps the configured payload cap to the selected transport's +advertised `MaxPayloadSize`. Client and server also reduce smux frame size to +fit the effective encrypted payload cap, accounting for crypto overhead. `0` +adds no extra cap beyond the selected transport's advertised limit. Delays use +Go duration syntax; if only `min_delay` is set, it is a fixed delay. Use the +same traffic settings on both peers. + ## Failover Profiles `mode: srv` and `mode: cnc` can define `profiles`. Top-level fields are used diff --git a/docs/failover.example.yaml b/docs/failover.example.yaml index 298a847..bf42482 100644 --- a/docs/failover.example.yaml +++ b/docs/failover.example.yaml @@ -19,6 +19,12 @@ liveness: # lifecycle: # max_session_duration: 6h +# Optional reliability shaping for encrypted wire messages. +# traffic: +# max_payload_size: 4096 +# min_delay: 5ms +# max_delay: 30ms + data: data profiles: diff --git a/docs/project-map.md b/docs/project-map.md index 0b09cc3..d0ebd41 100644 --- a/docs/project-map.md +++ b/docs/project-map.md @@ -73,6 +73,8 @@ Important fields: | `socks.*` | SOCKS fields | Client listener and optional server egress proxy. | | `engine.*` | direct engine fields | Used only with `auth.provider: none`. | | `liveness.*` | control liveness | Ping/pong interval, timeout, and missed-pong threshold. | +| `lifecycle.*` | session lifecycle | Planned call/session rotation. | +| `traffic.*` | send shaping | Encrypted wire-message size cap and optional pacing jitter. | `internal/app/session` is the main router: @@ -306,6 +308,7 @@ Implemented: - Profile start/end logs. - Planned session rotation with `lifecycle.max_session_duration`. - Shared supervisor status snapshots with bounded failover history. +- Shared traffic wrapper with payload cap, pacing jitter, and smux frame sizing. Still valuable: @@ -371,6 +374,8 @@ This mostly belongs in `pkg/olcrtc/tunnel` and `internal/server`. Provider APIs can drift. Worth adding: +- Central protected HTTP/WebSocket client creation with TLS 1.2+, + environment proxy support, HTTP/2 for HTTP, and bounded timeouts. - Better typed errors from auth providers. - Provider health probes. - Fixture-based contract tests for API response changes. diff --git a/docs/server.example.yaml b/docs/server.example.yaml index 300f7cf..112ce42 100644 --- a/docs/server.example.yaml +++ b/docs/server.example.yaml @@ -32,6 +32,12 @@ liveness: # lifecycle: # max_session_duration: 6h +# Optional reliability shaping for encrypted wire messages. +# traffic: +# max_payload_size: 4096 +# min_delay: 5ms +# max_delay: 30ms + # Outbound SOCKS5 proxy for server-side egress (optional) socks: proxy_addr: "" # e.g. "127.0.0.1" diff --git a/docs/settings.md b/docs/settings.md index 9f9d215..b3bf159 100644 --- a/docs/settings.md +++ b/docs/settings.md @@ -67,6 +67,8 @@ | `liveness.timeout` | Сколько ждать pong, по умолчанию `5s` | | `liveness.failures` | Сколько pong можно пропустить перед rebuild, по умолчанию `3` | | `lifecycle.max_session_duration` | Плановый rebuild сессии после указанного времени, например `6h`; если поле не задано, выключено | +| `traffic.max_payload_size` | Лимит размера зашифрованного wire-message; `0` = лимит транспорта | +| `traffic.min_delay` / `.max_delay` | Необязательный pacing отправки, например `5ms` / `30ms` | `crypto.key_file` читается относительно YAML-файла. Не указывай `crypto.key` и `crypto.key_file` одновременно. @@ -86,6 +88,13 @@ provider session. Когда таймер истекает, текущая `srv` плановый rebuild. Формат значения: `30m`, `2h`, `6h`; `0s` и отрицательные значения не принимаются. +`traffic` добавляет общий wrapper над выбранным transport. Он может ограничить +размер зашифрованного сообщения и добавить небольшую задержку перед отправкой. +Данные не обрезаются: если сообщение не помещается в эффективный лимит, send +возвращает явную ошибку. При заданном `max_payload_size` smux frame size также +уменьшается с учетом crypto overhead; при `0` остается лимит выбранного +transport. Используй одинаковые traffic-настройки на обеих сторонах. + --- ## mode: gen diff --git a/internal/app/session/session.go b/internal/app/session/session.go index 0b48f50..8df7b65 100644 --- a/internal/app/session/session.go +++ b/internal/app/session/session.go @@ -15,6 +15,7 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/carrier/builtin" "github.com/openlibrecommunity/olcrtc/internal/client" "github.com/openlibrecommunity/olcrtc/internal/control" + "github.com/openlibrecommunity/olcrtc/internal/crypto" "github.com/openlibrecommunity/olcrtc/internal/link" "github.com/openlibrecommunity/olcrtc/internal/link/direct" "github.com/openlibrecommunity/olcrtc/internal/logger" @@ -137,47 +138,59 @@ var ( // ErrLifecycleMaxSessionDurationInvalid indicates that lifecycle.max_session_duration is not a positive duration. ErrLifecycleMaxSessionDurationInvalid = errors.New( "invalid max session duration (set lifecycle.max_session_duration to a duration > 0)") + // ErrTrafficMaxPayloadSizeInvalid indicates that traffic.max_payload_size is not valid. + ErrTrafficMaxPayloadSizeInvalid = errors.New( + "invalid traffic max payload size (set traffic.max_payload_size to 0 or a value above crypto overhead)") + // ErrTrafficMinDelayInvalid indicates that traffic.min_delay is not a non-negative duration. + ErrTrafficMinDelayInvalid = errors.New( + "invalid traffic min delay (set traffic.min_delay to a duration >= 0)") + // ErrTrafficMaxDelayInvalid indicates that traffic.max_delay is not a non-negative duration. + ErrTrafficMaxDelayInvalid = errors.New( + "invalid traffic max delay (set traffic.max_delay to a duration >= 0 and >= traffic.min_delay)") ) // Config holds runtime session settings. type Config struct { - Mode string - Link string - Transport string - Auth string - Engine string - URL string - Token string - RoomID string - KeyHex string - SOCKSHost string - SOCKSPort int - SOCKSUser string - SOCKSPass string - DNSServer string - SOCKSProxyAddr string - SOCKSProxyPort int - VideoWidth int - VideoHeight int - VideoFPS int - VideoBitrate string - VideoHW string - VideoQRSize int - VideoQRRecovery string - VideoCodec string - VideoTileModule int - VideoTileRS int - VP8FPS int - VP8BatchSize int - SEIFPS int - SEIBatchSize int - SEIFragmentSize int - SEIAckTimeoutMS int - LivenessInterval string - LivenessTimeout string - LivenessFailures int - MaxSessionDuration string - Amount int + Mode string + Link string + Transport string + Auth string + Engine string + URL string + Token string + RoomID string + KeyHex string + SOCKSHost string + SOCKSPort int + SOCKSUser string + SOCKSPass string + DNSServer string + SOCKSProxyAddr string + SOCKSProxyPort int + VideoWidth int + VideoHeight int + VideoFPS int + VideoBitrate string + VideoHW string + VideoQRSize int + VideoQRRecovery string + VideoCodec string + VideoTileModule int + VideoTileRS int + VP8FPS int + VP8BatchSize int + SEIFPS int + SEIBatchSize int + SEIFragmentSize int + SEIAckTimeoutMS int + LivenessInterval string + LivenessTimeout string + LivenessFailures int + MaxSessionDuration string + TrafficMaxPayloadSize int + TrafficMinDelay string + TrafficMaxDelay string + Amount int } // RegisterDefaults registers built-in carriers and transports. @@ -333,6 +346,9 @@ func Validate(cfg Config) error { if err := validateLifecycleConfig(cfg); err != nil { return err } + if err := validateTrafficConfig(cfg); err != nil { + return err + } return validateModeConfig(cfg) } @@ -539,6 +555,48 @@ func maxSessionDuration(cfg Config) (time.Duration, error) { return d, nil } +func validateTrafficConfig(cfg Config) error { + _, err := trafficConfig(cfg) + return err +} + +func trafficConfig(cfg Config) (transport.TrafficConfig, error) { + if cfg.TrafficMaxPayloadSize < 0 || (cfg.TrafficMaxPayloadSize > 0 && + cfg.TrafficMaxPayloadSize <= crypto.WireOverhead) { + return transport.TrafficConfig{}, ErrTrafficMaxPayloadSizeInvalid + } + minDelay, err := parseOptionalNonNegativeDuration(cfg.TrafficMinDelay) + if err != nil { + return transport.TrafficConfig{}, fmt.Errorf("%w: %v", ErrTrafficMinDelayInvalid, err) + } + maxDelay, err := parseOptionalNonNegativeDuration(cfg.TrafficMaxDelay) + if err != nil { + return transport.TrafficConfig{}, fmt.Errorf("%w: %v", ErrTrafficMaxDelayInvalid, err) + } + if maxDelay > 0 && maxDelay < minDelay { + return transport.TrafficConfig{}, ErrTrafficMaxDelayInvalid + } + return transport.TrafficConfig{ + MaxPayloadSize: cfg.TrafficMaxPayloadSize, + MinDelay: minDelay, + MaxDelay: maxDelay, + }, nil +} + +func parseOptionalNonNegativeDuration(value string) (time.Duration, error) { + if value == "" { + return 0, nil + } + d, err := time.ParseDuration(value) + if err != nil { + return 0, err + } + if d < 0 { + return 0, fmt.Errorf("duration must be >= 0") + } + return d, nil +} + func isLoopbackListenHost(host string) bool { if host == "localhost" { return true @@ -560,9 +618,13 @@ func Run(ctx context.Context, cfg Config) error { if err != nil { return err } + traffic, err := trafficConfig(cfg) + if err != nil { + return err + } run := func(ctx context.Context) error { - return runOnce(ctx, cfg, roomURL, liveness) + return runOnce(ctx, cfg, roomURL, liveness, traffic) } if maxDuration > 0 { return runWithSessionRotation(ctx, maxDuration, run) @@ -570,7 +632,13 @@ func Run(ctx context.Context, cfg Config) error { return run(ctx) } -func runOnce(ctx context.Context, cfg Config, roomURL string, liveness control.Config) error { +func runOnce( + ctx context.Context, + cfg Config, + roomURL string, + liveness control.Config, + traffic transport.TrafficConfig, +) error { switch cfg.Mode { case modeSRV: if err := server.Run(ctx, server.Config{ @@ -602,6 +670,7 @@ func runOnce(ctx context.Context, cfg Config, roomURL string, liveness control.C URL: cfg.URL, Token: cfg.Token, Liveness: liveness, + Traffic: traffic, OnSessionOpen: func(sessionID, deviceID string, claims map[string]any) { logger.Infof("session opened: id=%s device=%s claims=%v", sessionID, deviceID, claims) }, @@ -646,6 +715,7 @@ func runOnce(ctx context.Context, cfg Config, roomURL string, liveness control.C URL: cfg.URL, Token: cfg.Token, Liveness: liveness, + Traffic: traffic, }); err != nil { return fmt.Errorf("client: %w", err) } diff --git a/internal/app/session/session_test.go b/internal/app/session/session_test.go index 5fc219d..d75371b 100644 --- a/internal/app/session/session_test.go +++ b/internal/app/session/session_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/openlibrecommunity/olcrtc/internal/control" + "github.com/openlibrecommunity/olcrtc/internal/crypto" ) func TestApplyTransportDefaults(t *testing.T) { @@ -522,6 +523,62 @@ func TestValidate(t *testing.T) { }(), want: ErrLifecycleMaxSessionDurationInvalid, }, + { + name: "traffic accepts shaping", + cfg: func() Config { + cfg := base + cfg.TrafficMaxPayloadSize = 4096 + cfg.TrafficMinDelay = "5ms" + cfg.TrafficMaxDelay = "30ms" + return cfg + }(), + }, + { + name: "traffic rejects negative max payload", + cfg: func() Config { + cfg := base + cfg.TrafficMaxPayloadSize = -1 + return cfg + }(), + want: ErrTrafficMaxPayloadSizeInvalid, + }, + { + name: "traffic rejects payload smaller than crypto overhead", + cfg: func() Config { + cfg := base + cfg.TrafficMaxPayloadSize = crypto.WireOverhead + return cfg + }(), + want: ErrTrafficMaxPayloadSizeInvalid, + }, + { + name: "traffic rejects bad min delay", + cfg: func() Config { + cfg := base + cfg.TrafficMinDelay = "nope" + return cfg + }(), + want: ErrTrafficMinDelayInvalid, + }, + { + name: "traffic rejects negative max delay", + cfg: func() Config { + cfg := base + cfg.TrafficMaxDelay = "-1ms" + return cfg + }(), + want: ErrTrafficMaxDelayInvalid, + }, + { + name: "traffic rejects max delay below min delay", + cfg: func() Config { + cfg := base + cfg.TrafficMinDelay = "30ms" + cfg.TrafficMaxDelay = "5ms" + return cfg + }(), + want: ErrTrafficMaxDelayInvalid, + }, } for _, tt := range tests { diff --git a/internal/auth/salutejazz/api.go b/internal/auth/salutejazz/api.go index 594ac5c..40cd092 100644 --- a/internal/auth/salutejazz/api.go +++ b/internal/auth/salutejazz/api.go @@ -9,9 +9,7 @@ import ( "encoding/json" "errors" "fmt" - "io" "net/http" - "strings" "github.com/google/uuid" "github.com/openlibrecommunity/olcrtc/internal/protect" @@ -122,7 +120,7 @@ func createMeeting(ctx context.Context, headers map[string]string) (*createRespo defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - return nil, statusError(errCreateRoomFailed, resp) + return nil, protect.StatusError(errCreateRoomFailed, resp, 1024) } var res createResponse @@ -174,7 +172,7 @@ func preconnect(ctx context.Context, roomID, password string, headers map[string defer func() { _ = preResp.Body.Close() }() if preResp.StatusCode != http.StatusOK { - return "", statusError(errPreconnectFailed, preResp) + return "", protect.StatusError(errPreconnectFailed, preResp, 1024) } var preconnectResp struct { @@ -186,15 +184,6 @@ func preconnect(ctx context.Context, roomID, password string, headers map[string return preconnectResp.ConnectorURL, nil } -func statusError(base error, resp *http.Response) error { - body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) - bodyText := strings.TrimSpace(string(body)) - if bodyText == "" { - return fmt.Errorf("%w: status %d", base, resp.StatusCode) - } - return fmt.Errorf("%w: status %d: %s", base, resp.StatusCode, bodyText) -} - func joinRoom(ctx context.Context, roomID, password string) (*roomInfo, error) { headers := anonymousHeaders() connectorURL, err := preconnect(ctx, roomID, password, headers) diff --git a/internal/auth/telemost/api.go b/internal/auth/telemost/api.go index cde00f0..a9b1116 100644 --- a/internal/auth/telemost/api.go +++ b/internal/auth/telemost/api.go @@ -11,7 +11,6 @@ import ( "encoding/json" "errors" "fmt" - "io" "net/http" "net/url" @@ -69,8 +68,7 @@ func GetConnectionInfo(ctx context.Context, roomURL, displayName string) (*Conne defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("%w %d: %s", ErrAPI, resp.StatusCode, body) + return nil, protect.StatusError(ErrAPI, resp, 4096) } var info ConnectionInfo diff --git a/internal/auth/wbstream/api.go b/internal/auth/wbstream/api.go index 4fc277b..ea1a927 100644 --- a/internal/auth/wbstream/api.go +++ b/internal/auth/wbstream/api.go @@ -10,7 +10,6 @@ import ( "encoding/json" "errors" "fmt" - "io" "net/http" "github.com/openlibrecommunity/olcrtc/internal/protect" @@ -84,8 +83,7 @@ func registerGuest(ctx context.Context, displayName string) (string, error) { defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - b, _ := io.ReadAll(resp.Body) - return "", fmt.Errorf("%w: %d %s", errGuestRegister, resp.StatusCode, b) + return "", protect.StatusError(errGuestRegister, resp, 4096) } var res guestRegisterResponse @@ -122,8 +120,7 @@ func createRoom(ctx context.Context, accessToken string) (string, error) { defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { - b, _ := io.ReadAll(resp.Body) - return "", fmt.Errorf("%w: %d %s", errCreateRoom, resp.StatusCode, b) + return "", protect.StatusError(errCreateRoom, resp, 4096) } var res createRoomResponse @@ -151,8 +148,7 @@ func joinRoom(ctx context.Context, accessToken, roomID string) error { defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - b, _ := io.ReadAll(resp.Body) - return fmt.Errorf("%w: %d %s", errJoinRoom, resp.StatusCode, b) + return protect.StatusError(errJoinRoom, resp, 4096) } return nil } @@ -180,8 +176,7 @@ func getToken(ctx context.Context, accessToken, roomID, displayName string) (tok defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - b, _ := io.ReadAll(resp.Body) - return tokenResponse{}, fmt.Errorf("%w: %d %s", errGetToken, resp.StatusCode, b) + return tokenResponse{}, protect.StatusError(errGetToken, resp, 4096) } var res tokenResponse diff --git a/internal/client/client.go b/internal/client/client.go index 001cb4c..2dfc153 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -24,6 +24,7 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/logger" "github.com/openlibrecommunity/olcrtc/internal/muxconn" "github.com/openlibrecommunity/olcrtc/internal/names" + "github.com/openlibrecommunity/olcrtc/internal/transport" "github.com/xtaci/smux" ) @@ -103,6 +104,7 @@ type Config struct { URL string Token string Liveness control.Config + Traffic transport.TrafficConfig // DeviceID overrides the persistent client-side device identifier. Leave // empty to derive one from DeviceIDPath (or generate a random one if both @@ -216,6 +218,7 @@ func (c *Client) bringUpLink( SEIBatchSize: cfg.SEIBatchSize, SEIFragmentSize: cfg.SEIFragmentSize, SEIAckTimeoutMS: cfg.SEIAckTimeoutMS, + Traffic: cfg.Traffic, }) if err != nil { return fmt.Errorf("failed to create link: %w", err) @@ -241,7 +244,7 @@ func (c *Client) bringUpLink( } c.conn = muxconn.New(ln, c.cipher) - sess, err := smux.Client(c.conn, smuxConfig()) + sess, err := smux.Client(c.conn, smuxConfig(linkMaxPayload(ln))) if err != nil { return fmt.Errorf("smux client: %w", err) } @@ -332,11 +335,17 @@ func resolveDeviceID(deviceID, path string) (string, error) { } // smuxConfig returns the tuned smux config used on both ends. -func smuxConfig() *smux.Config { +func smuxConfig(maxWirePayload ...int) *smux.Config { cfg := smux.DefaultConfig() cfg.Version = 2 cfg.KeepAliveDisabled = true cfg.MaxFrameSize = 32768 + if len(maxWirePayload) > 0 && maxWirePayload[0] > crypto.WireOverhead { + maxFrameSize := maxWirePayload[0] - crypto.WireOverhead + if maxFrameSize < cfg.MaxFrameSize { + cfg.MaxFrameSize = maxFrameSize + } + } cfg.MaxReceiveBuffer = 16 * 1024 * 1024 cfg.MaxStreamBuffer = 1024 * 1024 cfg.KeepAliveInterval = 10 * time.Second @@ -344,6 +353,14 @@ func smuxConfig() *smux.Config { return cfg } +func linkMaxPayload(ln link.Link) int { + provider, ok := ln.(link.FeaturesProvider) + if !ok { + return 0 + } + return provider.Features().MaxPayloadSize +} + func (c *Client) handleReconnect(ctx context.Context, cfg Config, cancel context.CancelFunc, reason string) bool { c.reconnectMu.Lock() defer c.reconnectMu.Unlock() @@ -421,7 +438,7 @@ func (c *Client) tryReopenSession( _ = old.Close() } - sess, err := smux.Client(conn, smuxConfig()) + sess, err := smux.Client(conn, smuxConfig(linkMaxPayload(c.ln))) if err != nil { logger.Warnf("smux re-init failed (attempt %d): %v", attempt, err) return false diff --git a/internal/client/client_test.go b/internal/client/client_test.go index 82d0099..40b3c22 100644 --- a/internal/client/client_test.go +++ b/internal/client/client_test.go @@ -49,6 +49,11 @@ func TestSmuxConfig(t *testing.T) { if cfg.Version != 2 || !cfg.KeepAliveDisabled || cfg.MaxFrameSize != 32768 || cfg.MaxReceiveBuffer != 16*1024*1024 { t.Fatalf("smuxConfig() = %+v", cfg) } + capped := smuxConfig(4096) + if capped.MaxFrameSize != 4096-cryptopkg.WireOverhead { + t.Fatalf("smuxConfig(4096).MaxFrameSize = %d, want %d", + capped.MaxFrameSize, 4096-cryptopkg.WireOverhead) + } } func TestSocks5Handshake(t *testing.T) { diff --git a/internal/config/config.go b/internal/config/config.go index 770adf5..3cd5a0a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -43,6 +43,7 @@ type File struct { SEI SEI `yaml:"sei"` Liveness Liveness `yaml:"liveness"` Lifecycle Lifecycle `yaml:"lifecycle"` + Traffic Traffic `yaml:"traffic"` Gen Gen `yaml:"gen"` Profiles []Profile `yaml:"profiles"` Failover Failover `yaml:"failover"` @@ -66,6 +67,7 @@ type Profile struct { SEI SEI `yaml:"sei"` Liveness Liveness `yaml:"liveness"` Lifecycle Lifecycle `yaml:"lifecycle"` + Traffic Traffic `yaml:"traffic"` } // Failover controls ordered profile failover. @@ -153,6 +155,13 @@ type Lifecycle struct { MaxSessionDuration string `yaml:"max_session_duration"` } +// Traffic controls optional reliability-oriented send shaping. +type Traffic struct { + MaxPayloadSize int `yaml:"max_payload_size"` + MinDelay string `yaml:"min_delay"` + MaxDelay string `yaml:"max_delay"` +} + // Gen controls room-generation mode. type Gen struct { Amount int `yaml:"amount"` @@ -268,6 +277,9 @@ func Apply(dst session.Config, f File) session.Config { dst.LivenessTimeout = pickString(dst.LivenessTimeout, f.Liveness.Timeout) dst.LivenessFailures = pickInt(dst.LivenessFailures, f.Liveness.Failures) dst.MaxSessionDuration = pickString(dst.MaxSessionDuration, f.Lifecycle.MaxSessionDuration) + dst.TrafficMaxPayloadSize = pickInt(dst.TrafficMaxPayloadSize, f.Traffic.MaxPayloadSize) + dst.TrafficMinDelay = pickString(dst.TrafficMinDelay, f.Traffic.MinDelay) + dst.TrafficMaxDelay = pickString(dst.TrafficMaxDelay, f.Traffic.MaxDelay) dst.Amount = pickInt(dst.Amount, f.Gen.Amount) return dst } @@ -310,6 +322,9 @@ func ApplyProfile(base session.Config, p Profile) session.Config { dst.LivenessTimeout = overlayString(dst.LivenessTimeout, p.Liveness.Timeout) dst.LivenessFailures = overlayInt(dst.LivenessFailures, p.Liveness.Failures) dst.MaxSessionDuration = overlayString(dst.MaxSessionDuration, p.Lifecycle.MaxSessionDuration) + dst.TrafficMaxPayloadSize = overlayInt(dst.TrafficMaxPayloadSize, p.Traffic.MaxPayloadSize) + dst.TrafficMinDelay = overlayString(dst.TrafficMinDelay, p.Traffic.MinDelay) + dst.TrafficMaxDelay = overlayString(dst.TrafficMaxDelay, p.Traffic.MaxDelay) return dst } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 06d1406..c699283 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -45,6 +45,10 @@ liveness: failures: 4 lifecycle: max_session_duration: 6h +traffic: + max_payload_size: 4096 + min_delay: 5ms + max_delay: 30ms gen: amount: 3 debug: true @@ -82,24 +86,27 @@ func requireLoadedFile(t *testing.T, f File) { func requireAppliedConfig(t *testing.T, got session.Config) { t.Helper() want := session.Config{ - Mode: testModeSrv, - Link: "direct", - Auth: testAuthProvider, - RoomID: testRoomID, - KeyHex: testCryptoKey, - Transport: "datachannel", - DNSServer: "1.1.1.1:53", - SOCKSHost: "127.0.0.1", - SOCKSPort: 1080, - SOCKSUser: "u", - SOCKSPass: "p", - VP8FPS: 25, - VP8BatchSize: 4, - LivenessInterval: "2s", - LivenessTimeout: "500ms", - LivenessFailures: 4, - MaxSessionDuration: "6h", - Amount: 3, + Mode: testModeSrv, + Link: "direct", + Auth: testAuthProvider, + RoomID: testRoomID, + KeyHex: testCryptoKey, + Transport: "datachannel", + DNSServer: "1.1.1.1:53", + SOCKSHost: "127.0.0.1", + SOCKSPort: 1080, + SOCKSUser: "u", + SOCKSPass: "p", + VP8FPS: 25, + VP8BatchSize: 4, + LivenessInterval: "2s", + LivenessTimeout: "500ms", + LivenessFailures: 4, + MaxSessionDuration: "6h", + TrafficMaxPayloadSize: 4096, + TrafficMinDelay: "5ms", + TrafficMaxDelay: "30ms", + Amount: 3, } if got != want { t.Fatalf("Apply produced wrong config: %+v, want %+v", got, want) @@ -148,6 +155,10 @@ liveness: failures: 5 lifecycle: max_session_duration: 6h +traffic: + max_payload_size: 8192 + min_delay: 10ms + max_delay: 40ms profiles: - name: wb-vp8 auth: @@ -162,6 +173,9 @@ profiles: interval: 1s lifecycle: max_session_duration: 30m + traffic: + max_payload_size: 4096 + max_delay: 20ms - name: jitsi-dc auth: provider: jitsi @@ -196,7 +210,8 @@ failover: } if first.KeyHex != "shared-key" || first.DNSServer != "1.1.1.1:53" || first.VP8FPS != 30 || first.LivenessInterval != "1s" || first.LivenessTimeout != "2s" || first.LivenessFailures != 5 || - first.MaxSessionDuration != "30m" { + first.MaxSessionDuration != "30m" || first.TrafficMaxPayloadSize != 4096 || + first.TrafficMinDelay != "10ms" || first.TrafficMaxDelay != "20ms" { t.Fatalf("first inherited/overlaid fields = %+v", first) } second := ApplyProfile(base, f.Profiles[1]) @@ -205,7 +220,8 @@ failover: t.Fatalf("second profile = %+v", second) } if second.LivenessInterval != "5s" || second.LivenessTimeout != "2s" || second.LivenessFailures != 5 || - second.MaxSessionDuration != "6h" { + second.MaxSessionDuration != "6h" || second.TrafficMaxPayloadSize != 8192 || + second.TrafficMinDelay != "10ms" || second.TrafficMaxDelay != "40ms" { t.Fatalf("second lifecycle/liveness fields = %+v", second) } } diff --git a/internal/crypto/chacha.go b/internal/crypto/chacha.go index 686d8b8..93a8425 100644 --- a/internal/crypto/chacha.go +++ b/internal/crypto/chacha.go @@ -10,6 +10,9 @@ import ( "golang.org/x/crypto/chacha20poly1305" ) +// WireOverhead is the number of bytes added to each encrypted message. +const WireOverhead = chacha20poly1305.NonceSizeX + chacha20poly1305.Overhead + var ( // ErrInvalidKeySize is returned when the encryption key is not 32 bytes. ErrInvalidKeySize = errors.New("invalid key size") diff --git a/internal/engine/goolom/lifecycle.go b/internal/engine/goolom/lifecycle.go index 316107f..7dd803d 100644 --- a/internal/engine/goolom/lifecycle.go +++ b/internal/engine/goolom/lifecycle.go @@ -112,10 +112,7 @@ func (s *Session) setupPeerConnections(config webrtc.Configuration) error { } func (s *Session) dialWebSocket() error { - wsDialer := websocket.Dialer{ - NetDialContext: protect.DialContext, - HandshakeTimeout: wsHandshakeTimeout, - } + wsDialer := protect.NewWebSocketDialer(wsHandshakeTimeout) ws, resp, err := wsDialer.Dial(s.mediaServerURL, nil) if err != nil { return fmt.Errorf("dial ws: %w", err) diff --git a/internal/engine/salutejazz/salutejazz.go b/internal/engine/salutejazz/salutejazz.go index 5daf47f..b1b8903 100644 --- a/internal/engine/salutejazz/salutejazz.go +++ b/internal/engine/salutejazz/salutejazz.go @@ -417,10 +417,7 @@ func (s *Session) waitForMediaReady(ctx context.Context, timeout time.Duration) } func (s *Session) dialWebSocket() error { - wsDialer := websocket.Dialer{ - NetDialContext: protect.DialContext, - HandshakeTimeout: wsHandshakeTimeout, - } + wsDialer := protect.NewWebSocketDialer(wsHandshakeTimeout) ws, resp, err := wsDialer.Dial(s.connectorURL, nil) if err != nil { diff --git a/internal/link/direct/direct.go b/internal/link/direct/direct.go index 4b2aa73..65089ab 100644 --- a/internal/link/direct/direct.go +++ b/internal/link/direct/direct.go @@ -43,6 +43,7 @@ func New(ctx context.Context, cfg link.Config) (link.Link, error) { SEIBatchSize: cfg.SEIBatchSize, SEIFragmentSize: cfg.SEIFragmentSize, SEIAckTimeoutMS: cfg.SEIAckTimeoutMS, + Traffic: cfg.Traffic, }) if err != nil { return nil, fmt.Errorf("create transport for direct link: %w", err) @@ -79,3 +80,6 @@ func (d *directLink) WatchConnection(ctx context.Context) { d.transport.WatchConnection(ctx) } func (d *directLink) CanSend() bool { return d.transport.CanSend() } + +// Features reports the direct link's underlying transport capabilities. +func (d *directLink) Features() link.Features { return d.transport.Features() } diff --git a/internal/link/direct/direct_test.go b/internal/link/direct/direct_test.go index 18edd2e..f891e88 100644 --- a/internal/link/direct/direct_test.go +++ b/internal/link/direct/direct_test.go @@ -79,12 +79,14 @@ func TestNewForwardsConfigAndMethods(t *testing.T) { VideoTileRS: 20, VP8FPS: 25, VP8BatchSize: 8, + Traffic: transport.TrafficConfig{MaxPayloadSize: 4096}, }) if err != nil { t.Fatalf("New() error = %v", err) } - if seen.DeviceID != "client" || seen.ProxyPort != 1080 || seen.VideoTileRS != 20 || seen.VP8BatchSize != 8 { + if seen.DeviceID != "client" || seen.ProxyPort != 1080 || seen.VideoTileRS != 20 || seen.VP8BatchSize != 8 || + seen.Traffic.MaxPayloadSize != 4096 { t.Fatalf("forwarded config = %+v", seen) } @@ -112,6 +114,9 @@ func TestNewForwardsConfigAndMethods(t *testing.T) { if !ln.CanSend() { t.Fatal("CanSend() = false, want true") } + if features := ln.(link.FeaturesProvider).Features(); features.MaxPayloadSize != 4096 { + t.Fatalf("Features() = %+v, want shaped max payload 4096", features) + } } func TestNewWrapsFactoryError(t *testing.T) { diff --git a/internal/link/link.go b/internal/link/link.go index f094cd0..c8957ac 100644 --- a/internal/link/link.go +++ b/internal/link/link.go @@ -4,6 +4,8 @@ package link import ( "context" "errors" + + "github.com/openlibrecommunity/olcrtc/internal/transport" ) var ( @@ -23,11 +25,19 @@ type Link interface { CanSend() bool } +// Features mirrors the underlying transport capabilities when a link can expose them. +type Features = transport.Features + +// FeaturesProvider is optionally implemented by links that can report wire limits. +type FeaturesProvider interface { + Features() Features +} + // Config holds common link configuration. type Config struct { - Transport string - Carrier string - RoomURL string + Transport string + Carrier string + RoomURL string // Engine, URL, Token are forwarded for the "none" auth carrier. Engine string URL string @@ -54,6 +64,7 @@ type Config struct { SEIBatchSize int SEIFragmentSize int SEIAckTimeoutMS int + Traffic transport.TrafficConfig } // Factory creates a link instance. diff --git a/internal/protect/protect.go b/internal/protect/protect.go index 29bc277..2919fa3 100644 --- a/internal/protect/protect.go +++ b/internal/protect/protect.go @@ -3,13 +3,38 @@ package protect import ( "context" + "crypto/tls" "fmt" + "io" "net" "net/http" + "regexp" + "strings" "syscall" "time" + + "github.com/gorilla/websocket" ) +const ( + defaultDialTimeout = 10 * time.Second + defaultKeepAlive = 30 * time.Second + defaultIdleConnTimeout = 30 * time.Second + defaultTLSHandshake = 10 * time.Second + defaultResponseHeader = 10 * time.Second + defaultWebSocketTimeout = 10 * time.Second + defaultHTTPClientTimeout = 30 * time.Second + defaultStatusBodyLimit = 1024 +) + +var ( + sensitiveFieldRE = regexp.MustCompile( + `(?i)((?:access[_-]?token|room[_-]?token|token|credentials)"?\s*[:=]\s*"?)` + + `[^",\s}]+`, + ) + sensitiveBearerRE = regexp.MustCompile(`(?i)(bearer\s+)[A-Za-z0-9._~+/=-]+`) +) //nolint:gochecknoglobals // compiled once for provider error redaction + // Protector is called with a socket file descriptor before connect. // On Android, this calls VpnService.protect(fd) to bypass VPN routing. var Protector func(fd int) bool //nolint:gochecknoglobals // package-level state intentional @@ -33,24 +58,70 @@ func controlFunc(network, _ string, c syscall.RawConn) error { // NewDialer returns a net.Dialer that calls Protector on each new socket. func NewDialer() *net.Dialer { return &net.Dialer{ - Timeout: 10 * time.Second, - KeepAlive: 30 * time.Second, + Timeout: defaultDialTimeout, + KeepAlive: defaultKeepAlive, Control: controlFunc, } } +// NewTLSConfig returns the shared TLS policy for provider HTTP/WebSocket clients. +func NewTLSConfig() *tls.Config { + return &tls.Config{MinVersion: tls.VersionTLS12} +} + +// NewHTTPTransport returns an HTTP transport using protected sockets and sane timeouts. +func NewHTTPTransport() *http.Transport { + dialer := NewDialer() + return &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: dialer.DialContext, + TLSClientConfig: NewTLSConfig(), + ForceAttemptHTTP2: true, + MaxIdleConns: 10, + IdleConnTimeout: defaultIdleConnTimeout, + TLSHandshakeTimeout: defaultTLSHandshake, + ResponseHeaderTimeout: defaultResponseHeader, + } +} + // NewHTTPClient returns an http.Client using protected sockets. func NewHTTPClient() *http.Client { - dialer := NewDialer() - transport := &http.Transport{ - DialContext: dialer.DialContext, - ForceAttemptHTTP2: true, - MaxIdleConns: 10, - IdleConnTimeout: 30 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ResponseHeaderTimeout: 10 * time.Second, + return &http.Client{ + Transport: NewHTTPTransport(), + Timeout: defaultHTTPClientTimeout, } - return &http.Client{Transport: transport} +} + +// NewWebSocketDialer returns a WebSocket dialer using protected sockets and shared TLS policy. +func NewWebSocketDialer(handshakeTimeout time.Duration) websocket.Dialer { + if handshakeTimeout <= 0 { + handshakeTimeout = defaultWebSocketTimeout + } + return websocket.Dialer{ + NetDialContext: DialContext, + Proxy: http.ProxyFromEnvironment, + TLSClientConfig: NewTLSConfig(), + HandshakeTimeout: handshakeTimeout, + } +} + +// StatusError formats an upstream HTTP error while bounding and redacting the body. +func StatusError(base error, resp *http.Response, limit int64) error { + if limit <= 0 { + limit = defaultStatusBodyLimit + } + body, _ := io.ReadAll(io.LimitReader(resp.Body, limit)) + bodyText := RedactSensitive(strings.TrimSpace(string(body))) + if bodyText == "" { + return fmt.Errorf("%w: status %d", base, resp.StatusCode) + } + return fmt.Errorf("%w: status %d: %s", base, resp.StatusCode, bodyText) +} + +// RedactSensitive removes common token-like values from provider error text. +func RedactSensitive(text string) string { + text = sensitiveBearerRE.ReplaceAllString(text, "${1}") + return sensitiveFieldRE.ReplaceAllString(text, "${1}") } // DialContext dials using a protected socket. diff --git a/internal/protect/protect_test.go b/internal/protect/protect_test.go index 515f82d..e07a666 100644 --- a/internal/protect/protect_test.go +++ b/internal/protect/protect_test.go @@ -2,9 +2,11 @@ package protect import ( "context" + "crypto/tls" "errors" "net" "net/http" + "strings" "syscall" "testing" "time" @@ -88,13 +90,57 @@ func TestNewDialerAndHTTPClient(t *testing.T) { if !ok { t.Fatalf("Transport type = %T, want *http.Transport", client.Transport) } - if tr.DialContext == nil || !tr.ForceAttemptHTTP2 || tr.MaxIdleConns != 10 || + if tr.Proxy == nil || tr.DialContext == nil || tr.TLSClientConfig == nil || + tr.TLSClientConfig.MinVersion != tls.VersionTLS12 || !tr.ForceAttemptHTTP2 || tr.MaxIdleConns != 10 || tr.IdleConnTimeout != 30*time.Second || tr.TLSHandshakeTimeout != 10*time.Second || - tr.ResponseHeaderTimeout != 10*time.Second { + tr.ResponseHeaderTimeout != 10*time.Second || client.Timeout != 30*time.Second { t.Fatalf("transport = %+v", tr) } } +func TestNewWebSocketDialer(t *testing.T) { + dialer := NewWebSocketDialer(3 * time.Second) + if dialer.NetDialContext == nil || dialer.Proxy == nil || dialer.TLSClientConfig == nil || + dialer.TLSClientConfig.MinVersion != tls.VersionTLS12 || + dialer.HandshakeTimeout != 3*time.Second { + t.Fatalf("NewWebSocketDialer() = %+v", dialer) + } + + defaulted := NewWebSocketDialer(0) + if defaulted.HandshakeTimeout != defaultWebSocketTimeout { + t.Fatalf("default HandshakeTimeout = %v, want %v", + defaulted.HandshakeTimeout, defaultWebSocketTimeout) + } +} + +func TestStatusErrorRedactsAndLimitsBody(t *testing.T) { + resp := &http.Response{ + StatusCode: http.StatusForbidden, + Body: ioNopCloser{strings.NewReader(`{"accessToken":"secret","message":"no"}`)}, + } + err := StatusError(errProtectBoom, resp, 1024) + if err == nil { + t.Fatal("StatusError() error = nil") + } + text := err.Error() + if strings.Contains(text, "secret") || !strings.Contains(text, "") { + t.Fatalf("StatusError() = %q, want redacted token", text) + } +} + +func TestRedactSensitiveBearer(t *testing.T) { + got := RedactSensitive("Authorization: Bearer abc.def") + if strings.Contains(got, "abc.def") || !strings.Contains(got, "Bearer ") { + t.Fatalf("RedactSensitive() = %q", got) + } +} + +type ioNopCloser struct { + *strings.Reader +} + +func (c ioNopCloser) Close() error { return nil } + func TestDialContextAndProxyDialer(t *testing.T) { var lc net.ListenConfig ln, err := lc.Listen(context.Background(), "tcp4", "127.0.0.1:0") diff --git a/internal/server/server.go b/internal/server/server.go index 7dae4eb..2c28805 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -21,6 +21,7 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/logger" "github.com/openlibrecommunity/olcrtc/internal/muxconn" "github.com/openlibrecommunity/olcrtc/internal/names" + "github.com/openlibrecommunity/olcrtc/internal/transport" "github.com/xtaci/smux" ) @@ -116,6 +117,7 @@ type Config struct { URL string Token string Liveness control.Config + Traffic transport.TrafficConfig // AuthHook is invoked after CLIENT_HELLO to authorize the client and // return a session ID. If nil, every client is admitted with a random UUID. @@ -234,11 +236,17 @@ func (s *Server) setupResolver() { // smuxConfig mirrors the client side. Both peers must agree on Version and // MaxFrameSize. -func smuxConfig() *smux.Config { +func smuxConfig(maxWirePayload ...int) *smux.Config { cfg := smux.DefaultConfig() cfg.Version = 2 cfg.KeepAliveDisabled = true cfg.MaxFrameSize = 32768 + if len(maxWirePayload) > 0 && maxWirePayload[0] > crypto.WireOverhead { + maxFrameSize := maxWirePayload[0] - crypto.WireOverhead + if maxFrameSize < cfg.MaxFrameSize { + cfg.MaxFrameSize = maxFrameSize + } + } cfg.MaxReceiveBuffer = 16 * 1024 * 1024 cfg.MaxStreamBuffer = 1024 * 1024 cfg.KeepAliveInterval = 10 * time.Second @@ -246,6 +254,14 @@ func smuxConfig() *smux.Config { return cfg } +func linkMaxPayload(ln link.Link) int { + provider, ok := ln.(link.FeaturesProvider) + if !ok { + return 0 + } + return provider.Features().MaxPayloadSize +} + func (s *Server) bringUpLink( ctx context.Context, cfg Config, @@ -280,6 +296,7 @@ func (s *Server) bringUpLink( SEIBatchSize: cfg.SEIBatchSize, SEIFragmentSize: cfg.SEIFragmentSize, SEIAckTimeoutMS: cfg.SEIAckTimeoutMS, + Traffic: cfg.Traffic, }) if err != nil { return fmt.Errorf("failed to create link: %w", err) @@ -316,7 +333,7 @@ func (s *Server) bringUpLink( func (s *Server) installSession() { conn := muxconn.New(s.ln, s.cipher) - sess, err := smux.Server(conn, smuxConfig()) + sess, err := smux.Server(conn, smuxConfig(linkMaxPayload(s.ln))) if err != nil { logger.Warnf("smux server init failed: %v", err) return @@ -342,7 +359,7 @@ func (s *Server) reinstallSession(dead *smux.Session) { // Pre-build the replacement so we can swap atomically below. newConn := muxconn.New(s.ln, s.cipher) - newSess, err := smux.Server(newConn, smuxConfig()) + newSess, err := smux.Server(newConn, smuxConfig(linkMaxPayload(s.ln))) if err != nil { logger.Warnf("smux server init failed: %v", err) _ = newConn.Close() diff --git a/internal/server/server_test.go b/internal/server/server_test.go index dc80b21..65a2bc5 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -50,6 +50,11 @@ func TestSmuxConfig(t *testing.T) { if cfg.Version != 2 || !cfg.KeepAliveDisabled || cfg.MaxFrameSize != 32768 || cfg.MaxReceiveBuffer != 16*1024*1024 { t.Fatalf("smuxConfig() = %+v", cfg) } + capped := smuxConfig(4096) + if capped.MaxFrameSize != 4096-cryptopkg.WireOverhead { + t.Fatalf("smuxConfig(4096).MaxFrameSize = %d, want %d", + capped.MaxFrameSize, 4096-cryptopkg.WireOverhead) + } } func TestParseConnectRequest(t *testing.T) { diff --git a/internal/transport/traffic.go b/internal/transport/traffic.go new file mode 100644 index 0000000..31f194b --- /dev/null +++ b/internal/transport/traffic.go @@ -0,0 +1,91 @@ +package transport + +import ( + "context" + "errors" + "fmt" + "math/rand/v2" + "sync" + "time" +) + +var ErrTrafficPayloadTooLarge = errors.New("traffic payload exceeds max_payload_size") + +type trafficTransport struct { + inner Transport + maxPayloadSize int + minDelay time.Duration + maxDelay time.Duration + sendMu sync.Mutex +} + +// WithTraffic wraps tr with optional payload caps and send pacing. +func WithTraffic(tr Transport, cfg TrafficConfig) Transport { + if tr == nil { + return nil + } + cfg = effectiveTrafficConfig(tr.Features(), cfg) + if cfg.MaxPayloadSize <= 0 && cfg.MinDelay <= 0 && cfg.MaxDelay <= 0 { + return tr + } + return &trafficTransport{ + inner: tr, + maxPayloadSize: cfg.MaxPayloadSize, + minDelay: cfg.MinDelay, + maxDelay: cfg.MaxDelay, + } +} + +func effectiveTrafficConfig(features Features, cfg TrafficConfig) TrafficConfig { + if cfg.MaxPayloadSize > 0 && features.MaxPayloadSize > 0 && features.MaxPayloadSize < cfg.MaxPayloadSize { + cfg.MaxPayloadSize = features.MaxPayloadSize + } + return cfg +} + +func (t *trafficTransport) Connect(ctx context.Context) error { return t.inner.Connect(ctx) } + +func (t *trafficTransport) Send(data []byte) error { + t.sendMu.Lock() + defer t.sendMu.Unlock() + if t.maxPayloadSize > 0 && len(data) > t.maxPayloadSize { + return fmt.Errorf("%w: size=%d max=%d", ErrTrafficPayloadTooLarge, len(data), t.maxPayloadSize) + } + if delay := t.nextDelay(); delay > 0 { + time.Sleep(delay) + } + return t.inner.Send(data) +} + +func (t *trafficTransport) Close() error { return t.inner.Close() } + +func (t *trafficTransport) SetReconnectCallback(cb func()) { t.inner.SetReconnectCallback(cb) } + +func (t *trafficTransport) SetShouldReconnect(fn func() bool) { t.inner.SetShouldReconnect(fn) } + +func (t *trafficTransport) SetEndedCallback(cb func(string)) { t.inner.SetEndedCallback(cb) } + +func (t *trafficTransport) WatchConnection(ctx context.Context) { t.inner.WatchConnection(ctx) } + +func (t *trafficTransport) CanSend() bool { return t.inner.CanSend() } + +func (t *trafficTransport) Features() Features { + features := t.inner.Features() + if t.maxPayloadSize > 0 && + (features.MaxPayloadSize == 0 || t.maxPayloadSize < features.MaxPayloadSize) { + features.MaxPayloadSize = t.maxPayloadSize + } + return features +} + +func (t *trafficTransport) nextDelay() time.Duration { + if t.maxDelay <= 0 && t.minDelay <= 0 { + return 0 + } + minDelay := t.minDelay + maxDelay := t.maxDelay + if maxDelay <= minDelay { + return minDelay + } + return minDelay + time.Duration(rand.Int64N(int64(maxDelay-minDelay))) //nolint:gosec,lll // G404: non-cryptographic pacing jitter +} diff --git a/internal/transport/traffic_test.go b/internal/transport/traffic_test.go new file mode 100644 index 0000000..9f6139a --- /dev/null +++ b/internal/transport/traffic_test.go @@ -0,0 +1,67 @@ +package transport + +import ( + "context" + "errors" + "testing" + "time" +) + +type trafficStubTransport struct { + features Features + sent [][]byte +} + +func (s *trafficStubTransport) Connect(context.Context) error { return nil } +func (s *trafficStubTransport) Send(data []byte) error { + s.sent = append(s.sent, append([]byte(nil), data...)) + return nil +} +func (s *trafficStubTransport) Close() error { return nil } +func (s *trafficStubTransport) SetReconnectCallback(func()) {} +func (s *trafficStubTransport) SetShouldReconnect(func() bool) {} +func (s *trafficStubTransport) SetEndedCallback(func(string)) {} +func (s *trafficStubTransport) WatchConnection(context.Context) {} +func (s *trafficStubTransport) CanSend() bool { return true } +func (s *trafficStubTransport) Features() Features { return s.features } + +func TestWithTrafficReturnsInnerWhenDisabled(t *testing.T) { + inner := &trafficStubTransport{} + got := WithTraffic(inner, TrafficConfig{}) + if got != inner { + t.Fatalf("WithTraffic disabled returned %T, want inner", got) + } +} + +func TestTrafficWrapperRejectsOversizedPayloadAndClampsFeatures(t *testing.T) { + inner := &trafficStubTransport{features: Features{MaxPayloadSize: 5}} + tr := WithTraffic(inner, TrafficConfig{MaxPayloadSize: 10}) + if features := tr.Features(); features.MaxPayloadSize != 5 { + t.Fatalf("Features().MaxPayloadSize = %d, want 5", features.MaxPayloadSize) + } + err := tr.Send([]byte("123456")) + if !errors.Is(err, ErrTrafficPayloadTooLarge) { + t.Fatalf("Send() error = %v, want %v", err, ErrTrafficPayloadTooLarge) + } + if len(inner.sent) != 0 { + t.Fatalf("inner sent %d payloads, want 0", len(inner.sent)) + } + if err := tr.Send([]byte("12345")); err != nil { + t.Fatalf("Send(max sized) error = %v", err) + } + if got := string(inner.sent[0]); got != "12345" { + t.Fatalf("inner payload = %q, want 12345", got) + } +} + +func TestTrafficWrapperAppliesMinimumDelay(t *testing.T) { + inner := &trafficStubTransport{} + tr := WithTraffic(inner, TrafficConfig{MinDelay: 2 * time.Millisecond}) + start := time.Now() + if err := tr.Send([]byte("x")); err != nil { + t.Fatalf("Send() error = %v", err) + } + if elapsed := time.Since(start); elapsed < 2*time.Millisecond { + t.Fatalf("Send() elapsed = %v, want at least 2ms", elapsed) + } +} diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 9e11240..2f37a41 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -4,6 +4,7 @@ package transport import ( "context" "errors" + "time" ) var ( @@ -32,10 +33,17 @@ type Transport interface { Features() Features } +// TrafficConfig controls optional reliability-oriented send shaping. +type TrafficConfig struct { + MaxPayloadSize int + MinDelay time.Duration + MaxDelay time.Duration +} + // Config holds common transport configuration. type Config struct { - Carrier string - RoomURL string + Carrier string + RoomURL string // Engine, URL, Token are forwarded to carrier.Config for the "none" auth // carrier (direct engine access without a service-specific auth flow). Engine string @@ -63,6 +71,7 @@ type Config struct { SEIBatchSize int SEIFragmentSize int SEIAckTimeoutMS int + Traffic TrafficConfig } // Factory creates a transport instance. @@ -81,7 +90,11 @@ func New(ctx context.Context, name string, cfg Config) (Transport, error) { if !ok { return nil, ErrTransportNotFound } - return factory(ctx, cfg) + tr, err := factory(ctx, cfg) + if err != nil { + return nil, err + } + return WithTraffic(tr, cfg.Traffic), nil } // Available returns a list of registered transport names.