Merge pull request #58 from cyber-debug/refine/livekit-reconnect

refine livekit reconnect and liveness
This commit is contained in:
zarazaex
2026-05-16 03:47:47 +03:00
committed by GitHub
46 changed files with 4785 additions and 294 deletions

View File

@@ -5,13 +5,17 @@ import (
"context"
"errors"
"fmt"
"net"
"slices"
"sync/atomic"
"time"
"github.com/openlibrecommunity/olcrtc/internal/auth"
"github.com/openlibrecommunity/olcrtc/internal/carrier"
"github.com/openlibrecommunity/olcrtc/internal/carrier/builtin"
"github.com/openlibrecommunity/olcrtc/internal/client"
"github.com/openlibrecommunity/olcrtc/internal/control"
"github.com/openlibrecommunity/olcrtc/internal/crypto"
"github.com/openlibrecommunity/olcrtc/internal/link"
"github.com/openlibrecommunity/olcrtc/internal/link/direct"
"github.com/openlibrecommunity/olcrtc/internal/logger"
@@ -37,18 +41,35 @@ const (
videoCodecTile = "tile"
)
// Transport defaults applied by ApplyTransportDefaults when the matching
// Config field is left at its zero value.
const (
// videochannel defaults; width/height are used for non-tile codecs only.
defaultVideoWidth = 1920
defaultVideoHeight = 1080
defaultVideoFPS = 30
defaultVideoBitrate = "2M"
defaultVideoHW = "none"
defaultVideoQRRecovery = "low"
// vp8channel defaults.
defaultVP8FPS = 25
defaultVP8BatchSize = 1
// seichannel defaults.
defaultSEIFPS = 60
defaultSEIBatchSize = 64
defaultSEIFragmentSize = 900
defaultSEIAckTimeoutMS = 2000
)
// sessionRestartDelay is the pause waitSessionRestart observes before the next
// session cycle. It is a variable rather than a constant so tests can shorten it.
var sessionRestartDelay = 2 * time.Second
var (
// ErrRoomIDRequired indicates that a room id is required for the selected carrier.
ErrRoomIDRequired = errors.New("room ID required (use -id <id>)")
ErrRoomIDRequired = errors.New("room ID required (set room.id)")
// ErrModeRequired indicates that mode is not one of the supported values.
ErrModeRequired = errors.New("mode required (use -mode srv, -mode cnc or -mode gen)")
// ErrAmountRequired indicates that -amount is required for gen mode.
ErrAmountRequired = errors.New("amount required for gen mode (use -amount <n>)")
ErrModeRequired = errors.New("mode required (set mode to srv, cnc or gen)")
// ErrAmountRequired indicates that gen.amount is required for gen mode.
ErrAmountRequired = errors.New("amount required for gen mode (set gen.amount)")
// ErrAuthRequired indicates that no auth provider was selected.
ErrAuthRequired = errors.New(
"auth provider required (use -auth jitsi, -auth telemost, -auth jazz, -auth wbstream or -auth none)")
// ErrURLRequired indicates that -url must be provided when the auth provider has no default URL.
ErrURLRequired = errors.New("SFU URL required (use -url wss://...)")
"auth provider required (set auth.provider to jitsi, telemost, jazz, wbstream or none)")
// ErrURLRequired indicates that auth.url must be provided when the auth provider has no default URL.
ErrURLRequired = errors.New("SFU URL required (set auth.url)")
// ErrUnsupportedCarrier indicates that carrier is not registered.
ErrUnsupportedCarrier = errors.New("unsupported carrier")
// ErrUnsupportedLink indicates that link is not registered.
@@ -57,88 +78,119 @@ var (
ErrUnsupportedTransport = errors.New("unsupported transport")
// ErrLinkRequired indicates that link is not provided.
ErrLinkRequired = errors.New("link required (use -link direct)")
ErrLinkRequired = errors.New("link required (set link to direct)")
// ErrTransportRequired indicates that transport is not provided.
ErrTransportRequired = errors.New(
"transport required (use -transport datachannel, -transport videochannel, " +
"-transport seichannel or -transport vp8channel)")
"transport required (set transport to datachannel, videochannel, seichannel or vp8channel)")
// ErrKeyRequired indicates that encryption key is not provided.
ErrKeyRequired = errors.New("key required (use -key <hex>)")
ErrKeyRequired = errors.New("key required (set crypto.key)")
// ErrDNSServerRequired indicates that dns server is not provided.
ErrDNSServerRequired = errors.New("dns server required (use -dns 1.1.1.1:53)")
ErrDNSServerRequired = errors.New("dns server required (set net.dns)")
// ErrVideoWidthRequired indicates that video width is required for videochannel.
ErrVideoWidthRequired = errors.New("video width required for videochannel (use -video-w)")
ErrVideoWidthRequired = errors.New("video width required for videochannel (set video.width)")
// ErrVideoHeightRequired indicates that video height is required for videochannel.
ErrVideoHeightRequired = errors.New("video height required for videochannel (use -video-h)")
ErrVideoHeightRequired = errors.New("video height required for videochannel (set video.height)")
// ErrVideoFPSRequired indicates that video fps is required for videochannel.
ErrVideoFPSRequired = errors.New("video fps required for videochannel (use -video-fps)")
ErrVideoFPSRequired = errors.New("video fps required for videochannel (set video.fps)")
// ErrVideoBitrateRequired indicates that video bitrate is required for videochannel.
ErrVideoBitrateRequired = errors.New(
"video bitrate required for videochannel (use -video-bitrate)")
"video bitrate required for videochannel (set video.bitrate)")
// ErrVideoHWRequired indicates that video hardware acceleration is required.
ErrVideoHWRequired = errors.New(
"video hardware acceleration required for videochannel (use -video-hw none/nvenc)")
"video hardware acceleration required for videochannel (set video.hw to none or nvenc)")
// ErrVideoCodecInvalid indicates that the video codec is not valid.
ErrVideoCodecInvalid = errors.New(
"invalid video codec for videochannel (use -video-codec qrcode or -video-codec tile)")
"invalid video codec for videochannel (set video.codec to qrcode or tile)")
// ErrTileCodecDimensions indicates that tile codec requires 1080x1080 dimensions.
ErrTileCodecDimensions = errors.New("tile codec requires -video-w 1080 -video-h 1080")
ErrTileCodecDimensions = errors.New("tile codec requires video.width: 1080 and video.height: 1080")
// ErrVP8FPSRequired indicates that vp8 fps is required for vp8channel.
ErrVP8FPSRequired = errors.New("vp8 fps required for vp8channel (use -vp8-fps)")
ErrVP8FPSRequired = errors.New("vp8 fps required for vp8channel (set vp8.fps)")
// ErrVP8BatchSizeRequired indicates that vp8 batch size is required for vp8channel.
ErrVP8BatchSizeRequired = errors.New("vp8 batch size required for vp8channel (use -vp8-batch)")
ErrVP8BatchSizeRequired = errors.New("vp8 batch size required for vp8channel (set vp8.batch_size)")
// ErrSEIFPSRequired indicates that seichannel fps is required.
ErrSEIFPSRequired = errors.New("fps required for seichannel (use -fps)")
ErrSEIFPSRequired = errors.New("fps required for seichannel (set sei.fps)")
// ErrSEIBatchSizeRequired indicates that seichannel batch size is required.
ErrSEIBatchSizeRequired = errors.New("batch size required for seichannel (use -batch)")
ErrSEIBatchSizeRequired = errors.New("batch size required for seichannel (set sei.batch_size)")
// ErrSEIFragmentSizeRequired indicates that seichannel fragment size is required.
ErrSEIFragmentSizeRequired = errors.New("fragment size required for seichannel (use -frag)")
ErrSEIFragmentSizeRequired = errors.New("fragment size required for seichannel (set sei.fragment_size)")
// ErrSEIAckTimeoutRequired indicates that seichannel ack timeout is required.
ErrSEIAckTimeoutRequired = errors.New("ack timeout required for seichannel (use -ack-ms)")
ErrSEIAckTimeoutRequired = errors.New("ack timeout required for seichannel (set sei.ack_timeout_ms)")
// ErrSOCKSHostRequired indicates that socks host is required for cnc mode.
ErrSOCKSHostRequired = errors.New("socks host required for cnc mode (use -socks-host)")
ErrSOCKSHostRequired = errors.New("socks host required for cnc mode (set socks.host)")
// ErrSOCKSPortRequired indicates that socks port is required for cnc mode.
ErrSOCKSPortRequired = errors.New("socks port required for cnc mode (use -socks-port)")
ErrSOCKSPortRequired = errors.New("socks port required for cnc mode (set socks.port)")
// ErrSOCKSAuthRequired indicates that a non-loopback SOCKS listener requires authentication.
ErrSOCKSAuthRequired = errors.New(
"socks auth required when binding outside loopback (set socks.user and socks.pass)")
// ErrLivenessIntervalInvalid indicates that liveness.interval is not a positive duration.
ErrLivenessIntervalInvalid = errors.New(
"invalid liveness interval (set liveness.interval to a duration > 0)")
// ErrLivenessTimeoutInvalid indicates that liveness.timeout is not a positive duration.
ErrLivenessTimeoutInvalid = errors.New(
"invalid liveness timeout (set liveness.timeout to a duration > 0)")
// ErrLivenessFailuresInvalid indicates that liveness.failures is not positive.
ErrLivenessFailuresInvalid = errors.New(
"invalid liveness failures (set liveness.failures to a value > 0)")
// ErrLifecycleMaxSessionDurationInvalid indicates that lifecycle.max_session_duration is not a positive duration.
ErrLifecycleMaxSessionDurationInvalid = errors.New(
"invalid max session duration (set lifecycle.max_session_duration to a duration > 0)")
// ErrTrafficMaxPayloadSizeInvalid indicates that traffic.max_payload_size is not valid.
ErrTrafficMaxPayloadSizeInvalid = errors.New(
"invalid traffic max payload size (set traffic.max_payload_size to 0 or a value above crypto overhead)")
// ErrTrafficMinDelayInvalid indicates that traffic.min_delay is not a non-negative duration.
ErrTrafficMinDelayInvalid = errors.New(
"invalid traffic min delay (set traffic.min_delay to a duration >= 0)")
// ErrTrafficMaxDelayInvalid indicates that traffic.max_delay is not a non-negative duration.
ErrTrafficMaxDelayInvalid = errors.New(
"invalid traffic max delay (set traffic.max_delay to a duration >= 0 and >= traffic.min_delay)")
)
// Config holds runtime session settings.
//
// Duration-valued settings (liveness, lifecycle, traffic) are kept as strings
// and parsed with time.ParseDuration when the config is validated or run.
type Config struct {
Mode string
Link string
Transport string
Auth string
Engine string
URL string
Token string
RoomID string
KeyHex string
SOCKSHost string
SOCKSPort int
SOCKSUser string
SOCKSPass string
DNSServer string
SOCKSProxyAddr string
SOCKSProxyPort int
VideoWidth int
VideoHeight int
VideoFPS int
VideoBitrate string
VideoHW string
VideoQRSize int
VideoQRRecovery string
VideoCodec string
VideoTileModule int
VideoTileRS int
VP8FPS int
VP8BatchSize int
SEIFPS int
SEIBatchSize int
SEIFragmentSize int
SEIAckTimeoutMS int
Amount int
Mode string
Link string
Transport string
Auth string
Engine string
URL string
Token string
RoomID string
KeyHex string
SOCKSHost string
SOCKSPort int
SOCKSUser string
SOCKSPass string
DNSServer string
SOCKSProxyAddr string
SOCKSProxyPort int
VideoWidth int
VideoHeight int
VideoFPS int
VideoBitrate string
VideoHW string
VideoQRSize int
VideoQRRecovery string
VideoCodec string
VideoTileModule int
VideoTileRS int
VP8FPS int
VP8BatchSize int
SEIFPS int
SEIBatchSize int
SEIFragmentSize int
SEIAckTimeoutMS int
// LivenessInterval and LivenessTimeout are duration strings that must be > 0;
// an empty string selects control.DefaultInterval / control.DefaultTimeout.
LivenessInterval string
LivenessTimeout string
// LivenessFailures must not be negative; 0 selects control.DefaultFailures.
LivenessFailures int
// MaxSessionDuration is a duration string > 0; empty disables session rotation.
MaxSessionDuration string
// TrafficMaxPayloadSize is either 0 or a value greater than crypto.WireOverhead.
TrafficMaxPayloadSize int
// TrafficMinDelay and TrafficMaxDelay are optional duration strings >= 0;
// a non-zero max delay must not be smaller than the min delay.
TrafficMinDelay string
TrafficMaxDelay string
Amount int
}
// RegisterDefaults registers built-in carriers and transports.
@@ -180,6 +232,94 @@ func ApplyAuthDefaults(cfg Config) (Config, error) {
return cfg, nil
}
// ApplyTransportDefaults returns cfg with the defaults of its selected
// transport filled in. Only transport-specific tuning fields are touched;
// a transport without defaults leaves cfg unchanged.
func ApplyTransportDefaults(cfg Config) Config {
	if cfg.Transport == transportVideo {
		return applyVideoDefaults(cfg)
	}
	if cfg.Transport == transportVP8 {
		return applyVP8Defaults(cfg)
	}
	if cfg.Transport == transportSEI {
		return applySEIDefaults(cfg)
	}
	return cfg
}
// ApplyLivenessDefaults returns a copy of cfg in which every unset liveness
// field (empty interval/timeout string, zero failure count) is replaced by the
// corresponding default from the control package.
func ApplyLivenessDefaults(cfg Config) Config {
	out := cfg
	if out.LivenessInterval == "" {
		out.LivenessInterval = control.DefaultInterval.String()
	}
	if out.LivenessTimeout == "" {
		out.LivenessTimeout = control.DefaultTimeout.String()
	}
	if out.LivenessFailures == 0 {
		out.LivenessFailures = control.DefaultFailures
	}
	return out
}
// applyVideoDefaults fills unset videochannel settings. The codec defaults to
// qrcode; the tile codec falls back to 1080x1080, any other codec to the
// package-wide default dimensions. Explicitly set fields are never overwritten.
func applyVideoDefaults(cfg Config) Config {
	if cfg.VideoCodec == "" {
		cfg.VideoCodec = videoCodecQRCode
	}

	// Pick the fallback frame size for the effective codec first, then apply it.
	width, height := defaultVideoWidth, defaultVideoHeight
	if cfg.VideoCodec == videoCodecTile {
		width, height = 1080, 1080
	}
	if cfg.VideoWidth == 0 {
		cfg.VideoWidth = width
	}
	if cfg.VideoHeight == 0 {
		cfg.VideoHeight = height
	}
	if cfg.VideoFPS == 0 {
		cfg.VideoFPS = defaultVideoFPS
	}

	textDefaults := []struct {
		field    *string
		fallback string
	}{
		{&cfg.VideoBitrate, defaultVideoBitrate},
		{&cfg.VideoHW, defaultVideoHW},
		{&cfg.VideoQRRecovery, defaultVideoQRRecovery},
	}
	for _, d := range textDefaults {
		if *d.field == "" {
			*d.field = d.fallback
		}
	}
	return cfg
}
// applyVP8Defaults fills a zero vp8channel frame rate or batch size with the
// package defaults and returns the updated copy.
func applyVP8Defaults(cfg Config) Config {
	fps, batch := cfg.VP8FPS, cfg.VP8BatchSize
	if fps == 0 {
		fps = defaultVP8FPS
	}
	if batch == 0 {
		batch = defaultVP8BatchSize
	}
	cfg.VP8FPS, cfg.VP8BatchSize = fps, batch
	return cfg
}
// applySEIDefaults replaces every zero-valued seichannel setting (fps, batch
// size, fragment size, ack timeout) with its package default.
func applySEIDefaults(cfg Config) Config {
	numericDefaults := []struct {
		field    *int
		fallback int
	}{
		{&cfg.SEIFPS, defaultSEIFPS},
		{&cfg.SEIBatchSize, defaultSEIBatchSize},
		{&cfg.SEIFragmentSize, defaultSEIFragmentSize},
		{&cfg.SEIAckTimeoutMS, defaultSEIAckTimeoutMS},
	}
	for _, d := range numericDefaults {
		if *d.field == 0 {
			*d.field = d.fallback
		}
	}
	return cfg
}
// Validate verifies that the runtime config refers to registered components and all required fields are present.
func Validate(cfg Config) error {
if err := validateMode(cfg); err != nil {
@@ -200,6 +340,15 @@ func Validate(cfg Config) error {
if err := validateTransportConfig(cfg); err != nil {
return err
}
if err := validateLivenessConfig(cfg); err != nil {
return err
}
if err := validateLifecycleConfig(cfg); err != nil {
return err
}
if err := validateTrafficConfig(cfg); err != nil {
return err
}
return validateModeConfig(cfg)
}
@@ -333,13 +482,163 @@ func validateModeConfig(cfg Config) error {
if cfg.SOCKSPort == 0 {
return ErrSOCKSPortRequired
}
if !isLoopbackListenHost(cfg.SOCKSHost) && (cfg.SOCKSUser == "" || cfg.SOCKSPass == "") {
return ErrSOCKSAuthRequired
}
return nil
}
// validateLivenessConfig checks the liveness settings in order: interval,
// timeout, then failure count. Unparseable or non-positive durations are
// reported wrapped in the matching sentinel error; a negative failure count
// yields ErrLivenessFailuresInvalid.
func validateLivenessConfig(cfg Config) error {
	_, err := parseLivenessDuration(cfg.LivenessInterval, control.DefaultInterval)
	if err != nil {
		return fmt.Errorf("%w: %v", ErrLivenessIntervalInvalid, err)
	}
	_, err = parseLivenessDuration(cfg.LivenessTimeout, control.DefaultTimeout)
	if err != nil {
		return fmt.Errorf("%w: %v", ErrLivenessTimeoutInvalid, err)
	}
	if cfg.LivenessFailures >= 0 {
		return nil
	}
	return ErrLivenessFailuresInvalid
}
// validateLifecycleConfig reports whether the lifecycle settings are usable by
// attempting to derive the max session duration and discarding the value.
func validateLifecycleConfig(cfg Config) error {
	_, err := maxSessionDuration(cfg)
	return err
}
// parseLivenessDuration parses a liveness duration setting.
//
// An empty value means "not configured" and yields def unchanged. Otherwise
// value must be accepted by time.ParseDuration and be strictly positive; on
// failure the returned duration is 0 and the error describes the problem.
func parseLivenessDuration(value string, def time.Duration) (time.Duration, error) {
	if value == "" {
		return def, nil
	}
	d, err := time.ParseDuration(value)
	if err != nil {
		return 0, err
	}
	if d <= 0 {
		// Constant message: errors.New avoids running it through the formatter.
		return 0, errors.New("duration must be > 0")
	}
	return d, nil
}
// livenessConfig converts the string/int liveness settings of cfg into a
// control.Config. Empty durations and a zero failure count fall back to the
// control package defaults; invalid values are reported with the matching
// sentinel error and a zero control.Config.
func livenessConfig(cfg Config) (control.Config, error) {
	var (
		out control.Config
		err error
	)
	if out.Interval, err = parseLivenessDuration(cfg.LivenessInterval, control.DefaultInterval); err != nil {
		return control.Config{}, fmt.Errorf("%w: %v", ErrLivenessIntervalInvalid, err)
	}
	if out.Timeout, err = parseLivenessDuration(cfg.LivenessTimeout, control.DefaultTimeout); err != nil {
		return control.Config{}, fmt.Errorf("%w: %v", ErrLivenessTimeoutInvalid, err)
	}
	out.Failures = cfg.LivenessFailures
	if out.Failures == 0 {
		out.Failures = control.DefaultFailures
	}
	if out.Failures < 0 {
		return control.Config{}, ErrLivenessFailuresInvalid
	}
	return out, nil
}
// maxSessionDuration returns the configured session lifetime limit.
// An empty setting yields 0 (no limit). A value that does not parse, or that
// is not strictly positive, is rejected with
// ErrLifecycleMaxSessionDurationInvalid.
func maxSessionDuration(cfg Config) (time.Duration, error) {
	raw := cfg.MaxSessionDuration
	if raw == "" {
		return 0, nil
	}
	limit, err := time.ParseDuration(raw)
	switch {
	case err != nil:
		return 0, fmt.Errorf("%w: %v", ErrLifecycleMaxSessionDurationInvalid, err)
	case limit <= 0:
		return 0, ErrLifecycleMaxSessionDurationInvalid
	}
	return limit, nil
}
// validateTrafficConfig reports whether the traffic settings of cfg can be
// turned into a transport.TrafficConfig; the derived value is discarded.
func validateTrafficConfig(cfg Config) error {
	if _, err := trafficConfig(cfg); err != nil {
		return err
	}
	return nil
}
// trafficConfig derives the transport traffic settings from cfg.
//
// The max payload size must be 0 or greater than crypto.WireOverhead. Both
// delays are optional non-negative durations, and a non-zero max delay must
// not be smaller than the min delay. Checks run in that order and the first
// failure is returned with a zero transport.TrafficConfig.
func trafficConfig(cfg Config) (transport.TrafficConfig, error) {
	var none transport.TrafficConfig

	size := cfg.TrafficMaxPayloadSize
	switch {
	case size < 0:
		return none, ErrTrafficMaxPayloadSizeInvalid
	case size > 0 && size <= crypto.WireOverhead:
		return none, ErrTrafficMaxPayloadSizeInvalid
	}

	lo, err := parseOptionalNonNegativeDuration(cfg.TrafficMinDelay)
	if err != nil {
		return none, fmt.Errorf("%w: %v", ErrTrafficMinDelayInvalid, err)
	}
	hi, err := parseOptionalNonNegativeDuration(cfg.TrafficMaxDelay)
	if err != nil {
		return none, fmt.Errorf("%w: %v", ErrTrafficMaxDelayInvalid, err)
	}
	if hi > 0 && hi < lo {
		return none, ErrTrafficMaxDelayInvalid
	}

	return transport.TrafficConfig{MaxPayloadSize: size, MinDelay: lo, MaxDelay: hi}, nil
}
// parseOptionalNonNegativeDuration parses an optional duration setting.
//
// An empty value means "not set" and yields 0. Otherwise value must be
// accepted by time.ParseDuration and must not be negative; on failure the
// returned duration is 0 and the error describes the problem.
func parseOptionalNonNegativeDuration(value string) (time.Duration, error) {
	if value == "" {
		return 0, nil
	}
	d, err := time.ParseDuration(value)
	if err != nil {
		return 0, err
	}
	if d < 0 {
		// Constant message: errors.New avoids running it through the formatter.
		return 0, errors.New("duration must be >= 0")
	}
	return d, nil
}
// isLoopbackListenHost reports whether host is the literal name "localhost"
// or an IP literal in a loopback range. Anything else, including the empty
// string and other hostnames, is treated as non-loopback.
func isLoopbackListenHost(host string) bool {
	if host == "localhost" {
		return true
	}
	if addr := net.ParseIP(host); addr != nil {
		return addr.IsLoopback()
	}
	return false
}
// Run fills transport and liveness defaults, derives the liveness, lifecycle
// and traffic settings from cfg, and then starts the configured mode. When a
// max session duration is set the session is driven through
// runWithSessionRotation; otherwise it runs exactly once.
func Run(ctx context.Context, cfg Config) error {
	cfg = ApplyLivenessDefaults(ApplyTransportDefaults(cfg))

	liveness, err := livenessConfig(cfg)
	if err != nil {
		return err
	}
	maxDuration, err := maxSessionDuration(cfg)
	if err != nil {
		return err
	}
	traffic, err := trafficConfig(cfg)
	if err != nil {
		return err
	}

	roomURL := cfg.RoomID
	session := func(runCtx context.Context) error {
		return runOnce(runCtx, cfg, roomURL, liveness, traffic)
	}
	if maxDuration <= 0 {
		return session(ctx)
	}
	return runWithSessionRotation(ctx, maxDuration, session)
}
func runOnce(
ctx context.Context,
cfg Config,
roomURL string,
liveness control.Config,
traffic transport.TrafficConfig,
) error {
switch cfg.Mode {
case modeSRV:
if err := server.Run(ctx, server.Config{
@@ -370,6 +669,8 @@ func Run(ctx context.Context, cfg Config) error {
Engine: cfg.Engine,
URL: cfg.URL,
Token: cfg.Token,
Liveness: liveness,
Traffic: traffic,
OnSessionOpen: func(sessionID, deviceID string, claims map[string]any) {
logger.Infof("session opened: id=%s device=%s claims=%v", sessionID, deviceID, claims)
},
@@ -413,6 +714,8 @@ func Run(ctx context.Context, cfg Config) error {
Engine: cfg.Engine,
URL: cfg.URL,
Token: cfg.Token,
Liveness: liveness,
Traffic: traffic,
}); err != nil {
return fmt.Errorf("client: %w", err)
}
@@ -422,6 +725,52 @@ func Run(ctx context.Context, cfg Config) error {
}
}
// runWithSessionRotation repeatedly invokes run, giving each cycle its own
// cancellable context that is cancelled once maxDuration has elapsed.
//
// It returns nil when ctx is cancelled, whether during a cycle or during the
// restart delay. It returns run's error only when run fails before the
// rotation timer fired and ctx is still live. A cycle that ends cleanly on its
// own is also followed by a restart after the delay.
func runWithSessionRotation(ctx context.Context, maxDuration time.Duration, run func(context.Context) error) error {
for cycle := 1; ; cycle++ {
currentCycle := cycle
runCtx, cancel := context.WithCancel(ctx)
// rotated records that this cycle was ended by the max-duration timer.
var rotated atomic.Bool
timer := time.AfterFunc(maxDuration, func() {
rotated.Store(true)
logger.Infof("session max duration reached: duration=%s cycle=%d", maxDuration, currentCycle)
cancel()
})
err := run(runCtx)
// Release the per-cycle context and stop the timer however run returned.
cancel()
timer.Stop()
// Parent cancellation is treated as a normal shutdown; any error from run is dropped.
if ctx.Err() != nil {
return nil
}
// The rotation timer fired: an error from the cancelled run is only logged,
// then the next cycle starts after the restart delay.
if rotated.Load() {
if err != nil {
logger.Warnf("session rotation ended with error: cycle=%d err=%v", currentCycle, err)
}
logger.Infof("session rotation restarting: next_cycle=%d", currentCycle+1)
if err := waitSessionRestart(ctx); err != nil {
return nil
}
continue
}
// run failed on its own (no rotation, no shutdown): surface the error.
if err != nil {
return err
}
logger.Infof("session ended cleanly with lifecycle rotation enabled: next_cycle=%d", currentCycle+1)
if err := waitSessionRestart(ctx); err != nil {
return nil
}
}
}
// waitSessionRestart blocks for sessionRestartDelay before the next session
// cycle. It returns nil once the delay has passed, or ctx.Err() if ctx is
// cancelled first.
func waitSessionRestart(ctx context.Context) error {
	delayElapsed := time.After(sessionRestartDelay)
	select {
	case <-delayElapsed:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
// ValidateGen validates that the config contains enough fields to run gen mode.
func ValidateGen(cfg Config) error {
if cfg.Auth == "" {

View File

@@ -3,9 +3,136 @@ package session
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/openlibrecommunity/olcrtc/internal/control"
"github.com/openlibrecommunity/olcrtc/internal/crypto"
)
// TestApplyTransportDefaults checks that each transport receives its
// documented defaults and that explicitly configured values survive untouched.
func TestApplyTransportDefaults(t *testing.T) {
	type testCase struct {
		name     string
		input    Config
		expected Config
	}

	explicitSEI := Config{
		Transport:       transportSEI,
		SEIFPS:          10,
		SEIBatchSize:    2,
		SEIFragmentSize: 300,
		SEIAckTimeoutMS: 1500,
	}

	cases := []testCase{
		{
			name:     "vp8",
			input:    Config{Transport: transportVP8},
			expected: Config{Transport: transportVP8, VP8FPS: 25, VP8BatchSize: 1},
		},
		{
			name:  "sei",
			input: Config{Transport: transportSEI},
			expected: Config{
				Transport:       transportSEI,
				SEIFPS:          60,
				SEIBatchSize:    64,
				SEIFragmentSize: 900,
				SEIAckTimeoutMS: 2000,
			},
		},
		{
			name:  "video qrcode",
			input: Config{Transport: transportVideo},
			expected: Config{
				Transport:       transportVideo,
				VideoWidth:      1920,
				VideoHeight:     1080,
				VideoFPS:        30,
				VideoBitrate:    "2M",
				VideoHW:         "none",
				VideoQRRecovery: "low",
				VideoCodec:      videoCodecQRCode,
			},
		},
		{
			name:  "video tile dimensions",
			input: Config{Transport: transportVideo, VideoCodec: videoCodecTile},
			expected: Config{
				Transport:       transportVideo,
				VideoWidth:      1080,
				VideoHeight:     1080,
				VideoFPS:        30,
				VideoBitrate:    "2M",
				VideoHW:         "none",
				VideoQRRecovery: "low",
				VideoCodec:      videoCodecTile,
			},
		},
		{
			name:     "keeps explicit values",
			input:    explicitSEI,
			expected: explicitSEI,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := ApplyTransportDefaults(tc.input); got != tc.expected {
				t.Fatalf("ApplyTransportDefaults() = %+v, want %+v", got, tc.expected)
			}
		})
	}
}
// TestApplyLivenessDefaults checks that an empty config receives the control
// package defaults and that explicit liveness settings are left unchanged.
func TestApplyLivenessDefaults(t *testing.T) {
	defaults := ApplyLivenessDefaults(Config{})

	wantInterval := control.DefaultInterval.String()
	if defaults.LivenessInterval != wantInterval {
		t.Fatalf("LivenessInterval = %q, want %q", defaults.LivenessInterval, wantInterval)
	}
	wantTimeout := control.DefaultTimeout.String()
	if defaults.LivenessTimeout != wantTimeout {
		t.Fatalf("LivenessTimeout = %q, want %q", defaults.LivenessTimeout, wantTimeout)
	}
	if defaults.LivenessFailures != control.DefaultFailures {
		t.Fatalf("LivenessFailures = %d, want %d", defaults.LivenessFailures, control.DefaultFailures)
	}

	explicit := Config{LivenessInterval: "1s", LivenessTimeout: "500ms", LivenessFailures: 9}
	got := ApplyLivenessDefaults(explicit)
	if got != explicit {
		t.Fatalf("ApplyLivenessDefaults() = %+v, want %+v", got, explicit)
	}
}
// TestRunWithSessionRotationRestartsAfterMaxDuration checks that a session
// which outlives the max duration is cancelled and started again: the first
// run blocks until its context is cancelled, the second one stops the loop.
func TestRunWithSessionRotationRestartsAfterMaxDuration(t *testing.T) {
	savedDelay := sessionRestartDelay
	sessionRestartDelay = time.Millisecond
	t.Cleanup(func() { sessionRestartDelay = savedDelay })

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	var calls atomic.Int32
	session := func(runCtx context.Context) error {
		if n := calls.Add(1); n < 2 {
			// First cycle: wait for the rotation timer to cancel us.
			<-runCtx.Done()
			return nil
		}
		// Second cycle: end the whole rotation loop.
		cancel()
		return nil
	}

	if err := runWithSessionRotation(ctx, 5*time.Millisecond, session); err != nil {
		t.Fatalf("runWithSessionRotation() error = %v", err)
	}
	if got := calls.Load(); got < 2 {
		t.Fatalf("run calls = %d, want at least 2", got)
	}
}
//nolint:maintidx // table-driven validation test naturally has many cases
func TestValidate(t *testing.T) {
RegisterDefaults()
@@ -310,6 +437,148 @@ func TestValidate(t *testing.T) {
}(),
want: ErrSOCKSPortRequired,
},
{
name: "cnc rejects unauthenticated wildcard socks bind",
cfg: func() Config {
cfg := base
cfg.Mode = modeCNC
cfg.SOCKSHost = "0.0.0.0"
cfg.SOCKSPort = 1080
return cfg
}(),
want: ErrSOCKSAuthRequired,
},
{
name: "cnc allows authenticated wildcard socks bind",
cfg: func() Config {
cfg := base
cfg.Mode = modeCNC
cfg.SOCKSHost = "0.0.0.0"
cfg.SOCKSPort = 1080
cfg.SOCKSUser = "user"
cfg.SOCKSPass = "pass"
return cfg
}(),
},
{
name: "cnc allows localhost socks bind without auth",
cfg: func() Config {
cfg := base
cfg.Mode = modeCNC
cfg.SOCKSHost = "localhost"
cfg.SOCKSPort = 1080
return cfg
}(),
},
{
name: "liveness rejects bad interval",
cfg: func() Config {
cfg := base
cfg.LivenessInterval = "nope"
return cfg
}(),
want: ErrLivenessIntervalInvalid,
},
{
name: "liveness rejects zero timeout",
cfg: func() Config {
cfg := base
cfg.LivenessTimeout = "0s"
return cfg
}(),
want: ErrLivenessTimeoutInvalid,
},
{
name: "liveness rejects negative failures",
cfg: func() Config {
cfg := base
cfg.LivenessFailures = -1
return cfg
}(),
want: ErrLivenessFailuresInvalid,
},
{
name: "lifecycle accepts max session duration",
cfg: func() Config {
cfg := base
cfg.MaxSessionDuration = "1h"
return cfg
}(),
},
{
name: "lifecycle rejects bad max session duration",
cfg: func() Config {
cfg := base
cfg.MaxSessionDuration = "nope"
return cfg
}(),
want: ErrLifecycleMaxSessionDurationInvalid,
},
{
name: "lifecycle rejects zero max session duration",
cfg: func() Config {
cfg := base
cfg.MaxSessionDuration = "0s"
return cfg
}(),
want: ErrLifecycleMaxSessionDurationInvalid,
},
{
name: "traffic accepts shaping",
cfg: func() Config {
cfg := base
cfg.TrafficMaxPayloadSize = 4096
cfg.TrafficMinDelay = "5ms"
cfg.TrafficMaxDelay = "30ms"
return cfg
}(),
},
{
name: "traffic rejects negative max payload",
cfg: func() Config {
cfg := base
cfg.TrafficMaxPayloadSize = -1
return cfg
}(),
want: ErrTrafficMaxPayloadSizeInvalid,
},
{
name: "traffic rejects payload smaller than crypto overhead",
cfg: func() Config {
cfg := base
cfg.TrafficMaxPayloadSize = crypto.WireOverhead
return cfg
}(),
want: ErrTrafficMaxPayloadSizeInvalid,
},
{
name: "traffic rejects bad min delay",
cfg: func() Config {
cfg := base
cfg.TrafficMinDelay = "nope"
return cfg
}(),
want: ErrTrafficMinDelayInvalid,
},
{
name: "traffic rejects negative max delay",
cfg: func() Config {
cfg := base
cfg.TrafficMaxDelay = "-1ms"
return cfg
}(),
want: ErrTrafficMaxDelayInvalid,
},
{
name: "traffic rejects max delay below min delay",
cfg: func() Config {
cfg := base
cfg.TrafficMinDelay = "30ms"
cfg.TrafficMaxDelay = "5ms"
return cfg
}(),
want: ErrTrafficMaxDelayInvalid,
},
}
for _, tt := range tests {

View File

@@ -9,9 +9,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/google/uuid"
"github.com/openlibrecommunity/olcrtc/internal/protect"
@@ -122,7 +120,7 @@ func createMeeting(ctx context.Context, headers map[string]string) (*createRespo
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return nil, statusError(errCreateRoomFailed, resp)
return nil, protect.StatusError(errCreateRoomFailed, resp, 1024)
}
var res createResponse
@@ -174,7 +172,7 @@ func preconnect(ctx context.Context, roomID, password string, headers map[string
defer func() { _ = preResp.Body.Close() }()
if preResp.StatusCode != http.StatusOK {
return "", statusError(errPreconnectFailed, preResp)
return "", protect.StatusError(errPreconnectFailed, preResp, 1024)
}
var preconnectResp struct {
@@ -186,15 +184,6 @@ func preconnect(ctx context.Context, roomID, password string, headers map[string
return preconnectResp.ConnectorURL, nil
}
// statusError wraps base with the HTTP status code of resp and, when present,
// up to 1024 bytes of the whitespace-trimmed response body. Read errors on the
// body are ignored: the body is best-effort diagnostic detail only.
func statusError(base error, resp *http.Response) error {
	limited := io.LimitReader(resp.Body, 1024)
	raw, _ := io.ReadAll(limited)
	if detail := strings.TrimSpace(string(raw)); detail != "" {
		return fmt.Errorf("%w: status %d: %s", base, resp.StatusCode, detail)
	}
	return fmt.Errorf("%w: status %d", base, resp.StatusCode)
}
func joinRoom(ctx context.Context, roomID, password string) (*roomInfo, error) {
headers := anonymousHeaders()
connectorURL, err := preconnect(ctx, roomID, password, headers)

View File

@@ -11,7 +11,6 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
@@ -69,8 +68,7 @@ func GetConnectionInfo(ctx context.Context, roomURL, displayName string) (*Conne
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("%w %d: %s", ErrAPI, resp.StatusCode, body)
return nil, protect.StatusError(ErrAPI, resp, 4096)
}
var info ConnectionInfo

View File

@@ -10,7 +10,6 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"github.com/openlibrecommunity/olcrtc/internal/protect"
@@ -84,8 +83,7 @@ func registerGuest(ctx context.Context, displayName string) (string, error) {
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
b, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("%w: %d %s", errGuestRegister, resp.StatusCode, b)
return "", protect.StatusError(errGuestRegister, resp, 4096)
}
var res guestRegisterResponse
@@ -122,8 +120,7 @@ func createRoom(ctx context.Context, accessToken string) (string, error) {
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
b, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("%w: %d %s", errCreateRoom, resp.StatusCode, b)
return "", protect.StatusError(errCreateRoom, resp, 4096)
}
var res createRoomResponse
@@ -151,8 +148,7 @@ func joinRoom(ctx context.Context, accessToken, roomID string) error {
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
b, _ := io.ReadAll(resp.Body)
return fmt.Errorf("%w: %d %s", errJoinRoom, resp.StatusCode, b)
return protect.StatusError(errJoinRoom, resp, 4096)
}
return nil
}
@@ -180,8 +176,7 @@ func getToken(ctx context.Context, accessToken, roomID, displayName string) (tok
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
b, _ := io.ReadAll(resp.Body)
return tokenResponse{}, fmt.Errorf("%w: %d %s", errGetToken, resp.StatusCode, b)
return tokenResponse{}, protect.StatusError(errGetToken, resp, 4096)
}
var res tokenResponse

View File

@@ -17,12 +17,14 @@ import (
"time"
"github.com/google/uuid"
"github.com/openlibrecommunity/olcrtc/internal/control"
"github.com/openlibrecommunity/olcrtc/internal/crypto"
"github.com/openlibrecommunity/olcrtc/internal/handshake"
"github.com/openlibrecommunity/olcrtc/internal/link"
"github.com/openlibrecommunity/olcrtc/internal/logger"
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
"github.com/openlibrecommunity/olcrtc/internal/names"
"github.com/openlibrecommunity/olcrtc/internal/transport"
"github.com/xtaci/smux"
)
@@ -54,7 +56,12 @@ type Client struct {
conn *muxconn.Conn
session *smux.Session
controlStrm *smux.Stream
controlStop context.CancelFunc
sessMu sync.RWMutex
reconnectMu sync.Mutex
healthMu sync.RWMutex
health control.Status
onHealth HealthFunc
deviceID string
sessionID string
claims map[string]any
@@ -63,6 +70,9 @@ type Client struct {
socksPass string
}
// HealthFunc is called when the client control health snapshot changes.
type HealthFunc func(control.Status)
// Config holds runtime configuration for [Run] and [RunWithReady].
type Config struct {
Link string
@@ -93,6 +103,8 @@ type Config struct {
Engine string
URL string
Token string
Liveness control.Config
Traffic transport.TrafficConfig
// DeviceID overrides the persistent client-side device identifier. Leave
// empty to derive one from DeviceIDPath (or generate a random one if both
@@ -106,6 +118,9 @@ type Config struct {
// Claims is sent to the server in CLIENT_HELLO and forwarded verbatim to
// the server's AuthHook. Free-form key/value bag for plan, user, region, etc.
Claims map[string]any
// OnHealth receives liveness/reconnect status updates. Nil means no-op.
OnHealth HealthFunc
}
// Run starts the client with the given configuration.
@@ -135,6 +150,7 @@ func RunWithReady(ctx context.Context, cfg Config, onReady func()) error {
dnsServer: cfg.DNSServer,
socksUser: cfg.SOCKSUser,
socksPass: cfg.SOCKSPass,
onHealth: cfg.OnHealth,
}
// shutdown is registered BEFORE bringUpLink so we always close any
@@ -202,6 +218,7 @@ func (c *Client) bringUpLink(
SEIBatchSize: cfg.SEIBatchSize,
SEIFragmentSize: cfg.SEIFragmentSize,
SEIAckTimeoutMS: cfg.SEIAckTimeoutMS,
Traffic: cfg.Traffic,
})
if err != nil {
return fmt.Errorf("failed to create link: %w", err)
@@ -217,7 +234,9 @@ func (c *Client) bringUpLink(
if ctx.Err() != nil {
return
}
c.handleReconnect()
if !c.handleReconnect(ctx, cfg, cancel, "carrier") {
cancel()
}
})
if err := ln.Connect(ctx); err != nil {
@@ -225,7 +244,7 @@ func (c *Client) bringUpLink(
}
c.conn = muxconn.New(ln, c.cipher)
sess, err := smux.Client(c.conn, smuxConfig())
sess, err := smux.Client(c.conn, smuxConfig(linkMaxPayload(ln)))
if err != nil {
return fmt.Errorf("smux client: %w", err)
}
@@ -243,14 +262,16 @@ func (c *Client) bringUpLink(
c.controlStrm = control
c.sessionID = sid
c.sessMu.Unlock()
c.recordSession(sid)
c.startControlLoop(ctx, cfg, cancel, control)
go ln.WatchConnection(ctx)
return nil
}
// openControlStream opens stream #1 on sess and performs the handshake.
// The stream stays open for the lifetime of the smux session — the server
// holds it parked, and it would carry future control messages.
// The stream stays open for the lifetime of the smux session and carries
// post-handshake control messages.
func openControlStream(
sess *smux.Session,
deviceID string,
@@ -314,11 +335,17 @@ func resolveDeviceID(deviceID, path string) (string, error) {
}
// smuxConfig returns the tuned smux config used on both ends.
func smuxConfig() *smux.Config {
func smuxConfig(maxWirePayload ...int) *smux.Config {
cfg := smux.DefaultConfig()
cfg.Version = 2
cfg.KeepAliveDisabled = true
cfg.MaxFrameSize = 32768
if len(maxWirePayload) > 0 && maxWirePayload[0] > crypto.WireOverhead {
maxFrameSize := maxWirePayload[0] - crypto.WireOverhead
if maxFrameSize < cfg.MaxFrameSize {
cfg.MaxFrameSize = maxFrameSize
}
}
cfg.MaxReceiveBuffer = 16 * 1024 * 1024
cfg.MaxStreamBuffer = 1024 * 1024
cfg.KeepAliveInterval = 10 * time.Second
@@ -326,8 +353,20 @@ func smuxConfig() *smux.Config {
return cfg
}
func (c *Client) handleReconnect() {
logger.Infof("client link reconnect - tearing down smux session")
// linkMaxPayload reports the MaxPayloadSize advertised by ln's features.
// It returns 0 when ln does not implement link.FeaturesProvider.
func linkMaxPayload(ln link.Link) int {
	if fp, hasFeatures := ln.(link.FeaturesProvider); hasFeatures {
		return fp.Features().MaxPayloadSize
	}
	return 0
}
func (c *Client) handleReconnect(ctx context.Context, cfg Config, cancel context.CancelFunc, reason string) bool {
c.reconnectMu.Lock()
defer c.reconnectMu.Unlock()
c.recordReconnect()
logger.Infof("client reconnect reason=%s - tearing down smux session", reason)
// Install a fresh muxconn immediately so onData never hits nil while
// the old session is being torn down. tryReopenSession will swap it
@@ -336,14 +375,19 @@ func (c *Client) handleReconnect() {
c.sessMu.Lock()
oldControl := c.controlStrm
oldControlStop := c.controlStop
oldSess := c.session
oldConn := c.conn
c.conn = newConn
c.session = nil
c.controlStrm = nil
c.controlStop = nil
c.sessionID = ""
c.sessMu.Unlock()
if oldControlStop != nil {
oldControlStop()
}
if oldControl != nil {
_ = oldControl.Close()
}
@@ -364,15 +408,26 @@ func (c *Client) handleReconnect() {
attemptDelay = 300 * time.Millisecond
)
for attempt := 1; attempt <= maxAttempts; attempt++ {
if c.tryReopenSession(attempt) {
return
logger.Infof("client reconnect attempt=%d reason=%s", attempt, reason)
if c.tryReopenSession(ctx, cfg, cancel, attempt) {
return true
}
select {
case <-ctx.Done():
return false
case <-time.After(attemptDelay):
}
time.Sleep(attemptDelay)
}
logger.Warnf("client reconnect: exhausted %d handshake attempts", maxAttempts)
return false
}
func (c *Client) tryReopenSession(attempt int) bool {
func (c *Client) tryReopenSession(
ctx context.Context,
cfg Config,
cancel context.CancelFunc,
attempt int,
) bool {
conn := muxconn.New(c.ln, c.cipher)
c.sessMu.Lock()
@@ -383,7 +438,7 @@ func (c *Client) tryReopenSession(attempt int) bool {
_ = old.Close()
}
sess, err := smux.Client(conn, smuxConfig())
sess, err := smux.Client(conn, smuxConfig(linkMaxPayload(c.ln)))
if err != nil {
logger.Warnf("smux re-init failed (attempt %d): %v", attempt, err)
return false
@@ -400,19 +455,138 @@ func (c *Client) tryReopenSession(attempt int) bool {
c.controlStrm = control
c.sessionID = sid
c.sessMu.Unlock()
c.recordSession(sid)
c.startControlLoop(ctx, cfg, cancel, control)
return true
}
// startControlLoop runs the control liveness loop for stream in a background
// goroutine.
//
// A child context of ctx bounds the loop; its cancel function is stored in
// c.controlStop under sessMu. The liveness callbacks from cfg.Liveness are
// wrapped so that the client's health snapshot is updated and a log line is
// emitted before the caller-supplied callback (if any) runs.
//
// When control.Run returns while neither the loop context nor ctx is done,
// the goroutine calls handleReconnect with reason "liveness" and invokes
// cancel if that reconnect reports failure.
func (c *Client) startControlLoop(
	ctx context.Context,
	cfg Config,
	cancel context.CancelFunc,
	stream *smux.Stream,
) {
	loopCtx, stopLoop := context.WithCancel(ctx)
	c.sessMu.Lock()
	c.controlStop = stopLoop
	c.sessMu.Unlock()

	// Work on a copy so cfg itself is handed to handleReconnect unmodified.
	wrapped := cfg.Liveness
	userPong := wrapped.OnPong
	userMissed := wrapped.OnMissedPong
	userUnhealthy := wrapped.OnUnhealthy

	wrapped.OnPong = func(h control.Health) {
		c.sessMu.RLock()
		activeID := c.sessionID
		c.sessMu.RUnlock()
		c.recordPong(h)
		logger.Debugf("control alive session=%s rtt=%v seq=%d", activeID, h.RTT, h.Seq)
		if userPong != nil {
			userPong(h)
		}
	}
	wrapped.OnMissedPong = func(missed int) {
		c.recordMissed(missed)
		logger.Warnf("control missed pong on client: missed_pongs=%d", missed)
		if userMissed != nil {
			userMissed(missed)
		}
	}
	wrapped.OnUnhealthy = func(missed int) {
		c.recordUnhealthy(missed)
		logger.Warnf("control stream unhealthy on client: missed_pongs=%d", missed)
		if userUnhealthy != nil {
			userUnhealthy(missed)
		}
	}

	go func() {
		runErr := control.Run(loopCtx, stream, wrapped)
		switch {
		case loopCtx.Err() != nil, ctx.Err() != nil:
			// The loop was stopped on purpose; nothing to recover.
			return
		case runErr != nil:
			logger.Warnf("client control stream ended: %v", runErr)
		}
		if recovered := c.handleReconnect(ctx, cfg, cancel, "liveness"); !recovered {
			cancel()
		}
	}()
}
// Status returns a copy of the client's current control health snapshot,
// read under the health read lock.
func (c *Client) Status() control.Status {
	c.healthMu.RLock()
	snapshot := c.health
	c.healthMu.RUnlock()
	return snapshot
}
// recordSession stores sessionID in the health snapshot, resets the
// missed-pong count to zero, and passes the updated snapshot to notifyHealth
// after the health lock has been released.
func (c *Client) recordSession(sessionID string) {
	snapshot := func() control.Status {
		c.healthMu.Lock()
		defer c.healthMu.Unlock()
		c.health.SessionID = sessionID
		c.health.MissedPongs = 0
		return c.health
	}()
	c.notifyHealth(snapshot)
}
// recordPong copies h.LastSeen and h.RTT into the health snapshot, resets
// the missed-pong count to zero, and passes the updated snapshot to
// notifyHealth after the health lock has been released.
func (c *Client) recordPong(h control.Health) {
	snapshot := func() control.Status {
		c.healthMu.Lock()
		defer c.healthMu.Unlock()
		c.health.LastPong = h.LastSeen
		c.health.LastRTT = h.RTT
		c.health.MissedPongs = 0
		return c.health
	}()
	c.notifyHealth(snapshot)
}
// recordMissed sets the missed-pong count in the health snapshot to missed
// and passes the updated snapshot to notifyHealth after the health lock has
// been released.
func (c *Client) recordMissed(missed int) {
	snapshot := func() control.Status {
		c.healthMu.Lock()
		defer c.healthMu.Unlock()
		c.health.MissedPongs = missed
		return c.health
	}()
	c.notifyHealth(snapshot)
}
// recordUnhealthy sets the missed-pong count to missed, increments the
// unhealthy-event counter, stamps LastUnhealthy with the current time, and
// passes the updated snapshot to notifyHealth after the health lock has been
// released.
func (c *Client) recordUnhealthy(missed int) {
	snapshot := func() control.Status {
		c.healthMu.Lock()
		defer c.healthMu.Unlock()
		c.health.MissedPongs = missed
		c.health.UnhealthyEvents++
		c.health.LastUnhealthy = time.Now()
		return c.health
	}()
	c.notifyHealth(snapshot)
}
// recordReconnect increments the reconnect counter in the health snapshot
// and passes the updated snapshot to notifyHealth after the health lock has
// been released.
func (c *Client) recordReconnect() {
	snapshot := func() control.Status {
		c.healthMu.Lock()
		defer c.healthMu.Unlock()
		c.health.Reconnects++
		return c.health
	}()
	c.notifyHealth(snapshot)
}
// notifyHealth forwards status to the client's onHealth callback.
// It does nothing when no callback is configured.
func (c *Client) notifyHealth(status control.Status) {
	if c.onHealth == nil {
		return
	}
	c.onHealth(status)
}
func (c *Client) shutdown() {
c.sessMu.Lock()
control := c.controlStrm
controlStop := c.controlStop
sess := c.session
conn := c.conn
c.controlStrm = nil
c.controlStop = nil
c.session = nil
c.conn = nil
c.sessMu.Unlock()
if controlStop != nil {
controlStop()
}
if conn != nil {
_ = conn.Close()
}

View File

@@ -11,6 +11,7 @@ import (
"testing"
"time"
"github.com/openlibrecommunity/olcrtc/internal/control"
cryptopkg "github.com/openlibrecommunity/olcrtc/internal/crypto"
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
"github.com/xtaci/smux"
@@ -48,6 +49,11 @@ func TestSmuxConfig(t *testing.T) {
if cfg.Version != 2 || !cfg.KeepAliveDisabled || cfg.MaxFrameSize != 32768 || cfg.MaxReceiveBuffer != 16*1024*1024 {
t.Fatalf("smuxConfig() = %+v", cfg)
}
capped := smuxConfig(4096)
if capped.MaxFrameSize != 4096-cryptopkg.WireOverhead {
t.Fatalf("smuxConfig(4096).MaxFrameSize = %d, want %d",
capped.MaxFrameSize, 4096-cryptopkg.WireOverhead)
}
}
func TestSocks5Handshake(t *testing.T) {
@@ -517,3 +523,96 @@ func TestShutdownClosesLinkAndConn(t *testing.T) {
t.Fatal("shutdown() did not close link")
}
}
// TestStartControlLoopReportsPong wires a client control loop to a peer
// control loop over an in-memory smux session pair and checks that a pong
// reaches the user's OnPong callback and is reflected in Client.Status.
func TestStartControlLoopReportsPong(t *testing.T) {
	a, b := net.Pipe()
	defer func() {
		_ = a.Close()
		_ = b.Close()
	}()

	serverSess, err := smux.Server(a, smuxConfig())
	if err != nil {
		t.Fatalf("smux.Server() error = %v", err)
	}
	defer func() { _ = serverSess.Close() }()

	clientSess, err := smux.Client(b, smuxConfig())
	if err != nil {
		t.Fatalf("smux.Client() error = %v", err)
	}
	defer func() { _ = clientSess.Close() }()

	peerStreamCh := make(chan *smux.Stream, 1)
	acceptErrCh := make(chan error, 1)
	go func() {
		stream, err := serverSess.AcceptStream()
		if err != nil {
			// Report the failure instead of leaving the test blocked on
			// peerStreamCh until the global test timeout.
			acceptErrCh <- err
			return
		}
		peerStreamCh <- stream
	}()

	stream, err := clientSess.OpenStream()
	if err != nil {
		t.Fatalf("OpenStream() error = %v", err)
	}

	var peerStream *smux.Stream
	select {
	case peerStream = <-peerStreamCh:
	case err := <-acceptErrCh:
		t.Fatalf("AcceptStream() error = %v", err)
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for peer control stream")
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	got := make(chan control.Health, 1)
	c := &Client{sessionID: "sid-control"}
	c.recordSession("sid-control")
	c.startControlLoop(ctx, Config{
		Liveness: control.Config{
			Interval: 10 * time.Millisecond,
			Timeout:  100 * time.Millisecond,
			Failures: 2,
			OnPong: func(h control.Health) {
				// Keep only the first report; later pongs are dropped.
				select {
				case got <- h:
				default:
				}
			},
		},
	}, cancel, stream)

	// The peer side answers pings (and sends its own) until ctx is canceled.
	go func() {
		_ = control.Run(ctx, peerStream, control.Config{
			Interval: 10 * time.Millisecond,
			Timeout:  100 * time.Millisecond,
			Failures: 2,
		})
	}()

	select {
	case h := <-got:
		if h.Seq == 0 {
			t.Fatal("Health.Seq = 0")
		}
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for control pong")
	}

	status := c.Status()
	if status.SessionID != "sid-control" {
		t.Fatalf("Status.SessionID = %q, want sid-control", status.SessionID)
	}
	if status.LastPong.IsZero() || status.LastRTT < 0 || status.MissedPongs != 0 {
		t.Fatalf("Status() = %+v", status)
	}
}
// TestStatusRecordsReconnectAndUnhealthy drives the record* helpers directly
// and checks both the resulting Status snapshot and that the health callback
// fired once per recorded event.
func TestStatusRecordsReconnectAndUnhealthy(t *testing.T) {
	var updates int
	c := &Client{onHealth: func(control.Status) { updates++ }}

	c.recordSession("sid-1")
	c.recordMissed(2)
	c.recordUnhealthy(3)
	c.recordReconnect()

	status := c.Status()
	asExpected := status.SessionID == "sid-1" &&
		status.MissedPongs == 3 &&
		status.UnhealthyEvents == 1 &&
		status.Reconnects == 1 &&
		!status.LastUnhealthy.IsZero()
	if !asExpected {
		t.Fatalf("Status() = %+v", status)
	}
	if updates != 4 {
		t.Fatalf("health updates = %d, want 4", updates)
	}
}

View File

@@ -1,10 +1,9 @@
// Package config loads olcrtc runtime configuration from YAML files.
//
// The YAML schema mirrors [session.Config]. Fields left unset in the file
// remain at their zero value, allowing CLI flags to fill them in. Use
// [Apply] to merge a parsed [File] onto an existing [session.Config];
// non-zero fields in the session config (typically populated from CLI flags)
// take precedence over the YAML values.
// remain at their zero value. Use [Apply] to map a parsed [File] onto an
// existing [session.Config]; non-zero fields in the session config take
// precedence over the YAML values.
//
//nolint:tagliatelle // YAML keys are the documented config file schema.
package config
@@ -13,31 +12,68 @@ import (
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/openlibrecommunity/olcrtc/internal/app/session"
"gopkg.in/yaml.v3"
)
// ErrConfigNotFound is returned when a config file path is set but the file does not exist.
var ErrConfigNotFound = errors.New("config file not found")
var (
// ErrConfigNotFound is returned when a config file path is set but the file does not exist.
ErrConfigNotFound = errors.New("config file not found")
// ErrCryptoKeyConflict is returned when both inline and file-backed keys are configured.
ErrCryptoKeyConflict = errors.New("crypto.key and crypto.key_file cannot both be set")
// ErrCryptoKeyFileEmpty is returned when crypto.key_file points to an empty file.
ErrCryptoKeyFileEmpty = errors.New("crypto key file is empty")
)
// File is the on-disk YAML schema.
type File struct {
Mode string `yaml:"mode"`
Link string `yaml:"link"`
Auth Auth `yaml:"auth"`
Room Room `yaml:"room"`
Crypto Crypto `yaml:"crypto"`
Net Net `yaml:"net"`
SOCKS SOCKS `yaml:"socks"`
Engine Engine `yaml:"engine"`
Video Video `yaml:"video"`
VP8 VP8 `yaml:"vp8"`
SEI SEI `yaml:"sei"`
Gen Gen `yaml:"gen"`
Data string `yaml:"data"`
Debug bool `yaml:"debug"`
FFmpeg string `yaml:"ffmpeg"`
Mode string `yaml:"mode"`
Link string `yaml:"link"`
Auth Auth `yaml:"auth"`
Room Room `yaml:"room"`
Crypto Crypto `yaml:"crypto"`
Net Net `yaml:"net"`
SOCKS SOCKS `yaml:"socks"`
Engine Engine `yaml:"engine"`
Video Video `yaml:"video"`
VP8 VP8 `yaml:"vp8"`
SEI SEI `yaml:"sei"`
Liveness Liveness `yaml:"liveness"`
Lifecycle Lifecycle `yaml:"lifecycle"`
Traffic Traffic `yaml:"traffic"`
Gen Gen `yaml:"gen"`
Profiles []Profile `yaml:"profiles"`
Failover Failover `yaml:"failover"`
Data string `yaml:"data"`
Debug bool `yaml:"debug"`
FFmpeg string `yaml:"ffmpeg"`
}
// Profile is a failover entry that overrides top-level runtime fields.
//
// ApplyProfile overlays a profile onto a base session config: only non-empty
// strings and non-zero integers in the profile replace the base value, so
// every field left unset here inherits the top-level setting.
type Profile struct {
// Name labels the profile in the YAML file; ApplyProfile does not read it.
Name string `yaml:"name"`
Link string `yaml:"link"`
Auth Auth `yaml:"auth"`
Room Room `yaml:"room"`
// Crypto may carry its own key or key_file; Load resolves key_file into
// Crypto.Key for each profile.
Crypto Crypto `yaml:"crypto"`
Net Net `yaml:"net"`
SOCKS SOCKS `yaml:"socks"`
Engine Engine `yaml:"engine"`
Video Video `yaml:"video"`
VP8 VP8 `yaml:"vp8"`
SEI SEI `yaml:"sei"`
Liveness Liveness `yaml:"liveness"`
Lifecycle Lifecycle `yaml:"lifecycle"`
Traffic Traffic `yaml:"traffic"`
}
// Failover controls ordered profile failover.
type Failover struct {
// RetryDelay is kept as a string in the file schema (for example "100ms").
// NOTE(review): presumably the pause between failover attempts; the code
// that parses and uses it is not visible here — confirm.
RetryDelay string `yaml:"retry_delay"`
// MaxCycles is an integer limit (for example 2).
// NOTE(review): presumably the number of passes over the profile list —
// confirm against the consumer.
MaxCycles int `yaml:"max_cycles"`
}
// Auth selects the auth provider.
@@ -52,7 +88,8 @@ type Room struct {
// Crypto holds the shared secret used to authenticate and encrypt the tunnel.
type Crypto struct {
Key string `yaml:"key"` // 64-char hex (32 bytes)
Key string `yaml:"key"` // 64-char hex (32 bytes)
KeyFile string `yaml:"key_file"` // path to a file containing crypto.key
}
// Net groups network and transport selection.
@@ -106,6 +143,25 @@ type SEI struct {
AckTimeoutMS int `yaml:"ack_timeout_ms"`
}
// Liveness tunes the post-handshake control stream ping/pong checks.
//
// Interval and Timeout stay strings in the file schema (for example "2s",
// "500ms"); Apply maps them, together with Failures, onto the session
// config's LivenessInterval, LivenessTimeout and LivenessFailures fields.
type Liveness struct {
Interval string `yaml:"interval"`
Timeout string `yaml:"timeout"`
Failures int `yaml:"failures"`
}
// Lifecycle controls planned session rebuilds.
type Lifecycle struct {
// MaxSessionDuration stays a string in the file schema (for example "6h");
// Apply maps it onto the session config's MaxSessionDuration field.
MaxSessionDuration string `yaml:"max_session_duration"`
}
// Traffic controls optional reliability-oriented send shaping.
//
// Apply maps these fields onto the session config's TrafficMaxPayloadSize,
// TrafficMinDelay and TrafficMaxDelay fields.
type Traffic struct {
// MaxPayloadSize is an integer size (for example 4096).
// NOTE(review): unit is presumably bytes — confirm against the consumer.
MaxPayloadSize int `yaml:"max_payload_size"`
// MinDelay and MaxDelay stay strings in the file schema (for example
// "5ms" and "30ms").
MinDelay string `yaml:"min_delay"`
MaxDelay string `yaml:"max_delay"`
}
// Gen controls room-generation mode.
type Gen struct {
Amount int `yaml:"amount"`
@@ -125,9 +181,63 @@ func Load(path string) (File, error) {
if err := yaml.Unmarshal(data, &f); err != nil {
return File{}, fmt.Errorf("parse config %s: %w", path, err)
}
if err := loadExternalSecrets(path, &f); err != nil {
return File{}, err
}
return f, nil
}
// loadExternalSecrets resolves file-backed crypto keys referenced by f.
//
// When the top-level crypto.key_file is set, its content becomes
// f.Crypto.Key; setting both crypto.key and crypto.key_file yields
// ErrCryptoKeyConflict. Profile-level key files are resolved afterwards via
// loadProfileSecrets. configPath is the path of the config file being loaded.
func loadExternalSecrets(configPath string, f *File) error {
	if keyFile := f.Crypto.KeyFile; keyFile != "" {
		if f.Crypto.Key != "" {
			return ErrCryptoKeyConflict
		}
		secret, err := readKeyFile(configPath, keyFile)
		if err != nil {
			return err
		}
		f.Crypto.Key = secret
	}
	return loadProfileSecrets(configPath, f.Profiles)
}
// loadProfileSecrets resolves crypto.key_file for every profile, writing the
// loaded key into the profile's Crypto.Key in place.
//
// Profiles without a key file are skipped. A profile that sets both key and
// key_file yields ErrCryptoKeyConflict. All errors are prefixed with the
// offending profile's index.
func loadProfileSecrets(configPath string, profiles []Profile) error {
	for idx := range profiles {
		entry := &profiles[idx].Crypto
		switch {
		case entry.KeyFile == "":
			continue
		case entry.Key != "":
			return fmt.Errorf("profiles[%d]: %w", idx, ErrCryptoKeyConflict)
		}
		secret, err := readKeyFile(configPath, entry.KeyFile)
		if err != nil {
			return fmt.Errorf("profiles[%d]: %w", idx, err)
		}
		entry.Key = secret
	}
	return nil
}
// readKeyFile returns the crypto key stored in keyFile.
//
// A relative keyFile is resolved against the directory containing
// configPath; an absolute keyFile is used as is. Leading and trailing
// whitespace (including the trailing newline) is trimmed from the content.
//
// It returns the wrapped os.ReadFile error when the file cannot be read, and
// an error wrapping ErrCryptoKeyFileEmpty when the trimmed content is empty.
// Both errors name the resolved path so the failing file can be identified.
func readKeyFile(configPath, keyFile string) (string, error) {
	keyPath := keyFile
	if !filepath.IsAbs(keyPath) {
		keyPath = filepath.Join(filepath.Dir(configPath), keyPath)
	}
	// #nosec G304 -- key_file is an explicit path in the user's config file.
	data, err := os.ReadFile(keyPath)
	if err != nil {
		return "", fmt.Errorf("read crypto key file %s: %w", keyPath, err)
	}
	key := strings.TrimSpace(string(data))
	if key == "" {
		return "", fmt.Errorf("%w: %s", ErrCryptoKeyFileEmpty, keyPath)
	}
	return key, nil
}
// Apply merges f onto dst. CLI-set fields (non-zero values in dst) win;
// YAML values fill in the rest.
func Apply(dst session.Config, f File) session.Config {
@@ -163,10 +273,61 @@ func Apply(dst session.Config, f File) session.Config {
dst.SEIBatchSize = pickInt(dst.SEIBatchSize, f.SEI.BatchSize)
dst.SEIFragmentSize = pickInt(dst.SEIFragmentSize, f.SEI.FragmentSize)
dst.SEIAckTimeoutMS = pickInt(dst.SEIAckTimeoutMS, f.SEI.AckTimeoutMS)
dst.LivenessInterval = pickString(dst.LivenessInterval, f.Liveness.Interval)
dst.LivenessTimeout = pickString(dst.LivenessTimeout, f.Liveness.Timeout)
dst.LivenessFailures = pickInt(dst.LivenessFailures, f.Liveness.Failures)
dst.MaxSessionDuration = pickString(dst.MaxSessionDuration, f.Lifecycle.MaxSessionDuration)
dst.TrafficMaxPayloadSize = pickInt(dst.TrafficMaxPayloadSize, f.Traffic.MaxPayloadSize)
dst.TrafficMinDelay = pickString(dst.TrafficMinDelay, f.Traffic.MinDelay)
dst.TrafficMaxDelay = pickString(dst.TrafficMaxDelay, f.Traffic.MaxDelay)
dst.Amount = pickInt(dst.Amount, f.Gen.Amount)
return dst
}
// ApplyProfile overlays a failover profile onto an already-applied base
// config and returns the merged copy.
//
// Every non-empty string and non-zero integer in p replaces the matching
// value from base; fields left at their zero value in p keep the base value.
// base is received by value, so the caller's config is not modified.
func ApplyProfile(base session.Config, p Profile) session.Config {
	merged := base
	netCfg, socksCfg, engineCfg := p.Net, p.SOCKS, p.Engine
	videoCfg, seiCfg := p.Video, p.SEI

	// Link, provider, engine and room selection.
	merged.Link = overlayString(base.Link, p.Link)
	merged.Transport = overlayString(base.Transport, netCfg.Transport)
	merged.Auth = overlayString(base.Auth, p.Auth.Provider)
	merged.Engine = overlayString(base.Engine, engineCfg.Name)
	merged.URL = overlayString(base.URL, engineCfg.URL)
	merged.Token = overlayString(base.Token, engineCfg.Token)
	merged.RoomID = overlayString(base.RoomID, p.Room.ID)
	merged.KeyHex = overlayString(base.KeyHex, p.Crypto.Key)

	// SOCKS and DNS.
	merged.SOCKSHost = overlayString(base.SOCKSHost, socksCfg.Host)
	merged.SOCKSPort = overlayInt(base.SOCKSPort, socksCfg.Port)
	merged.SOCKSUser = overlayString(base.SOCKSUser, socksCfg.User)
	merged.SOCKSPass = overlayString(base.SOCKSPass, socksCfg.Pass)
	merged.DNSServer = overlayString(base.DNSServer, netCfg.DNS)
	merged.SOCKSProxyAddr = overlayString(base.SOCKSProxyAddr, socksCfg.ProxyAddr)
	merged.SOCKSProxyPort = overlayInt(base.SOCKSProxyPort, socksCfg.ProxyPort)

	// Video transport.
	merged.VideoWidth = overlayInt(base.VideoWidth, videoCfg.Width)
	merged.VideoHeight = overlayInt(base.VideoHeight, videoCfg.Height)
	merged.VideoFPS = overlayInt(base.VideoFPS, videoCfg.FPS)
	merged.VideoBitrate = overlayString(base.VideoBitrate, videoCfg.Bitrate)
	merged.VideoHW = overlayString(base.VideoHW, videoCfg.HW)
	merged.VideoQRSize = overlayInt(base.VideoQRSize, videoCfg.QRSize)
	merged.VideoQRRecovery = overlayString(base.VideoQRRecovery, videoCfg.QRRecovery)
	merged.VideoCodec = overlayString(base.VideoCodec, videoCfg.Codec)
	merged.VideoTileModule = overlayInt(base.VideoTileModule, videoCfg.TileModule)
	merged.VideoTileRS = overlayInt(base.VideoTileRS, videoCfg.TileRS)

	// VP8 and SEI transports.
	merged.VP8FPS = overlayInt(base.VP8FPS, p.VP8.FPS)
	merged.VP8BatchSize = overlayInt(base.VP8BatchSize, p.VP8.BatchSize)
	merged.SEIFPS = overlayInt(base.SEIFPS, seiCfg.FPS)
	merged.SEIBatchSize = overlayInt(base.SEIBatchSize, seiCfg.BatchSize)
	merged.SEIFragmentSize = overlayInt(base.SEIFragmentSize, seiCfg.FragmentSize)
	merged.SEIAckTimeoutMS = overlayInt(base.SEIAckTimeoutMS, seiCfg.AckTimeoutMS)

	// Liveness, lifecycle and traffic shaping.
	merged.LivenessInterval = overlayString(base.LivenessInterval, p.Liveness.Interval)
	merged.LivenessTimeout = overlayString(base.LivenessTimeout, p.Liveness.Timeout)
	merged.LivenessFailures = overlayInt(base.LivenessFailures, p.Liveness.Failures)
	merged.MaxSessionDuration = overlayString(base.MaxSessionDuration, p.Lifecycle.MaxSessionDuration)
	merged.TrafficMaxPayloadSize = overlayInt(base.TrafficMaxPayloadSize, p.Traffic.MaxPayloadSize)
	merged.TrafficMinDelay = overlayString(base.TrafficMinDelay, p.Traffic.MinDelay)
	merged.TrafficMaxDelay = overlayString(base.TrafficMaxDelay, p.Traffic.MaxDelay)
	return merged
}
func pickString(cli, yamlVal string) string {
if cli != "" {
return cli
@@ -180,3 +341,17 @@ func pickInt(cli, yamlVal int) int {
}
return yamlVal
}
// overlayString returns override unless it is empty, in which case base is
// returned.
func overlayString(base, override string) string {
	if override == "" {
		return base
	}
	return override
}
// overlayInt returns override unless it is zero, in which case base is
// returned.
func overlayInt(base, override int) int {
	if override == 0 {
		return base
	}
	return override
}

View File

@@ -1,6 +1,7 @@
package config
import (
"errors"
"os"
"path/filepath"
"testing"
@@ -38,6 +39,16 @@ socks:
vp8:
fps: 25
batch_size: 4
liveness:
interval: 2s
timeout: 500ms
failures: 4
lifecycle:
max_session_duration: 6h
traffic:
max_payload_size: 4096
min_delay: 5ms
max_delay: 30ms
gen:
amount: 3
debug: true
@@ -75,20 +86,27 @@ func requireLoadedFile(t *testing.T, f File) {
func requireAppliedConfig(t *testing.T, got session.Config) {
t.Helper()
want := session.Config{
Mode: testModeSrv,
Link: "direct",
Auth: testAuthProvider,
RoomID: testRoomID,
KeyHex: testCryptoKey,
Transport: "datachannel",
DNSServer: "1.1.1.1:53",
SOCKSHost: "127.0.0.1",
SOCKSPort: 1080,
SOCKSUser: "u",
SOCKSPass: "p",
VP8FPS: 25,
VP8BatchSize: 4,
Amount: 3,
Mode: testModeSrv,
Link: "direct",
Auth: testAuthProvider,
RoomID: testRoomID,
KeyHex: testCryptoKey,
Transport: "datachannel",
DNSServer: "1.1.1.1:53",
SOCKSHost: "127.0.0.1",
SOCKSPort: 1080,
SOCKSUser: "u",
SOCKSPass: "p",
VP8FPS: 25,
VP8BatchSize: 4,
LivenessInterval: "2s",
LivenessTimeout: "500ms",
LivenessFailures: 4,
MaxSessionDuration: "6h",
TrafficMaxPayloadSize: 4096,
TrafficMinDelay: "5ms",
TrafficMaxDelay: "30ms",
Amount: 3,
}
if got != want {
t.Fatalf("Apply produced wrong config: %+v, want %+v", got, want)
@@ -121,6 +139,182 @@ func TestApplyCLIWins(t *testing.T) {
}
}
// TestLoadAndApplyProfile loads a config file that declares two failover
// profiles and checks that ApplyProfile replaces base values with the
// profile's non-zero fields while every unset profile field keeps the
// top-level value.
func TestLoadAndApplyProfile(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "olcrtc.yaml")
body := `
mode: srv
link: direct
crypto:
key: shared-key
net:
dns: 1.1.1.1:53
liveness:
interval: 5s
timeout: 2s
failures: 5
lifecycle:
max_session_duration: 6h
traffic:
max_payload_size: 8192
min_delay: 10ms
max_delay: 40ms
profiles:
- name: wb-vp8
auth:
provider: wbstream
room:
id: wb-room
net:
transport: vp8channel
vp8:
fps: 30
liveness:
interval: 1s
lifecycle:
max_session_duration: 30m
traffic:
max_payload_size: 4096
max_delay: 20ms
- name: jitsi-dc
auth:
provider: jitsi
room:
id: https://meet.example/room
net:
transport: datachannel
dns: 8.8.8.8:53
failover:
retry_delay: 100ms
max_cycles: 2
`
if err := os.WriteFile(path, []byte(body), 0o600); err != nil {
t.Fatalf("write config: %v", err)
}
f, err := Load(path)
if err != nil {
t.Fatalf("Load: %v", err)
}
if len(f.Profiles) != 2 {
t.Fatalf("profiles = %d, want 2", len(f.Profiles))
}
if f.Failover.RetryDelay != "100ms" || f.Failover.MaxCycles != 2 {
t.Fatalf("Failover = %+v, want retry_delay 100ms max_cycles 2", f.Failover)
}
base := Apply(session.Config{}, f)
// The first profile overrides some fields and inherits the rest.
first := ApplyProfile(base, f.Profiles[0])
if first.Auth != "wbstream" || first.Transport != "vp8channel" || first.RoomID != "wb-room" {
t.Fatalf("first profile = %+v", first)
}
if first.KeyHex != "shared-key" || first.DNSServer != "1.1.1.1:53" || first.VP8FPS != 30 ||
first.LivenessInterval != "1s" || first.LivenessTimeout != "2s" || first.LivenessFailures != 5 ||
first.MaxSessionDuration != "30m" || first.TrafficMaxPayloadSize != 4096 ||
first.TrafficMinDelay != "10ms" || first.TrafficMaxDelay != "20ms" {
t.Fatalf("first inherited/overlaid fields = %+v", first)
}
// The second profile sets no liveness/lifecycle/traffic values, so those
// must all come from the top-level config.
second := ApplyProfile(base, f.Profiles[1])
if second.Auth != "jitsi" || second.Transport != "datachannel" ||
second.RoomID != "https://meet.example/room" || second.DNSServer != "8.8.8.8:53" {
t.Fatalf("second profile = %+v", second)
}
if second.LivenessInterval != "5s" || second.LivenessTimeout != "2s" || second.LivenessFailures != 5 ||
second.MaxSessionDuration != "6h" || second.TrafficMaxPayloadSize != 8192 ||
second.TrafficMinDelay != "10ms" || second.TrafficMaxDelay != "40ms" {
t.Fatalf("second lifecycle/liveness fields = %+v", second)
}
}
// TestLoadProfileCryptoKeyFile checks that Load resolves a profile-level
// crypto.key_file (given relative to the config file) into the profile's
// Crypto.Key, with the trailing newline of the key file trimmed.
func TestLoadProfileCryptoKeyFile(t *testing.T) {
dir := t.TempDir()
if err := os.WriteFile(filepath.Join(dir, "profile.key"), []byte(testCryptoKey+"\n"), 0o600); err != nil {
t.Fatalf("write key: %v", err)
}
path := filepath.Join(dir, "olcrtc.yaml")
body := `
profiles:
- name: file-key
crypto:
key_file: profile.key
`
if err := os.WriteFile(path, []byte(body), 0o600); err != nil {
t.Fatalf("write config: %v", err)
}
f, err := Load(path)
if err != nil {
t.Fatalf("Load: %v", err)
}
if got := f.Profiles[0].Crypto.Key; got != testCryptoKey {
t.Fatalf("profile key = %q, want %q", got, testCryptoKey)
}
}
// TestLoadCryptoKeyFileRelativeToConfig checks that a relative top-level
// crypto.key_file is resolved against the config file's directory and that
// its trimmed content ends up in Crypto.Key.
func TestLoadCryptoKeyFileRelativeToConfig(t *testing.T) {
dir := t.TempDir()
keyPath := filepath.Join(dir, "secret.key")
if err := os.WriteFile(keyPath, []byte(testCryptoKey+"\n"), 0o600); err != nil {
t.Fatalf("write key: %v", err)
}
path := filepath.Join(dir, "olcrtc.yaml")
body := `
mode: srv
crypto:
key_file: secret.key
`
if err := os.WriteFile(path, []byte(body), 0o600); err != nil {
t.Fatalf("write config: %v", err)
}
f, err := Load(path)
if err != nil {
t.Fatalf("Load: %v", err)
}
if f.Crypto.Key != testCryptoKey {
t.Fatalf("Crypto.Key = %q, want %q", f.Crypto.Key, testCryptoKey)
}
}
// TestLoadCryptoKeyFileConflict checks that Load fails with
// ErrCryptoKeyConflict when both crypto.key and crypto.key_file are set.
// The key file is never created: the conflict is reported before it is read.
func TestLoadCryptoKeyFileConflict(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "olcrtc.yaml")
body := `
crypto:
key: deadbeef
key_file: secret.key
`
if err := os.WriteFile(path, []byte(body), 0o600); err != nil {
t.Fatalf("write config: %v", err)
}
_, err := Load(path)
if !errors.Is(err, ErrCryptoKeyConflict) {
t.Fatalf("Load() error = %v, want %v", err, ErrCryptoKeyConflict)
}
}
// TestLoadCryptoKeyFileEmpty checks that a key file containing only
// whitespace makes Load fail with ErrCryptoKeyFileEmpty.
func TestLoadCryptoKeyFileEmpty(t *testing.T) {
dir := t.TempDir()
keyPath := filepath.Join(dir, "secret.key")
if err := os.WriteFile(keyPath, []byte("\n"), 0o600); err != nil {
t.Fatalf("write key: %v", err)
}
path := filepath.Join(dir, "olcrtc.yaml")
body := `
crypto:
key_file: secret.key
`
if err := os.WriteFile(path, []byte(body), 0o600); err != nil {
t.Fatalf("write config: %v", err)
}
_, err := Load(path)
if !errors.Is(err, ErrCryptoKeyFileEmpty) {
t.Fatalf("Load() error = %v, want %v", err, ErrCryptoKeyFileEmpty)
}
}
func TestLoadMissing(t *testing.T) {
_, err := Load(filepath.Join(t.TempDir(), "nope.yaml"))
if err == nil {

343
internal/control/control.go Normal file
View File

@@ -0,0 +1,343 @@
// Package control implements the post-handshake control stream protocol.
//
// The control stream is the first smux stream after the olcrtc handshake. It
// stays inside the encrypted muxconn path, so ping/pong proves that the actual
// tunnel path still round-trips, not merely that the provider connection is up.
//
// Wire format matches the handshake framing: a 4-byte big-endian length
// followed by a JSON message.
//
//nolint:tagliatelle // JSON keys are the stable wire protocol schema.
package control
import (
"context"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"sync"
"time"
)
const (
// ProtoVersion identifies the control stream wire format.
ProtoVersion = 1
// MaxMessageSize caps one control frame.
MaxMessageSize = 16 * 1024
// DefaultInterval is the default interval between ping probes.
DefaultInterval = 10 * time.Second
// DefaultTimeout is the default time to wait for a pong.
DefaultTimeout = 5 * time.Second
// DefaultFailures is the default number of consecutive missed pongs before
// the stream is marked unhealthy.
DefaultFailures = 3
)
// MsgType labels a control message.
type MsgType string
const (
// TypePing is sent periodically to prove control-stream liveness.
TypePing MsgType = "CONTROL_PING"
// TypePong replies to a ping with the same sequence and timestamp.
TypePong MsgType = "CONTROL_PONG"
)
var (
// ErrUnhealthy is returned when the stream misses too many pong replies.
ErrUnhealthy = errors.New("control stream unhealthy")
// ErrProtocolVersion is returned when the peer announces an incompatible version.
ErrProtocolVersion = errors.New("incompatible control protocol version")
// ErrUnexpectedMessage is returned for unknown or malformed control message types.
ErrUnexpectedMessage = errors.New("unexpected control message")
// ErrFrameTooLarge is returned when a frame exceeds [MaxMessageSize].
ErrFrameTooLarge = errors.New("control frame too large")
)
// Message is one control-stream frame.
//
// It is serialized as JSON and carried behind a 4-byte big-endian length
// prefix.
type Message struct {
// Version must equal ProtoVersion; frames with another value are rejected
// when parsed.
Version int `json:"version"`
// Type is either TypePing or TypePong.
Type MsgType `json:"type"`
// Seq identifies a ping; the pong reply carries the same value back.
Seq uint64 `json:"seq,omitempty"`
// SentUnixNano is the sender's clock (Unix nanoseconds) when the ping was
// issued; the pong reply echoes it unchanged.
SentUnixNano int64 `json:"sent_unix_nano,omitempty"`
}
// Health is reported when a ping round trip completes.
type Health struct {
// Seq is the sequence number of the ping that was answered.
Seq uint64
// RTT is the time between issuing the ping and handling its pong, both
// measured on the local clock.
RTT time.Duration
// LastSeen is the local time at which the pong was handled.
LastSeen time.Time
}
// Status is a point-in-time view of control-stream health maintained by
// callers that embed the control loop.
//
// This package only defines the shape; it never fills the fields itself.
type Status struct {
// SessionID names the session the snapshot refers to.
SessionID string
// LastPong and LastRTT describe the most recent completed ping round trip.
LastPong time.Time
LastRTT time.Duration
// MissedPongs is the latest missed-pong count reported to the caller.
MissedPongs int
// Reconnects and UnhealthyEvents are running counters kept by the caller.
Reconnects uint64
UnhealthyEvents uint64
// LastUnhealthy is when the caller last recorded an unhealthy event.
LastUnhealthy time.Time
}
// Config controls the liveness loop.
//
// Non-positive Interval, Timeout or Failures values are replaced by
// DefaultInterval, DefaultTimeout and DefaultFailures when Run starts.
type Config struct {
// Interval is the period between ping probes.
Interval time.Duration
// Timeout is how long an outstanding ping may wait for its pong before it
// counts as missed. Expiry is evaluated on each probe tick.
Timeout time.Duration
// Failures is the number of missed pongs, counted since the last matched
// pong, at which the stream is declared unhealthy.
Failures int
// OnPong is called after a matching pong is received.
OnPong func(Health)
// OnMissedPong is called when one or more outstanding pongs time out.
OnMissedPong func(missed int)
// OnUnhealthy is called before Run returns [ErrUnhealthy].
OnUnhealthy func(missed int)
}
// withDefaults returns a copy of cfg in which every non-positive Interval,
// Timeout or Failures value has been replaced by its package default.
func (cfg Config) withDefaults() Config {
	resolved := cfg
	if resolved.Interval <= 0 {
		resolved.Interval = DefaultInterval
	}
	if resolved.Timeout <= 0 {
		resolved.Timeout = DefaultTimeout
	}
	if resolved.Failures <= 0 {
		resolved.Failures = DefaultFailures
	}
	return resolved
}
// Run drives bidirectional ping/pong liveness on rw until ctx is canceled,
// rw fails or closes, or the configured failure threshold is reached.
//
// It starts a read loop, a probe loop and a write loop, and returns as soon
// as the first of them finishes. rw is closed once the run context is done,
// which also happens when Run itself returns. A context cancellation or
// deadline error is reported as nil; any other error is returned unchanged.
func Run(ctx context.Context, rw io.ReadWriteCloser, cfg Config) error {
	runCtx, stop := context.WithCancel(ctx)
	defer stop()

	st := &state{
		rw:      rw,
		cfg:     cfg.withDefaults(),
		pending: map[uint64]time.Time{},
		now:     time.Now,
		out:     make(chan Message, 16),
	}

	loops := []func(context.Context) error{st.readLoop, st.probeLoop, st.writeLoop}
	// Buffered for every loop so the ones that finish late never block.
	results := make(chan error, len(loops))

	// Closing rw unblocks a read or write that is still in progress.
	go func() {
		<-runCtx.Done()
		_ = rw.Close()
	}()
	for _, loop := range loops {
		go func(run func(context.Context) error) { results <- run(runCtx) }(loop)
	}

	first := <-results
	stop()
	switch {
	case errors.Is(first, context.Canceled), errors.Is(first, context.DeadlineExceeded):
		return nil
	default:
		return first
	}
}
// state is the data shared by the read, probe and write loops of one Run
// invocation.
type state struct {
// rw is the stream frames are read from and written to.
rw io.ReadWriteCloser
// cfg is the loop configuration with defaults already applied.
cfg Config
// now supplies the current time; Run sets it to time.Now.
now func() time.Time
// out queues outgoing messages for writeLoop.
out chan Message
// mu guards pending, nextSeq and failures.
mu sync.Mutex
// pending maps the sequence number of each unanswered ping to its send time.
pending map[uint64]time.Time
// nextSeq is the sequence number of the most recently issued ping.
nextSeq uint64
// failures counts pings that timed out since the last matched pong.
failures int
}
// readLoop reads frames from the stream until an error occurs.
//
// Each ping is answered by queueing a pong that echoes its sequence number
// and timestamp; each pong is handed to handlePong. A read or enqueue
// failure is reported as ctx.Err() when ctx is already done, otherwise the
// underlying error is returned. Parse errors and unexpected message types
// end the loop with their own error.
func (s *state) readLoop(ctx context.Context) error {
	for {
		frame, readErr := readFrame(s.rw)
		if readErr != nil {
			if ctxErr := ctx.Err(); ctxErr != nil {
				return ctxErr
			}
			return readErr
		}

		msg, parseErr := parseMessage(frame)
		if parseErr != nil {
			return parseErr
		}

		if msg.Type == TypePong {
			s.handlePong(msg)
			continue
		}
		if msg.Type != TypePing {
			return fmt.Errorf("%w: got %q", ErrUnexpectedMessage, msg.Type)
		}

		reply := Message{
			Version:      ProtoVersion,
			Type:         TypePong,
			Seq:          msg.Seq,
			SentUnixNano: msg.SentUnixNano,
		}
		if err := s.enqueue(ctx, reply); err != nil {
			if ctxErr := ctx.Err(); ctxErr != nil {
				return ctxErr
			}
			return err
		}
	}
}
// probeLoop calls sendProbe once per configured interval until ctx is done
// or a probe reports an error. The first probe is sent one full interval
// after the loop starts.
func (s *state) probeLoop(ctx context.Context) error {
	tick := time.NewTicker(s.cfg.Interval)
	defer tick.Stop()

	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-tick.C:
		}
		if err := s.sendProbe(ctx); err != nil {
			return err
		}
	}
}
// sendProbe expires overdue pings and then either issues the next ping or
// declares the stream unhealthy.
//
// Every pending ping at least cfg.Timeout old is dropped and counted as a
// failure. If any ping expired in this call, OnMissedPong receives the
// running failure count. When that count has reached cfg.Failures,
// OnUnhealthy is invoked and an error wrapping ErrUnhealthy is returned
// without sending a new ping. Callbacks run after the mutex is released.
func (s *state) sendProbe(ctx context.Context) error {
	now := s.now()

	s.mu.Lock()
	expired := 0
	for seq, sentAt := range s.pending {
		if now.Sub(sentAt) >= s.cfg.Timeout {
			delete(s.pending, seq)
			s.failures++
			expired++
		}
	}
	total := s.failures
	unhealthy := total >= s.cfg.Failures
	var seq uint64
	if !unhealthy {
		// Register the new ping while still holding the lock so a fast pong
		// always finds its pending entry.
		s.nextSeq++
		seq = s.nextSeq
		s.pending[seq] = now
	}
	s.mu.Unlock()

	if expired > 0 && s.cfg.OnMissedPong != nil {
		s.cfg.OnMissedPong(total)
	}
	if unhealthy {
		if s.cfg.OnUnhealthy != nil {
			s.cfg.OnUnhealthy(total)
		}
		return fmt.Errorf("%w: missed %d pong(s)", ErrUnhealthy, total)
	}

	return s.enqueue(ctx, Message{
		Version:      ProtoVersion,
		Type:         TypePing,
		Seq:          seq,
		SentUnixNano: now.UnixNano(),
	})
}
// handlePong matches msg against the outstanding pings.
//
// A pong whose sequence number is still pending clears that entry, resets
// the failure counter and, if OnPong is set, reports the measured round
// trip. Pongs for unknown sequence numbers are ignored.
func (s *state) handlePong(msg Message) {
	receivedAt := s.now()

	s.mu.Lock()
	sentAt, known := s.pending[msg.Seq]
	if known {
		delete(s.pending, msg.Seq)
		s.failures = 0
	}
	s.mu.Unlock()

	if known && s.cfg.OnPong != nil {
		s.cfg.OnPong(Health{
			Seq:      msg.Seq,
			RTT:      receivedAt.Sub(sentAt),
			LastSeen: receivedAt,
		})
	}
}
// enqueue hands msg to the write loop's queue. It blocks while the queue is
// full and returns ctx.Err() if ctx is done before the message is accepted.
func (s *state) enqueue(ctx context.Context, msg Message) error {
	select {
	case s.out <- msg:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
// writeLoop writes queued messages to the stream, one frame per message,
// until ctx is done or a write fails. A write failure is reported as
// ctx.Err() when ctx is already done, otherwise as the write error itself.
func (s *state) writeLoop(ctx context.Context) error {
	for {
		var next Message
		select {
		case <-ctx.Done():
			return ctx.Err()
		case next = <-s.out:
		}

		err := writeFrame(s.rw, next)
		if err == nil {
			continue
		}
		if ctxErr := ctx.Err(); ctxErr != nil {
			return ctxErr
		}
		return err
	}
}
// parseMessage decodes one JSON control message and validates it.
//
// It returns an error wrapping ErrProtocolVersion when the version differs
// from ProtoVersion, and an error wrapping ErrUnexpectedMessage when the
// type is neither TypePing nor TypePong. On any error the returned Message
// is the zero value.
func parseMessage(raw []byte) (Message, error) {
	var decoded Message
	if err := json.Unmarshal(raw, &decoded); err != nil {
		return Message{}, fmt.Errorf("parse control message: %w", err)
	}
	if peer := decoded.Version; peer != ProtoVersion {
		return Message{}, fmt.Errorf("%w: peer v%d, local v%d",
			ErrProtocolVersion, peer, ProtoVersion)
	}
	switch decoded.Type {
	case TypePing, TypePong:
		return decoded, nil
	default:
		return Message{}, fmt.Errorf("%w: got %q", ErrUnexpectedMessage, decoded.Type)
	}
}
// writeFrame encodes msg as JSON and writes it to w as one frame: a 4-byte
// big-endian length followed by the JSON body.
//
// Header and body are assembled into one buffer and sent with a single
// Write call, so a write failure cannot leave a length prefix on the wire
// without its body, and each control message costs one write on the
// underlying stream instead of two.
//
// It returns an error wrapping ErrFrameTooLarge when the encoded body
// exceeds MaxMessageSize, and wraps marshal and write errors with context.
func writeFrame(w io.Writer, msg Message) error {
	body, err := json.Marshal(msg)
	if err != nil {
		return fmt.Errorf("marshal control message: %w", err)
	}
	if len(body) > MaxMessageSize {
		return fmt.Errorf("%w: %d > %d", ErrFrameTooLarge, len(body), MaxMessageSize)
	}

	const hdrLen = 4
	frame := make([]byte, hdrLen, hdrLen+len(body))
	binary.BigEndian.PutUint32(frame, uint32(len(body))) //nolint:gosec // len(body) bounded by MaxMessageSize
	frame = append(frame, body...)

	if _, err := w.Write(frame); err != nil {
		return fmt.Errorf("write control frame: %w", err)
	}
	return nil
}
// readFrame reads one length-prefixed frame from r and returns its body.
//
// The 4-byte big-endian length is read first; a length above MaxMessageSize
// yields an error wrapping ErrFrameTooLarge before any body bytes are read.
// Short reads of either part are reported as wrapped io.ReadFull errors.
func readFrame(r io.Reader) ([]byte, error) {
	var lenPrefix [4]byte
	if _, err := io.ReadFull(r, lenPrefix[:]); err != nil {
		return nil, fmt.Errorf("read control hdr: %w", err)
	}

	size := binary.BigEndian.Uint32(lenPrefix[:])
	if size > MaxMessageSize {
		return nil, fmt.Errorf("%w: %d > %d", ErrFrameTooLarge, size, MaxMessageSize)
	}

	payload := make([]byte, size)
	if _, err := io.ReadFull(r, payload); err != nil {
		return nil, fmt.Errorf("read control body: %w", err)
	}
	return payload, nil
}

View File

@@ -0,0 +1,138 @@
package control
import (
"context"
"encoding/binary"
"errors"
"io"
"net"
"testing"
"time"
)
// controlPair returns the two ends of an in-memory net.Pipe and registers a
// cleanup that closes both when the test finishes.
func controlPair(t *testing.T) (net.Conn, net.Conn) {
	t.Helper()
	left, right := net.Pipe()
	t.Cleanup(func() {
		for _, end := range []net.Conn{left, right} {
			_ = end.Close() // best-effort: the test may already have closed it
		}
	})
	return left, right
}
// TestRunPingPongReportsRTT runs the control loop on both ends of a pipe and
// checks that a pong is reported through OnPong with a non-zero sequence
// number and a non-negative RTT, and that both loops return nil once the
// context is cancelled.
func TestRunPingPongReportsRTT(t *testing.T) {
	a, b := controlPair(t)
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	got := make(chan Health, 1)
	cfg := Config{
		Interval: 10 * time.Millisecond,
		Timeout:  100 * time.Millisecond,
		Failures: 2,
		OnPong: func(h Health) {
			// Only the first report matters; never block the control loop.
			select {
			case got <- h:
			default:
			}
		},
	}
	// Buffered for both Run goroutines so neither blocks on its result.
	errCh := make(chan error, 2)
	go func() { errCh <- Run(ctx, a, cfg) }()
	go func() { errCh <- Run(ctx, b, cfg) }()
	select {
	case h := <-got:
		if h.Seq == 0 {
			t.Fatal("Health.Seq = 0")
		}
		if h.RTT < 0 {
			t.Fatalf("Health.RTT = %v", h.RTT)
		}
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for pong health")
	}
	cancel()
	// Cancellation is expected to be reported as a clean (nil) exit.
	for range 2 {
		if err := <-errCh; err != nil {
			t.Fatalf("Run() after cancel = %v", err)
		}
	}
}
// TestRunMarksUnhealthyAfterMissedPongs connects Run to a peer that reads and
// discards everything without ever answering. It expects Run to fail with
// ErrUnhealthy, OnUnhealthy to report at least the configured two misses, and
// OnMissedPong to have fired at least once.
func TestRunMarksUnhealthyAfterMissedPongs(t *testing.T) {
	a, b := controlPair(t)
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	// The peer swallows pings and never replies with a pong.
	go func() {
		_, _ = io.Copy(io.Discard, b)
	}()
	missedCh := make(chan int, 1)
	missedCallbackCh := make(chan int, 1)
	errCh := make(chan error, 1)
	go func() {
		errCh <- Run(ctx, a, Config{
			Interval: 10 * time.Millisecond,
			Timeout:  5 * time.Millisecond,
			Failures: 2,
			OnMissedPong: func(missed int) {
				// Keep only the first notification; never block the loop.
				select {
				case missedCallbackCh <- missed:
				default:
				}
			},
			OnUnhealthy: func(missed int) { missedCh <- missed },
		})
	}()
	select {
	case err := <-errCh:
		if !errors.Is(err, ErrUnhealthy) {
			t.Fatalf("Run() error = %v, want ErrUnhealthy", err)
		}
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for unhealthy result")
	}

	// recv bounds the wait for a callback value so that a callback which is
	// never invoked fails the test quickly instead of hanging it until the
	// global test timeout.
	recv := func(what string, ch <-chan int) int {
		t.Helper()
		select {
		case v := <-ch:
			return v
		case <-time.After(time.Second):
			t.Fatalf("timed out waiting for %s", what)
			return 0
		}
	}
	if missed := recv("OnUnhealthy callback", missedCh); missed < 2 {
		t.Fatalf("missed = %d, want >= 2", missed)
	}
	if missed := recv("OnMissedPong callback", missedCallbackCh); missed < 1 {
		t.Fatalf("missed callback = %d, want >= 1", missed)
	}
}
// TestRunRejectsBadProtocolVersion sends Run a well-formed ping frame whose
// Version is 999 and expects Run to fail with ErrProtocolVersion.
func TestRunRejectsBadProtocolVersion(t *testing.T) {
	a, b := controlPair(t)
	errCh := make(chan error, 1)
	go func() {
		// An hour-long interval keeps Run's own pings out of the picture.
		errCh <- Run(context.Background(), a, Config{Interval: time.Hour})
	}()
	if err := writeFrame(b, Message{Version: 999, Type: TypePing, Seq: 1}); err != nil {
		t.Fatalf("writeFrame() error = %v", err)
	}
	select {
	case err := <-errCh:
		if !errors.Is(err, ErrProtocolVersion) {
			t.Fatalf("Run() error = %v, want ErrProtocolVersion", err)
		}
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for protocol error")
	}
}
// TestReadFrameRejectsTooLarge writes only a frame header announcing
// MaxMessageSize+1 bytes and expects readFrame to fail with ErrFrameTooLarge
// without waiting for a body.
func TestReadFrameRejectsTooLarge(t *testing.T) {
	a, b := controlPair(t)
	// net.Pipe writes block until read, so the header is written from a
	// separate goroutine.
	go func() {
		var hdr [4]byte
		binary.BigEndian.PutUint32(hdr[:], MaxMessageSize+1)
		_, _ = b.Write(hdr[:])
	}()
	_, err := readFrame(a)
	if !errors.Is(err, ErrFrameTooLarge) {
		t.Fatalf("readFrame() error = %v, want ErrFrameTooLarge", err)
	}
}

View File

@@ -10,6 +10,9 @@ import (
"golang.org/x/crypto/chacha20poly1305"
)
// WireOverhead is the number of bytes added to each encrypted message: the
// XChaCha20-Poly1305 nonce (NonceSizeX) plus the AEAD authentication tag
// (Overhead).
const WireOverhead = chacha20poly1305.NonceSizeX + chacha20poly1305.Overhead
var (
// ErrInvalidKeySize is returned when the encryption key is not 32 bytes.
ErrInvalidKeySize = errors.New("invalid key size")

View File

@@ -24,6 +24,7 @@ import (
"github.com/openlibrecommunity/olcrtc/internal/client"
"github.com/openlibrecommunity/olcrtc/internal/link"
"github.com/openlibrecommunity/olcrtc/internal/server"
"github.com/openlibrecommunity/olcrtc/internal/supervisor"
"github.com/pion/webrtc/v4"
)
@@ -47,6 +48,7 @@ var (
errSocksUnexpectedReply = errors.New("unexpected SOCKS5 reply")
errSocksUnexpectedHello = errors.New("unexpected SOCKS5 greeting")
errPayloadMismatchOffset = errors.New("payload mismatch at offset")
errFailoverCarrier = errors.New("intentional failover carrier failure")
)
var (
@@ -347,6 +349,17 @@ func registerMemoryCarrierAs(t *testing.T, name string) {
})
}
// registerFailingCarrier registers a carrier, named after the running test,
// whose factory always fails with errFailoverCarrier, and returns that name.
func registerFailingCarrier(t *testing.T) string {
	t.Helper()
	session.RegisterDefaults()

	carrierName := "e2e-fail-" + t.Name()
	alwaysFail := func(context.Context, carrier.Config) (carrier.Session, error) {
		return nil, errFailoverCarrier
	}
	carrier.Register(carrierName, alwaysFail)
	return carrierName
}
// builtInCarrierNames returns the carrier names the tests treat as built in.
// A fresh slice is returned on every call.
func builtInCarrierNames() []string {
	return []string{"jazz", "telemost", "wbstream", "jitsi"} //nolint:goconst // test literal, repetition is intentional
}
@@ -1008,9 +1021,7 @@ func TestDirectLinkConnectsFastProviderTransportMatrix(t *testing.T) {
if err := ln.Connect(context.Background()); err != nil {
t.Fatalf("Connect() error = %v", err)
}
if !ln.CanSend() {
t.Fatal("CanSend() = false, want true")
}
assertLinkCanSendAfterConnect(t, ln, transportName)
if err := ln.Close(); err != nil {
t.Fatalf("Close() error = %v", err)
}
@@ -1020,6 +1031,20 @@ func TestDirectLinkConnectsFastProviderTransportMatrix(t *testing.T) {
}
}
// assertLinkCanSendAfterConnect checks CanSend right after Connect.
//
// For the SEI transport the link must NOT be sendable yet (no peer seichannel
// frame has been seen); for every other transport it must be sendable.
func assertLinkCanSendAfterConnect(t *testing.T, ln link.Link, transportName string) {
	t.Helper()
	sendable := ln.CanSend()
	isSEI := transportName == transportSEI
	switch {
	case isSEI && sendable:
		t.Fatal("CanSend() = true before peer seichannel frame")
	case !isSEI && !sendable:
		t.Fatal("CanSend() = false, want true")
	}
}
//nolint:cyclop // table-driven test naturally has many branches
func TestRealProviderTransportMatrix(t *testing.T) {
if !*realE2E {
@@ -1163,6 +1188,186 @@ func TestFrequentReconnectsStillAllowNewSOCKSConnections(t *testing.T) {
}
}
// TestSupervisorFailoverProfilesReachWorkingSOCKS runs a server-side and a
// client-side supervisor, each with a failing profile listed before a working
// in-memory one. It then verifies that traffic sent through the client's
// SOCKS listener is echoed back, that all four profiles were started, and
// that both supervisors stop without error once the context is cancelled.
func TestSupervisorFailoverProfilesReachWorkingSOCKS(t *testing.T) {
	echoAddr := startEchoServer(t)
	failingCarrier := registerFailingCarrier(t)
	memoryCarrier, room := registerMemoryCarrier(t)
	ctx, cancel := context.WithCancel(context.Background())
	t.Cleanup(cancel)
	socksAddr := freeLocalAddr(ctx, t)
	socksHost, socksPort := splitHostPort(t, socksAddr)
	// On both sides the failing profile comes first.
	serverProfiles := []supervisor.Profile{
		{Name: "failing-server", Config: failoverSessionConfig("srv", failingCarrier, "", 0)},
		{Name: "memory-server", Config: failoverSessionConfig("srv", memoryCarrier, "", 0)},
	}
	clientProfiles := []supervisor.Profile{
		{Name: "failing-client", Config: failoverSessionConfig("cnc", failingCarrier, socksHost, socksPort)},
		{Name: "memory-client", Config: failoverSessionConfig("cnc", memoryCarrier, socksHost, socksPort)},
	}
	// Shared by both supervisors; entries are "<side>:<profile name>".
	started := make(chan string, 8)
	serverErr := make(chan error, 1)
	go func() {
		serverErr <- supervisor.Run(ctx, failoverE2EConfig(serverProfiles, started, "server"), session.Run)
	}()
	room.waitConnected(t, 1)
	ready := make(chan struct{})
	var readyOnce sync.Once
	clientErr := make(chan error, 1)
	go func() {
		clientErr <- supervisor.Run(ctx, failoverE2EConfig(clientProfiles, started, "client"), func(ctx context.Context, cfg session.Config) error {
			return client.RunWithReady(ctx, clientConfigFromSession(cfg, socksAddr), func() {
				// Only readiness of the working (memory) profile counts.
				if cfg.Auth == memoryCarrier {
					readyOnce.Do(func() { close(ready) })
				}
			})
		})
	}()
	waitForReady(t, ready)
	conn := eventuallyConnectViaSOCKS(t, socksAddr, echoAddr)
	defer func() { _ = conn.Close() }()
	payload := []byte("olcrtc-failover-e2e\n")
	if _, err := conn.Write(payload); err != nil {
		t.Fatalf("write failover payload: %v", err)
	}
	if err := conn.SetReadDeadline(time.Now().Add(3 * time.Second)); err != nil {
		t.Fatalf("set failover read deadline: %v", err)
	}
	line, err := bufio.NewReader(conn).ReadBytes('\n')
	if err != nil {
		t.Fatalf("read failover echo: %v", err)
	}
	if !bytes.Equal(line, payload) {
		t.Fatalf("failover echo = %q, want %q", line, payload)
	}
	requireStartedProfiles(t, started, []string{
		"server:failing-server",
		"server:memory-server",
		"client:failing-client",
		"client:memory-client",
	})
	cancel()
	waitSupervisorStopped(t, "client", clientErr)
	waitSupervisorStopped(t, "server", serverErr)
}
// failoverSessionConfig builds the session.Config used by the failover test
// for the given mode and carrier. The SOCKS host and port are applied only in
// "cnc" mode and left at their zero values otherwise.
func failoverSessionConfig(mode, carrierName, socksHost string, socksPort int) session.Config {
	base := session.Config{
		Mode:      mode,
		Link:      linkDirect,
		Transport: transportData,
		Auth:      carrierName,
		RoomID:    testRoom,
		KeyHex:    testKeyHex,
		DNSServer: localDNSServer,
	}
	if mode != "cnc" {
		return base
	}
	base.SOCKSHost, base.SOCKSPort = socksHost, socksPort
	return base
}
// clientConfigFromSession translates a session.Config into the client.Config
// passed to client.RunWithReady. Most fields are copied one-to-one; the
// exceptions are noted inline.
func clientConfigFromSession(cfg session.Config, socksAddr string) client.Config {
	return client.Config{
		Link:            cfg.Link,
		Transport:       cfg.Transport,
		Carrier:         cfg.Auth,   // session "Auth" names the carrier
		RoomURL:         cfg.RoomID, // session "RoomID" is the client's room URL
		KeyHex:          cfg.KeyHex,
		LocalAddr:       socksAddr, // taken from the argument, not from cfg
		DNSServer:       cfg.DNSServer,
		DeviceID:        testClientDeviceID, // fixed test identity
		VideoWidth:      cfg.VideoWidth,
		VideoHeight:     cfg.VideoHeight,
		VideoFPS:        cfg.VideoFPS,
		VideoBitrate:    cfg.VideoBitrate,
		VideoHW:         cfg.VideoHW,
		VideoQRSize:     cfg.VideoQRSize,
		VideoQRRecovery: cfg.VideoQRRecovery,
		VideoCodec:      cfg.VideoCodec,
		VideoTileModule: cfg.VideoTileModule,
		VideoTileRS:     cfg.VideoTileRS,
		VP8FPS:          cfg.VP8FPS,
		VP8BatchSize:    cfg.VP8BatchSize,
		SEIFPS:          cfg.SEIFPS,
		SEIBatchSize:    cfg.SEIBatchSize,
		SEIFragmentSize: cfg.SEIFragmentSize,
		SEIAckTimeoutMS: cfg.SEIAckTimeoutMS,
		Engine:          cfg.Engine,
		URL:             cfg.URL,
		Token:           cfg.Token,
	}
}
// failoverE2EConfig builds a supervisor.Config over profiles with a 1ms retry
// delay. Every profile start is reported on started as "<side>:<name>"; the
// report is dropped rather than blocking when the channel is full.
func failoverE2EConfig(
	profiles []supervisor.Profile,
	started chan<- string,
	side string,
) supervisor.Config {
	report := func(profile supervisor.Profile, _ int) {
		label := side + ":" + profile.Name
		select {
		case started <- label:
		default:
		}
	}
	return supervisor.Config{
		Profiles:       profiles,
		RetryDelay:     time.Millisecond,
		OnProfileStart: report,
	}
}
// splitHostPort splits addr ("host:port") into its host and numeric port,
// failing the test immediately if addr is malformed or the port is not an
// integer.
func splitHostPort(t *testing.T, addr string) (string, int) {
	t.Helper()

	host, rawPort, splitErr := net.SplitHostPort(addr)
	if splitErr != nil {
		t.Fatalf("split host port %q: %v", addr, splitErr)
	}

	port, convErr := strconv.Atoi(rawPort)
	if convErr != nil {
		t.Fatalf("parse port %q: %v", rawPort, convErr)
	}
	return host, port
}
// requireStartedProfiles drains started until every name in want has been
// observed, failing the test if that does not happen within three seconds.
//
// Events that are not in want (and repeats of already-seen names) are
// recorded for the failure message but do not count towards completion, so
// an unrelated event can no longer end the wait early and make a wanted
// profile look missing.
func requireStartedProfiles(t *testing.T, started <-chan string, want []string) {
	t.Helper()

	// pending holds the wanted names that have not been observed yet.
	pending := make(map[string]struct{}, len(want))
	for _, name := range want {
		pending[name] = struct{}{}
	}

	seen := make(map[string]bool)
	deadline := time.After(3 * time.Second)
	for len(pending) > 0 {
		select {
		case item := <-started:
			seen[item] = true
			delete(pending, item)
		case <-deadline:
			t.Fatalf("started profiles = %v, want all %v", seen, want)
		}
	}
}
// waitSupervisorStopped waits up to three seconds for the named supervisor's
// result on ch and fails the test if it returned a non-nil error or did not
// finish in time.
func waitSupervisorStopped(t *testing.T, name string, ch <-chan error) {
	t.Helper()
	timeout := time.NewTimer(3 * time.Second)
	defer timeout.Stop()

	select {
	case <-timeout.C:
		t.Fatalf("%s supervisor did not stop", name)
	case err := <-ch:
		if err == nil {
			return
		}
		t.Fatalf("%s supervisor returned error: %v", name, err)
	}
}
func TestEndedCallbackStopsClientAndServer(t *testing.T) {
rt := startTunnel(t)
rt.room.triggerEnded("conference ended")

View File

@@ -112,10 +112,7 @@ func (s *Session) setupPeerConnections(config webrtc.Configuration) error {
}
func (s *Session) dialWebSocket() error {
wsDialer := websocket.Dialer{
NetDialContext: protect.DialContext,
HandshakeTimeout: wsHandshakeTimeout,
}
wsDialer := protect.NewWebSocketDialer(wsHandshakeTimeout)
ws, resp, err := wsDialer.Dial(s.mediaServerURL, nil)
if err != nil {
return fmt.Errorf("dial ws: %w", err)

View File

@@ -19,13 +19,17 @@ import (
protoLogger "github.com/livekit/protocol/logger"
lksdk "github.com/livekit/server-sdk-go/v2"
"github.com/openlibrecommunity/olcrtc/internal/engine"
"github.com/openlibrecommunity/olcrtc/internal/logger"
"github.com/pion/webrtc/v4"
)
const (
defaultSendQueueSize = 5000
dataPublishTopic = "olcrtc"
videoTrackName = "videochannel"
defaultSendQueueSize = 5000
defaultSendQueueCapHard = 4000
dataPublishTopic = "olcrtc"
videoTrackName = "videochannel"
reconnectWindow = 5 * time.Minute
maxReconnects = 10
)
var (
@@ -41,20 +45,98 @@ var (
ErrTokenRequired = errors.New("livekit access token required")
)
// roomHandle is the set of room operations the Session relies on. sdkRoom
// implements it on top of *lksdk.Room; because Session obtains rooms through
// its connectRoom function, another implementation can be substituted.
type roomHandle interface {
	// publishData sends one data payload to the room.
	publishData([]byte) error
	// publishTrack publishes a local track to the room.
	publishTrack(webrtc.TrackLocal) error
	// unpublishLocalTracks removes the local participant's published tracks.
	unpublishLocalTracks()
	// disconnect leaves the room.
	disconnect()
	// connectionState reports the room's current connection state.
	connectionState() lksdk.ConnectionState
}
// sdkRoom adapts a LiveKit SDK room to the roomHandle interface.
type sdkRoom struct {
	room *lksdk.Room
}

// publishData sends data as a reliable user-data packet on the olcrtc topic.
func (r *sdkRoom) publishData(data []byte) error {
	packet := lksdk.UserData(data)
	return r.room.LocalParticipant.PublishDataPacket(
		packet,
		lksdk.WithDataPublishTopic(dataPublishTopic),
		lksdk.WithDataPublishReliable(true),
	)
}

// publishTrack publishes track under the shared video track name and returns
// the SDK error unchanged.
func (r *sdkRoom) publishTrack(track webrtc.TrackLocal) error {
	opts := &lksdk.TrackPublicationOptions{Name: videoTrackName}
	if _, err := r.room.LocalParticipant.PublishTrack(track, opts); err != nil {
		return err
	}
	return nil
}

// unpublishLocalTracks unpublishes every local publication that has a SID.
// Failures are logged and do not stop the remaining tracks from being
// unpublished. It is a no-op when there is no room or local participant.
func (r *sdkRoom) unpublishLocalTracks() {
	if r.room == nil {
		return
	}
	local := r.room.LocalParticipant
	if local == nil {
		return
	}
	for _, pub := range local.TrackPublications() {
		sid := pub.SID()
		if sid == "" {
			continue
		}
		if err := local.UnpublishTrack(sid); err != nil {
			log.Printf("livekit unpublish track error: %v", err)
		}
	}
}

// disconnect leaves the room and then blocks for a fixed grace period.
func (r *sdkRoom) disconnect() {
	const evictionGrace = 2 * time.Second

	r.room.Disconnect()
	// Disconnect returns once the local SDK has torn down, which can be
	// before the server has evicted the participant. Waiting briefly keeps
	// an immediate reconnect from inheriting stale room state left behind by
	// a ghost participant.
	time.Sleep(evictionGrace)
}

// connectionState reports the SDK room's connection state.
func (r *sdkRoom) connectionState() lksdk.ConnectionState {
	state := r.room.ConnectionState()
	return state
}
// connectRoomFunc joins the room at url using token, wiring callback into the
// connection, and returns a handle to it. connectSDKRoom is the SDK-backed
// implementation.
type connectRoomFunc func(url, token string, callback *lksdk.RoomCallback) (roomHandle, error)
// connectSDKRoom joins a LiveKit room with the given token, with
// auto-subscribe enabled and SDK logging discarded, and wraps the resulting
// room in an sdkRoom. The SDK error is returned unchanged.
func connectSDKRoom(url, token string, callback *lksdk.RoomCallback) (roomHandle, error) {
	joined, err := lksdk.ConnectToRoomWithToken(
		url,
		token,
		callback,
		lksdk.WithAutoSubscribe(true),
		lksdk.WithLogger(protoLogger.GetDiscardLogger()),
	)
	if err == nil {
		return &sdkRoom{room: joined}, nil
	}
	return nil, err
}
// Session is the LiveKit engine handle.
type Session struct {
url string
token string
name string
room *lksdk.Room
refresh func(ctx context.Context) (engine.Credentials, error)
connectRoom connectRoomFunc
room roomHandle
roomMu sync.RWMutex
onData func([]byte)
onReconnect func(*webrtc.DataChannel)
shouldReconnect func() bool
onEnded func(string)
reconnectCh chan struct{}
closeCh chan struct{}
lastReconnect time.Time
reconnectCount int
sendQueue chan []byte
closed atomic.Bool
reconnecting atomic.Bool
done chan struct{}
cancel context.CancelFunc
shutdownOnce sync.Once
sendWorkerOnce sync.Once
videoTrackMu sync.RWMutex
videoTracks []webrtc.TrackLocal
onVideoTrack func(*webrtc.TrackRemote, *webrtc.RTPReceiver)
@@ -71,13 +153,17 @@ func New(ctx context.Context, cfg engine.Config) (engine.Session, error) {
}
_, cancel := context.WithCancel(ctx)
return &Session{
url: cfg.URL,
token: cfg.Token,
name: cfg.Name,
onData: cfg.OnData,
sendQueue: make(chan []byte, defaultSendQueueSize),
done: make(chan struct{}),
cancel: cancel,
url: cfg.URL,
token: cfg.Token,
name: cfg.Name,
refresh: cfg.Refresh,
connectRoom: connectSDKRoom,
onData: cfg.OnData,
reconnectCh: make(chan struct{}, 1),
closeCh: make(chan struct{}),
sendQueue: make(chan []byte, defaultSendQueueSize),
done: make(chan struct{}),
cancel: cancel,
}, nil
}
@@ -87,7 +173,16 @@ func (s *Session) Capabilities() engine.Capabilities {
}
// Connect joins the LiveKit room.
func (s *Session) Connect(_ context.Context) error {
// Connect clears the closed flag, joins the LiveKit room and starts the
// outbound send worker. The worker is not started when joining fails.
func (s *Session) Connect(ctx context.Context) error {
	s.closed.Store(false)
	err := s.connectSession(ctx)
	if err == nil {
		s.startSendWorker()
	}
	return err
}
func (s *Session) connectSession(_ context.Context) error {
roomCB := &lksdk.RoomCallback{
ParticipantCallback: lksdk.ParticipantCallback{
OnDataReceived: func(data []byte, _ lksdk.DataReceiveParams) {
@@ -108,45 +203,49 @@ func (s *Session) Connect(_ context.Context) error {
},
},
OnDisconnected: func() {
if !s.closed.Load() && s.onEnded != nil {
s.onEnded("disconnected from livekit")
if s.closed.Load() || s.reconnecting.Load() {
return
}
if !s.queueReconnect() {
s.signalEnded("disconnected from livekit")
}
},
}
room, err := lksdk.ConnectToRoomWithToken(
s.url,
s.token,
roomCB,
lksdk.WithAutoSubscribe(true),
lksdk.WithLogger(protoLogger.GetDiscardLogger()),
)
room, err := s.connectRoom(s.url, s.token, roomCB)
if err != nil {
return fmt.Errorf("connect to room: %w", err)
}
s.room = room
s.setRoom(room)
if err := s.publishPendingTracks(); err != nil {
return err
}
s.wg.Add(1)
go s.processSendQueue()
return nil
}
func (s *Session) publishPendingTracks() error {
room := s.currentRoom()
if room == nil {
return ErrRoomNotConnected
}
s.videoTrackMu.RLock()
defer s.videoTrackMu.RUnlock()
for _, track := range s.videoTracks {
if _, err := s.room.LocalParticipant.PublishTrack(track, &lksdk.TrackPublicationOptions{
Name: videoTrackName,
}); err != nil {
if err := room.publishTrack(track); err != nil {
return fmt.Errorf("failed to publish track: %w", err)
}
}
return nil
}
// startSendWorker launches the goroutine that drains the send queue. It is
// guarded by sendWorkerOnce, so repeated calls start at most one worker; the
// worker is tracked in s.wg.
func (s *Session) startSendWorker() {
	s.sendWorkerOnce.Do(func() {
		s.wg.Add(1)
		go s.processSendQueue()
	})
}
func (s *Session) processSendQueue() {
defer s.wg.Done()
for {
@@ -157,17 +256,33 @@ func (s *Session) processSendQueue() {
if !ok {
return
}
if err := s.room.LocalParticipant.PublishDataPacket(
lksdk.UserData(data),
lksdk.WithDataPublishTopic(dataPublishTopic),
lksdk.WithDataPublishReliable(true),
); err != nil {
room := s.waitForConnectedRoom()
if room == nil {
return
}
if err := room.publishData(data); err != nil {
log.Printf("livekit publish data error: %v", err)
}
}
}
}
// waitForConnectedRoom blocks until the session has a room in the connected
// state and returns it, polling every 50ms. It returns nil if s.done is
// closed first.
func (s *Session) waitForConnectedRoom() roomHandle {
	poll := time.NewTicker(50 * time.Millisecond)
	defer poll.Stop()

	for {
		if candidate := s.currentRoom(); candidate != nil {
			if candidate.connectionState() == lksdk.ConnectionStateConnected {
				return candidate
			}
		}
		select {
		case <-poll.C:
			// Check the room again on the next iteration.
		case <-s.done:
			return nil
		}
	}
}
// Send queues data for transmission.
func (s *Session) Send(data []byte) error {
if s.closed.Load() {
@@ -183,55 +298,160 @@ func (s *Session) Send(data []byte) error {
// Close terminates the session.
func (s *Session) Close() error {
if s.closed.CompareAndSwap(false, true) {
s.cancel()
close(s.done)
if s.room != nil {
s.unpublishLocalTracks()
s.room.Disconnect()
// LiveKit's Disconnect() returns once the local SDK state
// is torn down, not when the server has actually evicted
// the participant. Without giving the signalling channel
// time to flush the LEAVE_REQUEST and the server to act on
// it, a back-to-back reconnect from the same identity in
// the same room sees a still-alive ghost participant on
// the SFU and inherits stale publication state.
time.Sleep(2 * time.Second)
}
close(s.sendQueue)
s.wg.Wait()
}
s.closed.Store(true)
s.shutdown()
return nil
}
func (s *Session) unpublishLocalTracks() {
if s.room == nil || s.room.LocalParticipant == nil {
return
}
for _, publication := range s.room.LocalParticipant.TrackPublications() {
if publication.SID() == "" {
continue
func (s *Session) shutdown() {
s.shutdownOnce.Do(func() {
if s.cancel != nil {
s.cancel()
}
if err := s.room.LocalParticipant.UnpublishTrack(publication.SID()); err != nil {
log.Printf("livekit unpublish track error: %v", err)
closeSignal(s.closeCh)
closeSignal(s.done)
if room := s.swapRoom(nil); room != nil {
room.unpublishLocalTracks()
room.disconnect()
}
}
s.wg.Wait()
})
}
// SetReconnectCallback stores the reconnect callback (LiveKit reconnects internally; this is kept for API parity).
// SetReconnectCallback stores the reconnect callback.
func (s *Session) SetReconnectCallback(cb func(*webrtc.DataChannel)) { s.onReconnect = cb }
// SetShouldReconnect stores the reconnect predicate (kept for API parity).
// SetShouldReconnect stores the reconnect predicate.
func (s *Session) SetShouldReconnect(fn func() bool) { s.shouldReconnect = fn }
// SetEndedCallback registers a function to call when the session ends.
func (s *Session) SetEndedCallback(cb func(string)) { s.onEnded = cb }
// WatchConnection is a no-op; LiveKit handles connection supervision itself.
func (s *Session) WatchConnection(_ context.Context) {}
// WatchConnection services reconnect requests until ctx is done, the session
// is closed, or a reconnect attempt reports a terminal outcome.
func (s *Session) WatchConnection(ctx context.Context) {
	for {
		select {
		case <-s.reconnectCh:
			// A reconnect was requested; handled below.
		case <-s.closeCh:
			return
		case <-ctx.Done():
			return
		}
		if terminal := s.handleReconnectAttempt(ctx); terminal {
			return
		}
	}
}
// handleReconnectAttempt handles one reconnect request and reports whether
// the watcher should stop (true) or keep watching (false).
//
// Requests are counted within reconnectWindow; once more than maxReconnects
// have occurred the session is ended. Otherwise reconnect is retried, with a
// fixed delay of 2s per counted request (capped at 30s) between failures,
// until it succeeds or ctx/closeCh fires.
func (s *Session) handleReconnectAttempt(ctx context.Context) bool {
	// A quiet period longer than the window resets the budget.
	if time.Since(s.lastReconnect) > reconnectWindow {
		s.reconnectCount = 0
	}
	s.reconnectCount++
	s.lastReconnect = time.Now()

	if s.reconnectCount > maxReconnects {
		s.signalEnded("reconnect limit reached")
		return true
	}

	backoff := min(time.Duration(s.reconnectCount)*2*time.Second, 30*time.Second)
	for {
		err := s.reconnect(ctx)
		if err == nil {
			// Requests queued while reconnecting are now stale.
			s.drainReconnectQueue()
			return false
		}
		logger.Debugf("livekit reconnect failed: %v", err)

		select {
		case <-ctx.Done():
			return true
		case <-s.closeCh:
			return true
		case <-time.After(backoff):
			// Retry.
		}
	}
}
// reconnect tears down the current room (if any), optionally refreshes the
// credentials, joins a new room and finally invokes the reconnect callback.
// The reconnecting flag is set for the duration of the call.
//
// If the session is closed while the reconnect is in flight, the attempt is
// aborted with an error wrapping context.Canceled. This matters because
// shutdown runs only once: if it ran while no room was installed, a room
// connected afterwards would never be disconnected. Any room installed after
// the close is therefore torn down here.
func (s *Session) reconnect(ctx context.Context) error {
	s.reconnecting.Store(true)
	defer s.reconnecting.Store(false)

	if room := s.swapRoom(nil); room != nil {
		room.unpublishLocalTracks()
		room.disconnect()
	}
	// Tearing the old room down can take a while; do not join a new room if
	// the session was closed in the meantime.
	if s.closed.Load() {
		return fmt.Errorf("reconnect aborted, session closed: %w", context.Canceled)
	}
	if s.refresh != nil {
		creds, err := s.refresh(ctx)
		if err != nil {
			return fmt.Errorf("refresh credentials: %w", err)
		}
		s.applyRefreshedCredentials(creds)
	}
	if err := s.connectSession(ctx); err != nil {
		return err
	}
	// Close may have run while we were connecting. swapRoom is mutex-guarded,
	// so exactly one of shutdown and this branch gets to dispose of the room.
	if s.closed.Load() {
		if room := s.swapRoom(nil); room != nil {
			room.unpublishLocalTracks()
			room.disconnect()
		}
		return fmt.Errorf("reconnect aborted, session closed: %w", context.Canceled)
	}
	if s.onReconnect != nil {
		// LiveKit has no data channel to hand over, hence the nil argument.
		s.onReconnect(nil)
	}
	return nil
}
// applyRefreshedCredentials overwrites the session URL and token with the
// refreshed values. Empty fields in creds leave the current value untouched.
func (s *Session) applyRefreshedCredentials(creds engine.Credentials) {
	if url := creds.URL; url != "" {
		s.url = url
	}
	if token := creds.Token; token != "" {
		s.token = token
	}
}
// queueReconnect requests a reconnect and reports whether one is (now or
// already) pending.
//
// It returns false when the session is closed, a reconnect is already in
// progress, or the shouldReconnect predicate vetoes it. The request is
// coalesced: if one is already queued, nothing is added.
func (s *Session) queueReconnect() bool {
	switch {
	case s.closed.Load(), s.reconnecting.Load():
		return false
	case s.shouldReconnect != nil && !s.shouldReconnect():
		return false
	}

	select {
	case s.reconnectCh <- struct{}{}:
	default:
		// A request is already queued.
	}
	return true
}
// drainReconnectQueue discards every pending reconnect request without
// blocking; it returns as soon as reconnectCh is empty.
func (s *Session) drainReconnectQueue() {
	for {
		select {
		case <-s.reconnectCh:
		default:
			return
		}
	}
}
// signalEnded marks the session closed, shuts it down and then notifies the
// ended callback (if one is registered) with reason.
func (s *Session) signalEnded(reason string) {
	s.closed.Store(true)
	s.shutdown()

	if notify := s.onEnded; notify != nil {
		notify(reason)
	}
}
// CanSend reports whether the session is ready to accept data.
func (s *Session) CanSend() bool { return !s.closed.Load() && s.room != nil }
// CanSend reports whether the session can accept more outbound data: it must
// be neither closed nor reconnecting, the send queue must be below the hard
// watermark, and the current room must be in the connected state.
func (s *Session) CanSend() bool {
	switch {
	case s.closed.Load():
		return false
	case s.reconnecting.Load():
		return false
	case len(s.sendQueue) >= defaultSendQueueCapHard:
		return false
	}
	if room := s.currentRoom(); room != nil {
		return room.connectionState() == lksdk.ConnectionStateConnected
	}
	return false
}
// GetSendQueue exposes the outbound queue. The session's own channel is
// returned, not a copy.
func (s *Session) GetSendQueue() chan []byte { return s.sendQueue }
@@ -245,12 +465,11 @@ func (s *Session) AddVideoTrack(track webrtc.TrackLocal) error {
s.videoTracks = append(s.videoTracks, track)
s.videoTrackMu.Unlock()
if s.room == nil || s.room.LocalParticipant == nil {
room := s.currentRoom()
if room == nil {
return nil
}
if _, err := s.room.LocalParticipant.PublishTrack(track, &lksdk.TrackPublicationOptions{
Name: videoTrackName,
}); err != nil {
if err := room.publishTrack(track); err != nil {
return fmt.Errorf("failed to publish track: %w", err)
}
return nil
@@ -263,6 +482,34 @@ func (s *Session) SetVideoTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPR
s.onVideoTrack = cb
}
// currentRoom returns the room currently installed on the session (nil if
// none), reading it under the room read lock.
func (s *Session) currentRoom() roomHandle {
	s.roomMu.RLock()
	defer s.roomMu.RUnlock()
	return s.room
}
// setRoom installs room as the session's current room under the room write
// lock. The previous room, if any, is overwritten without being disconnected.
func (s *Session) setRoom(room roomHandle) {
	s.roomMu.Lock()
	defer s.roomMu.Unlock()
	s.room = room
}
// swapRoom installs room as the current room and returns the one it replaced
// (nil if there was none), both under a single hold of the room write lock.
func (s *Session) swapRoom(room roomHandle) roomHandle {
	s.roomMu.Lock()
	defer s.roomMu.Unlock()
	old := s.room
	s.room = room
	return old
}
// closeSignal closes ch unless it is already closed.
//
// The check and the close are not atomic, so this helper alone does not make
// concurrent calls on the same channel safe; callers must serialise them.
func closeSignal(ch chan struct{}) {
	select {
	case <-ch:
		// Already closed (nothing is ever sent on these signal channels).
	default:
		close(ch)
	}
}
// init registers New as the factory for the "livekit" engine.
func init() { //nolint:gochecknoinits // engine registration is the canonical Go pattern for plugins
	engine.Register("livekit", New)
}

View File

@@ -0,0 +1,306 @@
package livekit
import (
"context"
"errors"
"sync"
"testing"
"time"
lksdk "github.com/livekit/server-sdk-go/v2"
"github.com/openlibrecommunity/olcrtc/internal/engine"
"github.com/pion/webrtc/v4"
)
// fakeRoom is an in-memory roomHandle that records how it was used. All
// methods take mu, so it can be driven from several goroutines.
type fakeRoom struct {
	mu           sync.Mutex
	state        lksdk.ConnectionState // value reported by connectionState
	published    [][]byte              // copies of every payload passed to publishData
	tracks       int                   // number of publishTrack calls
	unpublished  int                   // number of unpublishLocalTracks calls
	disconnected int                   // number of disconnect calls
}

// newFakeRoom returns a fakeRoom that reports itself as connected.
func newFakeRoom() *fakeRoom {
	return &fakeRoom{state: lksdk.ConnectionStateConnected}
}

// publishData records a private copy of data and always succeeds.
func (r *fakeRoom) publishData(data []byte) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.published = append(r.published, append([]byte(nil), data...))
	return nil
}

// publishTrack counts the call and always succeeds; the track is ignored.
func (r *fakeRoom) publishTrack(webrtc.TrackLocal) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.tracks++
	return nil
}

// unpublishLocalTracks counts the call.
func (r *fakeRoom) unpublishLocalTracks() {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.unpublished++
}

// disconnect counts the call and switches the room to the disconnected state.
func (r *fakeRoom) disconnect() {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.disconnected++
	r.state = lksdk.ConnectionStateDisconnected
}

// connectionState returns the current fake state.
func (r *fakeRoom) connectionState() lksdk.ConnectionState {
	r.mu.Lock()
	defer r.mu.Unlock()
	return r.state
}
// fakeConnector is a connectRoomFunc implementation for tests. Each
// successful connect hands out a fresh fakeRoom and records the URL, token
// and callback it was called with, in call order.
type fakeConnector struct {
	mu        sync.Mutex
	urls      []string
	tokens    []string
	callbacks []*lksdk.RoomCallback
	rooms     []*fakeRoom
	connected chan struct{} // best-effort notification, one per successful connect
	err       error         // when non-nil, connect fails with it and records nothing
}

// newFakeConnector returns a connector whose notification channel buffers up
// to eight connects.
func newFakeConnector() *fakeConnector {
	return &fakeConnector{connected: make(chan struct{}, 8)}
}

// connect implements connectRoomFunc. It fails with c.err when set; otherwise
// it records the call and returns a new connected fakeRoom.
func (c *fakeConnector) connect(url, token string, cb *lksdk.RoomCallback) (roomHandle, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.err != nil {
		return nil, c.err
	}
	room := newFakeRoom()
	c.urls = append(c.urls, url)
	c.tokens = append(c.tokens, token)
	c.callbacks = append(c.callbacks, cb)
	c.rooms = append(c.rooms, room)
	// Notify without blocking: the send happens while c.mu is held, so a
	// blocking send on a full, undrained channel would deadlock the connector
	// and every accessor that needs the mutex.
	select {
	case c.connected <- struct{}{}:
	default:
	}
	return room, nil
}

// count returns the number of successful connects so far.
func (c *fakeConnector) count() int {
	c.mu.Lock()
	defer c.mu.Unlock()
	return len(c.rooms)
}

// callback returns the room callback passed to the i-th successful connect.
func (c *fakeConnector) callback(i int) *lksdk.RoomCallback {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.callbacks[i]
}

// room returns the fakeRoom handed out by the i-th successful connect.
func (c *fakeConnector) room(i int) *fakeRoom {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.rooms[i]
}

// snapshot returns copies of the recorded URLs and tokens.
func (c *fakeConnector) snapshot() ([]string, []string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	return append([]string(nil), c.urls...), append([]string(nil), c.tokens...)
}
// waitFor polls cond every 10ms and returns as soon as it reports true. The
// test is failed if cond has not become true within two seconds.
func waitFor(t *testing.T, cond func() bool) {
	t.Helper()
	const (
		limit = 2 * time.Second
		step  = 10 * time.Millisecond
	)
	for start := time.Now(); time.Since(start) < limit; time.Sleep(step) {
		if cond() {
			return
		}
	}
	t.Fatal("condition was not met before timeout")
}
// TestReconnectRefreshesCredentialsAndReplacesRoom simulates a disconnect on
// a connected session and checks that the watcher reconnects exactly once
// with refreshed credentials, cleans up the old room, fires the reconnect
// callback and leaves the session sendable.
func TestReconnectRefreshesCredentialsAndReplacesRoom(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	refreshes := 0
	sess, err := New(ctx, engine.Config{
		URL:   "wss://old",
		Token: "old-token",
		Refresh: func(context.Context) (engine.Credentials, error) {
			refreshes++
			return engine.Credentials{URL: "wss://new", Token: "new-token"}, nil
		},
	})
	if err != nil {
		t.Fatalf("New() error = %v", err)
	}
	s := sess.(*Session)
	// Replace the SDK connector so no real LiveKit server is needed.
	connector := newFakeConnector()
	s.connectRoom = connector.connect
	reconnected := make(chan struct{}, 1)
	s.SetReconnectCallback(func(*webrtc.DataChannel) {
		reconnected <- struct{}{}
	})
	if err := s.Connect(ctx); err != nil {
		t.Fatalf("Connect() error = %v", err)
	}
	go s.WatchConnection(ctx)
	// Simulate the SDK reporting that the first room dropped.
	connector.callback(0).OnDisconnected()
	waitFor(t, func() bool { return connector.count() == 2 })
	select {
	case <-reconnected:
	case <-time.After(time.Second):
		t.Fatal("reconnect callback was not called")
	}
	urls, tokens := connector.snapshot()
	if got, want := urls, []string{"wss://old", "wss://new"}; !equalStrings(got, want) {
		t.Fatalf("connect urls = %v, want %v", got, want)
	}
	if got, want := tokens, []string{"old-token", "new-token"}; !equalStrings(got, want) {
		t.Fatalf("connect tokens = %v, want %v", got, want)
	}
	// Read only after the reconnect callback was received above.
	if refreshes != 1 {
		t.Fatalf("refreshes = %d, want 1", refreshes)
	}
	oldRoom := connector.room(0)
	oldRoom.mu.Lock()
	if oldRoom.disconnected != 1 || oldRoom.unpublished != 1 {
		t.Fatalf("old room cleanup disconnected=%d unpublished=%d, want 1/1",
			oldRoom.disconnected, oldRoom.unpublished)
	}
	oldRoom.mu.Unlock()
	if !s.CanSend() {
		t.Fatal("CanSend() = false after reconnect, want true")
	}
	if err := s.Close(); err != nil {
		t.Fatalf("Close() error = %v", err)
	}
}
// TestDisconnectedEndsWhenReconnectDisallowed checks the terminal path: when
// the reconnect predicate returns false, a disconnect must end the session
// (ended callback, closed flag, room cleaned up exactly once, no second
// connect), and a later Close must not clean the room up again.
func TestDisconnectedEndsWhenReconnectDisallowed(t *testing.T) {
	ctx := context.Background()
	sess, err := New(ctx, engine.Config{URL: "wss://old", Token: "old-token"})
	if err != nil {
		t.Fatalf("New() error = %v", err)
	}
	s := sess.(*Session)
	connector := newFakeConnector()
	s.connectRoom = connector.connect
	s.SetShouldReconnect(func() bool { return false })
	ended := make(chan string, 1)
	s.SetEndedCallback(func(reason string) {
		ended <- reason
	})
	if err := s.Connect(ctx); err != nil {
		t.Fatalf("Connect() error = %v", err)
	}
	// Simulate the SDK reporting that the room dropped.
	connector.callback(0).OnDisconnected()
	select {
	case reason := <-ended:
		if reason != "disconnected from livekit" {
			t.Fatalf("ended reason = %q, want disconnected from livekit", reason)
		}
	case <-time.After(time.Second):
		t.Fatal("ended callback was not called")
	}
	if !s.closed.Load() {
		t.Fatal("closed = false after terminal disconnect")
	}
	if connector.count() != 1 {
		t.Fatalf("connect count = %d, want 1", connector.count())
	}
	room := connector.room(0)
	room.mu.Lock()
	if room.disconnected != 1 || room.unpublished != 1 {
		t.Fatalf("terminal room cleanup disconnected=%d unpublished=%d, want 1/1",
			room.disconnected, room.unpublished)
	}
	room.mu.Unlock()
	// Closing an already-ended session must be a no-op for the room.
	if err := s.Close(); err != nil {
		t.Fatalf("Close() error = %v", err)
	}
	room.mu.Lock()
	if room.disconnected != 1 || room.unpublished != 1 {
		t.Fatalf("second close cleanup disconnected=%d unpublished=%d, want still 1/1",
			room.disconnected, room.unpublished)
	}
	room.mu.Unlock()
}
// TestCanSendRequiresConnectedRoomAndQueueHeadroom walks CanSend through its
// conditions on a hand-built Session: no room, a disconnected room, a
// connected room, and finally a send queue filled to the hard watermark.
func TestCanSendRequiresConnectedRoomAndQueueHeadroom(t *testing.T) {
	s := &Session{
		sendQueue: make(chan []byte, defaultSendQueueSize),
		done:      make(chan struct{}),
		closeCh:   make(chan struct{}),
	}
	if s.CanSend() {
		t.Fatal("CanSend() = true without room")
	}
	room := newFakeRoom()
	// The fake is only touched from this goroutine, so its state is set
	// directly without taking room.mu.
	room.state = lksdk.ConnectionStateDisconnected
	s.setRoom(room)
	if s.CanSend() {
		t.Fatal("CanSend() = true for disconnected room")
	}
	room.state = lksdk.ConnectionStateConnected
	if !s.CanSend() {
		t.Fatal("CanSend() = false for connected room")
	}
	// Fill the queue up to (not beyond) the hard watermark.
	for i := 0; i < defaultSendQueueCapHard; i++ {
		s.sendQueue <- []byte("x")
	}
	if s.CanSend() {
		t.Fatal("CanSend() = true at queue high watermark")
	}
}
// TestReconnectFailureRetriesUntilContextDone uses a connector that cancels
// the context and then fails, and expects handleReconnectAttempt to report a
// terminal outcome instead of retrying further.
func TestReconnectFailureRetriesUntilContextDone(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	s := &Session{
		url:   "wss://old",
		token: "old-token",
		connectRoom: func(string, string, *lksdk.RoomCallback) (roomHandle, error) {
			// Cancel before failing so the retry wait sees a done context.
			cancel()
			return nil, errors.New("boom")
		},
		reconnectCh: make(chan struct{}, 1),
		closeCh:     make(chan struct{}),
		sendQueue:   make(chan []byte, defaultSendQueueSize),
		done:        make(chan struct{}),
	}
	if terminal := s.handleReconnectAttempt(ctx); !terminal {
		t.Fatal("handleReconnectAttempt() = false after context cancellation")
	}
}
// equalStrings reports whether a and b have the same length and the same
// elements in the same order. A nil slice and an empty slice compare equal.
func equalStrings(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}
	for idx, left := range a {
		if right := b[idx]; left != right {
			return false
		}
	}
	return true
}

View File

@@ -417,10 +417,7 @@ func (s *Session) waitForMediaReady(ctx context.Context, timeout time.Duration)
}
func (s *Session) dialWebSocket() error {
wsDialer := websocket.Dialer{
NetDialContext: protect.DialContext,
HandshakeTimeout: wsHandshakeTimeout,
}
wsDialer := protect.NewWebSocketDialer(wsHandshakeTimeout)
ws, resp, err := wsDialer.Dial(s.connectorURL, nil)
if err != nil {

View File

@@ -13,8 +13,8 @@
// │ │
//
// After the exchange the control stream stays open; tunnel traffic flows over
// additional smux streams opened by the client. The control stream may carry
// keepalives or future control messages.
// additional smux streams opened by the client. The control stream then
// carries ping/pong liveness and future control messages.
//
//nolint:tagliatelle // JSON keys are the stable wire protocol schema.
package handshake

View File

@@ -43,6 +43,7 @@ func New(ctx context.Context, cfg link.Config) (link.Link, error) {
SEIBatchSize: cfg.SEIBatchSize,
SEIFragmentSize: cfg.SEIFragmentSize,
SEIAckTimeoutMS: cfg.SEIAckTimeoutMS,
Traffic: cfg.Traffic,
})
if err != nil {
return nil, fmt.Errorf("create transport for direct link: %w", err)
@@ -79,3 +80,6 @@ func (d *directLink) WatchConnection(ctx context.Context) {
d.transport.WatchConnection(ctx)
}
// CanSend reports whether the underlying transport can currently send.
func (d *directLink) CanSend() bool { return d.transport.CanSend() }

// Features reports the direct link's underlying transport capabilities.
func (d *directLink) Features() link.Features { return d.transport.Features() }

View File

@@ -79,12 +79,14 @@ func TestNewForwardsConfigAndMethods(t *testing.T) {
VideoTileRS: 20,
VP8FPS: 25,
VP8BatchSize: 8,
Traffic: transport.TrafficConfig{MaxPayloadSize: 4096},
})
if err != nil {
t.Fatalf("New() error = %v", err)
}
if seen.DeviceID != "client" || seen.ProxyPort != 1080 || seen.VideoTileRS != 20 || seen.VP8BatchSize != 8 {
if seen.DeviceID != "client" || seen.ProxyPort != 1080 || seen.VideoTileRS != 20 || seen.VP8BatchSize != 8 ||
seen.Traffic.MaxPayloadSize != 4096 {
t.Fatalf("forwarded config = %+v", seen)
}
@@ -112,6 +114,9 @@ func TestNewForwardsConfigAndMethods(t *testing.T) {
if !ln.CanSend() {
t.Fatal("CanSend() = false, want true")
}
if features := ln.(link.FeaturesProvider).Features(); features.MaxPayloadSize != 4096 {
t.Fatalf("Features() = %+v, want shaped max payload 4096", features)
}
}
func TestNewWrapsFactoryError(t *testing.T) {

View File

@@ -4,6 +4,8 @@ package link
import (
"context"
"errors"
"github.com/openlibrecommunity/olcrtc/internal/transport"
)
var (
@@ -23,11 +25,19 @@ type Link interface {
CanSend() bool
}
// Features mirrors the underlying transport capabilities when a link can expose them.
type Features = transport.Features
// FeaturesProvider is optionally implemented by links that can report wire limits.
type FeaturesProvider interface {
Features() Features
}
// Config holds common link configuration.
type Config struct {
Transport string
Carrier string
RoomURL string
Transport string
Carrier string
RoomURL string
// Engine, URL, Token are forwarded for the "none" auth carrier.
Engine string
URL string
@@ -54,6 +64,7 @@ type Config struct {
SEIBatchSize int
SEIFragmentSize int
SEIAckTimeoutMS int
Traffic transport.TrafficConfig
}
// Factory creates a link instance.

View File

@@ -3,13 +3,38 @@ package protect
import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"regexp"
"strings"
"syscall"
"time"
"github.com/gorilla/websocket"
)
const (
defaultDialTimeout = 10 * time.Second
defaultKeepAlive = 30 * time.Second
defaultIdleConnTimeout = 30 * time.Second
defaultTLSHandshake = 10 * time.Second
defaultResponseHeader = 10 * time.Second
defaultWebSocketTimeout = 10 * time.Second
defaultHTTPClientTimeout = 30 * time.Second
defaultStatusBodyLimit = 1024
)
var (
sensitiveFieldRE = regexp.MustCompile(
`(?i)((?:access[_-]?token|room[_-]?token|token|credentials)"?\s*[:=]\s*"?)` +
`[^",\s}]+`,
)
sensitiveBearerRE = regexp.MustCompile(`(?i)(bearer\s+)[A-Za-z0-9._~+/=-]+`)
) //nolint:gochecknoglobals // compiled once for provider error redaction
// Protector is called with a socket file descriptor before connect.
// On Android, this calls VpnService.protect(fd) to bypass VPN routing.
var Protector func(fd int) bool //nolint:gochecknoglobals // package-level state intentional
@@ -33,24 +58,70 @@ func controlFunc(network, _ string, c syscall.RawConn) error {
// NewDialer returns a net.Dialer that calls Protector on each new socket.
func NewDialer() *net.Dialer {
return &net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
Timeout: defaultDialTimeout,
KeepAlive: defaultKeepAlive,
Control: controlFunc,
}
}
// NewTLSConfig builds the TLS settings shared by the provider HTTP and
// WebSocket clients. A fresh config is allocated on every call and the only
// policy it sets is a minimum protocol version of TLS 1.2.
func NewTLSConfig() *tls.Config {
	cfg := new(tls.Config)
	cfg.MinVersion = tls.VersionTLS12
	return cfg
}
// NewHTTPTransport builds an *http.Transport that dials through NewDialer
// (so every socket passes through the Protector hook), honours the proxy
// environment variables, applies the shared TLS policy, and bounds idle,
// TLS-handshake and response-header waits with the package defaults.
func NewHTTPTransport() *http.Transport {
	tr := &http.Transport{
		Proxy:           http.ProxyFromEnvironment,
		DialContext:     NewDialer().DialContext,
		TLSClientConfig: NewTLSConfig(),
	}
	tr.ForceAttemptHTTP2 = true
	tr.MaxIdleConns = 10
	tr.IdleConnTimeout = defaultIdleConnTimeout
	tr.TLSHandshakeTimeout = defaultTLSHandshake
	tr.ResponseHeaderTimeout = defaultResponseHeader
	return tr
}
// NewHTTPClient returns an http.Client using protected sockets.
func NewHTTPClient() *http.Client {
dialer := NewDialer()
transport := &http.Transport{
DialContext: dialer.DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 10,
IdleConnTimeout: 30 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ResponseHeaderTimeout: 10 * time.Second,
return &http.Client{
Transport: NewHTTPTransport(),
Timeout: defaultHTTPClientTimeout,
}
return &http.Client{Transport: transport}
}
// NewWebSocketDialer builds a websocket.Dialer that connects through the
// package's protected DialContext, honours the proxy environment variables and
// uses the shared TLS policy. A non-positive handshakeTimeout is replaced by
// defaultWebSocketTimeout.
func NewWebSocketDialer(handshakeTimeout time.Duration) websocket.Dialer {
	timeout := handshakeTimeout
	if timeout <= 0 {
		timeout = defaultWebSocketTimeout
	}

	var dialer websocket.Dialer
	dialer.NetDialContext = DialContext
	dialer.Proxy = http.ProxyFromEnvironment
	dialer.TLSClientConfig = NewTLSConfig()
	dialer.HandshakeTimeout = timeout
	return dialer
}
// StatusError wraps base with the upstream HTTP status code and, when the
// response carries one, a bounded and redacted excerpt of the body.
//
// At most limit bytes of resp.Body are read; a non-positive limit falls back
// to defaultStatusBodyLimit. The excerpt is whitespace-trimmed and passed
// through RedactSensitive before being embedded in the message. The body is
// not closed here; that remains the caller's responsibility.
//
// A nil resp yields "<base>: no response" and a nil resp.Body is treated as an
// empty body, so the helper never panics on a half-built response.
func StatusError(base error, resp *http.Response, limit int64) error {
	if resp == nil {
		return fmt.Errorf("%w: no response", base)
	}
	if limit <= 0 {
		limit = defaultStatusBodyLimit
	}

	var bodyText string
	if resp.Body != nil {
		// Best effort: a read failure still leaves the status code worth
		// reporting, so whatever was read before the error is used as-is.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, limit))
		bodyText = RedactSensitive(strings.TrimSpace(string(body)))
	}

	if bodyText == "" {
		return fmt.Errorf("%w: status %d", base, resp.StatusCode)
	}
	return fmt.Errorf("%w: status %d: %s", base, resp.StatusCode, bodyText)
}
// RedactSensitive masks token-like values in provider error text. Bearer
// credentials are replaced first, then values following token/credential style
// field names; in both cases the matched prefix is kept and the value becomes
// "<redacted>".
func RedactSensitive(text string) string {
	const mask = "${1}<redacted>"
	withoutBearer := sensitiveBearerRE.ReplaceAllString(text, mask)
	return sensitiveFieldRE.ReplaceAllString(withoutBearer, mask)
}
// DialContext dials using a protected socket.

View File

@@ -2,9 +2,11 @@ package protect
import (
"context"
"crypto/tls"
"errors"
"net"
"net/http"
"strings"
"syscall"
"testing"
"time"
@@ -88,13 +90,57 @@ func TestNewDialerAndHTTPClient(t *testing.T) {
if !ok {
t.Fatalf("Transport type = %T, want *http.Transport", client.Transport)
}
if tr.DialContext == nil || !tr.ForceAttemptHTTP2 || tr.MaxIdleConns != 10 ||
if tr.Proxy == nil || tr.DialContext == nil || tr.TLSClientConfig == nil ||
tr.TLSClientConfig.MinVersion != tls.VersionTLS12 || !tr.ForceAttemptHTTP2 || tr.MaxIdleConns != 10 ||
tr.IdleConnTimeout != 30*time.Second || tr.TLSHandshakeTimeout != 10*time.Second ||
tr.ResponseHeaderTimeout != 10*time.Second {
tr.ResponseHeaderTimeout != 10*time.Second || client.Timeout != 30*time.Second {
t.Fatalf("transport = %+v", tr)
}
}
// TestNewWebSocketDialer checks that an explicit handshake timeout is kept,
// that dial/proxy/TLS hooks are populated, and that a zero timeout falls back
// to the package default.
func TestNewWebSocketDialer(t *testing.T) {
	const explicit = 3 * time.Second

	got := NewWebSocketDialer(explicit)
	switch {
	case got.NetDialContext == nil,
		got.Proxy == nil,
		got.TLSClientConfig == nil,
		got.TLSClientConfig.MinVersion != tls.VersionTLS12,
		got.HandshakeTimeout != explicit:
		t.Fatalf("NewWebSocketDialer() = %+v", got)
	}

	if fallback := NewWebSocketDialer(0); fallback.HandshakeTimeout != defaultWebSocketTimeout {
		t.Fatalf("default HandshakeTimeout = %v, want %v",
			fallback.HandshakeTimeout, defaultWebSocketTimeout)
	}
}
// TestStatusErrorRedactsAndLimitsBody verifies both halves of StatusError's
// contract: token values in the upstream body are masked, and no more than
// limit bytes of the body reach the error text.
func TestStatusErrorRedactsAndLimitsBody(t *testing.T) {
	resp := &http.Response{
		StatusCode: http.StatusForbidden,
		Body:       ioNopCloser{strings.NewReader(`{"accessToken":"secret","message":"no"}`)},
	}

	err := StatusError(errProtectBoom, resp, 1024)
	if err == nil {
		t.Fatal("StatusError() error = nil")
	}
	text := err.Error()
	if strings.Contains(text, "secret") || !strings.Contains(text, "<redacted>") {
		t.Fatalf("StatusError() = %q, want redacted token", text)
	}

	// Only the first four body bytes may appear when limit is 4.
	limited := &http.Response{
		StatusCode: http.StatusBadGateway,
		Body:       ioNopCloser{strings.NewReader("abcdefghij")},
	}
	err = StatusError(errProtectBoom, limited, 4)
	if err == nil {
		t.Fatal("StatusError() with limit error = nil")
	}
	if !errors.Is(err, errProtectBoom) {
		t.Fatalf("StatusError() = %v, want wrapped %v", err, errProtectBoom)
	}
	if text := err.Error(); !strings.Contains(text, "abcd") || strings.Contains(text, "efgh") {
		t.Fatalf("StatusError() = %q, want body truncated to 4 bytes", text)
	}
}
// TestRedactSensitiveBearer checks that a bearer credential is masked while
// the "Bearer" prefix itself is preserved.
func TestRedactSensitiveBearer(t *testing.T) {
	const input = "Authorization: Bearer abc.def"

	redacted := RedactSensitive(input)
	leaked := strings.Contains(redacted, "abc.def")
	masked := strings.Contains(redacted, "Bearer <redacted>")
	if leaked || !masked {
		t.Fatalf("RedactSensitive() = %q", redacted)
	}
}
type ioNopCloser struct {
*strings.Reader
}
func (c ioNopCloser) Close() error { return nil }
func TestDialContextAndProxyDialer(t *testing.T) {
var lc net.ListenConfig
ln, err := lc.Listen(context.Background(), "tcp4", "127.0.0.1:0")

View File

@@ -14,12 +14,14 @@ import (
"time"
"github.com/google/uuid"
"github.com/openlibrecommunity/olcrtc/internal/control"
"github.com/openlibrecommunity/olcrtc/internal/crypto"
"github.com/openlibrecommunity/olcrtc/internal/handshake"
"github.com/openlibrecommunity/olcrtc/internal/link"
"github.com/openlibrecommunity/olcrtc/internal/logger"
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
"github.com/openlibrecommunity/olcrtc/internal/names"
"github.com/openlibrecommunity/olcrtc/internal/transport"
"github.com/xtaci/smux"
)
@@ -49,25 +51,33 @@ type SessionCloseFunc func(sessionID, reason string)
// bytesIn counts client→target bytes; bytesOut counts target→client bytes.
type TrafficFunc func(sessionID, addr string, bytesIn, bytesOut uint64)
// HealthFunc is called when the server control health snapshot changes.
type HealthFunc func(control.Status)
// Server handles incoming tunnel connections and proxies their traffic.
//
// Locking: sessMu guards the per-session fields (session, conn, controlStop,
// sessionID, deviceID); healthMu guards health.
type Server struct {
// Link, cipher and the smux session layered on top of them.
ln link.Link
cipher *crypto.Cipher
conn *muxconn.Conn
session *smux.Session
// controlStop cancels the context of the current session's control loop;
// it is nil when no control loop is installed.
controlStop context.CancelFunc
sessMu sync.RWMutex
reinstallMu sync.Mutex
healthMu sync.RWMutex
// wg tracks background goroutines started by the server, such as the
// control loop.
wg sync.WaitGroup
// Hooks supplied through Config.
authHook handshake.AuthFunc
onOpen SessionOpenFunc
onClose SessionCloseFunc
onTraffic TrafficFunc
onHealth HealthFunc
// Identity of the currently admitted client; both are reset to "" when the
// session is torn down or replaced.
deviceID string
sessionID string
dnsServer string
resolver *net.Resolver
socksProxyAddr string
socksProxyPort int
// liveness is the control-loop configuration template; each control loop
// works on its own copy with the callbacks wrapped.
liveness control.Config
// health is the latest control health snapshot returned by Status.
health control.Status
}
// ConnectRequest is a message from the client to establish a new connection.
@@ -106,6 +116,8 @@ type Config struct {
Engine string
URL string
Token string
Liveness control.Config
Traffic transport.TrafficConfig
// AuthHook is invoked after CLIENT_HELLO to authorize the client and
// return a session ID. If nil, every client is admitted with a random UUID.
@@ -117,6 +129,8 @@ type Config struct {
OnSessionClose SessionCloseFunc
// OnTraffic fires once per tunnel stream after both copy loops finish. Nil means no-op.
OnTraffic TrafficFunc
// OnHealth fires when liveness/reconnect status changes. Nil means no-op.
OnHealth HealthFunc
}
// Run starts the server with the given configuration.
@@ -145,6 +159,10 @@ func Run(ctx context.Context, cfg Config) error {
if onTraffic == nil {
onTraffic = func(string, string, uint64, uint64) {}
}
onHealth := cfg.OnHealth
if onHealth == nil {
onHealth = func(control.Status) {}
}
s := &Server{
cipher: cipher,
@@ -152,9 +170,11 @@ func Run(ctx context.Context, cfg Config) error {
onOpen: onOpen,
onClose: onClose,
onTraffic: onTraffic,
onHealth: onHealth,
dnsServer: cfg.DNSServer,
socksProxyAddr: cfg.SOCKSProxyAddr,
socksProxyPort: cfg.SOCKSProxyPort,
liveness: cfg.Liveness,
}
s.setupResolver()
@@ -216,11 +236,17 @@ func (s *Server) setupResolver() {
// smuxConfig mirrors the client side. Both peers must agree on Version and
// MaxFrameSize.
func smuxConfig() *smux.Config {
func smuxConfig(maxWirePayload ...int) *smux.Config {
cfg := smux.DefaultConfig()
cfg.Version = 2
cfg.KeepAliveDisabled = true
cfg.MaxFrameSize = 32768
if len(maxWirePayload) > 0 && maxWirePayload[0] > crypto.WireOverhead {
maxFrameSize := maxWirePayload[0] - crypto.WireOverhead
if maxFrameSize < cfg.MaxFrameSize {
cfg.MaxFrameSize = maxFrameSize
}
}
cfg.MaxReceiveBuffer = 16 * 1024 * 1024
cfg.MaxStreamBuffer = 1024 * 1024
cfg.KeepAliveInterval = 10 * time.Second
@@ -228,6 +254,14 @@ func smuxConfig() *smux.Config {
return cfg
}
// linkMaxPayload returns the MaxPayloadSize advertised by ln when the link
// implements link.FeaturesProvider, and 0 when it cannot report its limits.
func linkMaxPayload(ln link.Link) int {
	if provider, ok := ln.(link.FeaturesProvider); ok {
		return provider.Features().MaxPayloadSize
	}
	return 0
}
func (s *Server) bringUpLink(
ctx context.Context,
cfg Config,
@@ -262,6 +296,7 @@ func (s *Server) bringUpLink(
SEIBatchSize: cfg.SEIBatchSize,
SEIFragmentSize: cfg.SEIFragmentSize,
SEIAckTimeoutMS: cfg.SEIAckTimeoutMS,
Traffic: cfg.Traffic,
})
if err != nil {
return fmt.Errorf("failed to create link: %w", err)
@@ -298,7 +333,7 @@ func (s *Server) bringUpLink(
func (s *Server) installSession() {
conn := muxconn.New(s.ln, s.cipher)
sess, err := smux.Server(conn, smuxConfig())
sess, err := smux.Server(conn, smuxConfig(linkMaxPayload(s.ln)))
if err != nil {
logger.Warnf("smux server init failed: %v", err)
return
@@ -310,7 +345,8 @@ func (s *Server) installSession() {
}
func (s *Server) handleReconnect() {
logger.Infof("server link reconnect - tearing down smux session")
s.recordReconnect()
logger.Infof("server reconnect reason=carrier - tearing down smux session")
s.sessMu.RLock()
current := s.session
s.sessMu.RUnlock()
@@ -323,7 +359,7 @@ func (s *Server) reinstallSession(dead *smux.Session) {
// Pre-build the replacement so we can swap atomically below.
newConn := muxconn.New(s.ln, s.cipher)
newSess, err := smux.Server(newConn, smuxConfig())
newSess, err := smux.Server(newConn, smuxConfig(linkMaxPayload(s.ln)))
if err != nil {
logger.Warnf("smux server init failed: %v", err)
_ = newConn.Close()
@@ -340,13 +376,18 @@ func (s *Server) reinstallSession(dead *smux.Session) {
}
oldSess := s.session
oldConn := s.conn
oldControlStop := s.controlStop
oldSID := s.sessionID
s.session = newSess
s.conn = newConn
s.controlStop = nil
s.sessionID = ""
s.deviceID = ""
s.sessMu.Unlock()
if oldControlStop != nil {
oldControlStop()
}
if oldSess != nil {
_ = oldSess.Close()
}
@@ -362,13 +403,18 @@ func (s *Server) closeSession() {
s.sessMu.Lock()
sess := s.session
conn := s.conn
controlStop := s.controlStop
s.session = nil
s.conn = nil
s.controlStop = nil
oldSID := s.sessionID
s.sessionID = ""
s.deviceID = ""
s.sessMu.Unlock()
if controlStop != nil {
controlStop()
}
if conn != nil {
_ = conn.Close()
}
@@ -476,27 +522,120 @@ func (s *Server) acceptHandshake(ctx context.Context, sess *smux.Session) bool {
s.deviceID = hello.DeviceID
s.sessionID = sid
s.sessMu.Unlock()
s.recordSession(sid)
s.onOpen(sid, hello.DeviceID, hello.Claims)
logger.Infof("session %s opened (device=%s)", sid, hello.DeviceID)
// The control stream stays open for the lifetime of the session;
// keep it parked in a goroutine so the smux session does not close it.
s.wg.Add(1)
go func() {
defer s.wg.Done()
s.parkControlStream(stream)
}()
s.startControlLoop(ctx, sess, stream)
return true
}
// parkControlStream blocks reading from the control stream until it closes.
// Future control messages (kick, rate updates, etc.) would be dispatched here.
func (s *Server) parkControlStream(stream *smux.Stream) {
defer func() { _ = stream.Close() }()
buf := make([]byte, 64)
for {
if _, err := stream.Read(buf); err != nil {
// startControlLoop runs control.Run on the session's control stream in a
// background goroutine tracked by s.wg.
//
// The loop gets its own cancellable context derived from ctx; the cancel func
// is published as s.controlStop (under sessMu) so that session teardown can
// stop it. The configured liveness callbacks are wrapped so that every pong,
// missed pong and unhealthy event first updates the server health snapshot and
// is logged, and is then forwarded to the user-supplied callback, if any.
//
// When control.Run returns while neither context is cancelled, the exit is
// treated as a liveness failure: a reconnect is recorded and sess is handed to
// reinstallSession. If either context is already cancelled the goroutine just
// exits, because teardown is being driven from elsewhere.
func (s *Server) startControlLoop(ctx context.Context, sess *smux.Session, stream *smux.Stream) {
	controlCtx, stop := context.WithCancel(ctx)
	s.sessMu.Lock()
	s.controlStop = stop
	s.sessMu.Unlock()

	// Work on a copy so the wrappers below never leak into s.liveness.
	liveness := s.liveness
	onPong := liveness.OnPong
	onMissedPong := liveness.OnMissedPong
	onUnhealthy := liveness.OnUnhealthy
	liveness.OnPong = func(h control.Health) {
		s.sessMu.RLock()
		sid := s.sessionID
		s.sessMu.RUnlock()
		s.recordPong(h)
		logger.Debugf("control alive session=%s rtt=%v seq=%d", sid, h.RTT, h.Seq)
		if onPong != nil {
			onPong(h)
		}
	}
	liveness.OnMissedPong = func(missed int) {
		s.recordMissed(missed)
		logger.Warnf("control missed pong on server: missed_pongs=%d", missed)
		if onMissedPong != nil {
			onMissedPong(missed)
		}
	}
	liveness.OnUnhealthy = func(missed int) {
		s.recordUnhealthy(missed)
		logger.Warnf("control stream unhealthy on server: missed_pongs=%d", missed)
		if onUnhealthy != nil {
			onUnhealthy(missed)
		}
	}

	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		// Release controlCtx on every exit path. reinstallSession only calls
		// the stored cancel func when it actually swaps the session, so
		// without this the context would linger until ctx itself is cancelled.
		// Calling a CancelFunc more than once is harmless.
		defer stop()
		defer func() { _ = stream.Close() }()

		err := control.Run(controlCtx, stream, liveness)
		if controlCtx.Err() != nil || ctx.Err() != nil {
			return
		}
		if err != nil {
			logger.Warnf("server control stream ended: %v", err)
		}
		s.recordReconnect()
		logger.Infof("server reconnect reason=liveness - reinstalling smux session")
		s.reinstallSession(sess)
	}()
}
// Status returns a copy of the most recent server-side control health
// snapshot, read under the health lock.
func (s *Server) Status() control.Status {
	s.healthMu.RLock()
	snapshot := s.health
	s.healthMu.RUnlock()
	return snapshot
}
// recordSession stores the admitted session ID in the health snapshot, clears
// the missed-pong counter and publishes the updated snapshot.
func (s *Server) recordSession(sessionID string) {
	s.healthMu.Lock()
	current := &s.health
	current.SessionID, current.MissedPongs = sessionID, 0
	snapshot := *current
	s.healthMu.Unlock()

	// Notify outside the lock so the callback may call Status.
	s.notifyHealth(snapshot)
}
// recordPong copies the pong's timestamp and round-trip time into the health
// snapshot, clears the missed-pong counter and publishes the result.
func (s *Server) recordPong(h control.Health) {
	s.healthMu.Lock()
	current := &s.health
	current.LastPong, current.LastRTT, current.MissedPongs = h.LastSeen, h.RTT, 0
	snapshot := *current
	s.healthMu.Unlock()

	// Notify outside the lock so the callback may call Status.
	s.notifyHealth(snapshot)
}
// recordMissed stores the current count of consecutive missed pongs and
// publishes the updated health snapshot.
func (s *Server) recordMissed(missed int) {
	s.healthMu.Lock()
	current := &s.health
	current.MissedPongs = missed
	snapshot := *current
	s.healthMu.Unlock()

	// Notify outside the lock so the callback may call Status.
	s.notifyHealth(snapshot)
}
// recordUnhealthy notes that the control stream crossed the unhealthy
// threshold: it stores the missed-pong count, bumps the unhealthy-event
// counter, timestamps the event and publishes the updated snapshot.
func (s *Server) recordUnhealthy(missed int) {
	s.healthMu.Lock()
	current := &s.health
	current.MissedPongs = missed
	current.UnhealthyEvents++
	current.LastUnhealthy = time.Now()
	snapshot := *current
	s.healthMu.Unlock()

	// Notify outside the lock so the callback may call Status.
	s.notifyHealth(snapshot)
}
// recordReconnect increments the reconnect counter and publishes the updated
// health snapshot.
func (s *Server) recordReconnect() {
	s.healthMu.Lock()
	current := &s.health
	current.Reconnects++
	snapshot := *current
	s.healthMu.Unlock()

	// Notify outside the lock so the callback may call Status.
	s.notifyHealth(snapshot)
}
// notifyHealth forwards status to the configured health callback. A Server
// built without one (as some tests do) is tolerated.
func (s *Server) notifyHealth(status control.Status) {
	if s.onHealth == nil {
		return
	}
	s.onHealth(status)
}

View File

@@ -11,6 +11,7 @@ import (
"testing"
"time"
"github.com/openlibrecommunity/olcrtc/internal/control"
cryptopkg "github.com/openlibrecommunity/olcrtc/internal/crypto"
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
"github.com/xtaci/smux"
@@ -49,6 +50,11 @@ func TestSmuxConfig(t *testing.T) {
if cfg.Version != 2 || !cfg.KeepAliveDisabled || cfg.MaxFrameSize != 32768 || cfg.MaxReceiveBuffer != 16*1024*1024 {
t.Fatalf("smuxConfig() = %+v", cfg)
}
capped := smuxConfig(4096)
if capped.MaxFrameSize != 4096-cryptopkg.WireOverhead {
t.Fatalf("smuxConfig(4096).MaxFrameSize = %d, want %d",
capped.MaxFrameSize, 4096-cryptopkg.WireOverhead)
}
}
func TestParseConnectRequest(t *testing.T) {
@@ -373,6 +379,103 @@ func TestReinstallSessionFiresOnClose(t *testing.T) {
}
}
// TestStartControlLoopReportsPong wires a server and a client smux session
// over an in-memory pipe, starts the server control loop on the first stream
// and a plain control.Run on the client side, and expects a pong to reach both
// the user-supplied OnPong callback and the server's Status snapshot.
func TestStartControlLoopReportsPong(t *testing.T) {
	a, b := net.Pipe()
	defer func() {
		_ = a.Close()
		_ = b.Close()
	}()

	serverSess, err := smux.Server(a, smuxConfig())
	if err != nil {
		t.Fatalf("smux.Server() error = %v", err)
	}
	defer func() { _ = serverSess.Close() }()
	clientSess, err := smux.Client(b, smuxConfig())
	if err != nil {
		t.Fatalf("smux.Client() error = %v", err)
	}
	defer func() { _ = clientSess.Close() }()

	// Accept in the background and report failures too, so a broken accept
	// fails the test instead of blocking it until the global test timeout.
	serverStreamCh := make(chan *smux.Stream, 1)
	acceptErrCh := make(chan error, 1)
	go func() {
		stream, err := serverSess.AcceptStream()
		if err != nil {
			acceptErrCh <- err
			return
		}
		serverStreamCh <- stream
	}()
	clientStream, err := clientSess.OpenStream()
	if err != nil {
		t.Fatalf("OpenStream() error = %v", err)
	}
	var serverStream *smux.Stream
	select {
	case serverStream = <-serverStreamCh:
	case err := <-acceptErrCh:
		t.Fatalf("AcceptStream() error = %v", err)
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for server stream")
	}

	ctx, cancel := context.WithCancel(context.Background())
	got := make(chan control.Health, 1)
	s := &Server{
		sessionID: "sid-control",
		liveness: control.Config{
			Interval: 10 * time.Millisecond,
			Timeout:  100 * time.Millisecond,
			Failures: 2,
			OnPong: func(h control.Health) {
				// Keep only the first pong; later ones must not block.
				select {
				case got <- h:
				default:
				}
			},
		},
	}
	s.recordSession("sid-control")
	defer func() {
		cancel()
		s.wg.Wait()
	}()

	s.startControlLoop(ctx, serverSess, serverStream)
	go func() {
		_ = control.Run(ctx, clientStream, control.Config{
			Interval: 10 * time.Millisecond,
			Timeout:  100 * time.Millisecond,
			Failures: 2,
		})
	}()

	select {
	case h := <-got:
		if h.Seq == 0 {
			t.Fatal("Health.Seq = 0")
		}
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for control pong")
	}
	status := s.Status()
	if status.SessionID != "sid-control" {
		t.Fatalf("Status.SessionID = %q, want sid-control", status.SessionID)
	}
	if status.LastPong.IsZero() || status.LastRTT < 0 || status.MissedPongs != 0 {
		t.Fatalf("Status() = %+v", status)
	}
}
// TestStatusRecordsReconnectAndUnhealthy drives each health recorder once and
// checks both the resulting snapshot and that every recorder fired the health
// callback exactly once.
func TestStatusRecordsReconnectAndUnhealthy(t *testing.T) {
	var calls int
	srv := &Server{onHealth: func(control.Status) { calls++ }}

	srv.recordSession("sid-1")
	srv.recordMissed(2)
	srv.recordUnhealthy(3)
	srv.recordReconnect()

	got := srv.Status()
	asExpected := got.SessionID == "sid-1" &&
		got.MissedPongs == 3 &&
		got.UnhealthyEvents == 1 &&
		got.Reconnects == 1 &&
		!got.LastUnhealthy.IsZero()
	if !asExpected {
		t.Fatalf("Status() = %+v", got)
	}
	if calls != 4 {
		t.Fatalf("health updates = %d, want 4", calls)
	}
}
//nolint:cyclop // integration-style test needs setup, proxying, and traffic assertions together.
func TestDispatchFiresOnTraffic(t *testing.T) {
var lc net.ListenConfig

View File

@@ -0,0 +1,229 @@
// Package supervisor runs ordered session profiles with failover.
package supervisor
import (
"context"
"errors"
"fmt"
"time"
"github.com/openlibrecommunity/olcrtc/internal/app/session"
)
// DefaultRetryDelay is the pause between profile attempts that Run uses when
// Config.RetryDelay is zero.
const DefaultRetryDelay = 2 * time.Second
// DefaultHistoryLimit is the number of history events kept when
// Config.HistoryLimit is zero.
const DefaultHistoryLimit = 20
const (
// EventProfileStart marks a profile attempt starting.
EventProfileStart = "profile_start"
// EventProfileEnd marks a profile attempt ending.
EventProfileEnd = "profile_end"
)
var (
// ErrNoProfiles is returned when the supervisor is started without profiles.
ErrNoProfiles = errors.New("supervisor: no profiles configured")
// ErrMaxCyclesExceeded is returned after MaxCycles complete profile-list passes.
ErrMaxCyclesExceeded = errors.New("supervisor: max failover cycles exceeded")
)
// Profile is one runnable session configuration in an ordered failover list.
type Profile struct {
Name string
Config session.Config
}
// ProfileStatus summarizes one profile's failover history.
type ProfileStatus struct {
Name string
Starts int
Failures int
CleanEnds int
LastStarted time.Time
LastEnded time.Time
LastError string
}
// Event is one bounded failover history entry.
type Event struct {
Time time.Time
Type string
Profile string
Cycle int
Error string
}
// Status is a point-in-time view of the supervisor.
type Status struct {
Cycle int
ActiveProfile string
ActiveProfileIndex int
Profiles []ProfileStatus
History []Event
LastError string
}
// Runner starts one session profile and blocks until it ends or fails.
type Runner func(ctx context.Context, cfg session.Config) error
// Config controls ordered failover behavior.
type Config struct {
// Profiles is the ordered failover list; Run returns ErrNoProfiles when it
// is empty.
Profiles []Profile
// RetryDelay is the pause after each profile attempt. Zero selects
// DefaultRetryDelay; a negative value disables the pause.
RetryDelay time.Duration
// MaxCycles bounds the number of complete passes over Profiles. Zero or a
// negative value means no bound.
MaxCycles int
// OnProfileStart, if set, is called right before a profile is run.
OnProfileStart func(profile Profile, cycle int)
// OnProfileEnd, if set, is called after a profile returns while the context
// is still active; it is not called when the context was cancelled.
OnProfileEnd func(profile Profile, cycle int, err error)
// OnStatus, if set, receives a private copy of the status after every
// start and end transition.
OnStatus func(status Status)
// HistoryLimit caps Status.History. Zero selects DefaultHistoryLimit; a
// negative value disables history recording.
HistoryLimit int
}
// Run cycles through cfg.Profiles in order, invoking run for each one and
// moving on to the next profile (after the retry delay) whenever a profile
// returns while ctx is still active.
//
// It returns ErrNoProfiles for an empty profile list, nil once ctx is
// cancelled, and an error wrapping ErrMaxCyclesExceeded together with the last
// profile outcome after the final profile of cycle MaxCycles has ended.
func Run(ctx context.Context, cfg Config, run Runner) error {
	total := len(cfg.Profiles)
	if total == 0 {
		return ErrNoProfiles
	}

	delay := cfg.RetryDelay
	if delay == 0 {
		delay = DefaultRetryDelay
	}
	tracker := newStatusTracker(cfg.Profiles, cfg.HistoryLimit, cfg.OnStatus)

	for cycle := 1; ; cycle++ {
		for idx, profile := range cfg.Profiles {
			if ctx.Err() != nil {
				return nil
			}

			tracker.start(idx, cycle)
			if cfg.OnProfileStart != nil {
				cfg.OnProfileStart(profile, cycle)
			}

			runErr := run(ctx, profile.Config)
			// A cancelled context means shutdown, not a profile failure.
			if ctx.Err() != nil {
				return nil
			}

			tracker.end(idx, cycle, runErr)
			if cfg.OnProfileEnd != nil {
				cfg.OnProfileEnd(profile, cycle, runErr)
			}

			exhausted := cfg.MaxCycles > 0 && cycle >= cfg.MaxCycles && idx == total-1
			if exhausted {
				var cause error
				if runErr == nil {
					cause = fmt.Errorf("profile %q ended", profile.Name)
				} else {
					cause = fmt.Errorf("profile %q: %w", profile.Name, runErr)
				}
				return fmt.Errorf("%w after %d cycle(s): %w", ErrMaxCyclesExceeded, cycle, cause)
			}

			if waitRetryDelay(ctx, delay) != nil {
				return nil
			}
		}
	}
}
// statusTracker accumulates the supervisor Status and reports a copy of it
// after every profile start and end.
type statusTracker struct {
// status is the live, mutable state; observers only ever receive clones.
status Status
// notify is the optional observer; nil disables reporting.
notify func(Status)
// historyLimit caps len(status.History); a negative value disables history.
historyLimit int
}
// newStatusTracker creates a tracker with one ProfileStatus entry per profile
// (in the same order) and no active profile. A historyLimit of zero selects
// DefaultHistoryLimit.
func newStatusTracker(profiles []Profile, historyLimit int, notify func(Status)) *statusTracker {
	limit := historyLimit
	if limit == 0 {
		limit = DefaultHistoryLimit
	}

	perProfile := make([]ProfileStatus, len(profiles))
	for i := range profiles {
		perProfile[i].Name = profiles[i].Name
	}

	tracker := &statusTracker{notify: notify, historyLimit: limit}
	tracker.status.ActiveProfileIndex = -1
	tracker.status.Profiles = perProfile
	return tracker
}
// start marks the profile at profileIndex as the active one for the given
// cycle, bumps its start counter, appends a profile_start history event and
// emits the updated status.
func (t *statusTracker) start(profileIndex, cycle int) {
	startedAt := time.Now()

	entry := &t.status.Profiles[profileIndex]
	entry.LastStarted = startedAt
	entry.Starts++

	t.status.ActiveProfileIndex = profileIndex
	t.status.ActiveProfile = entry.Name
	t.status.Cycle = cycle

	t.appendHistory(Event{
		Type:    EventProfileStart,
		Profile: entry.Name,
		Cycle:   cycle,
		Time:    startedAt,
	})
	t.emit()
}
// end records the outcome of the profile at profileIndex: a nil err counts as
// a clean end, anything else as a failure. It clears the active-profile
// markers, appends a profile_end history event and emits the updated status.
func (t *statusTracker) end(profileIndex, cycle int, err error) {
	endedAt := time.Now()

	entry := &t.status.Profiles[profileIndex]
	entry.LastEnded = endedAt

	record := Event{Time: endedAt, Type: EventProfileEnd, Profile: entry.Name, Cycle: cycle}
	if err == nil {
		entry.CleanEnds++
		entry.LastError = ""
		t.status.LastError = fmt.Sprintf("profile %q ended", entry.Name)
	} else {
		entry.Failures++
		entry.LastError = err.Error()
		t.status.LastError = fmt.Sprintf("profile %q: %v", entry.Name, err)
		record.Error = err.Error()
	}

	t.status.ActiveProfileIndex = -1
	t.status.ActiveProfile = ""

	t.appendHistory(record)
	t.emit()
}
// appendHistory adds event to the history and drops the oldest entries so that
// at most historyLimit events remain. A negative limit disables recording.
func (t *statusTracker) appendHistory(event Event) {
	if t.historyLimit < 0 {
		return
	}
	t.status.History = append(t.status.History, event)
	if overflow := len(t.status.History) - t.historyLimit; overflow > 0 {
		t.status.History = t.status.History[overflow:]
	}
}
// emit hands a private copy of the current status to the observer, if one is
// configured.
func (t *statusTracker) emit() {
	if notify := t.notify; notify != nil {
		notify(cloneStatus(t.status))
	}
}
// cloneStatus returns status with its Profiles and History slices copied into
// fresh backing arrays, so the receiver cannot mutate the tracker's state.
// Empty slices come back as nil.
func cloneStatus(status Status) Status {
	var profiles []ProfileStatus
	profiles = append(profiles, status.Profiles...)

	var history []Event
	history = append(history, status.History...)

	status.Profiles, status.History = profiles, history
	return status
}
// waitRetryDelay blocks for delay or until ctx is done, whichever comes first.
// It returns ctx.Err() when the context ended the wait and nil otherwise; a
// non-positive delay returns nil immediately without consulting ctx.
func waitRetryDelay(ctx context.Context, delay time.Duration) error {
	if delay > 0 {
		pause := time.NewTimer(delay)
		defer pause.Stop()

		select {
		case <-pause.C:
		case <-ctx.Done():
			return ctx.Err()
		}
	}
	return nil
}

View File

@@ -0,0 +1,170 @@
package supervisor
import (
"context"
"errors"
"testing"
"time"
"github.com/openlibrecommunity/olcrtc/internal/app/session"
)
var errRunnerBoom = errors.New("boom")
func TestRunRequiresProfiles(t *testing.T) {
err := Run(context.Background(), Config{}, func(context.Context, session.Config) error { return nil })
if !errors.Is(err, ErrNoProfiles) {
t.Fatalf("Run() error = %v, want %v", err, ErrNoProfiles)
}
}
// TestRunAdvancesProfilesAndStopsAtMaxCycles runs two always-failing profiles
// with MaxCycles=1 and expects each to be started and ended exactly once, in
// order, before Run gives up with ErrMaxCyclesExceeded.
func TestRunAdvancesProfilesAndStopsAtMaxCycles(t *testing.T) {
profiles := []Profile{
{Name: "first", Config: session.Config{Auth: "wbstream"}},
{Name: "second", Config: session.Config{Auth: "jitsi"}},
}
var started []string
var ended []string
err := Run(context.Background(), Config{
Profiles: profiles,
// A negative delay skips the pause between profiles.
RetryDelay: -1,
MaxCycles: 1,
OnProfileStart: func(profile Profile, cycle int) {
started = append(started, profile.Name)
if cycle != 1 {
t.Fatalf("cycle = %d, want 1", cycle)
}
},
OnProfileEnd: func(profile Profile, _ int, err error) {
ended = append(ended, profile.Name)
// The callback must see the runner's error.
if !errors.Is(err, errRunnerBoom) {
t.Fatalf("profile %s err = %v, want %v", profile.Name, err, errRunnerBoom)
}
},
}, func(_ context.Context, cfg session.Config) error {
// Both profiles set Auth, so an empty value means the config was lost.
if cfg.Auth == "" {
t.Fatal("runner received empty auth")
}
return errRunnerBoom
})
if !errors.Is(err, ErrMaxCyclesExceeded) {
t.Fatalf("Run() error = %v, want %v", err, ErrMaxCyclesExceeded)
}
if got, want := started, []string{"first", "second"}; !equalStrings(got, want) {
t.Fatalf("started = %v, want %v", got, want)
}
if got, want := ended, []string{"first", "second"}; !equalStrings(got, want) {
t.Fatalf("ended = %v, want %v", got, want)
}
}
// TestRunEmitsStatusHistory checks the status snapshots emitted during a
// single pass over two failing profiles, including per-profile counters and
// the bounded event history.
func TestRunEmitsStatusHistory(t *testing.T) {
profiles := []Profile{
{Name: "first", Config: session.Config{Auth: "wbstream"}},
{Name: "second", Config: session.Config{Auth: "jitsi"}},
}
var snapshots []Status
err := Run(context.Background(), Config{
Profiles: profiles,
// A negative delay skips the pause between profiles.
RetryDelay: -1,
MaxCycles: 1,
// Four events are produced below, so a limit of 3 forces one eviction.
HistoryLimit: 3,
OnStatus: func(status Status) {
snapshots = append(snapshots, status)
},
}, func(_ context.Context, cfg session.Config) error {
if cfg.Auth == "first" {
t.Fatal("runner received profile name instead of config")
}
return errRunnerBoom
})
if !errors.Is(err, ErrMaxCyclesExceeded) {
t.Fatalf("Run() error = %v, want %v", err, ErrMaxCyclesExceeded)
}
// One snapshot per transition: start/end of "first", start/end of "second".
if len(snapshots) != 4 {
t.Fatalf("status snapshots = %d, want 4", len(snapshots))
}
first := snapshots[0]
if first.ActiveProfile != "first" || first.ActiveProfileIndex != 0 || first.Cycle != 1 {
t.Fatalf("first status = %+v", first)
}
if first.Profiles[0].Starts != 1 || first.Profiles[0].LastStarted.IsZero() {
t.Fatalf("first profile start status = %+v", first.Profiles[0])
}
last := snapshots[len(snapshots)-1]
// After the final end event no profile is active.
if last.ActiveProfile != "" || last.ActiveProfileIndex != -1 {
t.Fatalf("last active status = %+v", last)
}
if last.Profiles[0].Failures != 1 || last.Profiles[1].Failures != 1 {
t.Fatalf("profile failures = %+v", last.Profiles)
}
if last.LastError == "" || last.Profiles[1].LastError == "" {
t.Fatalf("last errors missing: %+v", last)
}
// The oldest event (start of "first") was evicted by the limit of 3, so
// the history now begins with the end of "first".
if len(last.History) != 3 {
t.Fatalf("history length = %d, want 3", len(last.History))
}
if last.History[0].Type != EventProfileEnd || last.History[0].Profile != "first" {
t.Fatalf("oldest bounded history event = %+v", last.History[0])
}
if last.History[2].Type != EventProfileEnd || last.History[2].Profile != "second" ||
last.History[2].Error == "" {
t.Fatalf("last history event = %+v", last.History[2])
}
}
// TestRunStatusSnapshotIsImmutable mutates the first snapshot handed to
// OnStatus and verifies that the next snapshot is unaffected, i.e. that
// observers receive copies rather than the supervisor's own slices.
func TestRunStatusSnapshotIsImmutable(t *testing.T) {
var first Status
var second Status
err := Run(context.Background(), Config{
Profiles: []Profile{{Name: "one"}},
RetryDelay: -1,
MaxCycles: 1,
OnStatus: func(status Status) {
// First call (profile start): keep the snapshot and scribble on it.
if first.Profiles == nil {
first = status
first.Profiles[0].Starts = 99
first.History[0].Profile = "mutated"
return
}
// Later call (profile end): must not see the scribbles above.
second = status
},
}, func(context.Context, session.Config) error {
return errRunnerBoom
})
if !errors.Is(err, ErrMaxCyclesExceeded) {
t.Fatalf("Run() error = %v, want %v", err, ErrMaxCyclesExceeded)
}
// Sanity check: the mutation really happened on the first snapshot.
if first.Profiles[0].Starts != 99 || first.History[0].Profile != "mutated" {
t.Fatalf("test mutation did not apply to snapshot: %+v", first)
}
if second.Profiles[0].Starts != 1 || second.History[0].Profile != "one" {
t.Fatalf("snapshot mutation leaked into later status: %+v", second)
}
}
func TestRunReturnsNilOnContextCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
err := Run(ctx, Config{
Profiles: []Profile{{Name: "one"}},
RetryDelay: time.Hour,
}, func(context.Context, session.Config) error {
cancel()
return nil
})
if err != nil {
t.Fatalf("Run() error = %v, want nil", err)
}
}
// equalStrings reports whether a and b have the same length and the same
// elements in the same order. Nil and empty slices compare equal.
func equalStrings(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}
	for idx, want := range a {
		if b[idx] != want {
			return false
		}
	}
	return true
}

View File

@@ -35,6 +35,7 @@ const (
protocolVersion byte = 1
frameTypeData byte = 1
frameTypeAck byte = 2
frameTypeHello byte = 3
)
var (
@@ -86,6 +87,7 @@ type streamTransport struct {
nextSeq atomic.Uint32
closed atomic.Bool
writerUp atomic.Bool
peerReady atomic.Bool
sendMu sync.Mutex
startWriter sync.Once
ackMu sync.Mutex
@@ -286,7 +288,7 @@ func (p *streamTransport) WatchConnection(ctx context.Context) {
// CanSend reports whether transport is ready for sending.
func (p *streamTransport) CanSend() bool {
return !p.closed.Load() && p.stream.CanSend()
return !p.closed.Load() && p.peerReady.Load() && p.stream.CanSend()
}
// Features describes the current seichannel transport semantics.
@@ -333,7 +335,7 @@ func (p *streamTransport) writerLoop() {
ticker := time.NewTicker(p.effectiveFrameInterval())
defer ticker.Stop()
idle := buildVideoAccessUnit(nil)
idle := buildVideoAccessUnit(encodeHelloFrame())
for {
select {
@@ -443,9 +445,13 @@ func (p *streamTransport) handleSample(sample []byte) {
}
switch frame.typ {
case frameTypeHello:
p.peerReady.Store(true)
case frameTypeAck:
p.peerReady.Store(true)
p.resolveAck(frame.seq, frame.crc)
case frameTypeData:
p.peerReady.Store(true)
p.handleInboundFrame(frame)
}
}
@@ -562,8 +568,8 @@ func encodeDataFrame(seq, crc uint32, totalLen, fragIdx, fragTotal int, payload
out[5] = frameTypeData
binary.BigEndian.PutUint32(out[6:10], seq)
binary.BigEndian.PutUint32(out[10:14], crc)
binary.BigEndian.PutUint32(out[14:18], uint32(totalLen)) //nolint:gosec,lll // G115: bounded conversion verified by surrounding logic
binary.BigEndian.PutUint16(out[18:20], uint16(fragIdx)) //nolint:gosec,lll // G115: bounded conversion verified by surrounding logic
binary.BigEndian.PutUint32(out[14:18], uint32(totalLen)) //nolint:gosec,lll // G115: bounded conversion verified by surrounding logic
binary.BigEndian.PutUint16(out[18:20], uint16(fragIdx)) //nolint:gosec,lll // G115: bounded conversion verified by surrounding logic
binary.BigEndian.PutUint16(out[20:22], uint16(fragTotal)) //nolint:gosec,lll // G115: bounded conversion verified by surrounding logic
copy(out[22:], payload)
return out
@@ -579,6 +585,14 @@ func encodeAckFrame(seq, crc uint32) []byte {
return out
}
// encodeHelloFrame builds the 6-byte hello frame: the big-endian protocol
// magic, the protocol version and the hello frame type, with no payload.
func encodeHelloFrame() []byte {
	frame := make([]byte, 0, 6)
	frame = binary.BigEndian.AppendUint32(frame, protocolMagic)
	return append(frame, protocolVersion, frameTypeHello)
}
func decodeTransportFrame(data []byte) (transportFrame, error) {
if len(data) < 6 {
return transportFrame{}, ErrFrameTooShort
@@ -592,6 +606,8 @@ func decodeTransportFrame(data []byte) (transportFrame, error) {
frame := transportFrame{typ: data[5]}
switch frame.typ {
case frameTypeHello:
return frame, nil
case frameTypeAck:
if len(data) < 14 {
return transportFrame{}, ErrAckTooShort

View File

@@ -78,3 +78,13 @@ func TestTransportFrameRoundTrip(t *testing.T) {
t.Fatalf("payload mismatch: got=%q", decoded.payload)
}
}
func TestHelloFrameRoundTrip(t *testing.T) {
hello, err := decodeTransportFrame(encodeHelloFrame())
if err != nil {
t.Fatalf("decodeTransportFrame(hello) failed: %v", err)
}
if hello.typ != frameTypeHello {
t.Fatalf("hello frame type = %d, want %d", hello.typ, frameTypeHello)
}
}

View File

@@ -103,8 +103,12 @@ func TestNewConnectCallbacksAndFeatures(t *testing.T) {
if stream.reconnect == nil || stream.should == nil || stream.ended == nil || !stream.watched {
t.Fatal("callbacks/watch were not forwarded")
}
if tr.CanSend() {
t.Fatal("CanSend() = true before peer hello")
}
tr.handleSample(buildVideoAccessUnit(encodeHelloFrame()))
if !tr.CanSend() {
t.Fatal("CanSend() = false, want true")
t.Fatal("CanSend() = false after peer hello")
}
if features := tr.Features(); !features.Reliable || !features.Ordered || !features.MessageOriented || features.MaxPayloadSize == 0 { //nolint:lll // long test description
t.Fatalf("Features() = %+v", features)

View File

@@ -0,0 +1,91 @@
package transport
import (
"context"
"errors"
"fmt"
"math/rand/v2"
"sync"
"time"
)
// ErrTrafficPayloadTooLarge is wrapped by the error the traffic wrapper's Send
// returns when a payload is longer than the configured maximum payload size.
var ErrTrafficPayloadTooLarge = errors.New("traffic payload exceeds max_payload_size")
// trafficTransport wraps another Transport, rejecting oversized payloads and
// pausing before each send. It is constructed by WithTraffic.
type trafficTransport struct {
// inner is the wrapped transport that every call is forwarded to.
inner Transport
// maxPayloadSize is the largest payload Send accepts; values <= 0 disable
// the cap.
maxPayloadSize int
// minDelay and maxDelay bound the pause nextDelay picks before each send.
minDelay time.Duration
maxDelay time.Duration
// sendMu is held for the whole of Send, including the pacing sleep, so
// sends go out one at a time.
sendMu sync.Mutex
}
// WithTraffic decorates tr with an optional payload-size cap and send pacing.
//
// A nil tr yields nil. The requested cap is first reconciled with the limit
// tr itself advertises; if neither a cap nor any delay remains enabled, tr is
// returned as-is without a wrapper.
func WithTraffic(tr Transport, cfg TrafficConfig) Transport {
	if tr == nil {
		return nil
	}

	effective := effectiveTrafficConfig(tr.Features(), cfg)
	capped := effective.MaxPayloadSize > 0
	paced := effective.MinDelay > 0 || effective.MaxDelay > 0
	if !capped && !paced {
		return tr
	}

	wrapped := new(trafficTransport)
	wrapped.inner = tr
	wrapped.maxPayloadSize = effective.MaxPayloadSize
	wrapped.minDelay = effective.MinDelay
	wrapped.maxDelay = effective.MaxDelay
	return wrapped
}
// effectiveTrafficConfig returns cfg with MaxPayloadSize lowered to the
// transport's own advertised limit when both are positive and the transport's
// limit is the smaller one. All other fields pass through untouched.
func effectiveTrafficConfig(features Features, cfg TrafficConfig) TrafficConfig {
	limit := features.MaxPayloadSize
	switch {
	case cfg.MaxPayloadSize <= 0, limit <= 0:
		// One side has no limit set, so there is nothing to reconcile.
	case limit < cfg.MaxPayloadSize:
		cfg.MaxPayloadSize = limit
	}
	return cfg
}
// Connect forwards the call, unchanged, to the wrapped transport.
func (t *trafficTransport) Connect(ctx context.Context) error {
	return t.inner.Connect(ctx)
}
// Send forwards data to the wrapped transport after enforcing the payload cap
// and the pacing delay.
//
// When the cap is enabled (maxPayloadSize > 0) and data is longer than it, Send
// returns an error wrapping ErrTrafficPayloadTooLarge and the wrapped transport
// is never called. Otherwise Send waits for the delay chosen by nextDelay and
// returns whatever the wrapped transport's Send returns. Accepted sends are
// serialized by sendMu, which stays held during the pause.
func (t *trafficTransport) Send(data []byte) error {
	// Reject oversized payloads before taking sendMu: the check only reads
	// maxPayloadSize, which WithTraffic sets at construction, so a doomed call
	// does not have to wait behind another sender's pacing sleep.
	if t.maxPayloadSize > 0 && len(data) > t.maxPayloadSize {
		return fmt.Errorf("%w: size=%d max=%d", ErrTrafficPayloadTooLarge, len(data), t.maxPayloadSize)
	}

	t.sendMu.Lock()
	defer t.sendMu.Unlock()

	if delay := t.nextDelay(); delay > 0 {
		time.Sleep(delay)
	}
	return t.inner.Send(data)
}
// Close forwards to the wrapped transport's Close.
func (t *trafficTransport) Close() error {
	return t.inner.Close()
}

// SetReconnectCallback hands cb to the wrapped transport.
func (t *trafficTransport) SetReconnectCallback(cb func()) {
	t.inner.SetReconnectCallback(cb)
}

// SetShouldReconnect hands fn to the wrapped transport.
func (t *trafficTransport) SetShouldReconnect(fn func() bool) {
	t.inner.SetShouldReconnect(fn)
}

// SetEndedCallback hands cb to the wrapped transport.
func (t *trafficTransport) SetEndedCallback(cb func(string)) {
	t.inner.SetEndedCallback(cb)
}

// WatchConnection forwards ctx to the wrapped transport's WatchConnection.
func (t *trafficTransport) WatchConnection(ctx context.Context) {
	t.inner.WatchConnection(ctx)
}

// CanSend reports whatever the wrapped transport's CanSend reports.
func (t *trafficTransport) CanSend() bool {
	return t.inner.CanSend()
}
// Features returns the wrapped transport's features, with MaxPayloadSize
// replaced by the wrapper's cap when that cap is enabled and the wrapped
// transport reports either no limit (zero) or a larger one.
func (t *trafficTransport) Features() Features {
	reported := t.inner.Features()
	if t.maxPayloadSize <= 0 {
		return reported
	}
	if reported.MaxPayloadSize == 0 || reported.MaxPayloadSize > t.maxPayloadSize {
		reported.MaxPayloadSize = t.maxPayloadSize
	}
	return reported
}
// nextDelay picks the pause to insert before the next send.
//
// It returns 0 when neither bound is positive, the (non-negative) minimum when
// maxDelay does not exceed it, and otherwise a random duration in the
// half-open range [minDelay, maxDelay).
func (t *trafficTransport) nextDelay() time.Duration {
	// A negative lower bound is meaningless as a pause. Left unclamped, it
	// would pull the jitter range below zero, so most draws would be negative
	// and Send would skip pacing even though maxDelay is positive.
	minDelay := t.minDelay
	if minDelay < 0 {
		minDelay = 0
	}
	maxDelay := t.maxDelay
	// Also covers "both bounds non-positive": minDelay is 0 here, so 0 is returned.
	if maxDelay <= minDelay {
		return minDelay
	}
	// maxDelay > minDelay, so the argument to Int64N is strictly positive.
	return minDelay + time.Duration(rand.Int64N(int64(maxDelay-minDelay))) //nolint:gosec,lll // G404: non-cryptographic pacing jitter
}

View File

@@ -0,0 +1,67 @@
package transport
import (
"context"
"errors"
"testing"
"time"
)
// trafficStubTransport is a minimal in-memory transport for the traffic
// wrapper tests: it reports fixed features and records every payload it is
// asked to send.
type trafficStubTransport struct {
	features Features // value returned by Features
	sent     [][]byte // copies of the payloads passed to Send, in call order
}

// Connect always succeeds.
func (stub *trafficStubTransport) Connect(context.Context) error {
	return nil
}

// Send records a private copy of data and reports success.
func (stub *trafficStubTransport) Send(data []byte) error {
	payload := append([]byte(nil), data...)
	stub.sent = append(stub.sent, payload)
	return nil
}

// Close always succeeds.
func (stub *trafficStubTransport) Close() error {
	return nil
}

// SetReconnectCallback ignores its argument.
func (stub *trafficStubTransport) SetReconnectCallback(func()) {}

// SetShouldReconnect ignores its argument.
func (stub *trafficStubTransport) SetShouldReconnect(func() bool) {}

// SetEndedCallback ignores its argument.
func (stub *trafficStubTransport) SetEndedCallback(func(string)) {}

// WatchConnection does nothing.
func (stub *trafficStubTransport) WatchConnection(context.Context) {}

// CanSend always reports true.
func (stub *trafficStubTransport) CanSend() bool {
	return true
}

// Features returns the features the stub was built with.
func (stub *trafficStubTransport) Features() Features {
	return stub.features
}
// TestWithTrafficReturnsInnerWhenDisabled checks that a zero TrafficConfig
// makes WithTraffic hand back the original transport rather than a wrapper.
func TestWithTrafficReturnsInnerWhenDisabled(t *testing.T) {
	stub := new(trafficStubTransport)
	if wrapped := WithTraffic(stub, TrafficConfig{}); wrapped != stub {
		t.Fatalf("WithTraffic disabled returned %T, want inner", wrapped)
	}
}
// TestTrafficWrapperRejectsOversizedPayloadAndClampsFeatures wraps a stub whose
// own limit (5) is below the requested cap (10) and checks that the wrapper
// advertises the smaller limit, rejects a 6-byte payload without forwarding
// it, and forwards a 5-byte payload unchanged.
func TestTrafficWrapperRejectsOversizedPayloadAndClampsFeatures(t *testing.T) {
	inner := &trafficStubTransport{features: Features{MaxPayloadSize: 5}}
	tr := WithTraffic(inner, TrafficConfig{MaxPayloadSize: 10})

	if features := tr.Features(); features.MaxPayloadSize != 5 {
		t.Fatalf("Features().MaxPayloadSize = %d, want 5", features.MaxPayloadSize)
	}

	err := tr.Send([]byte("123456"))
	if !errors.Is(err, ErrTrafficPayloadTooLarge) {
		t.Fatalf("Send() error = %v, want %v", err, ErrTrafficPayloadTooLarge)
	}
	if len(inner.sent) != 0 {
		t.Fatalf("inner sent %d payloads, want 0", len(inner.sent))
	}

	if err := tr.Send([]byte("12345")); err != nil {
		t.Fatalf("Send(max sized) error = %v", err)
	}
	// Check the count before indexing so a wrapper that drops the payload
	// fails with a message instead of an index-out-of-range panic.
	if len(inner.sent) != 1 {
		t.Fatalf("inner sent %d payloads, want 1", len(inner.sent))
	}
	if got := string(inner.sent[0]); got != "12345" {
		t.Fatalf("inner payload = %q, want 12345", got)
	}
}
func TestTrafficWrapperAppliesMinimumDelay(t *testing.T) {
inner := &trafficStubTransport{}
tr := WithTraffic(inner, TrafficConfig{MinDelay: 2 * time.Millisecond})
start := time.Now()
if err := tr.Send([]byte("x")); err != nil {
t.Fatalf("Send() error = %v", err)
}
if elapsed := time.Since(start); elapsed < 2*time.Millisecond {
t.Fatalf("Send() elapsed = %v, want at least 2ms", elapsed)
}
}

View File

@@ -4,6 +4,7 @@ package transport
import (
"context"
"errors"
"time"
)
var (
@@ -32,10 +33,17 @@ type Transport interface {
Features() Features
}
// TrafficConfig controls optional reliability-oriented send shaping applied by
// WithTraffic. When no field is positive, WithTraffic returns the transport
// unwrapped.
type TrafficConfig struct {
// MaxPayloadSize caps the length in bytes of a single Send payload; longer
// payloads fail with ErrTrafficPayloadTooLarge. Values <= 0 disable the cap.
// A smaller positive limit advertised by the transport's Features takes
// precedence.
MaxPayloadSize int
// MinDelay is the pause inserted before each send. It is used as a fixed
// delay when MaxDelay does not exceed it.
MinDelay time.Duration
// MaxDelay, when greater than MinDelay, makes the pause a random duration
// from MinDelay (inclusive) up to MaxDelay (exclusive).
MaxDelay time.Duration
}
// Config holds common transport configuration.
type Config struct {
Carrier string
RoomURL string
Carrier string
RoomURL string
// Engine, URL, Token are forwarded to carrier.Config for the "none" auth
// carrier (direct engine access without a service-specific auth flow).
Engine string
@@ -63,6 +71,7 @@ type Config struct {
SEIBatchSize int
SEIFragmentSize int
SEIAckTimeoutMS int
Traffic TrafficConfig
}
// Factory creates a transport instance.
@@ -81,7 +90,11 @@ func New(ctx context.Context, name string, cfg Config) (Transport, error) {
if !ok {
return nil, ErrTransportNotFound
}
return factory(ctx, cfg)
tr, err := factory(ctx, cfg)
if err != nil {
return nil, err
}
return WithTraffic(tr, cfg.Traffic), nil
}
// Available returns a list of registered transport names.