// Mirror of https://github.com/openlibrecommunity/olcrtc.git
// (synced 2026-05-26 07:08:11 +00:00; 634 lines, 14 KiB, Go).

package session
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/openlibrecommunity/olcrtc/internal/control"
|
|
"github.com/openlibrecommunity/olcrtc/internal/runtime"
|
|
)
|
|
|
|
// testBadDuration is the value the validation tests assign to duration-typed
// config fields (liveness interval, max session duration, traffic min delay)
// when they expect the corresponding "invalid" error to be returned.
const testBadDuration = "nope"
|
|
|
|
func TestApplyTransportDefaults(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
in Config
|
|
want Config
|
|
}{
|
|
{
|
|
name: "vp8",
|
|
in: Config{Transport: transportVP8},
|
|
want: Config{Transport: transportVP8, VP8: VP8Config{FPS: 60, BatchSize: 64}},
|
|
},
|
|
{
|
|
name: "sei",
|
|
in: Config{Transport: transportSEI},
|
|
want: Config{
|
|
Transport: transportSEI,
|
|
SEI: SEIConfig{FPS: 60, BatchSize: 64, FragmentSize: 900, AckTimeoutMS: 2000},
|
|
},
|
|
},
|
|
{
|
|
name: "video qrcode",
|
|
in: Config{Transport: transportVideo},
|
|
want: Config{
|
|
Transport: transportVideo,
|
|
Video: VideoConfig{
|
|
Width: 1920, Height: 1080, FPS: 30, Bitrate: "2M",
|
|
HW: defaultVideoHW, QRRecovery: "low", Codec: videoCodecQRCode,
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "video tile dimensions",
|
|
in: Config{Transport: transportVideo, Video: VideoConfig{Codec: videoCodecTile}},
|
|
want: Config{
|
|
Transport: transportVideo,
|
|
Video: VideoConfig{
|
|
Width: 1080, Height: 1080, FPS: 30, Bitrate: "2M",
|
|
HW: defaultVideoHW, QRRecovery: "low", Codec: videoCodecTile,
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "keeps explicit values",
|
|
in: Config{
|
|
Transport: transportSEI,
|
|
SEI: SEIConfig{FPS: 10, BatchSize: 2, FragmentSize: 300, AckTimeoutMS: 1500},
|
|
},
|
|
want: Config{
|
|
Transport: transportSEI,
|
|
SEI: SEIConfig{FPS: 10, BatchSize: 2, FragmentSize: 300, AckTimeoutMS: 1500},
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := ApplyTransportDefaults(tt.in)
|
|
if got != tt.want {
|
|
t.Fatalf("ApplyTransportDefaults() = %+v, want %+v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestApplyLivenessDefaults(t *testing.T) {
|
|
got := ApplyLivenessDefaults(Config{})
|
|
if got.LivenessInterval != control.DefaultInterval.String() {
|
|
t.Fatalf("LivenessInterval = %q, want %q", got.LivenessInterval, control.DefaultInterval.String())
|
|
}
|
|
if got.LivenessTimeout != control.DefaultTimeout.String() {
|
|
t.Fatalf("LivenessTimeout = %q, want %q", got.LivenessTimeout, control.DefaultTimeout.String())
|
|
}
|
|
if got.LivenessFailures != control.DefaultFailures {
|
|
t.Fatalf("LivenessFailures = %d, want %d", got.LivenessFailures, control.DefaultFailures)
|
|
}
|
|
|
|
explicit := Config{LivenessInterval: "1s", LivenessTimeout: "500ms", LivenessFailures: 9}
|
|
if got := ApplyLivenessDefaults(explicit); got != explicit {
|
|
t.Fatalf("ApplyLivenessDefaults() = %+v, want %+v", got, explicit)
|
|
}
|
|
}
|
|
|
|
func TestRunWithSessionRotationRestartsAfterMaxDuration(t *testing.T) {
|
|
oldRestartDelay := sessionRestartDelay
|
|
sessionRestartDelay = time.Millisecond
|
|
t.Cleanup(func() { sessionRestartDelay = oldRestartDelay })
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
var calls atomic.Int32
|
|
err := runWithSessionRotation(ctx, 5*time.Millisecond, func(ctx context.Context) error {
|
|
if calls.Add(1) >= 2 {
|
|
cancel()
|
|
return nil
|
|
}
|
|
<-ctx.Done()
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("runWithSessionRotation() error = %v", err)
|
|
}
|
|
if got := calls.Load(); got < 2 {
|
|
t.Fatalf("run calls = %d, want at least 2", got)
|
|
}
|
|
}
|
|
|
|
//nolint:maintidx // table-driven validation test naturally has many cases
|
|
func TestValidate(t *testing.T) {
|
|
RegisterDefaults()
|
|
|
|
base := Config{
|
|
Mode: modeSRV,
|
|
Transport: "datachannel",
|
|
Auth: "telemost",
|
|
RoomID: "room-1",
|
|
KeyHex: "00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff",
|
|
DNSServer: "1.1.1.1:53", //nolint:goconst // test literal, repetition is intentional
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
cfg Config
|
|
want error
|
|
}{
|
|
{name: "valid baseline", cfg: base},
|
|
{
|
|
name: "cnc requires socks host and port",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.Mode = modeCNC
|
|
cfg.SOCKSHost = "127.0.0.1"
|
|
cfg.SOCKSPort = 1080
|
|
return cfg
|
|
}(),
|
|
},
|
|
{
|
|
name: "missing mode",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.Mode = ""
|
|
return cfg
|
|
}(),
|
|
want: ErrModeRequired,
|
|
},
|
|
{
|
|
name: "unsupported carrier",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.Auth = "unknown" //nolint:goconst // test literal, repetition is intentional
|
|
return cfg
|
|
}(),
|
|
want: ErrUnsupportedCarrier,
|
|
},
|
|
{
|
|
name: "unsupported transport",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.Transport = "unknown"
|
|
return cfg
|
|
}(),
|
|
want: ErrUnsupportedTransport,
|
|
},
|
|
{
|
|
name: "room id required",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.RoomID = ""
|
|
return cfg
|
|
}(),
|
|
want: ErrRoomIDRequired,
|
|
},
|
|
{
|
|
name: "key required",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.KeyHex = ""
|
|
return cfg
|
|
}(),
|
|
want: ErrKeyRequired,
|
|
},
|
|
{
|
|
name: "dns server required",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.DNSServer = ""
|
|
return cfg
|
|
}(),
|
|
want: ErrDNSServerRequired,
|
|
},
|
|
{
|
|
name: "videochannel requires dimensions and bitrate settings",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.Transport = "videochannel" //nolint:goconst // test literal, repetition is intentional
|
|
return cfg
|
|
}(),
|
|
want: ErrVideoWidthRequired,
|
|
},
|
|
{
|
|
name: "videochannel rejects invalid codec",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.Transport = "videochannel"
|
|
cfg.Video.Width = 640
|
|
cfg.Video.Height = 480
|
|
cfg.Video.FPS = 30
|
|
cfg.Video.Bitrate = "1M"
|
|
cfg.Video.HW = defaultVideoHW
|
|
cfg.Video.Codec = "bogus"
|
|
return cfg
|
|
}(),
|
|
want: ErrVideoCodecInvalid,
|
|
},
|
|
{
|
|
name: "videochannel requires height",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.Transport = "videochannel"
|
|
cfg.Video.Width = 640
|
|
return cfg
|
|
}(),
|
|
want: ErrVideoHeightRequired,
|
|
},
|
|
{
|
|
name: "videochannel requires fps",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.Transport = "videochannel"
|
|
cfg.Video.Width = 640
|
|
cfg.Video.Height = 480
|
|
return cfg
|
|
}(),
|
|
want: ErrVideoFPSRequired,
|
|
},
|
|
{
|
|
name: "videochannel requires bitrate",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.Transport = "videochannel"
|
|
cfg.Video.Width = 640
|
|
cfg.Video.Height = 480
|
|
cfg.Video.FPS = 30
|
|
return cfg
|
|
}(),
|
|
want: ErrVideoBitrateRequired,
|
|
},
|
|
{
|
|
name: "videochannel requires hw",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.Transport = "videochannel"
|
|
cfg.Video.Width = 640
|
|
cfg.Video.Height = 480
|
|
cfg.Video.FPS = 30
|
|
cfg.Video.Bitrate = "1M"
|
|
return cfg
|
|
}(),
|
|
want: ErrVideoHWRequired,
|
|
},
|
|
{
|
|
name: "tile codec requires square 1080 dimensions",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.Transport = "videochannel"
|
|
cfg.Video.Width = 640
|
|
cfg.Video.Height = 480
|
|
cfg.Video.FPS = 30
|
|
cfg.Video.Bitrate = "1M"
|
|
cfg.Video.HW = defaultVideoHW
|
|
cfg.Video.Codec = "tile"
|
|
return cfg
|
|
}(),
|
|
want: ErrTileCodecDimensions,
|
|
},
|
|
{
|
|
name: "videochannel valid",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.Transport = "videochannel"
|
|
cfg.Video.Width = 1080
|
|
cfg.Video.Height = 1080
|
|
cfg.Video.FPS = 30
|
|
cfg.Video.Bitrate = "1M"
|
|
cfg.Video.HW = defaultVideoHW
|
|
cfg.Video.Codec = "tile"
|
|
return cfg
|
|
}(),
|
|
},
|
|
{
|
|
name: "vp8channel requires fps",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.Transport = "vp8channel" //nolint:goconst // test literal, repetition is intentional
|
|
return cfg
|
|
}(),
|
|
want: ErrVP8FPSRequired,
|
|
},
|
|
{
|
|
name: "vp8channel requires batch size",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.Transport = "vp8channel"
|
|
cfg.VP8.FPS = 25
|
|
return cfg
|
|
}(),
|
|
want: ErrVP8BatchSizeRequired,
|
|
},
|
|
{
|
|
name: "vp8channel valid",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.Transport = "vp8channel"
|
|
cfg.VP8.FPS = 25
|
|
cfg.VP8.BatchSize = 16
|
|
return cfg
|
|
}(),
|
|
},
|
|
{
|
|
name: "seichannel requires fps",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.Transport = "seichannel" //nolint:goconst // test literal, repetition is intentional
|
|
return cfg
|
|
}(),
|
|
want: ErrSEIFPSRequired,
|
|
},
|
|
{
|
|
name: "seichannel requires batch size",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.Transport = "seichannel"
|
|
cfg.SEI.FPS = 20
|
|
return cfg
|
|
}(),
|
|
want: ErrSEIBatchSizeRequired,
|
|
},
|
|
{
|
|
name: "seichannel requires fragment size",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.Transport = "seichannel"
|
|
cfg.SEI.FPS = 20
|
|
cfg.SEI.BatchSize = 1
|
|
return cfg
|
|
}(),
|
|
want: ErrSEIFragmentSizeRequired,
|
|
},
|
|
{
|
|
name: "seichannel requires ack timeout",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.Transport = "seichannel"
|
|
cfg.SEI.FPS = 20
|
|
cfg.SEI.BatchSize = 1
|
|
cfg.SEI.FragmentSize = 900
|
|
return cfg
|
|
}(),
|
|
want: ErrSEIAckTimeoutRequired,
|
|
},
|
|
{
|
|
name: "seichannel valid",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.Transport = "seichannel"
|
|
cfg.SEI.FPS = 20
|
|
cfg.SEI.BatchSize = 1
|
|
cfg.SEI.FragmentSize = 900
|
|
cfg.SEI.AckTimeoutMS = 3000
|
|
return cfg
|
|
}(),
|
|
},
|
|
{
|
|
name: "cnc requires socks host",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.Mode = modeCNC
|
|
cfg.SOCKSPort = 1080
|
|
return cfg
|
|
}(),
|
|
want: ErrSOCKSHostRequired,
|
|
},
|
|
{
|
|
name: "cnc requires socks port",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.Mode = modeCNC
|
|
cfg.SOCKSHost = "127.0.0.1"
|
|
return cfg
|
|
}(),
|
|
want: ErrSOCKSPortRequired,
|
|
},
|
|
{
|
|
name: "cnc rejects unauthenticated wildcard socks bind",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.Mode = modeCNC
|
|
cfg.SOCKSHost = "0.0.0.0"
|
|
cfg.SOCKSPort = 1080
|
|
return cfg
|
|
}(),
|
|
want: ErrSOCKSAuthRequired,
|
|
},
|
|
{
|
|
name: "cnc allows authenticated wildcard socks bind",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.Mode = modeCNC
|
|
cfg.SOCKSHost = "0.0.0.0"
|
|
cfg.SOCKSPort = 1080
|
|
cfg.SOCKSUser = "user"
|
|
cfg.SOCKSPass = "pass"
|
|
return cfg
|
|
}(),
|
|
},
|
|
{
|
|
name: "cnc allows localhost socks bind without auth",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.Mode = modeCNC
|
|
cfg.SOCKSHost = "localhost"
|
|
cfg.SOCKSPort = 1080
|
|
return cfg
|
|
}(),
|
|
},
|
|
{
|
|
name: "liveness rejects bad interval",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.LivenessInterval = testBadDuration
|
|
return cfg
|
|
}(),
|
|
want: ErrLivenessIntervalInvalid,
|
|
},
|
|
{
|
|
name: "liveness rejects zero timeout",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.LivenessTimeout = "0s"
|
|
return cfg
|
|
}(),
|
|
want: ErrLivenessTimeoutInvalid,
|
|
},
|
|
{
|
|
name: "liveness rejects negative failures",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.LivenessFailures = -1
|
|
return cfg
|
|
}(),
|
|
want: ErrLivenessFailuresInvalid,
|
|
},
|
|
{
|
|
name: "lifecycle accepts max session duration",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.MaxSessionDuration = "1h"
|
|
return cfg
|
|
}(),
|
|
},
|
|
{
|
|
name: "lifecycle rejects bad max session duration",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.MaxSessionDuration = testBadDuration
|
|
return cfg
|
|
}(),
|
|
want: ErrLifecycleMaxSessionDurationInvalid,
|
|
},
|
|
{
|
|
name: "lifecycle rejects zero max session duration",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.MaxSessionDuration = "0s"
|
|
return cfg
|
|
}(),
|
|
want: ErrLifecycleMaxSessionDurationInvalid,
|
|
},
|
|
{
|
|
name: "traffic accepts shaping",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.TrafficMaxPayloadSize = 4096
|
|
cfg.TrafficMinDelay = "5ms"
|
|
cfg.TrafficMaxDelay = "30ms"
|
|
return cfg
|
|
}(),
|
|
},
|
|
{
|
|
name: "traffic rejects negative max payload",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.TrafficMaxPayloadSize = -1
|
|
return cfg
|
|
}(),
|
|
want: ErrTrafficMaxPayloadSizeInvalid,
|
|
},
|
|
{
|
|
name: "traffic rejects payload too small for encrypted smux frame",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.TrafficMaxPayloadSize = runtime.MinSmuxWirePayload - 1
|
|
return cfg
|
|
}(),
|
|
want: ErrTrafficMaxPayloadSizeInvalid,
|
|
},
|
|
{
|
|
name: "traffic rejects bad min delay",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.TrafficMinDelay = testBadDuration
|
|
return cfg
|
|
}(),
|
|
want: ErrTrafficMinDelayInvalid,
|
|
},
|
|
{
|
|
name: "traffic rejects negative max delay",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.TrafficMaxDelay = "-1ms"
|
|
return cfg
|
|
}(),
|
|
want: ErrTrafficMaxDelayInvalid,
|
|
},
|
|
{
|
|
name: "traffic rejects max delay below min delay",
|
|
cfg: func() Config {
|
|
cfg := base
|
|
cfg.TrafficMinDelay = "30ms"
|
|
cfg.TrafficMaxDelay = "5ms"
|
|
return cfg
|
|
}(),
|
|
want: ErrTrafficMaxDelayInvalid,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := Validate(tt.cfg)
|
|
if tt.want == nil {
|
|
if err != nil {
|
|
t.Fatalf("Validate() error = %v", err)
|
|
}
|
|
return
|
|
}
|
|
if !errors.Is(err, tt.want) {
|
|
t.Fatalf("Validate() error = %v, want %v", err, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// testAuthWBStream is the Auth value used by the TestValidateGen cases.
const testAuthWBStream = "wbstream"
|
|
|
|
func TestValidateGen(t *testing.T) {
|
|
RegisterDefaults()
|
|
|
|
tests := []struct {
|
|
name string
|
|
cfg Config
|
|
want error
|
|
}{
|
|
{
|
|
name: "wbstream room generation unsupported",
|
|
cfg: Config{Auth: testAuthWBStream, DNSServer: "1.1.1.1:53", Amount: 3},
|
|
want: ErrUnsupportedCarrier,
|
|
},
|
|
{
|
|
name: "missing auth",
|
|
cfg: Config{DNSServer: "1.1.1.1:53", Amount: 1},
|
|
want: ErrAuthRequired,
|
|
},
|
|
{
|
|
name: "unsupported auth",
|
|
cfg: Config{Auth: "unknown", DNSServer: "1.1.1.1:53", Amount: 1},
|
|
want: ErrUnsupportedCarrier,
|
|
},
|
|
{
|
|
name: "missing dns",
|
|
cfg: Config{Auth: testAuthWBStream, Amount: 1},
|
|
want: ErrDNSServerRequired,
|
|
},
|
|
{
|
|
name: "amount zero",
|
|
cfg: Config{Auth: testAuthWBStream, DNSServer: "1.1.1.1:53", Amount: 0},
|
|
want: ErrAmountRequired,
|
|
},
|
|
{
|
|
name: "amount negative",
|
|
cfg: Config{Auth: testAuthWBStream, DNSServer: "1.1.1.1:53", Amount: -1},
|
|
want: ErrAmountRequired,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := ValidateGen(tt.cfg)
|
|
if tt.want == nil {
|
|
if err != nil {
|
|
t.Fatalf("ValidateGen() error = %v", err)
|
|
}
|
|
return
|
|
}
|
|
if !errors.Is(err, tt.want) {
|
|
t.Fatalf("ValidateGen() error = %v, want %v", err, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGenUnsupportedAuth(t *testing.T) {
|
|
RegisterDefaults()
|
|
cfg := Config{Auth: "telemost", DNSServer: "1.1.1.1:53", Amount: 1}
|
|
err := Gen(context.Background(), cfg, func(string) {})
|
|
if !errors.Is(err, ErrUnsupportedCarrier) {
|
|
t.Fatalf("Gen(telemost) error = %v, want ErrUnsupportedCarrier", err)
|
|
}
|
|
}
|