mirror of
https://github.com/openlibrecommunity/olcrtc.git
synced 2026-05-26 07:08:11 +00:00
Merge pull request #58 from cyber-debug/refine/livekit-reconnect
refine livekit reconnect and liveness
This commit is contained in:
@@ -5,13 +5,17 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/auth"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/carrier"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/carrier/builtin"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/client"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/control"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/crypto"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/link"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/link/direct"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/logger"
|
||||
@@ -37,18 +41,35 @@ const (
|
||||
videoCodecTile = "tile"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultVideoWidth = 1920
|
||||
defaultVideoHeight = 1080
|
||||
defaultVideoFPS = 30
|
||||
defaultVideoBitrate = "2M"
|
||||
defaultVideoHW = "none"
|
||||
defaultVideoQRRecovery = "low"
|
||||
defaultVP8FPS = 25
|
||||
defaultVP8BatchSize = 1
|
||||
defaultSEIFPS = 60
|
||||
defaultSEIBatchSize = 64
|
||||
defaultSEIFragmentSize = 900
|
||||
defaultSEIAckTimeoutMS = 2000
|
||||
)
|
||||
|
||||
var sessionRestartDelay = 2 * time.Second
|
||||
|
||||
var (
|
||||
// ErrRoomIDRequired indicates that a room id is required for the selected carrier.
|
||||
ErrRoomIDRequired = errors.New("room ID required (use -id <id>)")
|
||||
ErrRoomIDRequired = errors.New("room ID required (set room.id)")
|
||||
// ErrModeRequired indicates that mode is not one of the supported values.
|
||||
ErrModeRequired = errors.New("mode required (use -mode srv, -mode cnc or -mode gen)")
|
||||
// ErrAmountRequired indicates that -amount is required for gen mode.
|
||||
ErrAmountRequired = errors.New("amount required for gen mode (use -amount <n>)")
|
||||
ErrModeRequired = errors.New("mode required (set mode to srv, cnc or gen)")
|
||||
// ErrAmountRequired indicates that gen.amount is required for gen mode.
|
||||
ErrAmountRequired = errors.New("amount required for gen mode (set gen.amount)")
|
||||
// ErrAuthRequired indicates that no auth provider was selected.
|
||||
ErrAuthRequired = errors.New(
|
||||
"auth provider required (use -auth jitsi, -auth telemost, -auth jazz, -auth wbstream or -auth none)")
|
||||
// ErrURLRequired indicates that -url must be provided when the auth provider has no default URL.
|
||||
ErrURLRequired = errors.New("SFU URL required (use -url wss://...)")
|
||||
"auth provider required (set auth.provider to jitsi, telemost, jazz, wbstream or none)")
|
||||
// ErrURLRequired indicates that auth.url must be provided when the auth provider has no default URL.
|
||||
ErrURLRequired = errors.New("SFU URL required (set auth.url)")
|
||||
// ErrUnsupportedCarrier indicates that carrier is not registered.
|
||||
ErrUnsupportedCarrier = errors.New("unsupported carrier")
|
||||
// ErrUnsupportedLink indicates that link is not registered.
|
||||
@@ -57,88 +78,119 @@ var (
|
||||
ErrUnsupportedTransport = errors.New("unsupported transport")
|
||||
|
||||
// ErrLinkRequired indicates that link is not provided.
|
||||
ErrLinkRequired = errors.New("link required (use -link direct)")
|
||||
ErrLinkRequired = errors.New("link required (set link to direct)")
|
||||
// ErrTransportRequired indicates that transport is not provided.
|
||||
ErrTransportRequired = errors.New(
|
||||
"transport required (use -transport datachannel, -transport videochannel, " +
|
||||
"-transport seichannel or -transport vp8channel)")
|
||||
"transport required (set transport to datachannel, videochannel, seichannel or vp8channel)")
|
||||
// ErrKeyRequired indicates that encryption key is not provided.
|
||||
ErrKeyRequired = errors.New("key required (use -key <hex>)")
|
||||
ErrKeyRequired = errors.New("key required (set crypto.key)")
|
||||
// ErrDNSServerRequired indicates that dns server is not provided.
|
||||
ErrDNSServerRequired = errors.New("dns server required (use -dns 1.1.1.1:53)")
|
||||
ErrDNSServerRequired = errors.New("dns server required (set net.dns)")
|
||||
|
||||
// ErrVideoWidthRequired indicates that video width is required for videochannel.
|
||||
ErrVideoWidthRequired = errors.New("video width required for videochannel (use -video-w)")
|
||||
ErrVideoWidthRequired = errors.New("video width required for videochannel (set video.width)")
|
||||
// ErrVideoHeightRequired indicates that video height is required for videochannel.
|
||||
ErrVideoHeightRequired = errors.New("video height required for videochannel (use -video-h)")
|
||||
ErrVideoHeightRequired = errors.New("video height required for videochannel (set video.height)")
|
||||
// ErrVideoFPSRequired indicates that video fps is required for videochannel.
|
||||
ErrVideoFPSRequired = errors.New("video fps required for videochannel (use -video-fps)")
|
||||
ErrVideoFPSRequired = errors.New("video fps required for videochannel (set video.fps)")
|
||||
// ErrVideoBitrateRequired indicates that video bitrate is required for videochannel.
|
||||
ErrVideoBitrateRequired = errors.New(
|
||||
"video bitrate required for videochannel (use -video-bitrate)")
|
||||
"video bitrate required for videochannel (set video.bitrate)")
|
||||
// ErrVideoHWRequired indicates that video hardware acceleration is required.
|
||||
ErrVideoHWRequired = errors.New(
|
||||
"video hardware acceleration required for videochannel (use -video-hw none/nvenc)")
|
||||
"video hardware acceleration required for videochannel (set video.hw to none or nvenc)")
|
||||
// ErrVideoCodecInvalid indicates that the video codec is not valid.
|
||||
ErrVideoCodecInvalid = errors.New(
|
||||
"invalid video codec for videochannel (use -video-codec qrcode or -video-codec tile)")
|
||||
"invalid video codec for videochannel (set video.codec to qrcode or tile)")
|
||||
// ErrTileCodecDimensions indicates that tile codec requires 1080x1080 dimensions.
|
||||
ErrTileCodecDimensions = errors.New("tile codec requires -video-w 1080 -video-h 1080")
|
||||
ErrTileCodecDimensions = errors.New("tile codec requires video.width: 1080 and video.height: 1080")
|
||||
|
||||
// ErrVP8FPSRequired indicates that vp8 fps is required for vp8channel.
|
||||
ErrVP8FPSRequired = errors.New("vp8 fps required for vp8channel (use -vp8-fps)")
|
||||
ErrVP8FPSRequired = errors.New("vp8 fps required for vp8channel (set vp8.fps)")
|
||||
// ErrVP8BatchSizeRequired indicates that vp8 batch size is required for vp8channel.
|
||||
ErrVP8BatchSizeRequired = errors.New("vp8 batch size required for vp8channel (use -vp8-batch)")
|
||||
ErrVP8BatchSizeRequired = errors.New("vp8 batch size required for vp8channel (set vp8.batch_size)")
|
||||
// ErrSEIFPSRequired indicates that seichannel fps is required.
|
||||
ErrSEIFPSRequired = errors.New("fps required for seichannel (use -fps)")
|
||||
ErrSEIFPSRequired = errors.New("fps required for seichannel (set sei.fps)")
|
||||
// ErrSEIBatchSizeRequired indicates that seichannel batch size is required.
|
||||
ErrSEIBatchSizeRequired = errors.New("batch size required for seichannel (use -batch)")
|
||||
ErrSEIBatchSizeRequired = errors.New("batch size required for seichannel (set sei.batch_size)")
|
||||
// ErrSEIFragmentSizeRequired indicates that seichannel fragment size is required.
|
||||
ErrSEIFragmentSizeRequired = errors.New("fragment size required for seichannel (use -frag)")
|
||||
ErrSEIFragmentSizeRequired = errors.New("fragment size required for seichannel (set sei.fragment_size)")
|
||||
// ErrSEIAckTimeoutRequired indicates that seichannel ack timeout is required.
|
||||
ErrSEIAckTimeoutRequired = errors.New("ack timeout required for seichannel (use -ack-ms)")
|
||||
ErrSEIAckTimeoutRequired = errors.New("ack timeout required for seichannel (set sei.ack_timeout_ms)")
|
||||
|
||||
// ErrSOCKSHostRequired indicates that socks host is required for cnc mode.
|
||||
ErrSOCKSHostRequired = errors.New("socks host required for cnc mode (use -socks-host)")
|
||||
ErrSOCKSHostRequired = errors.New("socks host required for cnc mode (set socks.host)")
|
||||
// ErrSOCKSPortRequired indicates that socks port is required for cnc mode.
|
||||
ErrSOCKSPortRequired = errors.New("socks port required for cnc mode (use -socks-port)")
|
||||
ErrSOCKSPortRequired = errors.New("socks port required for cnc mode (set socks.port)")
|
||||
// ErrSOCKSAuthRequired indicates that a non-loopback SOCKS listener requires authentication.
|
||||
ErrSOCKSAuthRequired = errors.New(
|
||||
"socks auth required when binding outside loopback (set socks.user and socks.pass)")
|
||||
|
||||
// ErrLivenessIntervalInvalid indicates that liveness.interval is not a positive duration.
|
||||
ErrLivenessIntervalInvalid = errors.New(
|
||||
"invalid liveness interval (set liveness.interval to a duration > 0)")
|
||||
// ErrLivenessTimeoutInvalid indicates that liveness.timeout is not a positive duration.
|
||||
ErrLivenessTimeoutInvalid = errors.New(
|
||||
"invalid liveness timeout (set liveness.timeout to a duration > 0)")
|
||||
// ErrLivenessFailuresInvalid indicates that liveness.failures is not positive.
|
||||
ErrLivenessFailuresInvalid = errors.New(
|
||||
"invalid liveness failures (set liveness.failures to a value > 0)")
|
||||
// ErrLifecycleMaxSessionDurationInvalid indicates that lifecycle.max_session_duration is not a positive duration.
|
||||
ErrLifecycleMaxSessionDurationInvalid = errors.New(
|
||||
"invalid max session duration (set lifecycle.max_session_duration to a duration > 0)")
|
||||
// ErrTrafficMaxPayloadSizeInvalid indicates that traffic.max_payload_size is not valid.
|
||||
ErrTrafficMaxPayloadSizeInvalid = errors.New(
|
||||
"invalid traffic max payload size (set traffic.max_payload_size to 0 or a value above crypto overhead)")
|
||||
// ErrTrafficMinDelayInvalid indicates that traffic.min_delay is not a non-negative duration.
|
||||
ErrTrafficMinDelayInvalid = errors.New(
|
||||
"invalid traffic min delay (set traffic.min_delay to a duration >= 0)")
|
||||
// ErrTrafficMaxDelayInvalid indicates that traffic.max_delay is not a non-negative duration.
|
||||
ErrTrafficMaxDelayInvalid = errors.New(
|
||||
"invalid traffic max delay (set traffic.max_delay to a duration >= 0 and >= traffic.min_delay)")
|
||||
)
|
||||
|
||||
// Config holds runtime session settings.
|
||||
type Config struct {
|
||||
Mode string
|
||||
Link string
|
||||
Transport string
|
||||
Auth string
|
||||
Engine string
|
||||
URL string
|
||||
Token string
|
||||
RoomID string
|
||||
KeyHex string
|
||||
SOCKSHost string
|
||||
SOCKSPort int
|
||||
SOCKSUser string
|
||||
SOCKSPass string
|
||||
DNSServer string
|
||||
SOCKSProxyAddr string
|
||||
SOCKSProxyPort int
|
||||
VideoWidth int
|
||||
VideoHeight int
|
||||
VideoFPS int
|
||||
VideoBitrate string
|
||||
VideoHW string
|
||||
VideoQRSize int
|
||||
VideoQRRecovery string
|
||||
VideoCodec string
|
||||
VideoTileModule int
|
||||
VideoTileRS int
|
||||
VP8FPS int
|
||||
VP8BatchSize int
|
||||
SEIFPS int
|
||||
SEIBatchSize int
|
||||
SEIFragmentSize int
|
||||
SEIAckTimeoutMS int
|
||||
Amount int
|
||||
Mode string
|
||||
Link string
|
||||
Transport string
|
||||
Auth string
|
||||
Engine string
|
||||
URL string
|
||||
Token string
|
||||
RoomID string
|
||||
KeyHex string
|
||||
SOCKSHost string
|
||||
SOCKSPort int
|
||||
SOCKSUser string
|
||||
SOCKSPass string
|
||||
DNSServer string
|
||||
SOCKSProxyAddr string
|
||||
SOCKSProxyPort int
|
||||
VideoWidth int
|
||||
VideoHeight int
|
||||
VideoFPS int
|
||||
VideoBitrate string
|
||||
VideoHW string
|
||||
VideoQRSize int
|
||||
VideoQRRecovery string
|
||||
VideoCodec string
|
||||
VideoTileModule int
|
||||
VideoTileRS int
|
||||
VP8FPS int
|
||||
VP8BatchSize int
|
||||
SEIFPS int
|
||||
SEIBatchSize int
|
||||
SEIFragmentSize int
|
||||
SEIAckTimeoutMS int
|
||||
LivenessInterval string
|
||||
LivenessTimeout string
|
||||
LivenessFailures int
|
||||
MaxSessionDuration string
|
||||
TrafficMaxPayloadSize int
|
||||
TrafficMinDelay string
|
||||
TrafficMaxDelay string
|
||||
Amount int
|
||||
}
|
||||
|
||||
// RegisterDefaults registers built-in carriers and transports.
|
||||
@@ -180,6 +232,94 @@ func ApplyAuthDefaults(cfg Config) (Config, error) {
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// ApplyTransportDefaults fills documented transport defaults without changing core routing fields.
|
||||
func ApplyTransportDefaults(cfg Config) Config {
|
||||
switch cfg.Transport {
|
||||
case transportVideo:
|
||||
return applyVideoDefaults(cfg)
|
||||
case transportVP8:
|
||||
return applyVP8Defaults(cfg)
|
||||
case transportSEI:
|
||||
return applySEIDefaults(cfg)
|
||||
default:
|
||||
return cfg
|
||||
}
|
||||
}
|
||||
|
||||
// ApplyLivenessDefaults fills documented control-stream liveness defaults.
|
||||
func ApplyLivenessDefaults(cfg Config) Config {
|
||||
if cfg.LivenessInterval == "" {
|
||||
cfg.LivenessInterval = control.DefaultInterval.String()
|
||||
}
|
||||
if cfg.LivenessTimeout == "" {
|
||||
cfg.LivenessTimeout = control.DefaultTimeout.String()
|
||||
}
|
||||
if cfg.LivenessFailures == 0 {
|
||||
cfg.LivenessFailures = control.DefaultFailures
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func applyVideoDefaults(cfg Config) Config {
|
||||
if cfg.VideoCodec == "" {
|
||||
cfg.VideoCodec = videoCodecQRCode
|
||||
}
|
||||
if cfg.VideoCodec == videoCodecTile {
|
||||
if cfg.VideoWidth == 0 {
|
||||
cfg.VideoWidth = 1080
|
||||
}
|
||||
if cfg.VideoHeight == 0 {
|
||||
cfg.VideoHeight = 1080
|
||||
}
|
||||
} else {
|
||||
if cfg.VideoWidth == 0 {
|
||||
cfg.VideoWidth = defaultVideoWidth
|
||||
}
|
||||
if cfg.VideoHeight == 0 {
|
||||
cfg.VideoHeight = defaultVideoHeight
|
||||
}
|
||||
}
|
||||
if cfg.VideoFPS == 0 {
|
||||
cfg.VideoFPS = defaultVideoFPS
|
||||
}
|
||||
if cfg.VideoBitrate == "" {
|
||||
cfg.VideoBitrate = defaultVideoBitrate
|
||||
}
|
||||
if cfg.VideoHW == "" {
|
||||
cfg.VideoHW = defaultVideoHW
|
||||
}
|
||||
if cfg.VideoQRRecovery == "" {
|
||||
cfg.VideoQRRecovery = defaultVideoQRRecovery
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func applyVP8Defaults(cfg Config) Config {
|
||||
if cfg.VP8FPS == 0 {
|
||||
cfg.VP8FPS = defaultVP8FPS
|
||||
}
|
||||
if cfg.VP8BatchSize == 0 {
|
||||
cfg.VP8BatchSize = defaultVP8BatchSize
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func applySEIDefaults(cfg Config) Config {
|
||||
if cfg.SEIFPS == 0 {
|
||||
cfg.SEIFPS = defaultSEIFPS
|
||||
}
|
||||
if cfg.SEIBatchSize == 0 {
|
||||
cfg.SEIBatchSize = defaultSEIBatchSize
|
||||
}
|
||||
if cfg.SEIFragmentSize == 0 {
|
||||
cfg.SEIFragmentSize = defaultSEIFragmentSize
|
||||
}
|
||||
if cfg.SEIAckTimeoutMS == 0 {
|
||||
cfg.SEIAckTimeoutMS = defaultSEIAckTimeoutMS
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// Validate verifies that the runtime config refers to registered components and all required fields are present.
|
||||
func Validate(cfg Config) error {
|
||||
if err := validateMode(cfg); err != nil {
|
||||
@@ -200,6 +340,15 @@ func Validate(cfg Config) error {
|
||||
if err := validateTransportConfig(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateLivenessConfig(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateLifecycleConfig(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateTrafficConfig(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
return validateModeConfig(cfg)
|
||||
}
|
||||
|
||||
@@ -333,13 +482,163 @@ func validateModeConfig(cfg Config) error {
|
||||
if cfg.SOCKSPort == 0 {
|
||||
return ErrSOCKSPortRequired
|
||||
}
|
||||
if !isLoopbackListenHost(cfg.SOCKSHost) && (cfg.SOCKSUser == "" || cfg.SOCKSPass == "") {
|
||||
return ErrSOCKSAuthRequired
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateLivenessConfig(cfg Config) error {
|
||||
if _, err := parseLivenessDuration(cfg.LivenessInterval, control.DefaultInterval); err != nil {
|
||||
return fmt.Errorf("%w: %v", ErrLivenessIntervalInvalid, err)
|
||||
}
|
||||
if _, err := parseLivenessDuration(cfg.LivenessTimeout, control.DefaultTimeout); err != nil {
|
||||
return fmt.Errorf("%w: %v", ErrLivenessTimeoutInvalid, err)
|
||||
}
|
||||
if cfg.LivenessFailures < 0 {
|
||||
return ErrLivenessFailuresInvalid
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateLifecycleConfig(cfg Config) error {
|
||||
if _, err := maxSessionDuration(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseLivenessDuration(value string, def time.Duration) (time.Duration, error) {
|
||||
if value == "" {
|
||||
return def, nil
|
||||
}
|
||||
d, err := time.ParseDuration(value)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if d <= 0 {
|
||||
return 0, fmt.Errorf("duration must be > 0")
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func livenessConfig(cfg Config) (control.Config, error) {
|
||||
interval, err := parseLivenessDuration(cfg.LivenessInterval, control.DefaultInterval)
|
||||
if err != nil {
|
||||
return control.Config{}, fmt.Errorf("%w: %v", ErrLivenessIntervalInvalid, err)
|
||||
}
|
||||
timeout, err := parseLivenessDuration(cfg.LivenessTimeout, control.DefaultTimeout)
|
||||
if err != nil {
|
||||
return control.Config{}, fmt.Errorf("%w: %v", ErrLivenessTimeoutInvalid, err)
|
||||
}
|
||||
failures := cfg.LivenessFailures
|
||||
if failures == 0 {
|
||||
failures = control.DefaultFailures
|
||||
}
|
||||
if failures < 0 {
|
||||
return control.Config{}, ErrLivenessFailuresInvalid
|
||||
}
|
||||
return control.Config{Interval: interval, Timeout: timeout, Failures: failures}, nil
|
||||
}
|
||||
|
||||
func maxSessionDuration(cfg Config) (time.Duration, error) {
|
||||
if cfg.MaxSessionDuration == "" {
|
||||
return 0, nil
|
||||
}
|
||||
d, err := time.ParseDuration(cfg.MaxSessionDuration)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("%w: %v", ErrLifecycleMaxSessionDurationInvalid, err)
|
||||
}
|
||||
if d <= 0 {
|
||||
return 0, ErrLifecycleMaxSessionDurationInvalid
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func validateTrafficConfig(cfg Config) error {
|
||||
_, err := trafficConfig(cfg)
|
||||
return err
|
||||
}
|
||||
|
||||
func trafficConfig(cfg Config) (transport.TrafficConfig, error) {
|
||||
if cfg.TrafficMaxPayloadSize < 0 || (cfg.TrafficMaxPayloadSize > 0 &&
|
||||
cfg.TrafficMaxPayloadSize <= crypto.WireOverhead) {
|
||||
return transport.TrafficConfig{}, ErrTrafficMaxPayloadSizeInvalid
|
||||
}
|
||||
minDelay, err := parseOptionalNonNegativeDuration(cfg.TrafficMinDelay)
|
||||
if err != nil {
|
||||
return transport.TrafficConfig{}, fmt.Errorf("%w: %v", ErrTrafficMinDelayInvalid, err)
|
||||
}
|
||||
maxDelay, err := parseOptionalNonNegativeDuration(cfg.TrafficMaxDelay)
|
||||
if err != nil {
|
||||
return transport.TrafficConfig{}, fmt.Errorf("%w: %v", ErrTrafficMaxDelayInvalid, err)
|
||||
}
|
||||
if maxDelay > 0 && maxDelay < minDelay {
|
||||
return transport.TrafficConfig{}, ErrTrafficMaxDelayInvalid
|
||||
}
|
||||
return transport.TrafficConfig{
|
||||
MaxPayloadSize: cfg.TrafficMaxPayloadSize,
|
||||
MinDelay: minDelay,
|
||||
MaxDelay: maxDelay,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseOptionalNonNegativeDuration(value string) (time.Duration, error) {
|
||||
if value == "" {
|
||||
return 0, nil
|
||||
}
|
||||
d, err := time.ParseDuration(value)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if d < 0 {
|
||||
return 0, fmt.Errorf("duration must be >= 0")
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func isLoopbackListenHost(host string) bool {
|
||||
if host == "localhost" {
|
||||
return true
|
||||
}
|
||||
ip := net.ParseIP(host)
|
||||
return ip != nil && ip.IsLoopback()
|
||||
}
|
||||
|
||||
// Run starts the configured mode.
|
||||
func Run(ctx context.Context, cfg Config) error {
|
||||
cfg = ApplyTransportDefaults(cfg)
|
||||
cfg = ApplyLivenessDefaults(cfg)
|
||||
roomURL := cfg.RoomID
|
||||
liveness, err := livenessConfig(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
maxDuration, err := maxSessionDuration(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
traffic, err := trafficConfig(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
run := func(ctx context.Context) error {
|
||||
return runOnce(ctx, cfg, roomURL, liveness, traffic)
|
||||
}
|
||||
if maxDuration > 0 {
|
||||
return runWithSessionRotation(ctx, maxDuration, run)
|
||||
}
|
||||
return run(ctx)
|
||||
}
|
||||
|
||||
func runOnce(
|
||||
ctx context.Context,
|
||||
cfg Config,
|
||||
roomURL string,
|
||||
liveness control.Config,
|
||||
traffic transport.TrafficConfig,
|
||||
) error {
|
||||
switch cfg.Mode {
|
||||
case modeSRV:
|
||||
if err := server.Run(ctx, server.Config{
|
||||
@@ -370,6 +669,8 @@ func Run(ctx context.Context, cfg Config) error {
|
||||
Engine: cfg.Engine,
|
||||
URL: cfg.URL,
|
||||
Token: cfg.Token,
|
||||
Liveness: liveness,
|
||||
Traffic: traffic,
|
||||
OnSessionOpen: func(sessionID, deviceID string, claims map[string]any) {
|
||||
logger.Infof("session opened: id=%s device=%s claims=%v", sessionID, deviceID, claims)
|
||||
},
|
||||
@@ -413,6 +714,8 @@ func Run(ctx context.Context, cfg Config) error {
|
||||
Engine: cfg.Engine,
|
||||
URL: cfg.URL,
|
||||
Token: cfg.Token,
|
||||
Liveness: liveness,
|
||||
Traffic: traffic,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("client: %w", err)
|
||||
}
|
||||
@@ -422,6 +725,52 @@ func Run(ctx context.Context, cfg Config) error {
|
||||
}
|
||||
}
|
||||
|
||||
func runWithSessionRotation(ctx context.Context, maxDuration time.Duration, run func(context.Context) error) error {
|
||||
for cycle := 1; ; cycle++ {
|
||||
currentCycle := cycle
|
||||
runCtx, cancel := context.WithCancel(ctx)
|
||||
var rotated atomic.Bool
|
||||
timer := time.AfterFunc(maxDuration, func() {
|
||||
rotated.Store(true)
|
||||
logger.Infof("session max duration reached: duration=%s cycle=%d", maxDuration, currentCycle)
|
||||
cancel()
|
||||
})
|
||||
|
||||
err := run(runCtx)
|
||||
cancel()
|
||||
timer.Stop()
|
||||
if ctx.Err() != nil {
|
||||
return nil
|
||||
}
|
||||
if rotated.Load() {
|
||||
if err != nil {
|
||||
logger.Warnf("session rotation ended with error: cycle=%d err=%v", currentCycle, err)
|
||||
}
|
||||
logger.Infof("session rotation restarting: next_cycle=%d", currentCycle+1)
|
||||
if err := waitSessionRestart(ctx); err != nil {
|
||||
return nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logger.Infof("session ended cleanly with lifecycle rotation enabled: next_cycle=%d", currentCycle+1)
|
||||
if err := waitSessionRestart(ctx); err != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func waitSessionRestart(ctx context.Context) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(sessionRestartDelay):
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateGen validates that the config contains enough fields to run gen mode.
|
||||
func ValidateGen(cfg Config) error {
|
||||
if cfg.Auth == "" {
|
||||
|
||||
@@ -3,9 +3,136 @@ package session
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/control"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/crypto"
|
||||
)
|
||||
|
||||
func TestApplyTransportDefaults(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in Config
|
||||
want Config
|
||||
}{
|
||||
{
|
||||
name: "vp8",
|
||||
in: Config{Transport: transportVP8},
|
||||
want: Config{Transport: transportVP8, VP8FPS: 25, VP8BatchSize: 1},
|
||||
},
|
||||
{
|
||||
name: "sei",
|
||||
in: Config{Transport: transportSEI},
|
||||
want: Config{
|
||||
Transport: transportSEI,
|
||||
SEIFPS: 60,
|
||||
SEIBatchSize: 64,
|
||||
SEIFragmentSize: 900,
|
||||
SEIAckTimeoutMS: 2000,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "video qrcode",
|
||||
in: Config{Transport: transportVideo},
|
||||
want: Config{
|
||||
Transport: transportVideo,
|
||||
VideoWidth: 1920,
|
||||
VideoHeight: 1080,
|
||||
VideoFPS: 30,
|
||||
VideoBitrate: "2M",
|
||||
VideoHW: "none",
|
||||
VideoQRRecovery: "low",
|
||||
VideoCodec: videoCodecQRCode,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "video tile dimensions",
|
||||
in: Config{Transport: transportVideo, VideoCodec: videoCodecTile},
|
||||
want: Config{
|
||||
Transport: transportVideo,
|
||||
VideoWidth: 1080,
|
||||
VideoHeight: 1080,
|
||||
VideoFPS: 30,
|
||||
VideoBitrate: "2M",
|
||||
VideoHW: "none",
|
||||
VideoQRRecovery: "low",
|
||||
VideoCodec: videoCodecTile,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "keeps explicit values",
|
||||
in: Config{
|
||||
Transport: transportSEI,
|
||||
SEIFPS: 10,
|
||||
SEIBatchSize: 2,
|
||||
SEIFragmentSize: 300,
|
||||
SEIAckTimeoutMS: 1500,
|
||||
},
|
||||
want: Config{
|
||||
Transport: transportSEI,
|
||||
SEIFPS: 10,
|
||||
SEIBatchSize: 2,
|
||||
SEIFragmentSize: 300,
|
||||
SEIAckTimeoutMS: 1500,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ApplyTransportDefaults(tt.in)
|
||||
if got != tt.want {
|
||||
t.Fatalf("ApplyTransportDefaults() = %+v, want %+v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyLivenessDefaults(t *testing.T) {
|
||||
got := ApplyLivenessDefaults(Config{})
|
||||
if got.LivenessInterval != control.DefaultInterval.String() {
|
||||
t.Fatalf("LivenessInterval = %q, want %q", got.LivenessInterval, control.DefaultInterval.String())
|
||||
}
|
||||
if got.LivenessTimeout != control.DefaultTimeout.String() {
|
||||
t.Fatalf("LivenessTimeout = %q, want %q", got.LivenessTimeout, control.DefaultTimeout.String())
|
||||
}
|
||||
if got.LivenessFailures != control.DefaultFailures {
|
||||
t.Fatalf("LivenessFailures = %d, want %d", got.LivenessFailures, control.DefaultFailures)
|
||||
}
|
||||
|
||||
explicit := Config{LivenessInterval: "1s", LivenessTimeout: "500ms", LivenessFailures: 9}
|
||||
if got := ApplyLivenessDefaults(explicit); got != explicit {
|
||||
t.Fatalf("ApplyLivenessDefaults() = %+v, want %+v", got, explicit)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunWithSessionRotationRestartsAfterMaxDuration(t *testing.T) {
|
||||
oldRestartDelay := sessionRestartDelay
|
||||
sessionRestartDelay = time.Millisecond
|
||||
t.Cleanup(func() { sessionRestartDelay = oldRestartDelay })
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
var calls atomic.Int32
|
||||
err := runWithSessionRotation(ctx, 5*time.Millisecond, func(ctx context.Context) error {
|
||||
if calls.Add(1) >= 2 {
|
||||
cancel()
|
||||
return nil
|
||||
}
|
||||
<-ctx.Done()
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runWithSessionRotation() error = %v", err)
|
||||
}
|
||||
if got := calls.Load(); got < 2 {
|
||||
t.Fatalf("run calls = %d, want at least 2", got)
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:maintidx // table-driven validation test naturally has many cases
|
||||
func TestValidate(t *testing.T) {
|
||||
RegisterDefaults()
|
||||
@@ -310,6 +437,148 @@ func TestValidate(t *testing.T) {
|
||||
}(),
|
||||
want: ErrSOCKSPortRequired,
|
||||
},
|
||||
{
|
||||
name: "cnc rejects unauthenticated wildcard socks bind",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.Mode = modeCNC
|
||||
cfg.SOCKSHost = "0.0.0.0"
|
||||
cfg.SOCKSPort = 1080
|
||||
return cfg
|
||||
}(),
|
||||
want: ErrSOCKSAuthRequired,
|
||||
},
|
||||
{
|
||||
name: "cnc allows authenticated wildcard socks bind",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.Mode = modeCNC
|
||||
cfg.SOCKSHost = "0.0.0.0"
|
||||
cfg.SOCKSPort = 1080
|
||||
cfg.SOCKSUser = "user"
|
||||
cfg.SOCKSPass = "pass"
|
||||
return cfg
|
||||
}(),
|
||||
},
|
||||
{
|
||||
name: "cnc allows localhost socks bind without auth",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.Mode = modeCNC
|
||||
cfg.SOCKSHost = "localhost"
|
||||
cfg.SOCKSPort = 1080
|
||||
return cfg
|
||||
}(),
|
||||
},
|
||||
{
|
||||
name: "liveness rejects bad interval",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.LivenessInterval = "nope"
|
||||
return cfg
|
||||
}(),
|
||||
want: ErrLivenessIntervalInvalid,
|
||||
},
|
||||
{
|
||||
name: "liveness rejects zero timeout",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.LivenessTimeout = "0s"
|
||||
return cfg
|
||||
}(),
|
||||
want: ErrLivenessTimeoutInvalid,
|
||||
},
|
||||
{
|
||||
name: "liveness rejects negative failures",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.LivenessFailures = -1
|
||||
return cfg
|
||||
}(),
|
||||
want: ErrLivenessFailuresInvalid,
|
||||
},
|
||||
{
|
||||
name: "lifecycle accepts max session duration",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.MaxSessionDuration = "1h"
|
||||
return cfg
|
||||
}(),
|
||||
},
|
||||
{
|
||||
name: "lifecycle rejects bad max session duration",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.MaxSessionDuration = "nope"
|
||||
return cfg
|
||||
}(),
|
||||
want: ErrLifecycleMaxSessionDurationInvalid,
|
||||
},
|
||||
{
|
||||
name: "lifecycle rejects zero max session duration",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.MaxSessionDuration = "0s"
|
||||
return cfg
|
||||
}(),
|
||||
want: ErrLifecycleMaxSessionDurationInvalid,
|
||||
},
|
||||
{
|
||||
name: "traffic accepts shaping",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.TrafficMaxPayloadSize = 4096
|
||||
cfg.TrafficMinDelay = "5ms"
|
||||
cfg.TrafficMaxDelay = "30ms"
|
||||
return cfg
|
||||
}(),
|
||||
},
|
||||
{
|
||||
name: "traffic rejects negative max payload",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.TrafficMaxPayloadSize = -1
|
||||
return cfg
|
||||
}(),
|
||||
want: ErrTrafficMaxPayloadSizeInvalid,
|
||||
},
|
||||
{
|
||||
name: "traffic rejects payload smaller than crypto overhead",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.TrafficMaxPayloadSize = crypto.WireOverhead
|
||||
return cfg
|
||||
}(),
|
||||
want: ErrTrafficMaxPayloadSizeInvalid,
|
||||
},
|
||||
{
|
||||
name: "traffic rejects bad min delay",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.TrafficMinDelay = "nope"
|
||||
return cfg
|
||||
}(),
|
||||
want: ErrTrafficMinDelayInvalid,
|
||||
},
|
||||
{
|
||||
name: "traffic rejects negative max delay",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.TrafficMaxDelay = "-1ms"
|
||||
return cfg
|
||||
}(),
|
||||
want: ErrTrafficMaxDelayInvalid,
|
||||
},
|
||||
{
|
||||
name: "traffic rejects max delay below min delay",
|
||||
cfg: func() Config {
|
||||
cfg := base
|
||||
cfg.TrafficMinDelay = "30ms"
|
||||
cfg.TrafficMaxDelay = "5ms"
|
||||
return cfg
|
||||
}(),
|
||||
want: ErrTrafficMaxDelayInvalid,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -9,9 +9,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/protect"
|
||||
@@ -122,7 +120,7 @@ func createMeeting(ctx context.Context, headers map[string]string) (*createRespo
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, statusError(errCreateRoomFailed, resp)
|
||||
return nil, protect.StatusError(errCreateRoomFailed, resp, 1024)
|
||||
}
|
||||
|
||||
var res createResponse
|
||||
@@ -174,7 +172,7 @@ func preconnect(ctx context.Context, roomID, password string, headers map[string
|
||||
defer func() { _ = preResp.Body.Close() }()
|
||||
|
||||
if preResp.StatusCode != http.StatusOK {
|
||||
return "", statusError(errPreconnectFailed, preResp)
|
||||
return "", protect.StatusError(errPreconnectFailed, preResp, 1024)
|
||||
}
|
||||
|
||||
var preconnectResp struct {
|
||||
@@ -186,15 +184,6 @@ func preconnect(ctx context.Context, roomID, password string, headers map[string
|
||||
return preconnectResp.ConnectorURL, nil
|
||||
}
|
||||
|
||||
func statusError(base error, resp *http.Response) error {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
bodyText := strings.TrimSpace(string(body))
|
||||
if bodyText == "" {
|
||||
return fmt.Errorf("%w: status %d", base, resp.StatusCode)
|
||||
}
|
||||
return fmt.Errorf("%w: status %d: %s", base, resp.StatusCode, bodyText)
|
||||
}
|
||||
|
||||
func joinRoom(ctx context.Context, roomID, password string) (*roomInfo, error) {
|
||||
headers := anonymousHeaders()
|
||||
connectorURL, err := preconnect(ctx, roomID, password, headers)
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
@@ -69,8 +68,7 @@ func GetConnectionInfo(ctx context.Context, roomURL, displayName string) (*Conne
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("%w %d: %s", ErrAPI, resp.StatusCode, body)
|
||||
return nil, protect.StatusError(ErrAPI, resp, 4096)
|
||||
}
|
||||
|
||||
var info ConnectionInfo
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/protect"
|
||||
@@ -84,8 +83,7 @@ func registerGuest(ctx context.Context, displayName string) (string, error) {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
return "", fmt.Errorf("%w: %d %s", errGuestRegister, resp.StatusCode, b)
|
||||
return "", protect.StatusError(errGuestRegister, resp, 4096)
|
||||
}
|
||||
|
||||
var res guestRegisterResponse
|
||||
@@ -122,8 +120,7 @@ func createRoom(ctx context.Context, accessToken string) (string, error) {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
return "", fmt.Errorf("%w: %d %s", errCreateRoom, resp.StatusCode, b)
|
||||
return "", protect.StatusError(errCreateRoom, resp, 4096)
|
||||
}
|
||||
|
||||
var res createRoomResponse
|
||||
@@ -151,8 +148,7 @@ func joinRoom(ctx context.Context, accessToken, roomID string) error {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("%w: %d %s", errJoinRoom, resp.StatusCode, b)
|
||||
return protect.StatusError(errJoinRoom, resp, 4096)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -180,8 +176,7 @@ func getToken(ctx context.Context, accessToken, roomID, displayName string) (tok
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
return tokenResponse{}, fmt.Errorf("%w: %d %s", errGetToken, resp.StatusCode, b)
|
||||
return tokenResponse{}, protect.StatusError(errGetToken, resp, 4096)
|
||||
}
|
||||
|
||||
var res tokenResponse
|
||||
|
||||
@@ -17,12 +17,14 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/control"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/crypto"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/handshake"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/link"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/logger"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/names"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/transport"
|
||||
"github.com/xtaci/smux"
|
||||
)
|
||||
|
||||
@@ -54,7 +56,12 @@ type Client struct {
|
||||
conn *muxconn.Conn
|
||||
session *smux.Session
|
||||
controlStrm *smux.Stream
|
||||
controlStop context.CancelFunc
|
||||
sessMu sync.RWMutex
|
||||
reconnectMu sync.Mutex
|
||||
healthMu sync.RWMutex
|
||||
health control.Status
|
||||
onHealth HealthFunc
|
||||
deviceID string
|
||||
sessionID string
|
||||
claims map[string]any
|
||||
@@ -63,6 +70,9 @@ type Client struct {
|
||||
socksPass string
|
||||
}
|
||||
|
||||
// HealthFunc is called when the client control health snapshot changes.
|
||||
type HealthFunc func(control.Status)
|
||||
|
||||
// Config holds runtime configuration for [Run] and [RunWithReady].
|
||||
type Config struct {
|
||||
Link string
|
||||
@@ -93,6 +103,8 @@ type Config struct {
|
||||
Engine string
|
||||
URL string
|
||||
Token string
|
||||
Liveness control.Config
|
||||
Traffic transport.TrafficConfig
|
||||
|
||||
// DeviceID overrides the persistent client-side device identifier. Leave
|
||||
// empty to derive one from DeviceIDPath (or generate a random one if both
|
||||
@@ -106,6 +118,9 @@ type Config struct {
|
||||
// Claims is sent to the server in CLIENT_HELLO and forwarded verbatim to
|
||||
// the server's AuthHook. Free-form key/value bag for plan, user, region, etc.
|
||||
Claims map[string]any
|
||||
|
||||
// OnHealth receives liveness/reconnect status updates. Nil means no-op.
|
||||
OnHealth HealthFunc
|
||||
}
|
||||
|
||||
// Run starts the client with the given configuration.
|
||||
@@ -135,6 +150,7 @@ func RunWithReady(ctx context.Context, cfg Config, onReady func()) error {
|
||||
dnsServer: cfg.DNSServer,
|
||||
socksUser: cfg.SOCKSUser,
|
||||
socksPass: cfg.SOCKSPass,
|
||||
onHealth: cfg.OnHealth,
|
||||
}
|
||||
|
||||
// shutdown is registered BEFORE bringUpLink so we always close any
|
||||
@@ -202,6 +218,7 @@ func (c *Client) bringUpLink(
|
||||
SEIBatchSize: cfg.SEIBatchSize,
|
||||
SEIFragmentSize: cfg.SEIFragmentSize,
|
||||
SEIAckTimeoutMS: cfg.SEIAckTimeoutMS,
|
||||
Traffic: cfg.Traffic,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create link: %w", err)
|
||||
@@ -217,7 +234,9 @@ func (c *Client) bringUpLink(
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
c.handleReconnect()
|
||||
if !c.handleReconnect(ctx, cfg, cancel, "carrier") {
|
||||
cancel()
|
||||
}
|
||||
})
|
||||
|
||||
if err := ln.Connect(ctx); err != nil {
|
||||
@@ -225,7 +244,7 @@ func (c *Client) bringUpLink(
|
||||
}
|
||||
|
||||
c.conn = muxconn.New(ln, c.cipher)
|
||||
sess, err := smux.Client(c.conn, smuxConfig())
|
||||
sess, err := smux.Client(c.conn, smuxConfig(linkMaxPayload(ln)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("smux client: %w", err)
|
||||
}
|
||||
@@ -243,14 +262,16 @@ func (c *Client) bringUpLink(
|
||||
c.controlStrm = control
|
||||
c.sessionID = sid
|
||||
c.sessMu.Unlock()
|
||||
c.recordSession(sid)
|
||||
c.startControlLoop(ctx, cfg, cancel, control)
|
||||
|
||||
go ln.WatchConnection(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
// openControlStream opens stream #1 on sess and performs the handshake.
|
||||
// The stream stays open for the lifetime of the smux session — the server
|
||||
// holds it parked, and it would carry future control messages.
|
||||
// The stream stays open for the lifetime of the smux session and carries
|
||||
// post-handshake control messages.
|
||||
func openControlStream(
|
||||
sess *smux.Session,
|
||||
deviceID string,
|
||||
@@ -314,11 +335,17 @@ func resolveDeviceID(deviceID, path string) (string, error) {
|
||||
}
|
||||
|
||||
// smuxConfig returns the tuned smux config used on both ends.
|
||||
func smuxConfig() *smux.Config {
|
||||
func smuxConfig(maxWirePayload ...int) *smux.Config {
|
||||
cfg := smux.DefaultConfig()
|
||||
cfg.Version = 2
|
||||
cfg.KeepAliveDisabled = true
|
||||
cfg.MaxFrameSize = 32768
|
||||
if len(maxWirePayload) > 0 && maxWirePayload[0] > crypto.WireOverhead {
|
||||
maxFrameSize := maxWirePayload[0] - crypto.WireOverhead
|
||||
if maxFrameSize < cfg.MaxFrameSize {
|
||||
cfg.MaxFrameSize = maxFrameSize
|
||||
}
|
||||
}
|
||||
cfg.MaxReceiveBuffer = 16 * 1024 * 1024
|
||||
cfg.MaxStreamBuffer = 1024 * 1024
|
||||
cfg.KeepAliveInterval = 10 * time.Second
|
||||
@@ -326,8 +353,20 @@ func smuxConfig() *smux.Config {
|
||||
return cfg
|
||||
}
|
||||
|
||||
func (c *Client) handleReconnect() {
|
||||
logger.Infof("client link reconnect - tearing down smux session")
|
||||
func linkMaxPayload(ln link.Link) int {
|
||||
provider, ok := ln.(link.FeaturesProvider)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
return provider.Features().MaxPayloadSize
|
||||
}
|
||||
|
||||
func (c *Client) handleReconnect(ctx context.Context, cfg Config, cancel context.CancelFunc, reason string) bool {
|
||||
c.reconnectMu.Lock()
|
||||
defer c.reconnectMu.Unlock()
|
||||
|
||||
c.recordReconnect()
|
||||
logger.Infof("client reconnect reason=%s - tearing down smux session", reason)
|
||||
|
||||
// Install a fresh muxconn immediately so onData never hits nil while
|
||||
// the old session is being torn down. tryReopenSession will swap it
|
||||
@@ -336,14 +375,19 @@ func (c *Client) handleReconnect() {
|
||||
|
||||
c.sessMu.Lock()
|
||||
oldControl := c.controlStrm
|
||||
oldControlStop := c.controlStop
|
||||
oldSess := c.session
|
||||
oldConn := c.conn
|
||||
c.conn = newConn
|
||||
c.session = nil
|
||||
c.controlStrm = nil
|
||||
c.controlStop = nil
|
||||
c.sessionID = ""
|
||||
c.sessMu.Unlock()
|
||||
|
||||
if oldControlStop != nil {
|
||||
oldControlStop()
|
||||
}
|
||||
if oldControl != nil {
|
||||
_ = oldControl.Close()
|
||||
}
|
||||
@@ -364,15 +408,26 @@ func (c *Client) handleReconnect() {
|
||||
attemptDelay = 300 * time.Millisecond
|
||||
)
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
if c.tryReopenSession(attempt) {
|
||||
return
|
||||
logger.Infof("client reconnect attempt=%d reason=%s", attempt, reason)
|
||||
if c.tryReopenSession(ctx, cfg, cancel, attempt) {
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case <-time.After(attemptDelay):
|
||||
}
|
||||
time.Sleep(attemptDelay)
|
||||
}
|
||||
logger.Warnf("client reconnect: exhausted %d handshake attempts", maxAttempts)
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Client) tryReopenSession(attempt int) bool {
|
||||
func (c *Client) tryReopenSession(
|
||||
ctx context.Context,
|
||||
cfg Config,
|
||||
cancel context.CancelFunc,
|
||||
attempt int,
|
||||
) bool {
|
||||
conn := muxconn.New(c.ln, c.cipher)
|
||||
|
||||
c.sessMu.Lock()
|
||||
@@ -383,7 +438,7 @@ func (c *Client) tryReopenSession(attempt int) bool {
|
||||
_ = old.Close()
|
||||
}
|
||||
|
||||
sess, err := smux.Client(conn, smuxConfig())
|
||||
sess, err := smux.Client(conn, smuxConfig(linkMaxPayload(c.ln)))
|
||||
if err != nil {
|
||||
logger.Warnf("smux re-init failed (attempt %d): %v", attempt, err)
|
||||
return false
|
||||
@@ -400,19 +455,138 @@ func (c *Client) tryReopenSession(attempt int) bool {
|
||||
c.controlStrm = control
|
||||
c.sessionID = sid
|
||||
c.sessMu.Unlock()
|
||||
c.recordSession(sid)
|
||||
c.startControlLoop(ctx, cfg, cancel, control)
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *Client) startControlLoop(
|
||||
ctx context.Context,
|
||||
cfg Config,
|
||||
cancel context.CancelFunc,
|
||||
stream *smux.Stream,
|
||||
) {
|
||||
controlCtx, stop := context.WithCancel(ctx)
|
||||
c.sessMu.Lock()
|
||||
c.controlStop = stop
|
||||
c.sessMu.Unlock()
|
||||
|
||||
liveness := cfg.Liveness
|
||||
onPong := liveness.OnPong
|
||||
onMissedPong := liveness.OnMissedPong
|
||||
onUnhealthy := liveness.OnUnhealthy
|
||||
liveness.OnPong = func(h control.Health) {
|
||||
c.sessMu.RLock()
|
||||
sid := c.sessionID
|
||||
c.sessMu.RUnlock()
|
||||
c.recordPong(h)
|
||||
logger.Debugf("control alive session=%s rtt=%v seq=%d", sid, h.RTT, h.Seq)
|
||||
if onPong != nil {
|
||||
onPong(h)
|
||||
}
|
||||
}
|
||||
liveness.OnMissedPong = func(missed int) {
|
||||
c.recordMissed(missed)
|
||||
logger.Warnf("control missed pong on client: missed_pongs=%d", missed)
|
||||
if onMissedPong != nil {
|
||||
onMissedPong(missed)
|
||||
}
|
||||
}
|
||||
liveness.OnUnhealthy = func(missed int) {
|
||||
c.recordUnhealthy(missed)
|
||||
logger.Warnf("control stream unhealthy on client: missed_pongs=%d", missed)
|
||||
if onUnhealthy != nil {
|
||||
onUnhealthy(missed)
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := control.Run(controlCtx, stream, liveness)
|
||||
if controlCtx.Err() != nil || ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
logger.Warnf("client control stream ended: %v", err)
|
||||
}
|
||||
if !c.handleReconnect(ctx, cfg, cancel, "liveness") {
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Status returns the latest client-side control health snapshot.
|
||||
func (c *Client) Status() control.Status {
|
||||
c.healthMu.RLock()
|
||||
defer c.healthMu.RUnlock()
|
||||
return c.health
|
||||
}
|
||||
|
||||
func (c *Client) recordSession(sessionID string) {
|
||||
c.healthMu.Lock()
|
||||
c.health.SessionID = sessionID
|
||||
c.health.MissedPongs = 0
|
||||
status := c.health
|
||||
c.healthMu.Unlock()
|
||||
c.notifyHealth(status)
|
||||
}
|
||||
|
||||
func (c *Client) recordPong(h control.Health) {
|
||||
c.healthMu.Lock()
|
||||
c.health.LastPong = h.LastSeen
|
||||
c.health.LastRTT = h.RTT
|
||||
c.health.MissedPongs = 0
|
||||
status := c.health
|
||||
c.healthMu.Unlock()
|
||||
c.notifyHealth(status)
|
||||
}
|
||||
|
||||
func (c *Client) recordMissed(missed int) {
|
||||
c.healthMu.Lock()
|
||||
c.health.MissedPongs = missed
|
||||
status := c.health
|
||||
c.healthMu.Unlock()
|
||||
c.notifyHealth(status)
|
||||
}
|
||||
|
||||
func (c *Client) recordUnhealthy(missed int) {
|
||||
c.healthMu.Lock()
|
||||
c.health.MissedPongs = missed
|
||||
c.health.UnhealthyEvents++
|
||||
c.health.LastUnhealthy = time.Now()
|
||||
status := c.health
|
||||
c.healthMu.Unlock()
|
||||
c.notifyHealth(status)
|
||||
}
|
||||
|
||||
func (c *Client) recordReconnect() {
|
||||
c.healthMu.Lock()
|
||||
c.health.Reconnects++
|
||||
status := c.health
|
||||
c.healthMu.Unlock()
|
||||
c.notifyHealth(status)
|
||||
}
|
||||
|
||||
func (c *Client) notifyHealth(status control.Status) {
|
||||
if c.onHealth != nil {
|
||||
c.onHealth(status)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) shutdown() {
|
||||
c.sessMu.Lock()
|
||||
control := c.controlStrm
|
||||
controlStop := c.controlStop
|
||||
sess := c.session
|
||||
conn := c.conn
|
||||
c.controlStrm = nil
|
||||
c.controlStop = nil
|
||||
c.session = nil
|
||||
c.conn = nil
|
||||
c.sessMu.Unlock()
|
||||
|
||||
if controlStop != nil {
|
||||
controlStop()
|
||||
}
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/control"
|
||||
cryptopkg "github.com/openlibrecommunity/olcrtc/internal/crypto"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
|
||||
"github.com/xtaci/smux"
|
||||
@@ -48,6 +49,11 @@ func TestSmuxConfig(t *testing.T) {
|
||||
if cfg.Version != 2 || !cfg.KeepAliveDisabled || cfg.MaxFrameSize != 32768 || cfg.MaxReceiveBuffer != 16*1024*1024 {
|
||||
t.Fatalf("smuxConfig() = %+v", cfg)
|
||||
}
|
||||
capped := smuxConfig(4096)
|
||||
if capped.MaxFrameSize != 4096-cryptopkg.WireOverhead {
|
||||
t.Fatalf("smuxConfig(4096).MaxFrameSize = %d, want %d",
|
||||
capped.MaxFrameSize, 4096-cryptopkg.WireOverhead)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSocks5Handshake(t *testing.T) {
|
||||
@@ -517,3 +523,96 @@ func TestShutdownClosesLinkAndConn(t *testing.T) {
|
||||
t.Fatal("shutdown() did not close link")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartControlLoopReportsPong(t *testing.T) {
|
||||
a, b := net.Pipe()
|
||||
defer func() {
|
||||
_ = a.Close()
|
||||
_ = b.Close()
|
||||
}()
|
||||
|
||||
serverSess, err := smux.Server(a, smuxConfig())
|
||||
if err != nil {
|
||||
t.Fatalf("smux.Server() error = %v", err)
|
||||
}
|
||||
defer func() { _ = serverSess.Close() }()
|
||||
clientSess, err := smux.Client(b, smuxConfig())
|
||||
if err != nil {
|
||||
t.Fatalf("smux.Client() error = %v", err)
|
||||
}
|
||||
defer func() { _ = clientSess.Close() }()
|
||||
|
||||
peerStreamCh := make(chan *smux.Stream, 1)
|
||||
go func() {
|
||||
stream, err := serverSess.AcceptStream()
|
||||
if err == nil {
|
||||
peerStreamCh <- stream
|
||||
}
|
||||
}()
|
||||
|
||||
stream, err := clientSess.OpenStream()
|
||||
if err != nil {
|
||||
t.Fatalf("OpenStream() error = %v", err)
|
||||
}
|
||||
peerStream := <-peerStreamCh
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
got := make(chan control.Health, 1)
|
||||
c := &Client{sessionID: "sid-control"}
|
||||
c.recordSession("sid-control")
|
||||
c.startControlLoop(ctx, Config{
|
||||
Liveness: control.Config{
|
||||
Interval: 10 * time.Millisecond,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
Failures: 2,
|
||||
OnPong: func(h control.Health) {
|
||||
select {
|
||||
case got <- h:
|
||||
default:
|
||||
}
|
||||
},
|
||||
},
|
||||
}, cancel, stream)
|
||||
go func() {
|
||||
_ = control.Run(ctx, peerStream, control.Config{
|
||||
Interval: 10 * time.Millisecond,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
Failures: 2,
|
||||
})
|
||||
}()
|
||||
|
||||
select {
|
||||
case h := <-got:
|
||||
if h.Seq == 0 {
|
||||
t.Fatal("Health.Seq = 0")
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for control pong")
|
||||
}
|
||||
status := c.Status()
|
||||
if status.SessionID != "sid-control" {
|
||||
t.Fatalf("Status.SessionID = %q, want sid-control", status.SessionID)
|
||||
}
|
||||
if status.LastPong.IsZero() || status.LastRTT < 0 || status.MissedPongs != 0 {
|
||||
t.Fatalf("Status() = %+v", status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusRecordsReconnectAndUnhealthy(t *testing.T) {
|
||||
updates := 0
|
||||
c := &Client{onHealth: func(control.Status) { updates++ }}
|
||||
c.recordSession("sid-1")
|
||||
c.recordMissed(2)
|
||||
c.recordUnhealthy(3)
|
||||
c.recordReconnect()
|
||||
|
||||
status := c.Status()
|
||||
if status.SessionID != "sid-1" || status.MissedPongs != 3 ||
|
||||
status.UnhealthyEvents != 1 || status.Reconnects != 1 || status.LastUnhealthy.IsZero() {
|
||||
t.Fatalf("Status() = %+v", status)
|
||||
}
|
||||
if updates != 4 {
|
||||
t.Fatalf("health updates = %d, want 4", updates)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
// Package config loads olcrtc runtime configuration from YAML files.
|
||||
//
|
||||
// The YAML schema mirrors [session.Config]. Fields left unset in the file
|
||||
// remain at their zero value, allowing CLI flags to fill them in. Use
|
||||
// [Apply] to merge a parsed [File] onto an existing [session.Config];
|
||||
// non-zero fields in the session config (typically populated from CLI flags)
|
||||
// take precedence over the YAML values.
|
||||
// remain at their zero value. Use [Apply] to map a parsed [File] onto an
|
||||
// existing [session.Config]; non-zero fields in the session config take
|
||||
// precedence over the YAML values.
|
||||
//
|
||||
//nolint:tagliatelle // YAML keys are the documented config file schema.
|
||||
package config
|
||||
@@ -13,31 +12,68 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/app/session"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ErrConfigNotFound is returned when a config file path is set but the file does not exist.
|
||||
var ErrConfigNotFound = errors.New("config file not found")
|
||||
var (
|
||||
// ErrConfigNotFound is returned when a config file path is set but the file does not exist.
|
||||
ErrConfigNotFound = errors.New("config file not found")
|
||||
// ErrCryptoKeyConflict is returned when both inline and file-backed keys are configured.
|
||||
ErrCryptoKeyConflict = errors.New("crypto.key and crypto.key_file cannot both be set")
|
||||
// ErrCryptoKeyFileEmpty is returned when crypto.key_file points to an empty file.
|
||||
ErrCryptoKeyFileEmpty = errors.New("crypto key file is empty")
|
||||
)
|
||||
|
||||
// File is the on-disk YAML schema.
|
||||
type File struct {
|
||||
Mode string `yaml:"mode"`
|
||||
Link string `yaml:"link"`
|
||||
Auth Auth `yaml:"auth"`
|
||||
Room Room `yaml:"room"`
|
||||
Crypto Crypto `yaml:"crypto"`
|
||||
Net Net `yaml:"net"`
|
||||
SOCKS SOCKS `yaml:"socks"`
|
||||
Engine Engine `yaml:"engine"`
|
||||
Video Video `yaml:"video"`
|
||||
VP8 VP8 `yaml:"vp8"`
|
||||
SEI SEI `yaml:"sei"`
|
||||
Gen Gen `yaml:"gen"`
|
||||
Data string `yaml:"data"`
|
||||
Debug bool `yaml:"debug"`
|
||||
FFmpeg string `yaml:"ffmpeg"`
|
||||
Mode string `yaml:"mode"`
|
||||
Link string `yaml:"link"`
|
||||
Auth Auth `yaml:"auth"`
|
||||
Room Room `yaml:"room"`
|
||||
Crypto Crypto `yaml:"crypto"`
|
||||
Net Net `yaml:"net"`
|
||||
SOCKS SOCKS `yaml:"socks"`
|
||||
Engine Engine `yaml:"engine"`
|
||||
Video Video `yaml:"video"`
|
||||
VP8 VP8 `yaml:"vp8"`
|
||||
SEI SEI `yaml:"sei"`
|
||||
Liveness Liveness `yaml:"liveness"`
|
||||
Lifecycle Lifecycle `yaml:"lifecycle"`
|
||||
Traffic Traffic `yaml:"traffic"`
|
||||
Gen Gen `yaml:"gen"`
|
||||
Profiles []Profile `yaml:"profiles"`
|
||||
Failover Failover `yaml:"failover"`
|
||||
Data string `yaml:"data"`
|
||||
Debug bool `yaml:"debug"`
|
||||
FFmpeg string `yaml:"ffmpeg"`
|
||||
}
|
||||
|
||||
// Profile is a failover entry that overrides top-level runtime fields.
|
||||
type Profile struct {
|
||||
Name string `yaml:"name"`
|
||||
Link string `yaml:"link"`
|
||||
Auth Auth `yaml:"auth"`
|
||||
Room Room `yaml:"room"`
|
||||
Crypto Crypto `yaml:"crypto"`
|
||||
Net Net `yaml:"net"`
|
||||
SOCKS SOCKS `yaml:"socks"`
|
||||
Engine Engine `yaml:"engine"`
|
||||
Video Video `yaml:"video"`
|
||||
VP8 VP8 `yaml:"vp8"`
|
||||
SEI SEI `yaml:"sei"`
|
||||
Liveness Liveness `yaml:"liveness"`
|
||||
Lifecycle Lifecycle `yaml:"lifecycle"`
|
||||
Traffic Traffic `yaml:"traffic"`
|
||||
}
|
||||
|
||||
// Failover controls ordered profile failover.
|
||||
type Failover struct {
|
||||
RetryDelay string `yaml:"retry_delay"`
|
||||
MaxCycles int `yaml:"max_cycles"`
|
||||
}
|
||||
|
||||
// Auth selects the auth provider.
|
||||
@@ -52,7 +88,8 @@ type Room struct {
|
||||
|
||||
// Crypto holds the shared secret used to authenticate and encrypt the tunnel.
|
||||
type Crypto struct {
|
||||
Key string `yaml:"key"` // 64-char hex (32 bytes)
|
||||
Key string `yaml:"key"` // 64-char hex (32 bytes)
|
||||
KeyFile string `yaml:"key_file"` // path to a file containing crypto.key
|
||||
}
|
||||
|
||||
// Net groups network and transport selection.
|
||||
@@ -106,6 +143,25 @@ type SEI struct {
|
||||
AckTimeoutMS int `yaml:"ack_timeout_ms"`
|
||||
}
|
||||
|
||||
// Liveness tunes the post-handshake control stream ping/pong checks.
|
||||
type Liveness struct {
|
||||
Interval string `yaml:"interval"`
|
||||
Timeout string `yaml:"timeout"`
|
||||
Failures int `yaml:"failures"`
|
||||
}
|
||||
|
||||
// Lifecycle controls planned session rebuilds.
|
||||
type Lifecycle struct {
|
||||
MaxSessionDuration string `yaml:"max_session_duration"`
|
||||
}
|
||||
|
||||
// Traffic controls optional reliability-oriented send shaping.
|
||||
type Traffic struct {
|
||||
MaxPayloadSize int `yaml:"max_payload_size"`
|
||||
MinDelay string `yaml:"min_delay"`
|
||||
MaxDelay string `yaml:"max_delay"`
|
||||
}
|
||||
|
||||
// Gen controls room-generation mode.
|
||||
type Gen struct {
|
||||
Amount int `yaml:"amount"`
|
||||
@@ -125,9 +181,63 @@ func Load(path string) (File, error) {
|
||||
if err := yaml.Unmarshal(data, &f); err != nil {
|
||||
return File{}, fmt.Errorf("parse config %s: %w", path, err)
|
||||
}
|
||||
if err := loadExternalSecrets(path, &f); err != nil {
|
||||
return File{}, err
|
||||
}
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func loadExternalSecrets(configPath string, f *File) error {
|
||||
if f.Crypto.KeyFile == "" {
|
||||
return loadProfileSecrets(configPath, f.Profiles)
|
||||
}
|
||||
if f.Crypto.Key != "" {
|
||||
return ErrCryptoKeyConflict
|
||||
}
|
||||
|
||||
key, err := readKeyFile(configPath, f.Crypto.KeyFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f.Crypto.Key = key
|
||||
return loadProfileSecrets(configPath, f.Profiles)
|
||||
}
|
||||
|
||||
func loadProfileSecrets(configPath string, profiles []Profile) error {
|
||||
for i := range profiles {
|
||||
if profiles[i].Crypto.KeyFile == "" {
|
||||
continue
|
||||
}
|
||||
if profiles[i].Crypto.Key != "" {
|
||||
return fmt.Errorf("profiles[%d]: %w", i, ErrCryptoKeyConflict)
|
||||
}
|
||||
key, err := readKeyFile(configPath, profiles[i].Crypto.KeyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("profiles[%d]: %w", i, err)
|
||||
}
|
||||
profiles[i].Crypto.Key = key
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func readKeyFile(configPath, keyFile string) (string, error) {
|
||||
keyPath := keyFile
|
||||
if !filepath.IsAbs(keyPath) {
|
||||
keyPath = filepath.Join(filepath.Dir(configPath), keyPath)
|
||||
}
|
||||
|
||||
// #nosec G304 -- key_file is an explicit path in the user's config file.
|
||||
data, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read crypto key file %s: %w", keyPath, err)
|
||||
}
|
||||
key := strings.TrimSpace(string(data))
|
||||
if key == "" {
|
||||
return "", ErrCryptoKeyFileEmpty
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// Apply merges f onto dst. CLI-set fields (non-zero values in dst) win;
|
||||
// YAML values fill in the rest.
|
||||
func Apply(dst session.Config, f File) session.Config {
|
||||
@@ -163,10 +273,61 @@ func Apply(dst session.Config, f File) session.Config {
|
||||
dst.SEIBatchSize = pickInt(dst.SEIBatchSize, f.SEI.BatchSize)
|
||||
dst.SEIFragmentSize = pickInt(dst.SEIFragmentSize, f.SEI.FragmentSize)
|
||||
dst.SEIAckTimeoutMS = pickInt(dst.SEIAckTimeoutMS, f.SEI.AckTimeoutMS)
|
||||
dst.LivenessInterval = pickString(dst.LivenessInterval, f.Liveness.Interval)
|
||||
dst.LivenessTimeout = pickString(dst.LivenessTimeout, f.Liveness.Timeout)
|
||||
dst.LivenessFailures = pickInt(dst.LivenessFailures, f.Liveness.Failures)
|
||||
dst.MaxSessionDuration = pickString(dst.MaxSessionDuration, f.Lifecycle.MaxSessionDuration)
|
||||
dst.TrafficMaxPayloadSize = pickInt(dst.TrafficMaxPayloadSize, f.Traffic.MaxPayloadSize)
|
||||
dst.TrafficMinDelay = pickString(dst.TrafficMinDelay, f.Traffic.MinDelay)
|
||||
dst.TrafficMaxDelay = pickString(dst.TrafficMaxDelay, f.Traffic.MaxDelay)
|
||||
dst.Amount = pickInt(dst.Amount, f.Gen.Amount)
|
||||
return dst
|
||||
}
|
||||
|
||||
// ApplyProfile overlays a failover profile onto an already-applied base config.
|
||||
func ApplyProfile(base session.Config, p Profile) session.Config {
|
||||
dst := base
|
||||
dst.Link = overlayString(dst.Link, p.Link)
|
||||
dst.Transport = overlayString(dst.Transport, p.Net.Transport)
|
||||
dst.Auth = overlayString(dst.Auth, p.Auth.Provider)
|
||||
dst.Engine = overlayString(dst.Engine, p.Engine.Name)
|
||||
dst.URL = overlayString(dst.URL, p.Engine.URL)
|
||||
dst.Token = overlayString(dst.Token, p.Engine.Token)
|
||||
dst.RoomID = overlayString(dst.RoomID, p.Room.ID)
|
||||
dst.KeyHex = overlayString(dst.KeyHex, p.Crypto.Key)
|
||||
dst.SOCKSHost = overlayString(dst.SOCKSHost, p.SOCKS.Host)
|
||||
dst.SOCKSPort = overlayInt(dst.SOCKSPort, p.SOCKS.Port)
|
||||
dst.SOCKSUser = overlayString(dst.SOCKSUser, p.SOCKS.User)
|
||||
dst.SOCKSPass = overlayString(dst.SOCKSPass, p.SOCKS.Pass)
|
||||
dst.DNSServer = overlayString(dst.DNSServer, p.Net.DNS)
|
||||
dst.SOCKSProxyAddr = overlayString(dst.SOCKSProxyAddr, p.SOCKS.ProxyAddr)
|
||||
dst.SOCKSProxyPort = overlayInt(dst.SOCKSProxyPort, p.SOCKS.ProxyPort)
|
||||
dst.VideoWidth = overlayInt(dst.VideoWidth, p.Video.Width)
|
||||
dst.VideoHeight = overlayInt(dst.VideoHeight, p.Video.Height)
|
||||
dst.VideoFPS = overlayInt(dst.VideoFPS, p.Video.FPS)
|
||||
dst.VideoBitrate = overlayString(dst.VideoBitrate, p.Video.Bitrate)
|
||||
dst.VideoHW = overlayString(dst.VideoHW, p.Video.HW)
|
||||
dst.VideoQRSize = overlayInt(dst.VideoQRSize, p.Video.QRSize)
|
||||
dst.VideoQRRecovery = overlayString(dst.VideoQRRecovery, p.Video.QRRecovery)
|
||||
dst.VideoCodec = overlayString(dst.VideoCodec, p.Video.Codec)
|
||||
dst.VideoTileModule = overlayInt(dst.VideoTileModule, p.Video.TileModule)
|
||||
dst.VideoTileRS = overlayInt(dst.VideoTileRS, p.Video.TileRS)
|
||||
dst.VP8FPS = overlayInt(dst.VP8FPS, p.VP8.FPS)
|
||||
dst.VP8BatchSize = overlayInt(dst.VP8BatchSize, p.VP8.BatchSize)
|
||||
dst.SEIFPS = overlayInt(dst.SEIFPS, p.SEI.FPS)
|
||||
dst.SEIBatchSize = overlayInt(dst.SEIBatchSize, p.SEI.BatchSize)
|
||||
dst.SEIFragmentSize = overlayInt(dst.SEIFragmentSize, p.SEI.FragmentSize)
|
||||
dst.SEIAckTimeoutMS = overlayInt(dst.SEIAckTimeoutMS, p.SEI.AckTimeoutMS)
|
||||
dst.LivenessInterval = overlayString(dst.LivenessInterval, p.Liveness.Interval)
|
||||
dst.LivenessTimeout = overlayString(dst.LivenessTimeout, p.Liveness.Timeout)
|
||||
dst.LivenessFailures = overlayInt(dst.LivenessFailures, p.Liveness.Failures)
|
||||
dst.MaxSessionDuration = overlayString(dst.MaxSessionDuration, p.Lifecycle.MaxSessionDuration)
|
||||
dst.TrafficMaxPayloadSize = overlayInt(dst.TrafficMaxPayloadSize, p.Traffic.MaxPayloadSize)
|
||||
dst.TrafficMinDelay = overlayString(dst.TrafficMinDelay, p.Traffic.MinDelay)
|
||||
dst.TrafficMaxDelay = overlayString(dst.TrafficMaxDelay, p.Traffic.MaxDelay)
|
||||
return dst
|
||||
}
|
||||
|
||||
func pickString(cli, yamlVal string) string {
|
||||
if cli != "" {
|
||||
return cli
|
||||
@@ -180,3 +341,17 @@ func pickInt(cli, yamlVal int) int {
|
||||
}
|
||||
return yamlVal
|
||||
}
|
||||
|
||||
func overlayString(base, override string) string {
|
||||
if override != "" {
|
||||
return override
|
||||
}
|
||||
return base
|
||||
}
|
||||
|
||||
func overlayInt(base, override int) int {
|
||||
if override != 0 {
|
||||
return override
|
||||
}
|
||||
return base
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -38,6 +39,16 @@ socks:
|
||||
vp8:
|
||||
fps: 25
|
||||
batch_size: 4
|
||||
liveness:
|
||||
interval: 2s
|
||||
timeout: 500ms
|
||||
failures: 4
|
||||
lifecycle:
|
||||
max_session_duration: 6h
|
||||
traffic:
|
||||
max_payload_size: 4096
|
||||
min_delay: 5ms
|
||||
max_delay: 30ms
|
||||
gen:
|
||||
amount: 3
|
||||
debug: true
|
||||
@@ -75,20 +86,27 @@ func requireLoadedFile(t *testing.T, f File) {
|
||||
func requireAppliedConfig(t *testing.T, got session.Config) {
|
||||
t.Helper()
|
||||
want := session.Config{
|
||||
Mode: testModeSrv,
|
||||
Link: "direct",
|
||||
Auth: testAuthProvider,
|
||||
RoomID: testRoomID,
|
||||
KeyHex: testCryptoKey,
|
||||
Transport: "datachannel",
|
||||
DNSServer: "1.1.1.1:53",
|
||||
SOCKSHost: "127.0.0.1",
|
||||
SOCKSPort: 1080,
|
||||
SOCKSUser: "u",
|
||||
SOCKSPass: "p",
|
||||
VP8FPS: 25,
|
||||
VP8BatchSize: 4,
|
||||
Amount: 3,
|
||||
Mode: testModeSrv,
|
||||
Link: "direct",
|
||||
Auth: testAuthProvider,
|
||||
RoomID: testRoomID,
|
||||
KeyHex: testCryptoKey,
|
||||
Transport: "datachannel",
|
||||
DNSServer: "1.1.1.1:53",
|
||||
SOCKSHost: "127.0.0.1",
|
||||
SOCKSPort: 1080,
|
||||
SOCKSUser: "u",
|
||||
SOCKSPass: "p",
|
||||
VP8FPS: 25,
|
||||
VP8BatchSize: 4,
|
||||
LivenessInterval: "2s",
|
||||
LivenessTimeout: "500ms",
|
||||
LivenessFailures: 4,
|
||||
MaxSessionDuration: "6h",
|
||||
TrafficMaxPayloadSize: 4096,
|
||||
TrafficMinDelay: "5ms",
|
||||
TrafficMaxDelay: "30ms",
|
||||
Amount: 3,
|
||||
}
|
||||
if got != want {
|
||||
t.Fatalf("Apply produced wrong config: %+v, want %+v", got, want)
|
||||
@@ -121,6 +139,182 @@ func TestApplyCLIWins(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadAndApplyProfile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "olcrtc.yaml")
|
||||
body := `
|
||||
mode: srv
|
||||
link: direct
|
||||
crypto:
|
||||
key: shared-key
|
||||
net:
|
||||
dns: 1.1.1.1:53
|
||||
liveness:
|
||||
interval: 5s
|
||||
timeout: 2s
|
||||
failures: 5
|
||||
lifecycle:
|
||||
max_session_duration: 6h
|
||||
traffic:
|
||||
max_payload_size: 8192
|
||||
min_delay: 10ms
|
||||
max_delay: 40ms
|
||||
profiles:
|
||||
- name: wb-vp8
|
||||
auth:
|
||||
provider: wbstream
|
||||
room:
|
||||
id: wb-room
|
||||
net:
|
||||
transport: vp8channel
|
||||
vp8:
|
||||
fps: 30
|
||||
liveness:
|
||||
interval: 1s
|
||||
lifecycle:
|
||||
max_session_duration: 30m
|
||||
traffic:
|
||||
max_payload_size: 4096
|
||||
max_delay: 20ms
|
||||
- name: jitsi-dc
|
||||
auth:
|
||||
provider: jitsi
|
||||
room:
|
||||
id: https://meet.example/room
|
||||
net:
|
||||
transport: datachannel
|
||||
dns: 8.8.8.8:53
|
||||
failover:
|
||||
retry_delay: 100ms
|
||||
max_cycles: 2
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(body), 0o600); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
|
||||
f, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Load: %v", err)
|
||||
}
|
||||
if len(f.Profiles) != 2 {
|
||||
t.Fatalf("profiles = %d, want 2", len(f.Profiles))
|
||||
}
|
||||
if f.Failover.RetryDelay != "100ms" || f.Failover.MaxCycles != 2 {
|
||||
t.Fatalf("Failover = %+v, want retry_delay 100ms max_cycles 2", f.Failover)
|
||||
}
|
||||
|
||||
base := Apply(session.Config{}, f)
|
||||
first := ApplyProfile(base, f.Profiles[0])
|
||||
if first.Auth != "wbstream" || first.Transport != "vp8channel" || first.RoomID != "wb-room" {
|
||||
t.Fatalf("first profile = %+v", first)
|
||||
}
|
||||
if first.KeyHex != "shared-key" || first.DNSServer != "1.1.1.1:53" || first.VP8FPS != 30 ||
|
||||
first.LivenessInterval != "1s" || first.LivenessTimeout != "2s" || first.LivenessFailures != 5 ||
|
||||
first.MaxSessionDuration != "30m" || first.TrafficMaxPayloadSize != 4096 ||
|
||||
first.TrafficMinDelay != "10ms" || first.TrafficMaxDelay != "20ms" {
|
||||
t.Fatalf("first inherited/overlaid fields = %+v", first)
|
||||
}
|
||||
second := ApplyProfile(base, f.Profiles[1])
|
||||
if second.Auth != "jitsi" || second.Transport != "datachannel" ||
|
||||
second.RoomID != "https://meet.example/room" || second.DNSServer != "8.8.8.8:53" {
|
||||
t.Fatalf("second profile = %+v", second)
|
||||
}
|
||||
if second.LivenessInterval != "5s" || second.LivenessTimeout != "2s" || second.LivenessFailures != 5 ||
|
||||
second.MaxSessionDuration != "6h" || second.TrafficMaxPayloadSize != 8192 ||
|
||||
second.TrafficMinDelay != "10ms" || second.TrafficMaxDelay != "40ms" {
|
||||
t.Fatalf("second lifecycle/liveness fields = %+v", second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadProfileCryptoKeyFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := os.WriteFile(filepath.Join(dir, "profile.key"), []byte(testCryptoKey+"\n"), 0o600); err != nil {
|
||||
t.Fatalf("write key: %v", err)
|
||||
}
|
||||
path := filepath.Join(dir, "olcrtc.yaml")
|
||||
body := `
|
||||
profiles:
|
||||
- name: file-key
|
||||
crypto:
|
||||
key_file: profile.key
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(body), 0o600); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
|
||||
f, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Load: %v", err)
|
||||
}
|
||||
if got := f.Profiles[0].Crypto.Key; got != testCryptoKey {
|
||||
t.Fatalf("profile key = %q, want %q", got, testCryptoKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCryptoKeyFileRelativeToConfig(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
keyPath := filepath.Join(dir, "secret.key")
|
||||
if err := os.WriteFile(keyPath, []byte(testCryptoKey+"\n"), 0o600); err != nil {
|
||||
t.Fatalf("write key: %v", err)
|
||||
}
|
||||
path := filepath.Join(dir, "olcrtc.yaml")
|
||||
body := `
|
||||
mode: srv
|
||||
crypto:
|
||||
key_file: secret.key
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(body), 0o600); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
|
||||
f, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Load: %v", err)
|
||||
}
|
||||
if f.Crypto.Key != testCryptoKey {
|
||||
t.Fatalf("Crypto.Key = %q, want %q", f.Crypto.Key, testCryptoKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCryptoKeyFileConflict(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "olcrtc.yaml")
|
||||
body := `
|
||||
crypto:
|
||||
key: deadbeef
|
||||
key_file: secret.key
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(body), 0o600); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
|
||||
_, err := Load(path)
|
||||
if !errors.Is(err, ErrCryptoKeyConflict) {
|
||||
t.Fatalf("Load() error = %v, want %v", err, ErrCryptoKeyConflict)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCryptoKeyFileEmpty(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
keyPath := filepath.Join(dir, "secret.key")
|
||||
if err := os.WriteFile(keyPath, []byte("\n"), 0o600); err != nil {
|
||||
t.Fatalf("write key: %v", err)
|
||||
}
|
||||
path := filepath.Join(dir, "olcrtc.yaml")
|
||||
body := `
|
||||
crypto:
|
||||
key_file: secret.key
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(body), 0o600); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
|
||||
_, err := Load(path)
|
||||
if !errors.Is(err, ErrCryptoKeyFileEmpty) {
|
||||
t.Fatalf("Load() error = %v, want %v", err, ErrCryptoKeyFileEmpty)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadMissing(t *testing.T) {
|
||||
_, err := Load(filepath.Join(t.TempDir(), "nope.yaml"))
|
||||
if err == nil {
|
||||
|
||||
343
internal/control/control.go
Normal file
343
internal/control/control.go
Normal file
@@ -0,0 +1,343 @@
|
||||
// Package control implements the post-handshake control stream protocol.
|
||||
//
|
||||
// The control stream is the first smux stream after the olcrtc handshake. It
|
||||
// stays inside the encrypted muxconn path, so ping/pong proves that the actual
|
||||
// tunnel path still round-trips, not merely that the provider connection is up.
|
||||
//
|
||||
// Wire format matches the handshake framing: a 4-byte big-endian length
|
||||
// followed by a JSON message.
|
||||
//
|
||||
//nolint:tagliatelle // JSON keys are the stable wire protocol schema.
|
||||
package control
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// ProtoVersion identifies the control stream wire format.
|
||||
ProtoVersion = 1
|
||||
// MaxMessageSize caps one control frame.
|
||||
MaxMessageSize = 16 * 1024
|
||||
// DefaultInterval is the default interval between ping probes.
|
||||
DefaultInterval = 10 * time.Second
|
||||
// DefaultTimeout is the default time to wait for a pong.
|
||||
DefaultTimeout = 5 * time.Second
|
||||
// DefaultFailures is the default number of consecutive missed pongs before
|
||||
// the stream is marked unhealthy.
|
||||
DefaultFailures = 3
|
||||
)
|
||||
|
||||
// MsgType labels a control message.
|
||||
type MsgType string
|
||||
|
||||
const (
|
||||
// TypePing is sent periodically to prove control-stream liveness.
|
||||
TypePing MsgType = "CONTROL_PING"
|
||||
// TypePong replies to a ping with the same sequence and timestamp.
|
||||
TypePong MsgType = "CONTROL_PONG"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrUnhealthy is returned when the stream misses too many pong replies.
|
||||
ErrUnhealthy = errors.New("control stream unhealthy")
|
||||
// ErrProtocolVersion is returned when the peer announces an incompatible version.
|
||||
ErrProtocolVersion = errors.New("incompatible control protocol version")
|
||||
// ErrUnexpectedMessage is returned for unknown or malformed control message types.
|
||||
ErrUnexpectedMessage = errors.New("unexpected control message")
|
||||
// ErrFrameTooLarge is returned when a frame exceeds [MaxMessageSize].
|
||||
ErrFrameTooLarge = errors.New("control frame too large")
|
||||
)
|
||||
|
||||
// Message is one control-stream frame.
|
||||
type Message struct {
|
||||
Version int `json:"version"`
|
||||
Type MsgType `json:"type"`
|
||||
Seq uint64 `json:"seq,omitempty"`
|
||||
SentUnixNano int64 `json:"sent_unix_nano,omitempty"`
|
||||
}
|
||||
|
||||
// Health is reported when a ping round trip completes.
|
||||
type Health struct {
|
||||
Seq uint64
|
||||
RTT time.Duration
|
||||
LastSeen time.Time
|
||||
}
|
||||
|
||||
// Status is a point-in-time view of control-stream health maintained by
|
||||
// callers that embed the control loop.
|
||||
type Status struct {
|
||||
SessionID string
|
||||
LastPong time.Time
|
||||
LastRTT time.Duration
|
||||
MissedPongs int
|
||||
Reconnects uint64
|
||||
UnhealthyEvents uint64
|
||||
LastUnhealthy time.Time
|
||||
}
|
||||
|
||||
// Config controls the liveness loop.
|
||||
type Config struct {
|
||||
Interval time.Duration
|
||||
Timeout time.Duration
|
||||
Failures int
|
||||
|
||||
// OnPong is called after a matching pong is received.
|
||||
OnPong func(Health)
|
||||
// OnMissedPong is called when one or more outstanding pongs time out.
|
||||
OnMissedPong func(missed int)
|
||||
// OnUnhealthy is called before Run returns [ErrUnhealthy].
|
||||
OnUnhealthy func(missed int)
|
||||
}
|
||||
|
||||
func (cfg Config) withDefaults() Config {
|
||||
if cfg.Interval <= 0 {
|
||||
cfg.Interval = DefaultInterval
|
||||
}
|
||||
if cfg.Timeout <= 0 {
|
||||
cfg.Timeout = DefaultTimeout
|
||||
}
|
||||
if cfg.Failures <= 0 {
|
||||
cfg.Failures = DefaultFailures
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// Run drives bidirectional ping/pong liveness until ctx is canceled, rw closes,
|
||||
// or the configured failure threshold is reached.
|
||||
func Run(ctx context.Context, rw io.ReadWriteCloser, cfg Config) error {
|
||||
cfg = cfg.withDefaults()
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
state := &state{
|
||||
rw: rw,
|
||||
cfg: cfg,
|
||||
pending: make(map[uint64]time.Time),
|
||||
now: time.Now,
|
||||
out: make(chan Message, 16),
|
||||
}
|
||||
|
||||
errCh := make(chan error, 3)
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
_ = rw.Close()
|
||||
}()
|
||||
go func() { errCh <- state.readLoop(ctx) }()
|
||||
go func() { errCh <- state.probeLoop(ctx) }()
|
||||
go func() { errCh <- state.writeLoop(ctx) }()
|
||||
|
||||
err := <-errCh
|
||||
cancel()
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type state struct {
|
||||
rw io.ReadWriteCloser
|
||||
cfg Config
|
||||
now func() time.Time
|
||||
|
||||
out chan Message
|
||||
|
||||
mu sync.Mutex
|
||||
pending map[uint64]time.Time
|
||||
nextSeq uint64
|
||||
failures int
|
||||
}
|
||||
|
||||
func (s *state) readLoop(ctx context.Context) error {
|
||||
for {
|
||||
raw, err := readFrame(s.rw)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
return err
|
||||
}
|
||||
msg, err := parseMessage(raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch msg.Type {
|
||||
case TypePing:
|
||||
if err := s.enqueue(ctx, Message{
|
||||
Version: ProtoVersion,
|
||||
Type: TypePong,
|
||||
Seq: msg.Seq,
|
||||
SentUnixNano: msg.SentUnixNano,
|
||||
}); err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
return err
|
||||
}
|
||||
case TypePong:
|
||||
s.handlePong(msg)
|
||||
default:
|
||||
return fmt.Errorf("%w: got %q", ErrUnexpectedMessage, msg.Type)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *state) probeLoop(ctx context.Context) error {
|
||||
ticker := time.NewTicker(s.cfg.Interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
if err := s.sendProbe(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *state) sendProbe(ctx context.Context) error {
|
||||
now := s.now()
|
||||
|
||||
s.mu.Lock()
|
||||
missedNow := 0
|
||||
for seq, sent := range s.pending {
|
||||
if now.Sub(sent) < s.cfg.Timeout {
|
||||
continue
|
||||
}
|
||||
delete(s.pending, seq)
|
||||
s.failures++
|
||||
missedNow++
|
||||
}
|
||||
missed := s.failures
|
||||
if s.failures >= s.cfg.Failures {
|
||||
s.mu.Unlock()
|
||||
if missedNow > 0 && s.cfg.OnMissedPong != nil {
|
||||
s.cfg.OnMissedPong(missed)
|
||||
}
|
||||
if s.cfg.OnUnhealthy != nil {
|
||||
s.cfg.OnUnhealthy(missed)
|
||||
}
|
||||
return fmt.Errorf("%w: missed %d pong(s)", ErrUnhealthy, missed)
|
||||
}
|
||||
|
||||
s.nextSeq++
|
||||
seq := s.nextSeq
|
||||
s.pending[seq] = now
|
||||
s.mu.Unlock()
|
||||
if missedNow > 0 && s.cfg.OnMissedPong != nil {
|
||||
s.cfg.OnMissedPong(missed)
|
||||
}
|
||||
|
||||
return s.enqueue(ctx, Message{
|
||||
Version: ProtoVersion,
|
||||
Type: TypePing,
|
||||
Seq: seq,
|
||||
SentUnixNano: now.UnixNano(),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *state) handlePong(msg Message) {
|
||||
now := s.now()
|
||||
|
||||
s.mu.Lock()
|
||||
sent, ok := s.pending[msg.Seq]
|
||||
if ok {
|
||||
delete(s.pending, msg.Seq)
|
||||
s.failures = 0
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if !ok || s.cfg.OnPong == nil {
|
||||
return
|
||||
}
|
||||
s.cfg.OnPong(Health{
|
||||
Seq: msg.Seq,
|
||||
RTT: now.Sub(sent),
|
||||
LastSeen: now,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *state) enqueue(ctx context.Context, msg Message) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case s.out <- msg:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *state) writeLoop(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case msg := <-s.out:
|
||||
if err := writeFrame(s.rw, msg); err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseMessage(raw []byte) (Message, error) {
|
||||
var msg Message
|
||||
if err := json.Unmarshal(raw, &msg); err != nil {
|
||||
return Message{}, fmt.Errorf("parse control message: %w", err)
|
||||
}
|
||||
if msg.Version != ProtoVersion {
|
||||
return Message{}, fmt.Errorf("%w: peer v%d, local v%d",
|
||||
ErrProtocolVersion, msg.Version, ProtoVersion)
|
||||
}
|
||||
if msg.Type != TypePing && msg.Type != TypePong {
|
||||
return Message{}, fmt.Errorf("%w: got %q", ErrUnexpectedMessage, msg.Type)
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func writeFrame(w io.Writer, msg Message) error {
|
||||
body, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal control message: %w", err)
|
||||
}
|
||||
if len(body) > MaxMessageSize {
|
||||
return fmt.Errorf("%w: %d > %d", ErrFrameTooLarge, len(body), MaxMessageSize)
|
||||
}
|
||||
var hdr [4]byte
|
||||
binary.BigEndian.PutUint32(hdr[:], uint32(len(body))) //nolint:gosec // len(body) bounded by MaxMessageSize
|
||||
if _, err := w.Write(hdr[:]); err != nil {
|
||||
return fmt.Errorf("write control hdr: %w", err)
|
||||
}
|
||||
if _, err := w.Write(body); err != nil {
|
||||
return fmt.Errorf("write control body: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func readFrame(r io.Reader) ([]byte, error) {
|
||||
var hdr [4]byte
|
||||
if _, err := io.ReadFull(r, hdr[:]); err != nil {
|
||||
return nil, fmt.Errorf("read control hdr: %w", err)
|
||||
}
|
||||
n := binary.BigEndian.Uint32(hdr[:])
|
||||
if n > MaxMessageSize {
|
||||
return nil, fmt.Errorf("%w: %d > %d", ErrFrameTooLarge, n, MaxMessageSize)
|
||||
}
|
||||
buf := make([]byte, n)
|
||||
if _, err := io.ReadFull(r, buf); err != nil {
|
||||
return nil, fmt.Errorf("read control body: %w", err)
|
||||
}
|
||||
return buf, nil
|
||||
}
|
||||
138
internal/control/control_test.go
Normal file
138
internal/control/control_test.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package control
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func controlPair(t *testing.T) (net.Conn, net.Conn) {
|
||||
t.Helper()
|
||||
a, b := net.Pipe()
|
||||
t.Cleanup(func() {
|
||||
_ = a.Close()
|
||||
_ = b.Close()
|
||||
})
|
||||
return a, b
|
||||
}
|
||||
|
||||
func TestRunPingPongReportsRTT(t *testing.T) {
|
||||
a, b := controlPair(t)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
got := make(chan Health, 1)
|
||||
cfg := Config{
|
||||
Interval: 10 * time.Millisecond,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
Failures: 2,
|
||||
OnPong: func(h Health) {
|
||||
select {
|
||||
case got <- h:
|
||||
default:
|
||||
}
|
||||
},
|
||||
}
|
||||
errCh := make(chan error, 2)
|
||||
go func() { errCh <- Run(ctx, a, cfg) }()
|
||||
go func() { errCh <- Run(ctx, b, cfg) }()
|
||||
|
||||
select {
|
||||
case h := <-got:
|
||||
if h.Seq == 0 {
|
||||
t.Fatal("Health.Seq = 0")
|
||||
}
|
||||
if h.RTT < 0 {
|
||||
t.Fatalf("Health.RTT = %v", h.RTT)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for pong health")
|
||||
}
|
||||
|
||||
cancel()
|
||||
for range 2 {
|
||||
if err := <-errCh; err != nil {
|
||||
t.Fatalf("Run() after cancel = %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunMarksUnhealthyAfterMissedPongs(t *testing.T) {
|
||||
a, b := controlPair(t)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
_, _ = io.Copy(io.Discard, b)
|
||||
}()
|
||||
|
||||
missedCh := make(chan int, 1)
|
||||
missedCallbackCh := make(chan int, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- Run(ctx, a, Config{
|
||||
Interval: 10 * time.Millisecond,
|
||||
Timeout: 5 * time.Millisecond,
|
||||
Failures: 2,
|
||||
OnMissedPong: func(missed int) {
|
||||
select {
|
||||
case missedCallbackCh <- missed:
|
||||
default:
|
||||
}
|
||||
},
|
||||
OnUnhealthy: func(missed int) { missedCh <- missed },
|
||||
})
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if !errors.Is(err, ErrUnhealthy) {
|
||||
t.Fatalf("Run() error = %v, want ErrUnhealthy", err)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for unhealthy result")
|
||||
}
|
||||
if missed := <-missedCh; missed < 2 {
|
||||
t.Fatalf("missed = %d, want >= 2", missed)
|
||||
}
|
||||
if missed := <-missedCallbackCh; missed < 1 {
|
||||
t.Fatalf("missed callback = %d, want >= 1", missed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunRejectsBadProtocolVersion(t *testing.T) {
|
||||
a, b := controlPair(t)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- Run(context.Background(), a, Config{Interval: time.Hour})
|
||||
}()
|
||||
if err := writeFrame(b, Message{Version: 999, Type: TypePing, Seq: 1}); err != nil {
|
||||
t.Fatalf("writeFrame() error = %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if !errors.Is(err, ErrProtocolVersion) {
|
||||
t.Fatalf("Run() error = %v, want ErrProtocolVersion", err)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for protocol error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFrameRejectsTooLarge(t *testing.T) {
|
||||
a, b := controlPair(t)
|
||||
go func() {
|
||||
var hdr [4]byte
|
||||
binary.BigEndian.PutUint32(hdr[:], MaxMessageSize+1)
|
||||
_, _ = b.Write(hdr[:])
|
||||
}()
|
||||
_, err := readFrame(a)
|
||||
if !errors.Is(err, ErrFrameTooLarge) {
|
||||
t.Fatalf("readFrame() error = %v, want ErrFrameTooLarge", err)
|
||||
}
|
||||
}
|
||||
@@ -10,6 +10,9 @@ import (
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
// WireOverhead is the number of bytes added to each encrypted message.
|
||||
const WireOverhead = chacha20poly1305.NonceSizeX + chacha20poly1305.Overhead
|
||||
|
||||
var (
|
||||
// ErrInvalidKeySize is returned when the encryption key is not 32 bytes.
|
||||
ErrInvalidKeySize = errors.New("invalid key size")
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"github.com/openlibrecommunity/olcrtc/internal/client"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/link"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/server"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/supervisor"
|
||||
"github.com/pion/webrtc/v4"
|
||||
)
|
||||
|
||||
@@ -47,6 +48,7 @@ var (
|
||||
errSocksUnexpectedReply = errors.New("unexpected SOCKS5 reply")
|
||||
errSocksUnexpectedHello = errors.New("unexpected SOCKS5 greeting")
|
||||
errPayloadMismatchOffset = errors.New("payload mismatch at offset")
|
||||
errFailoverCarrier = errors.New("intentional failover carrier failure")
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -347,6 +349,17 @@ func registerMemoryCarrierAs(t *testing.T, name string) {
|
||||
})
|
||||
}
|
||||
|
||||
func registerFailingCarrier(t *testing.T) string {
|
||||
t.Helper()
|
||||
session.RegisterDefaults()
|
||||
|
||||
name := "e2e-fail-" + t.Name()
|
||||
carrier.Register(name, func(context.Context, carrier.Config) (carrier.Session, error) {
|
||||
return nil, errFailoverCarrier
|
||||
})
|
||||
return name
|
||||
}
|
||||
|
||||
func builtInCarrierNames() []string {
|
||||
return []string{"jazz", "telemost", "wbstream", "jitsi"} //nolint:goconst // test literal, repetition is intentional
|
||||
}
|
||||
@@ -1008,9 +1021,7 @@ func TestDirectLinkConnectsFastProviderTransportMatrix(t *testing.T) {
|
||||
if err := ln.Connect(context.Background()); err != nil {
|
||||
t.Fatalf("Connect() error = %v", err)
|
||||
}
|
||||
if !ln.CanSend() {
|
||||
t.Fatal("CanSend() = false, want true")
|
||||
}
|
||||
assertLinkCanSendAfterConnect(t, ln, transportName)
|
||||
if err := ln.Close(); err != nil {
|
||||
t.Fatalf("Close() error = %v", err)
|
||||
}
|
||||
@@ -1020,6 +1031,20 @@ func TestDirectLinkConnectsFastProviderTransportMatrix(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func assertLinkCanSendAfterConnect(t *testing.T, ln link.Link, transportName string) {
|
||||
t.Helper()
|
||||
|
||||
if transportName == transportSEI {
|
||||
if ln.CanSend() {
|
||||
t.Fatal("CanSend() = true before peer seichannel frame")
|
||||
}
|
||||
return
|
||||
}
|
||||
if !ln.CanSend() {
|
||||
t.Fatal("CanSend() = false, want true")
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:cyclop // table-driven test naturally has many branches
|
||||
func TestRealProviderTransportMatrix(t *testing.T) {
|
||||
if !*realE2E {
|
||||
@@ -1163,6 +1188,186 @@ func TestFrequentReconnectsStillAllowNewSOCKSConnections(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupervisorFailoverProfilesReachWorkingSOCKS(t *testing.T) {
|
||||
echoAddr := startEchoServer(t)
|
||||
failingCarrier := registerFailingCarrier(t)
|
||||
memoryCarrier, room := registerMemoryCarrier(t)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
socksAddr := freeLocalAddr(ctx, t)
|
||||
socksHost, socksPort := splitHostPort(t, socksAddr)
|
||||
|
||||
serverProfiles := []supervisor.Profile{
|
||||
{Name: "failing-server", Config: failoverSessionConfig("srv", failingCarrier, "", 0)},
|
||||
{Name: "memory-server", Config: failoverSessionConfig("srv", memoryCarrier, "", 0)},
|
||||
}
|
||||
clientProfiles := []supervisor.Profile{
|
||||
{Name: "failing-client", Config: failoverSessionConfig("cnc", failingCarrier, socksHost, socksPort)},
|
||||
{Name: "memory-client", Config: failoverSessionConfig("cnc", memoryCarrier, socksHost, socksPort)},
|
||||
}
|
||||
|
||||
started := make(chan string, 8)
|
||||
serverErr := make(chan error, 1)
|
||||
go func() {
|
||||
serverErr <- supervisor.Run(ctx, failoverE2EConfig(serverProfiles, started, "server"), session.Run)
|
||||
}()
|
||||
room.waitConnected(t, 1)
|
||||
|
||||
ready := make(chan struct{})
|
||||
var readyOnce sync.Once
|
||||
clientErr := make(chan error, 1)
|
||||
go func() {
|
||||
clientErr <- supervisor.Run(ctx, failoverE2EConfig(clientProfiles, started, "client"), func(ctx context.Context, cfg session.Config) error {
|
||||
return client.RunWithReady(ctx, clientConfigFromSession(cfg, socksAddr), func() {
|
||||
if cfg.Auth == memoryCarrier {
|
||||
readyOnce.Do(func() { close(ready) })
|
||||
}
|
||||
})
|
||||
})
|
||||
}()
|
||||
|
||||
waitForReady(t, ready)
|
||||
conn := eventuallyConnectViaSOCKS(t, socksAddr, echoAddr)
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
payload := []byte("olcrtc-failover-e2e\n")
|
||||
if _, err := conn.Write(payload); err != nil {
|
||||
t.Fatalf("write failover payload: %v", err)
|
||||
}
|
||||
if err := conn.SetReadDeadline(time.Now().Add(3 * time.Second)); err != nil {
|
||||
t.Fatalf("set failover read deadline: %v", err)
|
||||
}
|
||||
line, err := bufio.NewReader(conn).ReadBytes('\n')
|
||||
if err != nil {
|
||||
t.Fatalf("read failover echo: %v", err)
|
||||
}
|
||||
if !bytes.Equal(line, payload) {
|
||||
t.Fatalf("failover echo = %q, want %q", line, payload)
|
||||
}
|
||||
|
||||
requireStartedProfiles(t, started, []string{
|
||||
"server:failing-server",
|
||||
"server:memory-server",
|
||||
"client:failing-client",
|
||||
"client:memory-client",
|
||||
})
|
||||
|
||||
cancel()
|
||||
waitSupervisorStopped(t, "client", clientErr)
|
||||
waitSupervisorStopped(t, "server", serverErr)
|
||||
}
|
||||
|
||||
func failoverSessionConfig(mode, carrierName, socksHost string, socksPort int) session.Config {
|
||||
cfg := session.Config{
|
||||
Mode: mode,
|
||||
Link: linkDirect,
|
||||
Transport: transportData,
|
||||
Auth: carrierName,
|
||||
RoomID: testRoom,
|
||||
KeyHex: testKeyHex,
|
||||
DNSServer: localDNSServer,
|
||||
}
|
||||
if mode == "cnc" {
|
||||
cfg.SOCKSHost = socksHost
|
||||
cfg.SOCKSPort = socksPort
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func clientConfigFromSession(cfg session.Config, socksAddr string) client.Config {
|
||||
return client.Config{
|
||||
Link: cfg.Link,
|
||||
Transport: cfg.Transport,
|
||||
Carrier: cfg.Auth,
|
||||
RoomURL: cfg.RoomID,
|
||||
KeyHex: cfg.KeyHex,
|
||||
LocalAddr: socksAddr,
|
||||
DNSServer: cfg.DNSServer,
|
||||
DeviceID: testClientDeviceID,
|
||||
VideoWidth: cfg.VideoWidth,
|
||||
VideoHeight: cfg.VideoHeight,
|
||||
VideoFPS: cfg.VideoFPS,
|
||||
VideoBitrate: cfg.VideoBitrate,
|
||||
VideoHW: cfg.VideoHW,
|
||||
VideoQRSize: cfg.VideoQRSize,
|
||||
VideoQRRecovery: cfg.VideoQRRecovery,
|
||||
VideoCodec: cfg.VideoCodec,
|
||||
VideoTileModule: cfg.VideoTileModule,
|
||||
VideoTileRS: cfg.VideoTileRS,
|
||||
VP8FPS: cfg.VP8FPS,
|
||||
VP8BatchSize: cfg.VP8BatchSize,
|
||||
SEIFPS: cfg.SEIFPS,
|
||||
SEIBatchSize: cfg.SEIBatchSize,
|
||||
SEIFragmentSize: cfg.SEIFragmentSize,
|
||||
SEIAckTimeoutMS: cfg.SEIAckTimeoutMS,
|
||||
Engine: cfg.Engine,
|
||||
URL: cfg.URL,
|
||||
Token: cfg.Token,
|
||||
}
|
||||
}
|
||||
|
||||
func failoverE2EConfig(
|
||||
profiles []supervisor.Profile,
|
||||
started chan<- string,
|
||||
side string,
|
||||
) supervisor.Config {
|
||||
return supervisor.Config{
|
||||
Profiles: profiles,
|
||||
RetryDelay: time.Millisecond,
|
||||
OnProfileStart: func(profile supervisor.Profile, _ int) {
|
||||
select {
|
||||
case started <- side + ":" + profile.Name:
|
||||
default:
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func splitHostPort(t *testing.T, addr string) (string, int) {
|
||||
t.Helper()
|
||||
host, portText, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
t.Fatalf("split host port %q: %v", addr, err)
|
||||
}
|
||||
port, err := strconv.Atoi(portText)
|
||||
if err != nil {
|
||||
t.Fatalf("parse port %q: %v", portText, err)
|
||||
}
|
||||
return host, port
|
||||
}
|
||||
|
||||
func requireStartedProfiles(t *testing.T, started <-chan string, want []string) {
|
||||
t.Helper()
|
||||
seen := make(map[string]bool)
|
||||
deadline := time.After(3 * time.Second)
|
||||
for len(seen) < len(want) {
|
||||
select {
|
||||
case item := <-started:
|
||||
seen[item] = true
|
||||
case <-deadline:
|
||||
t.Fatalf("started profiles = %v, want all %v", seen, want)
|
||||
}
|
||||
}
|
||||
for _, item := range want {
|
||||
if !seen[item] {
|
||||
t.Fatalf("started profiles = %v, missing %s", seen, item)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func waitSupervisorStopped(t *testing.T, name string, ch <-chan error) {
|
||||
t.Helper()
|
||||
select {
|
||||
case err := <-ch:
|
||||
if err != nil {
|
||||
t.Fatalf("%s supervisor returned error: %v", name, err)
|
||||
}
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatalf("%s supervisor did not stop", name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEndedCallbackStopsClientAndServer(t *testing.T) {
|
||||
rt := startTunnel(t)
|
||||
rt.room.triggerEnded("conference ended")
|
||||
|
||||
@@ -112,10 +112,7 @@ func (s *Session) setupPeerConnections(config webrtc.Configuration) error {
|
||||
}
|
||||
|
||||
func (s *Session) dialWebSocket() error {
|
||||
wsDialer := websocket.Dialer{
|
||||
NetDialContext: protect.DialContext,
|
||||
HandshakeTimeout: wsHandshakeTimeout,
|
||||
}
|
||||
wsDialer := protect.NewWebSocketDialer(wsHandshakeTimeout)
|
||||
ws, resp, err := wsDialer.Dial(s.mediaServerURL, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial ws: %w", err)
|
||||
|
||||
@@ -19,13 +19,17 @@ import (
|
||||
protoLogger "github.com/livekit/protocol/logger"
|
||||
lksdk "github.com/livekit/server-sdk-go/v2"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/engine"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/logger"
|
||||
"github.com/pion/webrtc/v4"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultSendQueueSize = 5000
|
||||
dataPublishTopic = "olcrtc"
|
||||
videoTrackName = "videochannel"
|
||||
defaultSendQueueSize = 5000
|
||||
defaultSendQueueCapHard = 4000
|
||||
dataPublishTopic = "olcrtc"
|
||||
videoTrackName = "videochannel"
|
||||
reconnectWindow = 5 * time.Minute
|
||||
maxReconnects = 10
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -41,20 +45,98 @@ var (
|
||||
ErrTokenRequired = errors.New("livekit access token required")
|
||||
)
|
||||
|
||||
type roomHandle interface {
|
||||
publishData([]byte) error
|
||||
publishTrack(webrtc.TrackLocal) error
|
||||
unpublishLocalTracks()
|
||||
disconnect()
|
||||
connectionState() lksdk.ConnectionState
|
||||
}
|
||||
|
||||
type sdkRoom struct {
|
||||
room *lksdk.Room
|
||||
}
|
||||
|
||||
func (r *sdkRoom) publishData(data []byte) error {
|
||||
return r.room.LocalParticipant.PublishDataPacket(
|
||||
lksdk.UserData(data),
|
||||
lksdk.WithDataPublishTopic(dataPublishTopic),
|
||||
lksdk.WithDataPublishReliable(true),
|
||||
)
|
||||
}
|
||||
|
||||
func (r *sdkRoom) publishTrack(track webrtc.TrackLocal) error {
|
||||
_, err := r.room.LocalParticipant.PublishTrack(track, &lksdk.TrackPublicationOptions{Name: videoTrackName})
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *sdkRoom) unpublishLocalTracks() {
|
||||
if r.room == nil || r.room.LocalParticipant == nil {
|
||||
return
|
||||
}
|
||||
for _, publication := range r.room.LocalParticipant.TrackPublications() {
|
||||
if publication.SID() == "" {
|
||||
continue
|
||||
}
|
||||
if err := r.room.LocalParticipant.UnpublishTrack(publication.SID()); err != nil {
|
||||
log.Printf("livekit unpublish track error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *sdkRoom) disconnect() {
|
||||
r.room.Disconnect()
|
||||
// LiveKit's Disconnect returns after local SDK teardown, before the
|
||||
// server necessarily evicts the participant. Give the signalling path a
|
||||
// short grace period so immediate reconnects do not inherit stale room
|
||||
// state from a ghost participant.
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
|
||||
func (r *sdkRoom) connectionState() lksdk.ConnectionState {
|
||||
return r.room.ConnectionState()
|
||||
}
|
||||
|
||||
type connectRoomFunc func(url, token string, callback *lksdk.RoomCallback) (roomHandle, error)
|
||||
|
||||
func connectSDKRoom(url, token string, callback *lksdk.RoomCallback) (roomHandle, error) {
|
||||
room, err := lksdk.ConnectToRoomWithToken(
|
||||
url,
|
||||
token,
|
||||
callback,
|
||||
lksdk.WithAutoSubscribe(true),
|
||||
lksdk.WithLogger(protoLogger.GetDiscardLogger()),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &sdkRoom{room: room}, nil
|
||||
}
|
||||
|
||||
// Session is the LiveKit engine handle.
|
||||
type Session struct {
|
||||
url string
|
||||
token string
|
||||
name string
|
||||
room *lksdk.Room
|
||||
refresh func(ctx context.Context) (engine.Credentials, error)
|
||||
connectRoom connectRoomFunc
|
||||
room roomHandle
|
||||
roomMu sync.RWMutex
|
||||
onData func([]byte)
|
||||
onReconnect func(*webrtc.DataChannel)
|
||||
shouldReconnect func() bool
|
||||
onEnded func(string)
|
||||
reconnectCh chan struct{}
|
||||
closeCh chan struct{}
|
||||
lastReconnect time.Time
|
||||
reconnectCount int
|
||||
sendQueue chan []byte
|
||||
closed atomic.Bool
|
||||
reconnecting atomic.Bool
|
||||
done chan struct{}
|
||||
cancel context.CancelFunc
|
||||
shutdownOnce sync.Once
|
||||
sendWorkerOnce sync.Once
|
||||
videoTrackMu sync.RWMutex
|
||||
videoTracks []webrtc.TrackLocal
|
||||
onVideoTrack func(*webrtc.TrackRemote, *webrtc.RTPReceiver)
|
||||
@@ -71,13 +153,17 @@ func New(ctx context.Context, cfg engine.Config) (engine.Session, error) {
|
||||
}
|
||||
_, cancel := context.WithCancel(ctx)
|
||||
return &Session{
|
||||
url: cfg.URL,
|
||||
token: cfg.Token,
|
||||
name: cfg.Name,
|
||||
onData: cfg.OnData,
|
||||
sendQueue: make(chan []byte, defaultSendQueueSize),
|
||||
done: make(chan struct{}),
|
||||
cancel: cancel,
|
||||
url: cfg.URL,
|
||||
token: cfg.Token,
|
||||
name: cfg.Name,
|
||||
refresh: cfg.Refresh,
|
||||
connectRoom: connectSDKRoom,
|
||||
onData: cfg.OnData,
|
||||
reconnectCh: make(chan struct{}, 1),
|
||||
closeCh: make(chan struct{}),
|
||||
sendQueue: make(chan []byte, defaultSendQueueSize),
|
||||
done: make(chan struct{}),
|
||||
cancel: cancel,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -87,7 +173,16 @@ func (s *Session) Capabilities() engine.Capabilities {
|
||||
}
|
||||
|
||||
// Connect joins the LiveKit room.
|
||||
func (s *Session) Connect(_ context.Context) error {
|
||||
func (s *Session) Connect(ctx context.Context) error {
|
||||
s.closed.Store(false)
|
||||
if err := s.connectSession(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
s.startSendWorker()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) connectSession(_ context.Context) error {
|
||||
roomCB := &lksdk.RoomCallback{
|
||||
ParticipantCallback: lksdk.ParticipantCallback{
|
||||
OnDataReceived: func(data []byte, _ lksdk.DataReceiveParams) {
|
||||
@@ -108,45 +203,49 @@ func (s *Session) Connect(_ context.Context) error {
|
||||
},
|
||||
},
|
||||
OnDisconnected: func() {
|
||||
if !s.closed.Load() && s.onEnded != nil {
|
||||
s.onEnded("disconnected from livekit")
|
||||
if s.closed.Load() || s.reconnecting.Load() {
|
||||
return
|
||||
}
|
||||
if !s.queueReconnect() {
|
||||
s.signalEnded("disconnected from livekit")
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
room, err := lksdk.ConnectToRoomWithToken(
|
||||
s.url,
|
||||
s.token,
|
||||
roomCB,
|
||||
lksdk.WithAutoSubscribe(true),
|
||||
lksdk.WithLogger(protoLogger.GetDiscardLogger()),
|
||||
)
|
||||
room, err := s.connectRoom(s.url, s.token, roomCB)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to room: %w", err)
|
||||
}
|
||||
|
||||
s.room = room
|
||||
s.setRoom(room)
|
||||
if err := s.publishPendingTracks(); err != nil {
|
||||
return err
|
||||
}
|
||||
s.wg.Add(1)
|
||||
go s.processSendQueue()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) publishPendingTracks() error {
|
||||
room := s.currentRoom()
|
||||
if room == nil {
|
||||
return ErrRoomNotConnected
|
||||
}
|
||||
s.videoTrackMu.RLock()
|
||||
defer s.videoTrackMu.RUnlock()
|
||||
for _, track := range s.videoTracks {
|
||||
if _, err := s.room.LocalParticipant.PublishTrack(track, &lksdk.TrackPublicationOptions{
|
||||
Name: videoTrackName,
|
||||
}); err != nil {
|
||||
if err := room.publishTrack(track); err != nil {
|
||||
return fmt.Errorf("failed to publish track: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) startSendWorker() {
|
||||
s.sendWorkerOnce.Do(func() {
|
||||
s.wg.Add(1)
|
||||
go s.processSendQueue()
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Session) processSendQueue() {
|
||||
defer s.wg.Done()
|
||||
for {
|
||||
@@ -157,17 +256,33 @@ func (s *Session) processSendQueue() {
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := s.room.LocalParticipant.PublishDataPacket(
|
||||
lksdk.UserData(data),
|
||||
lksdk.WithDataPublishTopic(dataPublishTopic),
|
||||
lksdk.WithDataPublishReliable(true),
|
||||
); err != nil {
|
||||
room := s.waitForConnectedRoom()
|
||||
if room == nil {
|
||||
return
|
||||
}
|
||||
if err := room.publishData(data); err != nil {
|
||||
log.Printf("livekit publish data error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) waitForConnectedRoom() roomHandle {
|
||||
ticker := time.NewTicker(50 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
room := s.currentRoom()
|
||||
if room != nil && room.connectionState() == lksdk.ConnectionStateConnected {
|
||||
return room
|
||||
}
|
||||
select {
|
||||
case <-s.done:
|
||||
return nil
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Send queues data for transmission.
|
||||
func (s *Session) Send(data []byte) error {
|
||||
if s.closed.Load() {
|
||||
@@ -183,55 +298,160 @@ func (s *Session) Send(data []byte) error {
|
||||
|
||||
// Close terminates the session.
|
||||
func (s *Session) Close() error {
|
||||
if s.closed.CompareAndSwap(false, true) {
|
||||
s.cancel()
|
||||
close(s.done)
|
||||
if s.room != nil {
|
||||
s.unpublishLocalTracks()
|
||||
s.room.Disconnect()
|
||||
// LiveKit's Disconnect() returns once the local SDK state
|
||||
// is torn down, not when the server has actually evicted
|
||||
// the participant. Without giving the signalling channel
|
||||
// time to flush the LEAVE_REQUEST and the server to act on
|
||||
// it, a back-to-back reconnect from the same identity in
|
||||
// the same room sees a still-alive ghost participant on
|
||||
// the SFU and inherits stale publication state.
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
close(s.sendQueue)
|
||||
s.wg.Wait()
|
||||
}
|
||||
s.closed.Store(true)
|
||||
s.shutdown()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) unpublishLocalTracks() {
|
||||
if s.room == nil || s.room.LocalParticipant == nil {
|
||||
return
|
||||
}
|
||||
for _, publication := range s.room.LocalParticipant.TrackPublications() {
|
||||
if publication.SID() == "" {
|
||||
continue
|
||||
func (s *Session) shutdown() {
|
||||
s.shutdownOnce.Do(func() {
|
||||
if s.cancel != nil {
|
||||
s.cancel()
|
||||
}
|
||||
if err := s.room.LocalParticipant.UnpublishTrack(publication.SID()); err != nil {
|
||||
log.Printf("livekit unpublish track error: %v", err)
|
||||
closeSignal(s.closeCh)
|
||||
closeSignal(s.done)
|
||||
if room := s.swapRoom(nil); room != nil {
|
||||
room.unpublishLocalTracks()
|
||||
room.disconnect()
|
||||
}
|
||||
}
|
||||
s.wg.Wait()
|
||||
})
|
||||
}
|
||||
|
||||
// SetReconnectCallback stores the reconnect callback (LiveKit reconnects internally; this is kept for API parity).
|
||||
// SetReconnectCallback stores the reconnect callback.
|
||||
func (s *Session) SetReconnectCallback(cb func(*webrtc.DataChannel)) { s.onReconnect = cb }
|
||||
|
||||
// SetShouldReconnect stores the reconnect predicate (kept for API parity).
|
||||
// SetShouldReconnect stores the reconnect predicate.
|
||||
func (s *Session) SetShouldReconnect(fn func() bool) { s.shouldReconnect = fn }
|
||||
|
||||
// SetEndedCallback registers a function to call when the session ends.
|
||||
func (s *Session) SetEndedCallback(cb func(string)) { s.onEnded = cb }
|
||||
|
||||
// WatchConnection is a no-op; LiveKit handles connection supervision itself.
|
||||
func (s *Session) WatchConnection(_ context.Context) {}
|
||||
// WatchConnection monitors the connection lifecycle and reconnects as needed.
|
||||
func (s *Session) WatchConnection(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-s.closeCh:
|
||||
return
|
||||
case <-s.reconnectCh:
|
||||
if s.handleReconnectAttempt(ctx) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) handleReconnectAttempt(ctx context.Context) bool {
|
||||
if time.Since(s.lastReconnect) > reconnectWindow {
|
||||
s.reconnectCount = 0
|
||||
}
|
||||
s.reconnectCount++
|
||||
s.lastReconnect = time.Now()
|
||||
|
||||
if s.reconnectCount > maxReconnects {
|
||||
s.signalEnded("reconnect limit reached")
|
||||
return true
|
||||
}
|
||||
|
||||
backoff := time.Duration(s.reconnectCount) * 2 * time.Second
|
||||
if backoff > 30*time.Second {
|
||||
backoff = 30 * time.Second
|
||||
}
|
||||
|
||||
for {
|
||||
if err := s.reconnect(ctx); err != nil {
|
||||
logger.Debugf("livekit reconnect failed: %v", err)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return true
|
||||
case <-s.closeCh:
|
||||
return true
|
||||
case <-time.After(backoff):
|
||||
continue
|
||||
}
|
||||
}
|
||||
s.drainReconnectQueue()
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) reconnect(ctx context.Context) error {
|
||||
s.reconnecting.Store(true)
|
||||
defer s.reconnecting.Store(false)
|
||||
|
||||
if room := s.swapRoom(nil); room != nil {
|
||||
room.unpublishLocalTracks()
|
||||
room.disconnect()
|
||||
}
|
||||
|
||||
if s.refresh != nil {
|
||||
creds, err := s.refresh(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("refresh credentials: %w", err)
|
||||
}
|
||||
s.applyRefreshedCredentials(creds)
|
||||
}
|
||||
|
||||
if err := s.connectSession(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.onReconnect != nil {
|
||||
s.onReconnect(nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) applyRefreshedCredentials(creds engine.Credentials) {
|
||||
if creds.URL != "" {
|
||||
s.url = creds.URL
|
||||
}
|
||||
if creds.Token != "" {
|
||||
s.token = creds.Token
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) queueReconnect() bool {
|
||||
if s.closed.Load() || s.reconnecting.Load() {
|
||||
return false
|
||||
}
|
||||
if s.shouldReconnect != nil && !s.shouldReconnect() {
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case s.reconnectCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Session) drainReconnectQueue() {
|
||||
for {
|
||||
select {
|
||||
case <-s.reconnectCh:
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) signalEnded(reason string) {
|
||||
s.closed.Store(true)
|
||||
s.shutdown()
|
||||
if s.onEnded != nil {
|
||||
s.onEnded(reason)
|
||||
}
|
||||
}
|
||||
|
||||
// CanSend reports whether the session is ready to accept data.
|
||||
func (s *Session) CanSend() bool { return !s.closed.Load() && s.room != nil }
|
||||
func (s *Session) CanSend() bool {
|
||||
if s.closed.Load() || s.reconnecting.Load() || len(s.sendQueue) >= defaultSendQueueCapHard {
|
||||
return false
|
||||
}
|
||||
room := s.currentRoom()
|
||||
return room != nil && room.connectionState() == lksdk.ConnectionStateConnected
|
||||
}
|
||||
|
||||
// GetSendQueue exposes the outbound queue.
|
||||
func (s *Session) GetSendQueue() chan []byte { return s.sendQueue }
|
||||
@@ -245,12 +465,11 @@ func (s *Session) AddVideoTrack(track webrtc.TrackLocal) error {
|
||||
s.videoTracks = append(s.videoTracks, track)
|
||||
s.videoTrackMu.Unlock()
|
||||
|
||||
if s.room == nil || s.room.LocalParticipant == nil {
|
||||
room := s.currentRoom()
|
||||
if room == nil {
|
||||
return nil
|
||||
}
|
||||
if _, err := s.room.LocalParticipant.PublishTrack(track, &lksdk.TrackPublicationOptions{
|
||||
Name: videoTrackName,
|
||||
}); err != nil {
|
||||
if err := room.publishTrack(track); err != nil {
|
||||
return fmt.Errorf("failed to publish track: %w", err)
|
||||
}
|
||||
return nil
|
||||
@@ -263,6 +482,34 @@ func (s *Session) SetVideoTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPR
|
||||
s.onVideoTrack = cb
|
||||
}
|
||||
|
||||
func (s *Session) currentRoom() roomHandle {
|
||||
s.roomMu.RLock()
|
||||
defer s.roomMu.RUnlock()
|
||||
return s.room
|
||||
}
|
||||
|
||||
func (s *Session) setRoom(room roomHandle) {
|
||||
s.roomMu.Lock()
|
||||
defer s.roomMu.Unlock()
|
||||
s.room = room
|
||||
}
|
||||
|
||||
func (s *Session) swapRoom(room roomHandle) roomHandle {
|
||||
s.roomMu.Lock()
|
||||
defer s.roomMu.Unlock()
|
||||
old := s.room
|
||||
s.room = room
|
||||
return old
|
||||
}
|
||||
|
||||
func closeSignal(ch chan struct{}) {
|
||||
select {
|
||||
case <-ch:
|
||||
default:
|
||||
close(ch)
|
||||
}
|
||||
}
|
||||
|
||||
func init() { //nolint:gochecknoinits // engine registration is the canonical Go pattern for plugins
|
||||
engine.Register("livekit", New)
|
||||
}
|
||||
|
||||
306
internal/engine/livekit/livekit_test.go
Normal file
306
internal/engine/livekit/livekit_test.go
Normal file
@@ -0,0 +1,306 @@
|
||||
package livekit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
lksdk "github.com/livekit/server-sdk-go/v2"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/engine"
|
||||
"github.com/pion/webrtc/v4"
|
||||
)
|
||||
|
||||
type fakeRoom struct {
|
||||
mu sync.Mutex
|
||||
state lksdk.ConnectionState
|
||||
published [][]byte
|
||||
tracks int
|
||||
unpublished int
|
||||
disconnected int
|
||||
}
|
||||
|
||||
func newFakeRoom() *fakeRoom {
|
||||
return &fakeRoom{state: lksdk.ConnectionStateConnected}
|
||||
}
|
||||
|
||||
func (r *fakeRoom) publishData(data []byte) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.published = append(r.published, append([]byte(nil), data...))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *fakeRoom) publishTrack(webrtc.TrackLocal) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.tracks++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *fakeRoom) unpublishLocalTracks() {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.unpublished++
|
||||
}
|
||||
|
||||
func (r *fakeRoom) disconnect() {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.disconnected++
|
||||
r.state = lksdk.ConnectionStateDisconnected
|
||||
}
|
||||
|
||||
func (r *fakeRoom) connectionState() lksdk.ConnectionState {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return r.state
|
||||
}
|
||||
|
||||
type fakeConnector struct {
|
||||
mu sync.Mutex
|
||||
urls []string
|
||||
tokens []string
|
||||
callbacks []*lksdk.RoomCallback
|
||||
rooms []*fakeRoom
|
||||
connected chan struct{}
|
||||
err error
|
||||
}
|
||||
|
||||
func newFakeConnector() *fakeConnector {
|
||||
return &fakeConnector{connected: make(chan struct{}, 8)}
|
||||
}
|
||||
|
||||
func (c *fakeConnector) connect(url, token string, cb *lksdk.RoomCallback) (roomHandle, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.err != nil {
|
||||
return nil, c.err
|
||||
}
|
||||
room := newFakeRoom()
|
||||
c.urls = append(c.urls, url)
|
||||
c.tokens = append(c.tokens, token)
|
||||
c.callbacks = append(c.callbacks, cb)
|
||||
c.rooms = append(c.rooms, room)
|
||||
c.connected <- struct{}{}
|
||||
return room, nil
|
||||
}
|
||||
|
||||
func (c *fakeConnector) count() int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return len(c.rooms)
|
||||
}
|
||||
|
||||
func (c *fakeConnector) callback(i int) *lksdk.RoomCallback {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.callbacks[i]
|
||||
}
|
||||
|
||||
func (c *fakeConnector) room(i int) *fakeRoom {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.rooms[i]
|
||||
}
|
||||
|
||||
func (c *fakeConnector) snapshot() ([]string, []string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return append([]string(nil), c.urls...), append([]string(nil), c.tokens...)
|
||||
}
|
||||
|
||||
func waitFor(t *testing.T, cond func() bool) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
if cond() {
|
||||
return
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
t.Fatal("condition was not met before timeout")
|
||||
}
|
||||
|
||||
func TestReconnectRefreshesCredentialsAndReplacesRoom(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
refreshes := 0
|
||||
sess, err := New(ctx, engine.Config{
|
||||
URL: "wss://old",
|
||||
Token: "old-token",
|
||||
Refresh: func(context.Context) (engine.Credentials, error) {
|
||||
refreshes++
|
||||
return engine.Credentials{URL: "wss://new", Token: "new-token"}, nil
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
s := sess.(*Session)
|
||||
connector := newFakeConnector()
|
||||
s.connectRoom = connector.connect
|
||||
|
||||
reconnected := make(chan struct{}, 1)
|
||||
s.SetReconnectCallback(func(*webrtc.DataChannel) {
|
||||
reconnected <- struct{}{}
|
||||
})
|
||||
|
||||
if err := s.Connect(ctx); err != nil {
|
||||
t.Fatalf("Connect() error = %v", err)
|
||||
}
|
||||
go s.WatchConnection(ctx)
|
||||
|
||||
connector.callback(0).OnDisconnected()
|
||||
|
||||
waitFor(t, func() bool { return connector.count() == 2 })
|
||||
select {
|
||||
case <-reconnected:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("reconnect callback was not called")
|
||||
}
|
||||
|
||||
urls, tokens := connector.snapshot()
|
||||
if got, want := urls, []string{"wss://old", "wss://new"}; !equalStrings(got, want) {
|
||||
t.Fatalf("connect urls = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := tokens, []string{"old-token", "new-token"}; !equalStrings(got, want) {
|
||||
t.Fatalf("connect tokens = %v, want %v", got, want)
|
||||
}
|
||||
if refreshes != 1 {
|
||||
t.Fatalf("refreshes = %d, want 1", refreshes)
|
||||
}
|
||||
oldRoom := connector.room(0)
|
||||
oldRoom.mu.Lock()
|
||||
if oldRoom.disconnected != 1 || oldRoom.unpublished != 1 {
|
||||
t.Fatalf("old room cleanup disconnected=%d unpublished=%d, want 1/1",
|
||||
oldRoom.disconnected, oldRoom.unpublished)
|
||||
}
|
||||
oldRoom.mu.Unlock()
|
||||
if !s.CanSend() {
|
||||
t.Fatal("CanSend() = false after reconnect, want true")
|
||||
}
|
||||
|
||||
if err := s.Close(); err != nil {
|
||||
t.Fatalf("Close() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDisconnectedEndsWhenReconnectDisallowed(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
sess, err := New(ctx, engine.Config{URL: "wss://old", Token: "old-token"})
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
s := sess.(*Session)
|
||||
connector := newFakeConnector()
|
||||
s.connectRoom = connector.connect
|
||||
s.SetShouldReconnect(func() bool { return false })
|
||||
|
||||
ended := make(chan string, 1)
|
||||
s.SetEndedCallback(func(reason string) {
|
||||
ended <- reason
|
||||
})
|
||||
|
||||
if err := s.Connect(ctx); err != nil {
|
||||
t.Fatalf("Connect() error = %v", err)
|
||||
}
|
||||
connector.callback(0).OnDisconnected()
|
||||
|
||||
select {
|
||||
case reason := <-ended:
|
||||
if reason != "disconnected from livekit" {
|
||||
t.Fatalf("ended reason = %q, want disconnected from livekit", reason)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("ended callback was not called")
|
||||
}
|
||||
if !s.closed.Load() {
|
||||
t.Fatal("closed = false after terminal disconnect")
|
||||
}
|
||||
if connector.count() != 1 {
|
||||
t.Fatalf("connect count = %d, want 1", connector.count())
|
||||
}
|
||||
room := connector.room(0)
|
||||
room.mu.Lock()
|
||||
if room.disconnected != 1 || room.unpublished != 1 {
|
||||
t.Fatalf("terminal room cleanup disconnected=%d unpublished=%d, want 1/1",
|
||||
room.disconnected, room.unpublished)
|
||||
}
|
||||
room.mu.Unlock()
|
||||
|
||||
if err := s.Close(); err != nil {
|
||||
t.Fatalf("Close() error = %v", err)
|
||||
}
|
||||
room.mu.Lock()
|
||||
if room.disconnected != 1 || room.unpublished != 1 {
|
||||
t.Fatalf("second close cleanup disconnected=%d unpublished=%d, want still 1/1",
|
||||
room.disconnected, room.unpublished)
|
||||
}
|
||||
room.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestCanSendRequiresConnectedRoomAndQueueHeadroom(t *testing.T) {
|
||||
s := &Session{
|
||||
sendQueue: make(chan []byte, defaultSendQueueSize),
|
||||
done: make(chan struct{}),
|
||||
closeCh: make(chan struct{}),
|
||||
}
|
||||
if s.CanSend() {
|
||||
t.Fatal("CanSend() = true without room")
|
||||
}
|
||||
|
||||
room := newFakeRoom()
|
||||
room.state = lksdk.ConnectionStateDisconnected
|
||||
s.setRoom(room)
|
||||
if s.CanSend() {
|
||||
t.Fatal("CanSend() = true for disconnected room")
|
||||
}
|
||||
|
||||
room.state = lksdk.ConnectionStateConnected
|
||||
if !s.CanSend() {
|
||||
t.Fatal("CanSend() = false for connected room")
|
||||
}
|
||||
|
||||
for i := 0; i < defaultSendQueueCapHard; i++ {
|
||||
s.sendQueue <- []byte("x")
|
||||
}
|
||||
if s.CanSend() {
|
||||
t.Fatal("CanSend() = true at queue high watermark")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReconnectFailureRetriesUntilContextDone(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s := &Session{
|
||||
url: "wss://old",
|
||||
token: "old-token",
|
||||
connectRoom: func(string, string, *lksdk.RoomCallback) (roomHandle, error) {
|
||||
cancel()
|
||||
return nil, errors.New("boom")
|
||||
},
|
||||
reconnectCh: make(chan struct{}, 1),
|
||||
closeCh: make(chan struct{}),
|
||||
sendQueue: make(chan []byte, defaultSendQueueSize),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
if terminal := s.handleReconnectAttempt(ctx); !terminal {
|
||||
t.Fatal("handleReconnectAttempt() = false after context cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
func equalStrings(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -417,10 +417,7 @@ func (s *Session) waitForMediaReady(ctx context.Context, timeout time.Duration)
|
||||
}
|
||||
|
||||
func (s *Session) dialWebSocket() error {
|
||||
wsDialer := websocket.Dialer{
|
||||
NetDialContext: protect.DialContext,
|
||||
HandshakeTimeout: wsHandshakeTimeout,
|
||||
}
|
||||
wsDialer := protect.NewWebSocketDialer(wsHandshakeTimeout)
|
||||
|
||||
ws, resp, err := wsDialer.Dial(s.connectorURL, nil)
|
||||
if err != nil {
|
||||
|
||||
@@ -13,8 +13,8 @@
|
||||
// │ │
|
||||
//
|
||||
// After the exchange the control stream stays open; tunnel traffic flows over
|
||||
// additional smux streams opened by the client. The control stream may carry
|
||||
// keepalives or future control messages.
|
||||
// additional smux streams opened by the client. The control stream then
|
||||
// carries ping/pong liveness and future control messages.
|
||||
//
|
||||
//nolint:tagliatelle // JSON keys are the stable wire protocol schema.
|
||||
package handshake
|
||||
|
||||
@@ -43,6 +43,7 @@ func New(ctx context.Context, cfg link.Config) (link.Link, error) {
|
||||
SEIBatchSize: cfg.SEIBatchSize,
|
||||
SEIFragmentSize: cfg.SEIFragmentSize,
|
||||
SEIAckTimeoutMS: cfg.SEIAckTimeoutMS,
|
||||
Traffic: cfg.Traffic,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create transport for direct link: %w", err)
|
||||
@@ -79,3 +80,6 @@ func (d *directLink) WatchConnection(ctx context.Context) {
|
||||
d.transport.WatchConnection(ctx)
|
||||
}
|
||||
func (d *directLink) CanSend() bool { return d.transport.CanSend() }
|
||||
|
||||
// Features reports the direct link's underlying transport capabilities.
|
||||
func (d *directLink) Features() link.Features { return d.transport.Features() }
|
||||
|
||||
@@ -79,12 +79,14 @@ func TestNewForwardsConfigAndMethods(t *testing.T) {
|
||||
VideoTileRS: 20,
|
||||
VP8FPS: 25,
|
||||
VP8BatchSize: 8,
|
||||
Traffic: transport.TrafficConfig{MaxPayloadSize: 4096},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
if seen.DeviceID != "client" || seen.ProxyPort != 1080 || seen.VideoTileRS != 20 || seen.VP8BatchSize != 8 {
|
||||
if seen.DeviceID != "client" || seen.ProxyPort != 1080 || seen.VideoTileRS != 20 || seen.VP8BatchSize != 8 ||
|
||||
seen.Traffic.MaxPayloadSize != 4096 {
|
||||
t.Fatalf("forwarded config = %+v", seen)
|
||||
}
|
||||
|
||||
@@ -112,6 +114,9 @@ func TestNewForwardsConfigAndMethods(t *testing.T) {
|
||||
if !ln.CanSend() {
|
||||
t.Fatal("CanSend() = false, want true")
|
||||
}
|
||||
if features := ln.(link.FeaturesProvider).Features(); features.MaxPayloadSize != 4096 {
|
||||
t.Fatalf("Features() = %+v, want shaped max payload 4096", features)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWrapsFactoryError(t *testing.T) {
|
||||
|
||||
@@ -4,6 +4,8 @@ package link
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/transport"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -23,11 +25,19 @@ type Link interface {
|
||||
CanSend() bool
|
||||
}
|
||||
|
||||
// Features mirrors the underlying transport capabilities when a link can expose them.
|
||||
type Features = transport.Features
|
||||
|
||||
// FeaturesProvider is optionally implemented by links that can report wire limits.
|
||||
type FeaturesProvider interface {
|
||||
Features() Features
|
||||
}
|
||||
|
||||
// Config holds common link configuration.
|
||||
type Config struct {
|
||||
Transport string
|
||||
Carrier string
|
||||
RoomURL string
|
||||
Transport string
|
||||
Carrier string
|
||||
RoomURL string
|
||||
// Engine, URL, Token are forwarded for the "none" auth carrier.
|
||||
Engine string
|
||||
URL string
|
||||
@@ -54,6 +64,7 @@ type Config struct {
|
||||
SEIBatchSize int
|
||||
SEIFragmentSize int
|
||||
SEIAckTimeoutMS int
|
||||
Traffic transport.TrafficConfig
|
||||
}
|
||||
|
||||
// Factory creates a link instance.
|
||||
|
||||
@@ -3,13 +3,38 @@ package protect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultDialTimeout = 10 * time.Second
|
||||
defaultKeepAlive = 30 * time.Second
|
||||
defaultIdleConnTimeout = 30 * time.Second
|
||||
defaultTLSHandshake = 10 * time.Second
|
||||
defaultResponseHeader = 10 * time.Second
|
||||
defaultWebSocketTimeout = 10 * time.Second
|
||||
defaultHTTPClientTimeout = 30 * time.Second
|
||||
defaultStatusBodyLimit = 1024
|
||||
)
|
||||
|
||||
var (
|
||||
sensitiveFieldRE = regexp.MustCompile(
|
||||
`(?i)((?:access[_-]?token|room[_-]?token|token|credentials)"?\s*[:=]\s*"?)` +
|
||||
`[^",\s}]+`,
|
||||
)
|
||||
sensitiveBearerRE = regexp.MustCompile(`(?i)(bearer\s+)[A-Za-z0-9._~+/=-]+`)
|
||||
) //nolint:gochecknoglobals // compiled once for provider error redaction
|
||||
|
||||
// Protector is called with a socket file descriptor before connect.
|
||||
// On Android, this calls VpnService.protect(fd) to bypass VPN routing.
|
||||
var Protector func(fd int) bool //nolint:gochecknoglobals // package-level state intentional
|
||||
@@ -33,24 +58,70 @@ func controlFunc(network, _ string, c syscall.RawConn) error {
|
||||
// NewDialer returns a net.Dialer that calls Protector on each new socket.
|
||||
func NewDialer() *net.Dialer {
|
||||
return &net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
Timeout: defaultDialTimeout,
|
||||
KeepAlive: defaultKeepAlive,
|
||||
Control: controlFunc,
|
||||
}
|
||||
}
|
||||
|
||||
// NewTLSConfig returns the shared TLS policy for provider HTTP/WebSocket clients.
|
||||
func NewTLSConfig() *tls.Config {
|
||||
return &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
}
|
||||
|
||||
// NewHTTPTransport returns an HTTP transport using protected sockets and sane timeouts.
|
||||
func NewHTTPTransport() *http.Transport {
|
||||
dialer := NewDialer()
|
||||
return &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: dialer.DialContext,
|
||||
TLSClientConfig: NewTLSConfig(),
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 10,
|
||||
IdleConnTimeout: defaultIdleConnTimeout,
|
||||
TLSHandshakeTimeout: defaultTLSHandshake,
|
||||
ResponseHeaderTimeout: defaultResponseHeader,
|
||||
}
|
||||
}
|
||||
|
||||
// NewHTTPClient returns an http.Client using protected sockets.
|
||||
func NewHTTPClient() *http.Client {
|
||||
dialer := NewDialer()
|
||||
transport := &http.Transport{
|
||||
DialContext: dialer.DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 10,
|
||||
IdleConnTimeout: 30 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ResponseHeaderTimeout: 10 * time.Second,
|
||||
return &http.Client{
|
||||
Transport: NewHTTPTransport(),
|
||||
Timeout: defaultHTTPClientTimeout,
|
||||
}
|
||||
return &http.Client{Transport: transport}
|
||||
}
|
||||
|
||||
// NewWebSocketDialer returns a WebSocket dialer using protected sockets and shared TLS policy.
|
||||
func NewWebSocketDialer(handshakeTimeout time.Duration) websocket.Dialer {
|
||||
if handshakeTimeout <= 0 {
|
||||
handshakeTimeout = defaultWebSocketTimeout
|
||||
}
|
||||
return websocket.Dialer{
|
||||
NetDialContext: DialContext,
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
TLSClientConfig: NewTLSConfig(),
|
||||
HandshakeTimeout: handshakeTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// StatusError formats an upstream HTTP error while bounding and redacting the body.
|
||||
func StatusError(base error, resp *http.Response, limit int64) error {
|
||||
if limit <= 0 {
|
||||
limit = defaultStatusBodyLimit
|
||||
}
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, limit))
|
||||
bodyText := RedactSensitive(strings.TrimSpace(string(body)))
|
||||
if bodyText == "" {
|
||||
return fmt.Errorf("%w: status %d", base, resp.StatusCode)
|
||||
}
|
||||
return fmt.Errorf("%w: status %d: %s", base, resp.StatusCode, bodyText)
|
||||
}
|
||||
|
||||
// RedactSensitive removes common token-like values from provider error text.
|
||||
func RedactSensitive(text string) string {
|
||||
text = sensitiveBearerRE.ReplaceAllString(text, "${1}<redacted>")
|
||||
return sensitiveFieldRE.ReplaceAllString(text, "${1}<redacted>")
|
||||
}
|
||||
|
||||
// DialContext dials using a protected socket.
|
||||
|
||||
@@ -2,9 +2,11 @@ package protect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -88,13 +90,57 @@ func TestNewDialerAndHTTPClient(t *testing.T) {
|
||||
if !ok {
|
||||
t.Fatalf("Transport type = %T, want *http.Transport", client.Transport)
|
||||
}
|
||||
if tr.DialContext == nil || !tr.ForceAttemptHTTP2 || tr.MaxIdleConns != 10 ||
|
||||
if tr.Proxy == nil || tr.DialContext == nil || tr.TLSClientConfig == nil ||
|
||||
tr.TLSClientConfig.MinVersion != tls.VersionTLS12 || !tr.ForceAttemptHTTP2 || tr.MaxIdleConns != 10 ||
|
||||
tr.IdleConnTimeout != 30*time.Second || tr.TLSHandshakeTimeout != 10*time.Second ||
|
||||
tr.ResponseHeaderTimeout != 10*time.Second {
|
||||
tr.ResponseHeaderTimeout != 10*time.Second || client.Timeout != 30*time.Second {
|
||||
t.Fatalf("transport = %+v", tr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWebSocketDialer(t *testing.T) {
|
||||
dialer := NewWebSocketDialer(3 * time.Second)
|
||||
if dialer.NetDialContext == nil || dialer.Proxy == nil || dialer.TLSClientConfig == nil ||
|
||||
dialer.TLSClientConfig.MinVersion != tls.VersionTLS12 ||
|
||||
dialer.HandshakeTimeout != 3*time.Second {
|
||||
t.Fatalf("NewWebSocketDialer() = %+v", dialer)
|
||||
}
|
||||
|
||||
defaulted := NewWebSocketDialer(0)
|
||||
if defaulted.HandshakeTimeout != defaultWebSocketTimeout {
|
||||
t.Fatalf("default HandshakeTimeout = %v, want %v",
|
||||
defaulted.HandshakeTimeout, defaultWebSocketTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusErrorRedactsAndLimitsBody(t *testing.T) {
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusForbidden,
|
||||
Body: ioNopCloser{strings.NewReader(`{"accessToken":"secret","message":"no"}`)},
|
||||
}
|
||||
err := StatusError(errProtectBoom, resp, 1024)
|
||||
if err == nil {
|
||||
t.Fatal("StatusError() error = nil")
|
||||
}
|
||||
text := err.Error()
|
||||
if strings.Contains(text, "secret") || !strings.Contains(text, "<redacted>") {
|
||||
t.Fatalf("StatusError() = %q, want redacted token", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedactSensitiveBearer(t *testing.T) {
|
||||
got := RedactSensitive("Authorization: Bearer abc.def")
|
||||
if strings.Contains(got, "abc.def") || !strings.Contains(got, "Bearer <redacted>") {
|
||||
t.Fatalf("RedactSensitive() = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
type ioNopCloser struct {
|
||||
*strings.Reader
|
||||
}
|
||||
|
||||
func (c ioNopCloser) Close() error { return nil }
|
||||
|
||||
func TestDialContextAndProxyDialer(t *testing.T) {
|
||||
var lc net.ListenConfig
|
||||
ln, err := lc.Listen(context.Background(), "tcp4", "127.0.0.1:0")
|
||||
|
||||
@@ -14,12 +14,14 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/control"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/crypto"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/handshake"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/link"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/logger"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/names"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/transport"
|
||||
"github.com/xtaci/smux"
|
||||
)
|
||||
|
||||
@@ -49,25 +51,33 @@ type SessionCloseFunc func(sessionID, reason string)
|
||||
// bytesIn counts client→target bytes; bytesOut counts target→client bytes.
|
||||
type TrafficFunc func(sessionID, addr string, bytesIn, bytesOut uint64)
|
||||
|
||||
// HealthFunc is called when the server control health snapshot changes.
|
||||
type HealthFunc func(control.Status)
|
||||
|
||||
// Server handles incoming tunnel connections and proxies their traffic.
|
||||
type Server struct {
|
||||
ln link.Link
|
||||
cipher *crypto.Cipher
|
||||
conn *muxconn.Conn
|
||||
session *smux.Session
|
||||
controlStop context.CancelFunc
|
||||
sessMu sync.RWMutex
|
||||
reinstallMu sync.Mutex
|
||||
healthMu sync.RWMutex
|
||||
wg sync.WaitGroup
|
||||
authHook handshake.AuthFunc
|
||||
onOpen SessionOpenFunc
|
||||
onClose SessionCloseFunc
|
||||
onTraffic TrafficFunc
|
||||
onHealth HealthFunc
|
||||
deviceID string
|
||||
sessionID string
|
||||
dnsServer string
|
||||
resolver *net.Resolver
|
||||
socksProxyAddr string
|
||||
socksProxyPort int
|
||||
liveness control.Config
|
||||
health control.Status
|
||||
}
|
||||
|
||||
// ConnectRequest is a message from the client to establish a new connection.
|
||||
@@ -106,6 +116,8 @@ type Config struct {
|
||||
Engine string
|
||||
URL string
|
||||
Token string
|
||||
Liveness control.Config
|
||||
Traffic transport.TrafficConfig
|
||||
|
||||
// AuthHook is invoked after CLIENT_HELLO to authorize the client and
|
||||
// return a session ID. If nil, every client is admitted with a random UUID.
|
||||
@@ -117,6 +129,8 @@ type Config struct {
|
||||
OnSessionClose SessionCloseFunc
|
||||
// OnTraffic fires once per tunnel stream after both copy loops finish. Nil means no-op.
|
||||
OnTraffic TrafficFunc
|
||||
// OnHealth fires when liveness/reconnect status changes. Nil means no-op.
|
||||
OnHealth HealthFunc
|
||||
}
|
||||
|
||||
// Run starts the server with the given configuration.
|
||||
@@ -145,6 +159,10 @@ func Run(ctx context.Context, cfg Config) error {
|
||||
if onTraffic == nil {
|
||||
onTraffic = func(string, string, uint64, uint64) {}
|
||||
}
|
||||
onHealth := cfg.OnHealth
|
||||
if onHealth == nil {
|
||||
onHealth = func(control.Status) {}
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
cipher: cipher,
|
||||
@@ -152,9 +170,11 @@ func Run(ctx context.Context, cfg Config) error {
|
||||
onOpen: onOpen,
|
||||
onClose: onClose,
|
||||
onTraffic: onTraffic,
|
||||
onHealth: onHealth,
|
||||
dnsServer: cfg.DNSServer,
|
||||
socksProxyAddr: cfg.SOCKSProxyAddr,
|
||||
socksProxyPort: cfg.SOCKSProxyPort,
|
||||
liveness: cfg.Liveness,
|
||||
}
|
||||
s.setupResolver()
|
||||
|
||||
@@ -216,11 +236,17 @@ func (s *Server) setupResolver() {
|
||||
|
||||
// smuxConfig mirrors the client side. Both peers must agree on Version and
|
||||
// MaxFrameSize.
|
||||
func smuxConfig() *smux.Config {
|
||||
func smuxConfig(maxWirePayload ...int) *smux.Config {
|
||||
cfg := smux.DefaultConfig()
|
||||
cfg.Version = 2
|
||||
cfg.KeepAliveDisabled = true
|
||||
cfg.MaxFrameSize = 32768
|
||||
if len(maxWirePayload) > 0 && maxWirePayload[0] > crypto.WireOverhead {
|
||||
maxFrameSize := maxWirePayload[0] - crypto.WireOverhead
|
||||
if maxFrameSize < cfg.MaxFrameSize {
|
||||
cfg.MaxFrameSize = maxFrameSize
|
||||
}
|
||||
}
|
||||
cfg.MaxReceiveBuffer = 16 * 1024 * 1024
|
||||
cfg.MaxStreamBuffer = 1024 * 1024
|
||||
cfg.KeepAliveInterval = 10 * time.Second
|
||||
@@ -228,6 +254,14 @@ func smuxConfig() *smux.Config {
|
||||
return cfg
|
||||
}
|
||||
|
||||
func linkMaxPayload(ln link.Link) int {
|
||||
provider, ok := ln.(link.FeaturesProvider)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
return provider.Features().MaxPayloadSize
|
||||
}
|
||||
|
||||
func (s *Server) bringUpLink(
|
||||
ctx context.Context,
|
||||
cfg Config,
|
||||
@@ -262,6 +296,7 @@ func (s *Server) bringUpLink(
|
||||
SEIBatchSize: cfg.SEIBatchSize,
|
||||
SEIFragmentSize: cfg.SEIFragmentSize,
|
||||
SEIAckTimeoutMS: cfg.SEIAckTimeoutMS,
|
||||
Traffic: cfg.Traffic,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create link: %w", err)
|
||||
@@ -298,7 +333,7 @@ func (s *Server) bringUpLink(
|
||||
|
||||
func (s *Server) installSession() {
|
||||
conn := muxconn.New(s.ln, s.cipher)
|
||||
sess, err := smux.Server(conn, smuxConfig())
|
||||
sess, err := smux.Server(conn, smuxConfig(linkMaxPayload(s.ln)))
|
||||
if err != nil {
|
||||
logger.Warnf("smux server init failed: %v", err)
|
||||
return
|
||||
@@ -310,7 +345,8 @@ func (s *Server) installSession() {
|
||||
}
|
||||
|
||||
func (s *Server) handleReconnect() {
|
||||
logger.Infof("server link reconnect - tearing down smux session")
|
||||
s.recordReconnect()
|
||||
logger.Infof("server reconnect reason=carrier - tearing down smux session")
|
||||
s.sessMu.RLock()
|
||||
current := s.session
|
||||
s.sessMu.RUnlock()
|
||||
@@ -323,7 +359,7 @@ func (s *Server) reinstallSession(dead *smux.Session) {
|
||||
|
||||
// Pre-build the replacement so we can swap atomically below.
|
||||
newConn := muxconn.New(s.ln, s.cipher)
|
||||
newSess, err := smux.Server(newConn, smuxConfig())
|
||||
newSess, err := smux.Server(newConn, smuxConfig(linkMaxPayload(s.ln)))
|
||||
if err != nil {
|
||||
logger.Warnf("smux server init failed: %v", err)
|
||||
_ = newConn.Close()
|
||||
@@ -340,13 +376,18 @@ func (s *Server) reinstallSession(dead *smux.Session) {
|
||||
}
|
||||
oldSess := s.session
|
||||
oldConn := s.conn
|
||||
oldControlStop := s.controlStop
|
||||
oldSID := s.sessionID
|
||||
s.session = newSess
|
||||
s.conn = newConn
|
||||
s.controlStop = nil
|
||||
s.sessionID = ""
|
||||
s.deviceID = ""
|
||||
s.sessMu.Unlock()
|
||||
|
||||
if oldControlStop != nil {
|
||||
oldControlStop()
|
||||
}
|
||||
if oldSess != nil {
|
||||
_ = oldSess.Close()
|
||||
}
|
||||
@@ -362,13 +403,18 @@ func (s *Server) closeSession() {
|
||||
s.sessMu.Lock()
|
||||
sess := s.session
|
||||
conn := s.conn
|
||||
controlStop := s.controlStop
|
||||
s.session = nil
|
||||
s.conn = nil
|
||||
s.controlStop = nil
|
||||
oldSID := s.sessionID
|
||||
s.sessionID = ""
|
||||
s.deviceID = ""
|
||||
s.sessMu.Unlock()
|
||||
|
||||
if controlStop != nil {
|
||||
controlStop()
|
||||
}
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
@@ -476,27 +522,120 @@ func (s *Server) acceptHandshake(ctx context.Context, sess *smux.Session) bool {
|
||||
s.deviceID = hello.DeviceID
|
||||
s.sessionID = sid
|
||||
s.sessMu.Unlock()
|
||||
s.recordSession(sid)
|
||||
s.onOpen(sid, hello.DeviceID, hello.Claims)
|
||||
logger.Infof("session %s opened (device=%s)", sid, hello.DeviceID)
|
||||
// The control stream stays open for the lifetime of the session;
|
||||
// keep it parked in a goroutine so the smux session does not close it.
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.parkControlStream(stream)
|
||||
}()
|
||||
s.startControlLoop(ctx, sess, stream)
|
||||
return true
|
||||
}
|
||||
|
||||
// parkControlStream blocks reading from the control stream until it closes.
|
||||
// Future control messages (kick, rate updates, etc.) would be dispatched here.
|
||||
func (s *Server) parkControlStream(stream *smux.Stream) {
|
||||
defer func() { _ = stream.Close() }()
|
||||
buf := make([]byte, 64)
|
||||
for {
|
||||
if _, err := stream.Read(buf); err != nil {
|
||||
func (s *Server) startControlLoop(ctx context.Context, sess *smux.Session, stream *smux.Stream) {
|
||||
controlCtx, stop := context.WithCancel(ctx)
|
||||
s.sessMu.Lock()
|
||||
s.controlStop = stop
|
||||
s.sessMu.Unlock()
|
||||
|
||||
liveness := s.liveness
|
||||
onPong := liveness.OnPong
|
||||
onMissedPong := liveness.OnMissedPong
|
||||
onUnhealthy := liveness.OnUnhealthy
|
||||
liveness.OnPong = func(h control.Health) {
|
||||
s.sessMu.RLock()
|
||||
sid := s.sessionID
|
||||
s.sessMu.RUnlock()
|
||||
s.recordPong(h)
|
||||
logger.Debugf("control alive session=%s rtt=%v seq=%d", sid, h.RTT, h.Seq)
|
||||
if onPong != nil {
|
||||
onPong(h)
|
||||
}
|
||||
}
|
||||
liveness.OnMissedPong = func(missed int) {
|
||||
s.recordMissed(missed)
|
||||
logger.Warnf("control missed pong on server: missed_pongs=%d", missed)
|
||||
if onMissedPong != nil {
|
||||
onMissedPong(missed)
|
||||
}
|
||||
}
|
||||
liveness.OnUnhealthy = func(missed int) {
|
||||
s.recordUnhealthy(missed)
|
||||
logger.Warnf("control stream unhealthy on server: missed_pongs=%d", missed)
|
||||
if onUnhealthy != nil {
|
||||
onUnhealthy(missed)
|
||||
}
|
||||
}
|
||||
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
defer func() { _ = stream.Close() }()
|
||||
err := control.Run(controlCtx, stream, liveness)
|
||||
if controlCtx.Err() != nil || ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
logger.Warnf("server control stream ended: %v", err)
|
||||
}
|
||||
s.recordReconnect()
|
||||
logger.Infof("server reconnect reason=liveness - reinstalling smux session")
|
||||
s.reinstallSession(sess)
|
||||
}()
|
||||
}
|
||||
|
||||
// Status returns the latest server-side control health snapshot.
|
||||
func (s *Server) Status() control.Status {
|
||||
s.healthMu.RLock()
|
||||
defer s.healthMu.RUnlock()
|
||||
return s.health
|
||||
}
|
||||
|
||||
func (s *Server) recordSession(sessionID string) {
|
||||
s.healthMu.Lock()
|
||||
s.health.SessionID = sessionID
|
||||
s.health.MissedPongs = 0
|
||||
status := s.health
|
||||
s.healthMu.Unlock()
|
||||
s.notifyHealth(status)
|
||||
}
|
||||
|
||||
func (s *Server) recordPong(h control.Health) {
|
||||
s.healthMu.Lock()
|
||||
s.health.LastPong = h.LastSeen
|
||||
s.health.LastRTT = h.RTT
|
||||
s.health.MissedPongs = 0
|
||||
status := s.health
|
||||
s.healthMu.Unlock()
|
||||
s.notifyHealth(status)
|
||||
}
|
||||
|
||||
func (s *Server) recordMissed(missed int) {
|
||||
s.healthMu.Lock()
|
||||
s.health.MissedPongs = missed
|
||||
status := s.health
|
||||
s.healthMu.Unlock()
|
||||
s.notifyHealth(status)
|
||||
}
|
||||
|
||||
func (s *Server) recordUnhealthy(missed int) {
|
||||
s.healthMu.Lock()
|
||||
s.health.MissedPongs = missed
|
||||
s.health.UnhealthyEvents++
|
||||
s.health.LastUnhealthy = time.Now()
|
||||
status := s.health
|
||||
s.healthMu.Unlock()
|
||||
s.notifyHealth(status)
|
||||
}
|
||||
|
||||
func (s *Server) recordReconnect() {
|
||||
s.healthMu.Lock()
|
||||
s.health.Reconnects++
|
||||
status := s.health
|
||||
s.healthMu.Unlock()
|
||||
s.notifyHealth(status)
|
||||
}
|
||||
|
||||
func (s *Server) notifyHealth(status control.Status) {
|
||||
if s.onHealth != nil {
|
||||
s.onHealth(status)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/control"
|
||||
cryptopkg "github.com/openlibrecommunity/olcrtc/internal/crypto"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
|
||||
"github.com/xtaci/smux"
|
||||
@@ -49,6 +50,11 @@ func TestSmuxConfig(t *testing.T) {
|
||||
if cfg.Version != 2 || !cfg.KeepAliveDisabled || cfg.MaxFrameSize != 32768 || cfg.MaxReceiveBuffer != 16*1024*1024 {
|
||||
t.Fatalf("smuxConfig() = %+v", cfg)
|
||||
}
|
||||
capped := smuxConfig(4096)
|
||||
if capped.MaxFrameSize != 4096-cryptopkg.WireOverhead {
|
||||
t.Fatalf("smuxConfig(4096).MaxFrameSize = %d, want %d",
|
||||
capped.MaxFrameSize, 4096-cryptopkg.WireOverhead)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConnectRequest(t *testing.T) {
|
||||
@@ -373,6 +379,103 @@ func TestReinstallSessionFiresOnClose(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartControlLoopReportsPong(t *testing.T) {
|
||||
a, b := net.Pipe()
|
||||
defer func() {
|
||||
_ = a.Close()
|
||||
_ = b.Close()
|
||||
}()
|
||||
|
||||
serverSess, err := smux.Server(a, smuxConfig())
|
||||
if err != nil {
|
||||
t.Fatalf("smux.Server() error = %v", err)
|
||||
}
|
||||
defer func() { _ = serverSess.Close() }()
|
||||
clientSess, err := smux.Client(b, smuxConfig())
|
||||
if err != nil {
|
||||
t.Fatalf("smux.Client() error = %v", err)
|
||||
}
|
||||
defer func() { _ = clientSess.Close() }()
|
||||
|
||||
serverStreamCh := make(chan *smux.Stream, 1)
|
||||
go func() {
|
||||
stream, err := serverSess.AcceptStream()
|
||||
if err == nil {
|
||||
serverStreamCh <- stream
|
||||
}
|
||||
}()
|
||||
|
||||
clientStream, err := clientSess.OpenStream()
|
||||
if err != nil {
|
||||
t.Fatalf("OpenStream() error = %v", err)
|
||||
}
|
||||
serverStream := <-serverStreamCh
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
got := make(chan control.Health, 1)
|
||||
s := &Server{
|
||||
sessionID: "sid-control",
|
||||
liveness: control.Config{
|
||||
Interval: 10 * time.Millisecond,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
Failures: 2,
|
||||
OnPong: func(h control.Health) {
|
||||
select {
|
||||
case got <- h:
|
||||
default:
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
s.recordSession("sid-control")
|
||||
defer func() {
|
||||
cancel()
|
||||
s.wg.Wait()
|
||||
}()
|
||||
s.startControlLoop(ctx, serverSess, serverStream)
|
||||
go func() {
|
||||
_ = control.Run(ctx, clientStream, control.Config{
|
||||
Interval: 10 * time.Millisecond,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
Failures: 2,
|
||||
})
|
||||
}()
|
||||
|
||||
select {
|
||||
case h := <-got:
|
||||
if h.Seq == 0 {
|
||||
t.Fatal("Health.Seq = 0")
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for control pong")
|
||||
}
|
||||
status := s.Status()
|
||||
if status.SessionID != "sid-control" {
|
||||
t.Fatalf("Status.SessionID = %q, want sid-control", status.SessionID)
|
||||
}
|
||||
if status.LastPong.IsZero() || status.LastRTT < 0 || status.MissedPongs != 0 {
|
||||
t.Fatalf("Status() = %+v", status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusRecordsReconnectAndUnhealthy(t *testing.T) {
|
||||
updates := 0
|
||||
s := &Server{onHealth: func(control.Status) { updates++ }}
|
||||
s.recordSession("sid-1")
|
||||
s.recordMissed(2)
|
||||
s.recordUnhealthy(3)
|
||||
s.recordReconnect()
|
||||
|
||||
status := s.Status()
|
||||
if status.SessionID != "sid-1" || status.MissedPongs != 3 ||
|
||||
status.UnhealthyEvents != 1 || status.Reconnects != 1 || status.LastUnhealthy.IsZero() {
|
||||
t.Fatalf("Status() = %+v", status)
|
||||
}
|
||||
if updates != 4 {
|
||||
t.Fatalf("health updates = %d, want 4", updates)
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:cyclop // integration-style test needs setup, proxying, and traffic assertions together.
|
||||
func TestDispatchFiresOnTraffic(t *testing.T) {
|
||||
var lc net.ListenConfig
|
||||
|
||||
229
internal/supervisor/supervisor.go
Normal file
229
internal/supervisor/supervisor.go
Normal file
@@ -0,0 +1,229 @@
|
||||
// Package supervisor runs ordered session profiles with failover.
|
||||
package supervisor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/app/session"
|
||||
)
|
||||
|
||||
const DefaultRetryDelay = 2 * time.Second
|
||||
const DefaultHistoryLimit = 20
|
||||
|
||||
const (
|
||||
// EventProfileStart marks a profile attempt starting.
|
||||
EventProfileStart = "profile_start"
|
||||
// EventProfileEnd marks a profile attempt ending.
|
||||
EventProfileEnd = "profile_end"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrNoProfiles is returned when the supervisor is started without profiles.
|
||||
ErrNoProfiles = errors.New("supervisor: no profiles configured")
|
||||
// ErrMaxCyclesExceeded is returned after MaxCycles complete profile-list passes.
|
||||
ErrMaxCyclesExceeded = errors.New("supervisor: max failover cycles exceeded")
|
||||
)
|
||||
|
||||
// Profile is one runnable session configuration in an ordered failover list.
|
||||
type Profile struct {
|
||||
Name string
|
||||
Config session.Config
|
||||
}
|
||||
|
||||
// ProfileStatus summarizes one profile's failover history.
|
||||
type ProfileStatus struct {
|
||||
Name string
|
||||
Starts int
|
||||
Failures int
|
||||
CleanEnds int
|
||||
LastStarted time.Time
|
||||
LastEnded time.Time
|
||||
LastError string
|
||||
}
|
||||
|
||||
// Event is one bounded failover history entry.
|
||||
type Event struct {
|
||||
Time time.Time
|
||||
Type string
|
||||
Profile string
|
||||
Cycle int
|
||||
Error string
|
||||
}
|
||||
|
||||
// Status is a point-in-time view of the supervisor.
|
||||
type Status struct {
|
||||
Cycle int
|
||||
ActiveProfile string
|
||||
ActiveProfileIndex int
|
||||
Profiles []ProfileStatus
|
||||
History []Event
|
||||
LastError string
|
||||
}
|
||||
|
||||
// Runner starts one session profile and blocks until it ends or fails.
|
||||
type Runner func(ctx context.Context, cfg session.Config) error
|
||||
|
||||
// Config controls ordered failover behavior.
|
||||
type Config struct {
|
||||
Profiles []Profile
|
||||
RetryDelay time.Duration
|
||||
MaxCycles int
|
||||
|
||||
OnProfileStart func(profile Profile, cycle int)
|
||||
OnProfileEnd func(profile Profile, cycle int, err error)
|
||||
OnStatus func(status Status)
|
||||
HistoryLimit int
|
||||
}
|
||||
|
||||
// Run starts profiles in order. If a profile exits while ctx is still active,
|
||||
// the supervisor waits RetryDelay and advances to the next profile.
|
||||
func Run(ctx context.Context, cfg Config, run Runner) error {
|
||||
if len(cfg.Profiles) == 0 {
|
||||
return ErrNoProfiles
|
||||
}
|
||||
if cfg.RetryDelay == 0 {
|
||||
cfg.RetryDelay = DefaultRetryDelay
|
||||
}
|
||||
state := newStatusTracker(cfg.Profiles, cfg.HistoryLimit, cfg.OnStatus)
|
||||
|
||||
var lastErr error
|
||||
for cycle := 1; ; cycle++ {
|
||||
for i, profile := range cfg.Profiles {
|
||||
if ctx.Err() != nil {
|
||||
return nil
|
||||
}
|
||||
state.start(i, cycle)
|
||||
if cfg.OnProfileStart != nil {
|
||||
cfg.OnProfileStart(profile, cycle)
|
||||
}
|
||||
|
||||
err := run(ctx, profile.Config)
|
||||
if ctx.Err() != nil {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("profile %q: %w", profile.Name, err)
|
||||
} else {
|
||||
lastErr = fmt.Errorf("profile %q ended", profile.Name)
|
||||
}
|
||||
state.end(i, cycle, err)
|
||||
if cfg.OnProfileEnd != nil {
|
||||
cfg.OnProfileEnd(profile, cycle, err)
|
||||
}
|
||||
|
||||
if cfg.MaxCycles > 0 && cycle >= cfg.MaxCycles && i == len(cfg.Profiles)-1 {
|
||||
return fmt.Errorf("%w after %d cycle(s): %w", ErrMaxCyclesExceeded, cycle, lastErr)
|
||||
}
|
||||
if err := waitRetryDelay(ctx, cfg.RetryDelay); err != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type statusTracker struct {
|
||||
status Status
|
||||
notify func(Status)
|
||||
historyLimit int
|
||||
}
|
||||
|
||||
func newStatusTracker(profiles []Profile, historyLimit int, notify func(Status)) *statusTracker {
|
||||
if historyLimit == 0 {
|
||||
historyLimit = DefaultHistoryLimit
|
||||
}
|
||||
statusProfiles := make([]ProfileStatus, 0, len(profiles))
|
||||
for _, profile := range profiles {
|
||||
statusProfiles = append(statusProfiles, ProfileStatus{Name: profile.Name})
|
||||
}
|
||||
return &statusTracker{
|
||||
status: Status{
|
||||
ActiveProfileIndex: -1,
|
||||
Profiles: statusProfiles,
|
||||
},
|
||||
notify: notify,
|
||||
historyLimit: historyLimit,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *statusTracker) start(profileIndex, cycle int) {
|
||||
now := time.Now()
|
||||
profile := &t.status.Profiles[profileIndex]
|
||||
profile.Starts++
|
||||
profile.LastStarted = now
|
||||
t.status.Cycle = cycle
|
||||
t.status.ActiveProfile = profile.Name
|
||||
t.status.ActiveProfileIndex = profileIndex
|
||||
t.appendHistory(Event{
|
||||
Time: now,
|
||||
Type: EventProfileStart,
|
||||
Profile: profile.Name,
|
||||
Cycle: cycle,
|
||||
})
|
||||
t.emit()
|
||||
}
|
||||
|
||||
func (t *statusTracker) end(profileIndex, cycle int, err error) {
|
||||
now := time.Now()
|
||||
profile := &t.status.Profiles[profileIndex]
|
||||
profile.LastEnded = now
|
||||
event := Event{
|
||||
Time: now,
|
||||
Type: EventProfileEnd,
|
||||
Profile: profile.Name,
|
||||
Cycle: cycle,
|
||||
}
|
||||
if err != nil {
|
||||
profile.Failures++
|
||||
profile.LastError = err.Error()
|
||||
t.status.LastError = fmt.Sprintf("profile %q: %v", profile.Name, err)
|
||||
event.Error = err.Error()
|
||||
} else {
|
||||
profile.CleanEnds++
|
||||
profile.LastError = ""
|
||||
t.status.LastError = fmt.Sprintf("profile %q ended", profile.Name)
|
||||
}
|
||||
t.status.ActiveProfile = ""
|
||||
t.status.ActiveProfileIndex = -1
|
||||
t.appendHistory(event)
|
||||
t.emit()
|
||||
}
|
||||
|
||||
func (t *statusTracker) appendHistory(event Event) {
|
||||
if t.historyLimit < 0 {
|
||||
return
|
||||
}
|
||||
t.status.History = append(t.status.History, event)
|
||||
if len(t.status.History) > t.historyLimit {
|
||||
t.status.History = t.status.History[len(t.status.History)-t.historyLimit:]
|
||||
}
|
||||
}
|
||||
|
||||
func (t *statusTracker) emit() {
|
||||
if t.notify == nil {
|
||||
return
|
||||
}
|
||||
t.notify(cloneStatus(t.status))
|
||||
}
|
||||
|
||||
func cloneStatus(status Status) Status {
|
||||
status.Profiles = append([]ProfileStatus(nil), status.Profiles...)
|
||||
status.History = append([]Event(nil), status.History...)
|
||||
return status
|
||||
}
|
||||
|
||||
func waitRetryDelay(ctx context.Context, delay time.Duration) error {
|
||||
if delay <= 0 {
|
||||
return nil
|
||||
}
|
||||
timer := time.NewTimer(delay)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
170
internal/supervisor/supervisor_test.go
Normal file
170
internal/supervisor/supervisor_test.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package supervisor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/app/session"
|
||||
)
|
||||
|
||||
var errRunnerBoom = errors.New("boom")
|
||||
|
||||
func TestRunRequiresProfiles(t *testing.T) {
|
||||
err := Run(context.Background(), Config{}, func(context.Context, session.Config) error { return nil })
|
||||
if !errors.Is(err, ErrNoProfiles) {
|
||||
t.Fatalf("Run() error = %v, want %v", err, ErrNoProfiles)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunAdvancesProfilesAndStopsAtMaxCycles(t *testing.T) {
|
||||
profiles := []Profile{
|
||||
{Name: "first", Config: session.Config{Auth: "wbstream"}},
|
||||
{Name: "second", Config: session.Config{Auth: "jitsi"}},
|
||||
}
|
||||
var started []string
|
||||
var ended []string
|
||||
err := Run(context.Background(), Config{
|
||||
Profiles: profiles,
|
||||
RetryDelay: -1,
|
||||
MaxCycles: 1,
|
||||
OnProfileStart: func(profile Profile, cycle int) {
|
||||
started = append(started, profile.Name)
|
||||
if cycle != 1 {
|
||||
t.Fatalf("cycle = %d, want 1", cycle)
|
||||
}
|
||||
},
|
||||
OnProfileEnd: func(profile Profile, _ int, err error) {
|
||||
ended = append(ended, profile.Name)
|
||||
if !errors.Is(err, errRunnerBoom) {
|
||||
t.Fatalf("profile %s err = %v, want %v", profile.Name, err, errRunnerBoom)
|
||||
}
|
||||
},
|
||||
}, func(_ context.Context, cfg session.Config) error {
|
||||
if cfg.Auth == "" {
|
||||
t.Fatal("runner received empty auth")
|
||||
}
|
||||
return errRunnerBoom
|
||||
})
|
||||
if !errors.Is(err, ErrMaxCyclesExceeded) {
|
||||
t.Fatalf("Run() error = %v, want %v", err, ErrMaxCyclesExceeded)
|
||||
}
|
||||
if got, want := started, []string{"first", "second"}; !equalStrings(got, want) {
|
||||
t.Fatalf("started = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := ended, []string{"first", "second"}; !equalStrings(got, want) {
|
||||
t.Fatalf("ended = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunEmitsStatusHistory(t *testing.T) {
|
||||
profiles := []Profile{
|
||||
{Name: "first", Config: session.Config{Auth: "wbstream"}},
|
||||
{Name: "second", Config: session.Config{Auth: "jitsi"}},
|
||||
}
|
||||
var snapshots []Status
|
||||
err := Run(context.Background(), Config{
|
||||
Profiles: profiles,
|
||||
RetryDelay: -1,
|
||||
MaxCycles: 1,
|
||||
HistoryLimit: 3,
|
||||
OnStatus: func(status Status) {
|
||||
snapshots = append(snapshots, status)
|
||||
},
|
||||
}, func(_ context.Context, cfg session.Config) error {
|
||||
if cfg.Auth == "first" {
|
||||
t.Fatal("runner received profile name instead of config")
|
||||
}
|
||||
return errRunnerBoom
|
||||
})
|
||||
if !errors.Is(err, ErrMaxCyclesExceeded) {
|
||||
t.Fatalf("Run() error = %v, want %v", err, ErrMaxCyclesExceeded)
|
||||
}
|
||||
if len(snapshots) != 4 {
|
||||
t.Fatalf("status snapshots = %d, want 4", len(snapshots))
|
||||
}
|
||||
first := snapshots[0]
|
||||
if first.ActiveProfile != "first" || first.ActiveProfileIndex != 0 || first.Cycle != 1 {
|
||||
t.Fatalf("first status = %+v", first)
|
||||
}
|
||||
if first.Profiles[0].Starts != 1 || first.Profiles[0].LastStarted.IsZero() {
|
||||
t.Fatalf("first profile start status = %+v", first.Profiles[0])
|
||||
}
|
||||
last := snapshots[len(snapshots)-1]
|
||||
if last.ActiveProfile != "" || last.ActiveProfileIndex != -1 {
|
||||
t.Fatalf("last active status = %+v", last)
|
||||
}
|
||||
if last.Profiles[0].Failures != 1 || last.Profiles[1].Failures != 1 {
|
||||
t.Fatalf("profile failures = %+v", last.Profiles)
|
||||
}
|
||||
if last.LastError == "" || last.Profiles[1].LastError == "" {
|
||||
t.Fatalf("last errors missing: %+v", last)
|
||||
}
|
||||
if len(last.History) != 3 {
|
||||
t.Fatalf("history length = %d, want 3", len(last.History))
|
||||
}
|
||||
if last.History[0].Type != EventProfileEnd || last.History[0].Profile != "first" {
|
||||
t.Fatalf("oldest bounded history event = %+v", last.History[0])
|
||||
}
|
||||
if last.History[2].Type != EventProfileEnd || last.History[2].Profile != "second" ||
|
||||
last.History[2].Error == "" {
|
||||
t.Fatalf("last history event = %+v", last.History[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunStatusSnapshotIsImmutable(t *testing.T) {
|
||||
var first Status
|
||||
var second Status
|
||||
err := Run(context.Background(), Config{
|
||||
Profiles: []Profile{{Name: "one"}},
|
||||
RetryDelay: -1,
|
||||
MaxCycles: 1,
|
||||
OnStatus: func(status Status) {
|
||||
if first.Profiles == nil {
|
||||
first = status
|
||||
first.Profiles[0].Starts = 99
|
||||
first.History[0].Profile = "mutated"
|
||||
return
|
||||
}
|
||||
second = status
|
||||
},
|
||||
}, func(context.Context, session.Config) error {
|
||||
return errRunnerBoom
|
||||
})
|
||||
if !errors.Is(err, ErrMaxCyclesExceeded) {
|
||||
t.Fatalf("Run() error = %v, want %v", err, ErrMaxCyclesExceeded)
|
||||
}
|
||||
if first.Profiles[0].Starts != 99 || first.History[0].Profile != "mutated" {
|
||||
t.Fatalf("test mutation did not apply to snapshot: %+v", first)
|
||||
}
|
||||
if second.Profiles[0].Starts != 1 || second.History[0].Profile != "one" {
|
||||
t.Fatalf("snapshot mutation leaked into later status: %+v", second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunReturnsNilOnContextCancel(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
err := Run(ctx, Config{
|
||||
Profiles: []Profile{{Name: "one"}},
|
||||
RetryDelay: time.Hour,
|
||||
}, func(context.Context, session.Config) error {
|
||||
cancel()
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Run() error = %v, want nil", err)
|
||||
}
|
||||
}
|
||||
|
||||
func equalStrings(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -35,6 +35,7 @@ const (
|
||||
protocolVersion byte = 1
|
||||
frameTypeData byte = 1
|
||||
frameTypeAck byte = 2
|
||||
frameTypeHello byte = 3
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -86,6 +87,7 @@ type streamTransport struct {
|
||||
nextSeq atomic.Uint32
|
||||
closed atomic.Bool
|
||||
writerUp atomic.Bool
|
||||
peerReady atomic.Bool
|
||||
sendMu sync.Mutex
|
||||
startWriter sync.Once
|
||||
ackMu sync.Mutex
|
||||
@@ -286,7 +288,7 @@ func (p *streamTransport) WatchConnection(ctx context.Context) {
|
||||
|
||||
// CanSend reports whether transport is ready for sending.
|
||||
func (p *streamTransport) CanSend() bool {
|
||||
return !p.closed.Load() && p.stream.CanSend()
|
||||
return !p.closed.Load() && p.peerReady.Load() && p.stream.CanSend()
|
||||
}
|
||||
|
||||
// Features describes the current seichannel transport semantics.
|
||||
@@ -333,7 +335,7 @@ func (p *streamTransport) writerLoop() {
|
||||
ticker := time.NewTicker(p.effectiveFrameInterval())
|
||||
defer ticker.Stop()
|
||||
|
||||
idle := buildVideoAccessUnit(nil)
|
||||
idle := buildVideoAccessUnit(encodeHelloFrame())
|
||||
|
||||
for {
|
||||
select {
|
||||
@@ -443,9 +445,13 @@ func (p *streamTransport) handleSample(sample []byte) {
|
||||
}
|
||||
|
||||
switch frame.typ {
|
||||
case frameTypeHello:
|
||||
p.peerReady.Store(true)
|
||||
case frameTypeAck:
|
||||
p.peerReady.Store(true)
|
||||
p.resolveAck(frame.seq, frame.crc)
|
||||
case frameTypeData:
|
||||
p.peerReady.Store(true)
|
||||
p.handleInboundFrame(frame)
|
||||
}
|
||||
}
|
||||
@@ -562,8 +568,8 @@ func encodeDataFrame(seq, crc uint32, totalLen, fragIdx, fragTotal int, payload
|
||||
out[5] = frameTypeData
|
||||
binary.BigEndian.PutUint32(out[6:10], seq)
|
||||
binary.BigEndian.PutUint32(out[10:14], crc)
|
||||
binary.BigEndian.PutUint32(out[14:18], uint32(totalLen)) //nolint:gosec,lll // G115: bounded conversion verified by surrounding logic
|
||||
binary.BigEndian.PutUint16(out[18:20], uint16(fragIdx)) //nolint:gosec,lll // G115: bounded conversion verified by surrounding logic
|
||||
binary.BigEndian.PutUint32(out[14:18], uint32(totalLen)) //nolint:gosec,lll // G115: bounded conversion verified by surrounding logic
|
||||
binary.BigEndian.PutUint16(out[18:20], uint16(fragIdx)) //nolint:gosec,lll // G115: bounded conversion verified by surrounding logic
|
||||
binary.BigEndian.PutUint16(out[20:22], uint16(fragTotal)) //nolint:gosec,lll // G115: bounded conversion verified by surrounding logic
|
||||
copy(out[22:], payload)
|
||||
return out
|
||||
@@ -579,6 +585,14 @@ func encodeAckFrame(seq, crc uint32) []byte {
|
||||
return out
|
||||
}
|
||||
|
||||
func encodeHelloFrame() []byte {
|
||||
out := make([]byte, 6)
|
||||
binary.BigEndian.PutUint32(out[0:4], protocolMagic)
|
||||
out[4] = protocolVersion
|
||||
out[5] = frameTypeHello
|
||||
return out
|
||||
}
|
||||
|
||||
func decodeTransportFrame(data []byte) (transportFrame, error) {
|
||||
if len(data) < 6 {
|
||||
return transportFrame{}, ErrFrameTooShort
|
||||
@@ -592,6 +606,8 @@ func decodeTransportFrame(data []byte) (transportFrame, error) {
|
||||
|
||||
frame := transportFrame{typ: data[5]}
|
||||
switch frame.typ {
|
||||
case frameTypeHello:
|
||||
return frame, nil
|
||||
case frameTypeAck:
|
||||
if len(data) < 14 {
|
||||
return transportFrame{}, ErrAckTooShort
|
||||
|
||||
@@ -78,3 +78,13 @@ func TestTransportFrameRoundTrip(t *testing.T) {
|
||||
t.Fatalf("payload mismatch: got=%q", decoded.payload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHelloFrameRoundTrip(t *testing.T) {
|
||||
hello, err := decodeTransportFrame(encodeHelloFrame())
|
||||
if err != nil {
|
||||
t.Fatalf("decodeTransportFrame(hello) failed: %v", err)
|
||||
}
|
||||
if hello.typ != frameTypeHello {
|
||||
t.Fatalf("hello frame type = %d, want %d", hello.typ, frameTypeHello)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,8 +103,12 @@ func TestNewConnectCallbacksAndFeatures(t *testing.T) {
|
||||
if stream.reconnect == nil || stream.should == nil || stream.ended == nil || !stream.watched {
|
||||
t.Fatal("callbacks/watch were not forwarded")
|
||||
}
|
||||
if tr.CanSend() {
|
||||
t.Fatal("CanSend() = true before peer hello")
|
||||
}
|
||||
tr.handleSample(buildVideoAccessUnit(encodeHelloFrame()))
|
||||
if !tr.CanSend() {
|
||||
t.Fatal("CanSend() = false, want true")
|
||||
t.Fatal("CanSend() = false after peer hello")
|
||||
}
|
||||
if features := tr.Features(); !features.Reliable || !features.Ordered || !features.MessageOriented || features.MaxPayloadSize == 0 { //nolint:lll // long test description
|
||||
t.Fatalf("Features() = %+v", features)
|
||||
|
||||
91
internal/transport/traffic.go
Normal file
91
internal/transport/traffic.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand/v2"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ErrTrafficPayloadTooLarge = errors.New("traffic payload exceeds max_payload_size")
|
||||
|
||||
type trafficTransport struct {
|
||||
inner Transport
|
||||
maxPayloadSize int
|
||||
minDelay time.Duration
|
||||
maxDelay time.Duration
|
||||
sendMu sync.Mutex
|
||||
}
|
||||
|
||||
// WithTraffic wraps tr with optional payload caps and send pacing.
|
||||
func WithTraffic(tr Transport, cfg TrafficConfig) Transport {
|
||||
if tr == nil {
|
||||
return nil
|
||||
}
|
||||
cfg = effectiveTrafficConfig(tr.Features(), cfg)
|
||||
if cfg.MaxPayloadSize <= 0 && cfg.MinDelay <= 0 && cfg.MaxDelay <= 0 {
|
||||
return tr
|
||||
}
|
||||
return &trafficTransport{
|
||||
inner: tr,
|
||||
maxPayloadSize: cfg.MaxPayloadSize,
|
||||
minDelay: cfg.MinDelay,
|
||||
maxDelay: cfg.MaxDelay,
|
||||
}
|
||||
}
|
||||
|
||||
func effectiveTrafficConfig(features Features, cfg TrafficConfig) TrafficConfig {
|
||||
if cfg.MaxPayloadSize > 0 && features.MaxPayloadSize > 0 && features.MaxPayloadSize < cfg.MaxPayloadSize {
|
||||
cfg.MaxPayloadSize = features.MaxPayloadSize
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func (t *trafficTransport) Connect(ctx context.Context) error { return t.inner.Connect(ctx) }
|
||||
|
||||
func (t *trafficTransport) Send(data []byte) error {
|
||||
t.sendMu.Lock()
|
||||
defer t.sendMu.Unlock()
|
||||
if t.maxPayloadSize > 0 && len(data) > t.maxPayloadSize {
|
||||
return fmt.Errorf("%w: size=%d max=%d", ErrTrafficPayloadTooLarge, len(data), t.maxPayloadSize)
|
||||
}
|
||||
if delay := t.nextDelay(); delay > 0 {
|
||||
time.Sleep(delay)
|
||||
}
|
||||
return t.inner.Send(data)
|
||||
}
|
||||
|
||||
func (t *trafficTransport) Close() error { return t.inner.Close() }
|
||||
|
||||
func (t *trafficTransport) SetReconnectCallback(cb func()) { t.inner.SetReconnectCallback(cb) }
|
||||
|
||||
func (t *trafficTransport) SetShouldReconnect(fn func() bool) { t.inner.SetShouldReconnect(fn) }
|
||||
|
||||
func (t *trafficTransport) SetEndedCallback(cb func(string)) { t.inner.SetEndedCallback(cb) }
|
||||
|
||||
func (t *trafficTransport) WatchConnection(ctx context.Context) { t.inner.WatchConnection(ctx) }
|
||||
|
||||
func (t *trafficTransport) CanSend() bool { return t.inner.CanSend() }
|
||||
|
||||
func (t *trafficTransport) Features() Features {
|
||||
features := t.inner.Features()
|
||||
if t.maxPayloadSize > 0 &&
|
||||
(features.MaxPayloadSize == 0 || t.maxPayloadSize < features.MaxPayloadSize) {
|
||||
features.MaxPayloadSize = t.maxPayloadSize
|
||||
}
|
||||
return features
|
||||
}
|
||||
|
||||
func (t *trafficTransport) nextDelay() time.Duration {
|
||||
if t.maxDelay <= 0 && t.minDelay <= 0 {
|
||||
return 0
|
||||
}
|
||||
minDelay := t.minDelay
|
||||
maxDelay := t.maxDelay
|
||||
if maxDelay <= minDelay {
|
||||
return minDelay
|
||||
}
|
||||
return minDelay + time.Duration(rand.Int64N(int64(maxDelay-minDelay))) //nolint:gosec,lll // G404: non-cryptographic pacing jitter
|
||||
}
|
||||
67
internal/transport/traffic_test.go
Normal file
67
internal/transport/traffic_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type trafficStubTransport struct {
|
||||
features Features
|
||||
sent [][]byte
|
||||
}
|
||||
|
||||
func (s *trafficStubTransport) Connect(context.Context) error { return nil }
|
||||
func (s *trafficStubTransport) Send(data []byte) error {
|
||||
s.sent = append(s.sent, append([]byte(nil), data...))
|
||||
return nil
|
||||
}
|
||||
func (s *trafficStubTransport) Close() error { return nil }
|
||||
func (s *trafficStubTransport) SetReconnectCallback(func()) {}
|
||||
func (s *trafficStubTransport) SetShouldReconnect(func() bool) {}
|
||||
func (s *trafficStubTransport) SetEndedCallback(func(string)) {}
|
||||
func (s *trafficStubTransport) WatchConnection(context.Context) {}
|
||||
func (s *trafficStubTransport) CanSend() bool { return true }
|
||||
func (s *trafficStubTransport) Features() Features { return s.features }
|
||||
|
||||
func TestWithTrafficReturnsInnerWhenDisabled(t *testing.T) {
|
||||
inner := &trafficStubTransport{}
|
||||
got := WithTraffic(inner, TrafficConfig{})
|
||||
if got != inner {
|
||||
t.Fatalf("WithTraffic disabled returned %T, want inner", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrafficWrapperRejectsOversizedPayloadAndClampsFeatures(t *testing.T) {
|
||||
inner := &trafficStubTransport{features: Features{MaxPayloadSize: 5}}
|
||||
tr := WithTraffic(inner, TrafficConfig{MaxPayloadSize: 10})
|
||||
if features := tr.Features(); features.MaxPayloadSize != 5 {
|
||||
t.Fatalf("Features().MaxPayloadSize = %d, want 5", features.MaxPayloadSize)
|
||||
}
|
||||
err := tr.Send([]byte("123456"))
|
||||
if !errors.Is(err, ErrTrafficPayloadTooLarge) {
|
||||
t.Fatalf("Send() error = %v, want %v", err, ErrTrafficPayloadTooLarge)
|
||||
}
|
||||
if len(inner.sent) != 0 {
|
||||
t.Fatalf("inner sent %d payloads, want 0", len(inner.sent))
|
||||
}
|
||||
if err := tr.Send([]byte("12345")); err != nil {
|
||||
t.Fatalf("Send(max sized) error = %v", err)
|
||||
}
|
||||
if got := string(inner.sent[0]); got != "12345" {
|
||||
t.Fatalf("inner payload = %q, want 12345", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrafficWrapperAppliesMinimumDelay(t *testing.T) {
|
||||
inner := &trafficStubTransport{}
|
||||
tr := WithTraffic(inner, TrafficConfig{MinDelay: 2 * time.Millisecond})
|
||||
start := time.Now()
|
||||
if err := tr.Send([]byte("x")); err != nil {
|
||||
t.Fatalf("Send() error = %v", err)
|
||||
}
|
||||
if elapsed := time.Since(start); elapsed < 2*time.Millisecond {
|
||||
t.Fatalf("Send() elapsed = %v, want at least 2ms", elapsed)
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ package transport
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -32,10 +33,17 @@ type Transport interface {
|
||||
Features() Features
|
||||
}
|
||||
|
||||
// TrafficConfig controls optional reliability-oriented send shaping.
|
||||
type TrafficConfig struct {
|
||||
MaxPayloadSize int
|
||||
MinDelay time.Duration
|
||||
MaxDelay time.Duration
|
||||
}
|
||||
|
||||
// Config holds common transport configuration.
|
||||
type Config struct {
|
||||
Carrier string
|
||||
RoomURL string
|
||||
Carrier string
|
||||
RoomURL string
|
||||
// Engine, URL, Token are forwarded to carrier.Config for the "none" auth
|
||||
// carrier (direct engine access without a service-specific auth flow).
|
||||
Engine string
|
||||
@@ -63,6 +71,7 @@ type Config struct {
|
||||
SEIBatchSize int
|
||||
SEIFragmentSize int
|
||||
SEIAckTimeoutMS int
|
||||
Traffic TrafficConfig
|
||||
}
|
||||
|
||||
// Factory creates a transport instance.
|
||||
@@ -81,7 +90,11 @@ func New(ctx context.Context, name string, cfg Config) (Transport, error) {
|
||||
if !ok {
|
||||
return nil, ErrTransportNotFound
|
||||
}
|
||||
return factory(ctx, cfg)
|
||||
tr, err := factory(ctx, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return WithTraffic(tr, cfg.Traffic), nil
|
||||
}
|
||||
|
||||
// Available returns a list of registered transport names.
|
||||
|
||||
Reference in New Issue
Block a user