feat(test): init base test

This commit is contained in:
zarazaex69
2026-05-06 21:44:41 +03:00
parent 1b01c8b01b
commit aa49808e68
37 changed files with 4290 additions and 9 deletions

2
.gitignore vendored
View File

@@ -247,3 +247,5 @@ build/
GEMINI.md
code/package-lock.json
olcrtc
!cmd/olcrtc/
!cmd/olcrtc/main_test.go

137
cmd/olcrtc/main_test.go Normal file
View File

@@ -0,0 +1,137 @@
package main
import (
"errors"
"os"
"path/filepath"
"testing"
"github.com/openlibrecommunity/olcrtc/internal/app/session"
"github.com/openlibrecommunity/olcrtc/internal/logger"
)
func TestToSessionConfigAndFirstNonEmpty(t *testing.T) {
cfg := config{
mode: "cnc",
link: "direct",
transport: "vp8channel",
provider: "jazz",
roomID: "room",
clientID: "client",
keyHex: "key",
socksHost: "127.0.0.1",
socksPort: 1080,
dnsServer: "1.1.1.1:53",
socksProxyAddr: "proxy",
socksProxyPort: 1081,
videoWidth: 640,
videoHeight: 480,
videoFPS: 30,
videoBitrate: "1M",
videoHW: "none",
videoQRSize: 4,
videoQRRecovery: "low",
videoCodec: "qrcode",
videoTileModule: 4,
videoTileRS: 20,
vp8FPS: 25,
vp8BatchSize: 8,
}
got := toSessionConfig(cfg)
if got.Mode != cfg.mode || got.Carrier != "jazz" || got.SOCKSPort != cfg.socksPort ||
got.VideoTileRS != cfg.videoTileRS || got.VP8BatchSize != cfg.vp8BatchSize {
t.Fatalf("toSessionConfig() = %+v", got)
}
cfg.carrier = "telemost"
got = toSessionConfig(cfg)
if got.Carrier != "telemost" {
t.Fatalf("carrier precedence = %q, want telemost", got.Carrier)
}
if got := firstNonEmpty("", "", "x", "y"); got != "x" {
t.Fatalf("firstNonEmpty() = %q, want x", got)
}
if got := firstNonEmpty("", ""); got != "" {
t.Fatalf("firstNonEmpty(empty) = %q, want empty", got)
}
}
func TestConfigureLogging(t *testing.T) {
logger.SetVerbose(false)
configureLogging(true)
if !logger.IsVerbose() {
t.Fatal("configureLogging(true) did not enable verbose logging")
}
logger.SetVerbose(false)
configureLogging(false)
if logger.IsVerbose() {
t.Fatal("configureLogging(false) enabled verbose logging")
}
}
func TestResolveDataDir(t *testing.T) {
abs := filepath.Join(t.TempDir(), "data")
got, err := resolveDataDir(abs)
if err != nil {
t.Fatalf("resolveDataDir(abs) error = %v", err)
}
if got != abs {
t.Fatalf("resolveDataDir(abs) = %q, want %q", got, abs)
}
got, err = resolveDataDir("data")
if err != nil {
t.Fatalf("resolveDataDir(rel) error = %v", err)
}
if filepath.Base(got) != "data" || !filepath.IsAbs(got) {
t.Fatalf("resolveDataDir(rel) = %q, want absolute path ending in data", got)
}
}
func TestLoadNames(t *testing.T) {
dir := t.TempDir()
if err := os.WriteFile(filepath.Join(dir, "names"), []byte("A\n"), 0o600); err != nil {
t.Fatalf("WriteFile(names) error = %v", err)
}
if err := os.WriteFile(filepath.Join(dir, "surnames"), []byte("B\n"), 0o600); err != nil {
t.Fatalf("WriteFile(surnames) error = %v", err)
}
if err := loadNames(dir); err != nil {
t.Fatalf("loadNames() error = %v", err)
}
}
func TestWaitForShutdown(t *testing.T) {
errCh := make(chan error, 1)
errCh <- nil
if err := waitForShutdown(errCh); err != nil {
t.Fatalf("waitForShutdown(nil) error = %v", err)
}
want := errors.New("boom")
errCh = make(chan error, 1)
errCh <- want
if err := waitForShutdown(errCh); !errors.Is(err, want) {
t.Fatalf("waitForShutdown(error) = %v, want %v", err, want)
}
}
func TestValidateConfigAliasStillValidates(t *testing.T) {
session.RegisterDefaults()
cfg := config{
mode: "srv",
link: "direct",
transport: "datachannel",
provider: "jazz",
clientID: "client",
keyHex: "key",
dnsServer: "1.1.1.1:53",
videoCodec: "qrcode",
}
if err := session.Validate(toSessionConfig(cfg)); err != nil {
t.Fatalf("Validate(toSessionConfig(alias)) error = %v", err)
}
}

View File

@@ -0,0 +1,303 @@
package session
import (
"errors"
"testing"
)
func TestValidate(t *testing.T) {
RegisterDefaults()
base := Config{
Mode: modeSRV,
Link: "direct",
Transport: "datachannel",
Carrier: "telemost",
RoomID: "room-1",
ClientID: "client-1",
KeyHex: "00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff",
DNSServer: "1.1.1.1:53",
}
tests := []struct {
name string
cfg Config
want error
}{
{name: "valid baseline", cfg: base},
{
name: "jazz allows empty room id",
cfg: func() Config {
cfg := base
cfg.Carrier = "jazz"
cfg.RoomID = ""
return cfg
}(),
},
{
name: "cnc requires socks host and port",
cfg: func() Config {
cfg := base
cfg.Mode = modeCNC
cfg.SOCKSHost = "127.0.0.1"
cfg.SOCKSPort = 1080
return cfg
}(),
},
{
name: "missing mode",
cfg: func() Config {
cfg := base
cfg.Mode = ""
return cfg
}(),
want: ErrModeRequired,
},
{
name: "unsupported carrier",
cfg: func() Config {
cfg := base
cfg.Carrier = "unknown"
return cfg
}(),
want: ErrUnsupportedCarrier,
},
{
name: "unsupported link",
cfg: func() Config {
cfg := base
cfg.Link = "unknown"
return cfg
}(),
want: ErrUnsupportedLink,
},
{
name: "unsupported transport",
cfg: func() Config {
cfg := base
cfg.Transport = "unknown"
return cfg
}(),
want: ErrUnsupportedTransport,
},
{
name: "room id required for non jazz",
cfg: func() Config {
cfg := base
cfg.RoomID = ""
return cfg
}(),
want: ErrRoomIDRequired,
},
{
name: "client id required",
cfg: func() Config {
cfg := base
cfg.ClientID = ""
return cfg
}(),
want: ErrClientIDRequired,
},
{
name: "key required",
cfg: func() Config {
cfg := base
cfg.KeyHex = ""
return cfg
}(),
want: ErrKeyRequired,
},
{
name: "dns server required",
cfg: func() Config {
cfg := base
cfg.DNSServer = ""
return cfg
}(),
want: ErrDNSServerRequired,
},
{
name: "videochannel requires dimensions and bitrate settings",
cfg: func() Config {
cfg := base
cfg.Transport = "videochannel"
return cfg
}(),
want: ErrVideoWidthRequired,
},
{
name: "videochannel rejects invalid codec",
cfg: func() Config {
cfg := base
cfg.Transport = "videochannel"
cfg.VideoWidth = 640
cfg.VideoHeight = 480
cfg.VideoFPS = 30
cfg.VideoBitrate = "1M"
cfg.VideoHW = "none"
cfg.VideoCodec = "bogus"
return cfg
}(),
want: ErrVideoCodecInvalid,
},
{
name: "videochannel requires height",
cfg: func() Config {
cfg := base
cfg.Transport = "videochannel"
cfg.VideoWidth = 640
return cfg
}(),
want: ErrVideoHeightRequired,
},
{
name: "videochannel requires fps",
cfg: func() Config {
cfg := base
cfg.Transport = "videochannel"
cfg.VideoWidth = 640
cfg.VideoHeight = 480
return cfg
}(),
want: ErrVideoFPSRequired,
},
{
name: "videochannel requires bitrate",
cfg: func() Config {
cfg := base
cfg.Transport = "videochannel"
cfg.VideoWidth = 640
cfg.VideoHeight = 480
cfg.VideoFPS = 30
return cfg
}(),
want: ErrVideoBitrateRequired,
},
{
name: "videochannel requires hw",
cfg: func() Config {
cfg := base
cfg.Transport = "videochannel"
cfg.VideoWidth = 640
cfg.VideoHeight = 480
cfg.VideoFPS = 30
cfg.VideoBitrate = "1M"
return cfg
}(),
want: ErrVideoHWRequired,
},
{
name: "tile codec requires square 1080 dimensions",
cfg: func() Config {
cfg := base
cfg.Transport = "videochannel"
cfg.VideoWidth = 640
cfg.VideoHeight = 480
cfg.VideoFPS = 30
cfg.VideoBitrate = "1M"
cfg.VideoHW = "none"
cfg.VideoCodec = "tile"
return cfg
}(),
want: ErrTileCodecDimensions,
},
{
name: "videochannel valid",
cfg: func() Config {
cfg := base
cfg.Transport = "videochannel"
cfg.VideoWidth = 1080
cfg.VideoHeight = 1080
cfg.VideoFPS = 30
cfg.VideoBitrate = "1M"
cfg.VideoHW = "none"
cfg.VideoCodec = "tile"
return cfg
}(),
},
{
name: "vp8channel requires fps",
cfg: func() Config {
cfg := base
cfg.Transport = "vp8channel"
return cfg
}(),
want: ErrVP8FPSRequired,
},
{
name: "vp8channel requires batch size",
cfg: func() Config {
cfg := base
cfg.Transport = "vp8channel"
cfg.VP8FPS = 25
return cfg
}(),
want: ErrVP8BatchSizeRequired,
},
{
name: "vp8channel valid",
cfg: func() Config {
cfg := base
cfg.Transport = "vp8channel"
cfg.VP8FPS = 25
cfg.VP8BatchSize = 16
return cfg
}(),
},
{
name: "cnc requires socks host",
cfg: func() Config {
cfg := base
cfg.Mode = modeCNC
cfg.SOCKSPort = 1080
return cfg
}(),
want: ErrSOCKSHostRequired,
},
{
name: "cnc requires socks port",
cfg: func() Config {
cfg := base
cfg.Mode = modeCNC
cfg.SOCKSHost = "127.0.0.1"
return cfg
}(),
want: ErrSOCKSPortRequired,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := Validate(tt.cfg)
if tt.want == nil {
if err != nil {
t.Fatalf("Validate() error = %v", err)
}
return
}
if !errors.Is(err, tt.want) {
t.Fatalf("Validate() error = %v, want %v", err, tt.want)
}
})
}
}
func TestBuildRoomURL(t *testing.T) {
tests := []struct {
carrier string
roomID string
want string
}{
{carrier: "telemost", roomID: "abc", want: "https://telemost.yandex.ru/j/abc"},
{carrier: "jazz", roomID: "", want: "any"},
{carrier: "jazz", roomID: "room", want: "room"},
{carrier: "wbstream", roomID: "wb", want: "wb"},
{carrier: "other", roomID: "raw", want: "raw"},
}
for _, tt := range tests {
if got := buildRoomURL(tt.carrier, tt.roomID); got != tt.want {
t.Fatalf("buildRoomURL(%q, %q) = %q, want %q", tt.carrier, tt.roomID, got, tt.want)
}
}
}

View File

@@ -0,0 +1,18 @@
package builtin
import (
"slices"
"testing"
"github.com/openlibrecommunity/olcrtc/internal/carrier"
)
func TestRegister(t *testing.T) {
Register()
available := carrier.Available()
for _, want := range []string{"jazz", "telemost", "wbstream"} {
if !slices.Contains(available, want) {
t.Fatalf("Available() = %v, missing %q", available, want)
}
}
}

View File

@@ -0,0 +1,251 @@
package carrier
import (
"context"
"errors"
"reflect"
"testing"
"github.com/openlibrecommunity/olcrtc/internal/provider"
"github.com/pion/webrtc/v4"
)
type stubProvider struct {
connectErr error
sendErr error
closeErr error
canSend bool
reconnectCallback func(*webrtc.DataChannel)
shouldReconnect func() bool
endedCallback func(string)
watchCalled bool
addTrackErr error
trackHandlerCalled bool
}
func (s *stubProvider) Connect(context.Context) error { return s.connectErr }
func (s *stubProvider) Send([]byte) error { return s.sendErr }
func (s *stubProvider) Close() error { return s.closeErr }
func (s *stubProvider) SetReconnectCallback(cb func(*webrtc.DataChannel)) { s.reconnectCallback = cb }
func (s *stubProvider) SetShouldReconnect(fn func() bool) { s.shouldReconnect = fn }
func (s *stubProvider) SetEndedCallback(cb func(string)) { s.endedCallback = cb }
func (s *stubProvider) WatchConnection(context.Context) { s.watchCalled = true }
func (s *stubProvider) CanSend() bool { return s.canSend }
func (s *stubProvider) GetSendQueue() chan []byte { return nil }
func (s *stubProvider) GetBufferedAmount() uint64 { return 0 }
func (s *stubProvider) AddVideoTrack(webrtc.TrackLocal) error { return s.addTrackErr }
func (s *stubProvider) SetVideoTrackHandler(func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) {
s.trackHandlerCalled = true
}
type plainProvider struct {
connectErr error
sendErr error
closeErr error
canSend bool
reconnectCallback func(*webrtc.DataChannel)
shouldReconnect func() bool
endedCallback func(string)
watchCalled bool
}
func (p *plainProvider) Connect(context.Context) error { return p.connectErr }
func (p *plainProvider) Send([]byte) error { return p.sendErr }
func (p *plainProvider) Close() error { return p.closeErr }
func (p *plainProvider) SetReconnectCallback(cb func(*webrtc.DataChannel)) { p.reconnectCallback = cb }
func (p *plainProvider) SetShouldReconnect(fn func() bool) { p.shouldReconnect = fn }
func (p *plainProvider) SetEndedCallback(cb func(string)) { p.endedCallback = cb }
func (p *plainProvider) WatchConnection(context.Context) { p.watchCalled = true }
func (p *plainProvider) CanSend() bool { return p.canSend }
func (p *plainProvider) GetSendQueue() chan []byte { return nil }
func (p *plainProvider) GetBufferedAmount() uint64 { return 0 }
func snapshotCarrierRegistry() map[string]Factory {
out := make(map[string]Factory, len(registry))
for k, v := range registry {
out[k] = v
}
return out
}
func restoreCarrierRegistry(src map[string]Factory) {
registry = make(map[string]Factory, len(src))
for k, v := range src {
registry[k] = v
}
}
func TestRegisterLegacyAndAvailable(t *testing.T) {
old := snapshotCarrierRegistry()
t.Cleanup(func() { restoreCarrierRegistry(old) })
RegisterLegacy("legacy-test", func(_ context.Context, cfg provider.Config) (provider.Provider, error) {
if cfg.Name != "peer" {
t.Fatalf("provider config name = %q, want peer", cfg.Name)
}
return &stubProvider{canSend: true}, nil
})
sess, err := New(context.Background(), "legacy-test", Config{Name: "peer"})
if err != nil {
t.Fatalf("New() error = %v", err)
}
caps := sess.Capabilities()
if !caps.ByteStream || !caps.VideoTrack {
t.Fatalf("Capabilities() = %+v, want byte and video true", caps)
}
if !reflect.DeepEqual(Available(), []string{"legacy-test"}) {
t.Fatalf("Available() = %#v, want %#v", Available(), []string{"legacy-test"})
}
}
func TestNewReturnsErrCarrierNotFound(t *testing.T) {
old := snapshotCarrierRegistry()
t.Cleanup(func() { restoreCarrierRegistry(old) })
registry = map[string]Factory{}
_, err := New(context.Background(), "missing", Config{})
if !errors.Is(err, ErrCarrierNotFound) {
t.Fatalf("New() error = %v, want %v", err, ErrCarrierNotFound)
}
}
func TestLegacySessionOpenVideoTrackUnsupported(t *testing.T) {
sess := &legacySession{provider: &plainProvider{}}
caps := sess.Capabilities()
if !caps.ByteStream || caps.VideoTrack {
t.Fatalf("Capabilities() = %+v, want byte true and video false", caps)
}
_, err := sess.OpenVideoTrack()
if !errors.Is(err, ErrVideoTrackUnsupported) {
t.Fatalf("OpenVideoTrack() error = %v, want %v", err, ErrVideoTrackUnsupported)
}
}
func TestLegacyByteStreamWrapsProviderAndCallbacks(t *testing.T) {
prov := &stubProvider{canSend: true}
stream := &legacyByteStream{provider: prov}
called := false
stream.SetReconnectCallback(func() { called = true })
if prov.reconnectCallback == nil {
t.Fatal("SetReconnectCallback() did not install provider callback")
}
prov.reconnectCallback(nil)
if !called {
t.Fatal("reconnect callback was not adapted")
}
reconnectAllowed := false
stream.SetShouldReconnect(func() bool { reconnectAllowed = true; return true })
if prov.shouldReconnect == nil || !prov.shouldReconnect() || !reconnectAllowed {
t.Fatal("SetShouldReconnect() was not forwarded")
}
ended := ""
stream.SetEndedCallback(func(reason string) { ended = reason })
if prov.endedCallback == nil {
t.Fatal("SetEndedCallback() was not forwarded")
}
prov.endedCallback("bye")
if ended != "bye" {
t.Fatalf("ended callback reason = %q, want bye", ended)
}
stream.WatchConnection(context.Background())
if !prov.watchCalled {
t.Fatal("WatchConnection() was not forwarded")
}
if !stream.CanSend() {
t.Fatal("CanSend() = false, want true")
}
}
func TestLegacyByteStreamWrapsErrors(t *testing.T) {
prov := &stubProvider{
connectErr: errors.New("connect boom"),
sendErr: errors.New("send boom"),
closeErr: errors.New("close boom"),
}
stream := &legacyByteStream{provider: prov}
if err := stream.Connect(context.Background()); err == nil || err.Error() != "connect: connect boom" {
t.Fatalf("Connect() error = %v", err)
}
if err := stream.Send([]byte("x")); err == nil || err.Error() != "send: send boom" {
t.Fatalf("Send() error = %v", err)
}
if err := stream.Close(); err == nil || err.Error() != "close: close boom" {
t.Fatalf("Close() error = %v", err)
}
}
func TestLegacySessionOpenByteStreamAndVideoTrack(t *testing.T) {
prov := &stubProvider{canSend: true}
sess := &legacySession{provider: prov}
stream, err := sess.OpenByteStream()
if err != nil {
t.Fatalf("OpenByteStream() error = %v", err)
}
if !stream.CanSend() {
t.Fatal("byte stream CanSend() = false, want true")
}
video, err := sess.OpenVideoTrack()
if err != nil {
t.Fatalf("OpenVideoTrack() error = %v", err)
}
if err := video.Connect(context.Background()); err != nil {
t.Fatalf("video Connect() error = %v", err)
}
if err := video.Close(); err != nil {
t.Fatalf("video Close() error = %v", err)
}
video.SetShouldReconnect(func() bool { return true })
video.SetEndedCallback(func(string) {})
video.WatchConnection(context.Background())
if !video.CanSend() || prov.shouldReconnect == nil || prov.endedCallback == nil || !prov.watchCalled {
t.Fatal("video adapter did not forward calls")
}
}
func TestLegacyVideoTrackWrapsOperations(t *testing.T) {
prov := &stubProvider{canSend: true, addTrackErr: errors.New("track boom")}
track := &legacyVideoTrack{provider: prov}
called := false
track.SetReconnectCallback(func() { called = true })
prov.reconnectCallback(nil)
if !called {
t.Fatal("reconnect callback was not adapted")
}
track.SetTrackHandler(func(*webrtc.TrackRemote, *webrtc.RTPReceiver) {})
if !prov.trackHandlerCalled {
t.Fatal("SetTrackHandler() was not forwarded")
}
if err := track.AddTrack(nil); err == nil || err.Error() != "add track: track boom" {
t.Fatalf("AddTrack() error = %v", err)
}
}
func TestLegacyVideoTrackWrapsConnectCloseErrors(t *testing.T) {
prov := &stubProvider{
connectErr: errors.New("connect boom"),
closeErr: errors.New("close boom"),
}
track := &legacyVideoTrack{provider: prov}
if err := track.Connect(context.Background()); err == nil || err.Error() != "connect: connect boom" {
t.Fatalf("Connect() error = %v", err)
}
if err := track.Close(); err == nil || err.Error() != "close: close boom" {
t.Fatalf("Close() error = %v", err)
}
}

View File

@@ -0,0 +1,419 @@
package client
import (
"bytes"
"context"
"encoding/binary"
"encoding/json"
"errors"
"io"
"net"
"testing"
"time"
cryptopkg "github.com/openlibrecommunity/olcrtc/internal/crypto"
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
"github.com/xtaci/smux"
)
func TestSetupCipher(t *testing.T) {
keyHex := "00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff"
cipher, err := setupCipher(keyHex)
if err != nil {
t.Fatalf("setupCipher() error = %v", err)
}
if cipher == nil {
t.Fatal("setupCipher() returned nil cipher")
}
}
func TestSetupCipherRejectsBadInput(t *testing.T) {
if _, err := setupCipher("zz"); err == nil {
t.Fatal("setupCipher() unexpectedly succeeded for bad hex")
}
if _, err := setupCipher("00"); !errors.Is(err, ErrKeySize) {
t.Fatalf("setupCipher() error = %v, want ErrKeySize", err)
}
}
func TestSmuxConfig(t *testing.T) {
cfg := smuxConfig()
if cfg.Version != 2 || cfg.MaxFrameSize != 32768 || cfg.MaxReceiveBuffer != 16*1024*1024 {
t.Fatalf("smuxConfig() = %+v", cfg)
}
}
func TestSocks5Handshake(t *testing.T) {
c := &Client{}
server, client := net.Pipe()
defer func() {
_ = server.Close()
_ = client.Close()
}()
done := make(chan error, 1)
go func() {
done <- c.socks5Handshake(server)
}()
if _, err := client.Write([]byte{5, 1, 0}); err != nil {
t.Fatalf("Write() error = %v", err)
}
resp := make([]byte, 2)
if _, err := io.ReadFull(client, resp); err != nil {
t.Fatalf("ReadFull() error = %v", err)
}
if err := <-done; err != nil {
t.Fatalf("socks5Handshake() error = %v", err)
}
if !bytes.Equal(resp, []byte{5, 0}) {
t.Fatalf("handshake response = %v, want [5 0]", resp)
}
}
func TestSocks5HandshakeRejectsVersion(t *testing.T) {
c := &Client{}
server, client := net.Pipe()
defer func() {
_ = server.Close()
_ = client.Close()
}()
done := make(chan error, 1)
go func() {
done <- c.socks5Handshake(server)
}()
if _, err := client.Write([]byte{4, 1}); err != nil {
t.Fatalf("Write() error = %v", err)
}
if err := <-done; !errors.Is(err, ErrInvalidSOCKSVersion) {
t.Fatalf("socks5Handshake() error = %v, want %v", err, ErrInvalidSOCKSVersion)
}
}
func TestSocks5HandshakeReadMethodsError(t *testing.T) {
c := &Client{}
server, client := net.Pipe()
defer func() {
_ = server.Close()
_ = client.Close()
}()
done := make(chan error, 1)
go func() {
done <- c.socks5Handshake(server)
}()
if _, err := client.Write([]byte{5, 2, 0}); err != nil {
t.Fatalf("Write() error = %v", err)
}
_ = client.Close()
if err := <-done; err == nil {
t.Fatal("socks5Handshake() unexpectedly succeeded")
}
}
func TestSocks5RequestIPv4(t *testing.T) {
c := &Client{}
server, client := net.Pipe()
defer func() {
_ = server.Close()
_ = client.Close()
}()
done := make(chan struct {
addr string
port int
err error
}, 1)
go func() {
addr, port, err := c.socks5Request(server)
done <- struct {
addr string
port int
err error
}{addr: addr, port: port, err: err}
}()
req := []byte{5, 1, 0, 1, 127, 0, 0, 1}
port := make([]byte, 2)
binary.BigEndian.PutUint16(port, 8080)
if _, err := client.Write(append(req, port...)); err != nil {
t.Fatalf("Write() error = %v", err)
}
res := <-done
if res.err != nil {
t.Fatalf("socks5Request() error = %v", res.err)
}
if res.addr != "127.0.0.1" || res.port != 8080 {
t.Fatalf("socks5Request() = (%q, %d), want (127.0.0.1, 8080)", res.addr, res.port)
}
}
func TestSocks5RequestDomain(t *testing.T) {
c := &Client{}
server, client := net.Pipe()
defer func() {
_ = server.Close()
_ = client.Close()
}()
done := make(chan struct {
addr string
port int
err error
}, 1)
go func() {
addr, port, err := c.socks5Request(server)
done <- struct {
addr string
port int
err error
}{addr: addr, port: port, err: err}
}()
req := []byte{5, 1, 0, 3, 11}
req = append(req, []byte("example.com")...)
port := make([]byte, 2)
binary.BigEndian.PutUint16(port, 443)
if _, err := client.Write(append(req, port...)); err != nil {
t.Fatalf("Write() error = %v", err)
}
res := <-done
if res.err != nil {
t.Fatalf("socks5Request() error = %v", res.err)
}
if res.addr != "example.com" || res.port != 443 {
t.Fatalf("socks5Request() = (%q, %d), want (example.com, 443)", res.addr, res.port)
}
}
func TestSocks5RequestRejectsCommandAndAddressType(t *testing.T) {
c := &Client{}
server, client := net.Pipe()
defer func() {
_ = server.Close()
_ = client.Close()
}()
done := make(chan error, 1)
go func() {
_, _, err := c.socks5Request(server)
done <- err
}()
if _, err := client.Write([]byte{5, 2, 0, 1}); err != nil {
t.Fatalf("Write() error = %v", err)
}
if err := <-done; !errors.Is(err, ErrUnsupportedSOCKSCommand) {
t.Fatalf("socks5Request() error = %v, want %v", err, ErrUnsupportedSOCKSCommand)
}
server2, client2 := net.Pipe()
defer func() {
_ = server2.Close()
_ = client2.Close()
}()
done = make(chan error, 1)
go func() {
_, _, err := c.socks5Request(server2)
done <- err
}()
if _, err := client2.Write([]byte{5, 1, 0, 9}); err != nil {
t.Fatalf("Write() error = %v", err)
}
if err := <-done; !errors.Is(err, ErrUnsupportedAddressType) {
t.Fatalf("socks5Request() error = %v, want %v", err, ErrUnsupportedAddressType)
}
}
func TestSocks5RequestReadPortError(t *testing.T) {
c := &Client{}
server, client := net.Pipe()
defer func() {
_ = server.Close()
_ = client.Close()
}()
done := make(chan error, 1)
go func() {
_, _, err := c.socks5Request(server)
done <- err
}()
if _, err := client.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1, 0}); err != nil {
t.Fatalf("Write() error = %v", err)
}
_ = client.Close()
if err := <-done; err == nil {
t.Fatal("socks5Request() unexpectedly succeeded")
}
}
func TestReplyBuffers(t *testing.T) {
if !bytes.Equal(replySuccess(), []byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0}) {
t.Fatalf("replySuccess() = %v", replySuccess())
}
if !bytes.Equal(replyHostUnreachable(), []byte{5, 4, 0, 1, 0, 0, 0, 0, 0, 0}) {
t.Fatalf("replyHostUnreachable() = %v", replyHostUnreachable())
}
}
func TestReadSocks5AddrReadErrors(t *testing.T) {
c := &Client{}
server, client := net.Pipe()
defer func() {
_ = server.Close()
_ = client.Close()
}()
done := make(chan error, 1)
go func() {
_, err := c.readSocks5Addr(server, 1)
done <- err
}()
time.Sleep(10 * time.Millisecond)
_ = client.Close()
if err := <-done; err == nil {
t.Fatal("readSocks5Addr() unexpectedly succeeded")
}
}
func TestSendConnectRequestOverSmux(t *testing.T) {
a, b := net.Pipe()
defer func() {
_ = a.Close()
_ = b.Close()
}()
serverSess, err := smux.Server(a, smuxConfig())
if err != nil {
t.Fatalf("smux.Server() error = %v", err)
}
defer func() { _ = serverSess.Close() }()
clientSess, err := smux.Client(b, smuxConfig())
if err != nil {
t.Fatalf("smux.Client() error = %v", err)
}
defer func() { _ = clientSess.Close() }()
done := make(chan error, 1)
go func() {
stream, err := serverSess.AcceptStream()
if err != nil {
done <- err
return
}
defer func() { _ = stream.Close() }()
var req map[string]any
if err := json.NewDecoder(stream).Decode(&req); err != nil {
done <- err
return
}
if req["cmd"] != "connect" || req["clientId"] != "client-1" || req["addr"] != "example.com" {
done <- errors.New("unexpected connect request")
return
}
_, err = stream.Write([]byte{0x00})
done <- err
}()
stream, err := clientSess.OpenStream()
if err != nil {
t.Fatalf("OpenStream() error = %v", err)
}
defer func() { _ = stream.Close() }()
c := &Client{clientID: "client-1"}
if err := c.sendConnectRequest(stream, "example.com", 443); err != nil {
t.Fatalf("sendConnectRequest() error = %v", err)
}
if err := <-done; err != nil {
t.Fatalf("server side error = %v", err)
}
}
func TestSendConnectRequestRejectsBadAck(t *testing.T) {
a, b := net.Pipe()
defer func() {
_ = a.Close()
_ = b.Close()
}()
serverSess, err := smux.Server(a, smuxConfig())
if err != nil {
t.Fatalf("smux.Server() error = %v", err)
}
defer func() { _ = serverSess.Close() }()
clientSess, err := smux.Client(b, smuxConfig())
if err != nil {
t.Fatalf("smux.Client() error = %v", err)
}
defer func() { _ = clientSess.Close() }()
go func() {
stream, err := serverSess.AcceptStream()
if err != nil {
return
}
defer func() { _ = stream.Close() }()
_, _ = io.CopyN(io.Discard, stream, 1)
_, _ = stream.Write([]byte{0x01})
}()
stream, err := clientSess.OpenStream()
if err != nil {
t.Fatalf("OpenStream() error = %v", err)
}
defer func() { _ = stream.Close() }()
c := &Client{clientID: "client-1"}
if err := c.sendConnectRequest(stream, "example.com", 443); !errors.Is(err, ErrRemoteNotReady) {
t.Fatalf("sendConnectRequest() error = %v, want %v", err, ErrRemoteNotReady)
}
}
type closerLinkStub struct {
closed bool
}
func (s *closerLinkStub) Connect(context.Context) error { return nil }
func (s *closerLinkStub) Send([]byte) error { return nil }
func (s *closerLinkStub) Close() error { s.closed = true; return nil }
func (s *closerLinkStub) SetReconnectCallback(func()) {}
func (s *closerLinkStub) SetShouldReconnect(func() bool) {}
func (s *closerLinkStub) SetEndedCallback(func(string)) {}
func (s *closerLinkStub) WatchConnection(context.Context) {}
func (s *closerLinkStub) CanSend() bool { return true }
func TestOnDataWithNilConn(t *testing.T) {
c := &Client{}
c.onData([]byte("ignored"))
}
func TestShutdownClosesLinkAndConn(t *testing.T) {
cipher, err := cryptopkg.NewCipher("01234567890123456789012345678901")
if err != nil {
t.Fatalf("NewCipher() error = %v", err)
}
ln := &closerLinkStub{}
c := &Client{
ln: ln,
cipher: cipher,
conn: muxconn.New(ln, cipher),
}
c.shutdown()
if !ln.closed {
t.Fatal("shutdown() did not close link")
}
}

View File

@@ -0,0 +1,50 @@
package crypto
import (
"bytes"
"errors"
"testing"
)
func TestNewCipherRejectsWrongKeySize(t *testing.T) {
_, err := NewCipher("short")
if !errors.Is(err, ErrInvalidKeySize) {
t.Fatalf("NewCipher() error = %v, want %v", err, ErrInvalidKeySize)
}
}
func TestCipherRoundTrip(t *testing.T) {
c, err := NewCipher("01234567890123456789012345678901")
if err != nil {
t.Fatalf("NewCipher() error = %v", err)
}
plaintext := []byte("hello world")
ciphertext, err := c.Encrypt(plaintext)
if err != nil {
t.Fatalf("Encrypt() error = %v", err)
}
if bytes.Equal(ciphertext, plaintext) {
t.Fatal("ciphertext unexpectedly matches plaintext")
}
got, err := c.Decrypt(ciphertext)
if err != nil {
t.Fatalf("Decrypt() error = %v", err)
}
if !bytes.Equal(got, plaintext) {
t.Fatalf("Decrypt() = %q, want %q", got, plaintext)
}
}
func TestDecryptRejectsShortCiphertext(t *testing.T) {
c, err := NewCipher("01234567890123456789012345678901")
if err != nil {
t.Fatalf("NewCipher() error = %v", err)
}
_, err = c.Decrypt([]byte("short"))
if !errors.Is(err, ErrCiphertextTooShort) {
t.Fatalf("Decrypt() error = %v, want %v", err, ErrCiphertextTooShort)
}
}

View File

@@ -0,0 +1,137 @@
package direct
import (
"context"
"errors"
"testing"
"github.com/openlibrecommunity/olcrtc/internal/link"
"github.com/openlibrecommunity/olcrtc/internal/transport"
)
type stubTransport struct {
connectErr error
sendErr error
closeErr error
canSend bool
connectCalled bool
sendData []byte
watched bool
reconnectCB func()
shouldFn func() bool
endedCB func(string)
}
func (s *stubTransport) Connect(context.Context) error {
s.connectCalled = true
return s.connectErr
}
func (s *stubTransport) Send(data []byte) error {
s.sendData = append([]byte(nil), data...)
return s.sendErr
}
func (s *stubTransport) Close() error { return s.closeErr }
func (s *stubTransport) SetReconnectCallback(cb func()) {
s.reconnectCB = cb
}
func (s *stubTransport) SetShouldReconnect(fn func() bool) { s.shouldFn = fn }
func (s *stubTransport) SetEndedCallback(cb func(string)) { s.endedCB = cb }
func (s *stubTransport) WatchConnection(context.Context) { s.watched = true }
func (s *stubTransport) CanSend() bool { return s.canSend }
func (s *stubTransport) Features() transport.Features { return transport.Features{} }
func TestNewForwardsConfigAndMethods(t *testing.T) {
name := "direct-test-forward"
var seen transport.Config
tr := &stubTransport{canSend: true}
transport.Register(name, func(_ context.Context, cfg transport.Config) (transport.Transport, error) {
seen = cfg
return tr, nil
})
ln, err := New(context.Background(), link.Config{
Transport: name,
Carrier: "carrier",
RoomURL: "room",
ClientID: "client",
Name: "peer",
DNSServer: "1.1.1.1:53",
ProxyAddr: "127.0.0.1",
ProxyPort: 1080,
VideoWidth: 640,
VideoHeight: 480,
VideoFPS: 30,
VideoBitrate: "1M",
VideoHW: "none",
VideoQRSize: 4,
VideoQRRecovery: "low",
VideoCodec: "qrcode",
VideoTileModule: 3,
VideoTileRS: 20,
VP8FPS: 25,
VP8BatchSize: 8,
})
if err != nil {
t.Fatalf("New() error = %v", err)
}
if seen.ClientID != "client" || seen.ProxyPort != 1080 || seen.VideoTileRS != 20 || seen.VP8BatchSize != 8 {
t.Fatalf("forwarded config = %+v", seen)
}
if err := ln.Connect(context.Background()); err != nil {
t.Fatalf("Connect() error = %v", err)
}
if !tr.connectCalled {
t.Fatal("Connect() was not forwarded")
}
if err := ln.Send([]byte("payload")); err != nil {
t.Fatalf("Send() error = %v", err)
}
if string(tr.sendData) != "payload" {
t.Fatalf("Send() forwarded %q, want payload", tr.sendData)
}
ln.SetReconnectCallback(func() {})
ln.SetShouldReconnect(func() bool { return true })
ln.SetEndedCallback(func(string) {})
ln.WatchConnection(context.Background())
if tr.reconnectCB == nil || tr.shouldFn == nil || tr.endedCB == nil || !tr.watched {
t.Fatal("callbacks/watch were not forwarded")
}
if !ln.CanSend() {
t.Fatal("CanSend() = false, want true")
}
}
func TestNewWrapsFactoryError(t *testing.T) {
name := "direct-test-error"
transport.Register(name, func(context.Context, transport.Config) (transport.Transport, error) {
return nil, errors.New("boom")
})
_, err := New(context.Background(), link.Config{Transport: name})
if err == nil || err.Error() != "create transport for direct link: boom" {
t.Fatalf("New() error = %v", err)
}
}
func TestDirectLinkWrapsTransportErrors(t *testing.T) {
ln := &directLink{transport: &stubTransport{
connectErr: errors.New("connect boom"),
sendErr: errors.New("send boom"),
closeErr: errors.New("close boom"),
}}
if err := ln.Connect(context.Background()); err == nil || err.Error() != "transport connect: connect boom" {
t.Fatalf("Connect() error = %v", err)
}
if err := ln.Send([]byte("x")); err == nil || err.Error() != "transport send: send boom" {
t.Fatalf("Send() error = %v", err)
}
if err := ln.Close(); err == nil || err.Error() != "transport close: close boom" {
t.Fatalf("Close() error = %v", err)
}
}

View File

@@ -0,0 +1,71 @@
package link
import (
"context"
"errors"
"reflect"
"testing"
)
type stubLink struct{}
func (s *stubLink) Connect(context.Context) error { return nil }
func (s *stubLink) Send([]byte) error { return nil }
func (s *stubLink) Close() error { return nil }
func (s *stubLink) SetReconnectCallback(func()) {}
func (s *stubLink) SetShouldReconnect(func() bool) {}
func (s *stubLink) SetEndedCallback(func(string)) {}
func (s *stubLink) WatchConnection(context.Context) {}
func (s *stubLink) CanSend() bool { return true }
func snapshotLinkRegistry() map[string]Factory {
out := make(map[string]Factory, len(registry))
for k, v := range registry {
out[k] = v
}
return out
}
func restoreLinkRegistry(src map[string]Factory) {
registry = make(map[string]Factory, len(src))
for k, v := range src {
registry[k] = v
}
}
func TestNewAndAvailable(t *testing.T) {
old := snapshotLinkRegistry()
t.Cleanup(func() { restoreLinkRegistry(old) })
called := false
Register("test-link", func(_ context.Context, cfg Config) (Link, error) {
called = cfg.ClientID == "client-1"
return &stubLink{}, nil
})
got, err := New(context.Background(), "test-link", Config{ClientID: "client-1"})
if err != nil {
t.Fatalf("New() error = %v", err)
}
if !called {
t.Fatal("factory did not receive config")
}
if _, ok := got.(*stubLink); !ok {
t.Fatalf("New() returned %T, want *stubLink", got)
}
if !reflect.DeepEqual(Available(), []string{"test-link"}) {
t.Fatalf("Available() = %#v, want %#v", Available(), []string{"test-link"})
}
}
func TestNewReturnsErrLinkNotFound(t *testing.T) {
old := snapshotLinkRegistry()
t.Cleanup(func() { restoreLinkRegistry(old) })
registry = map[string]Factory{}
_, err := New(context.Background(), "missing", Config{})
if !errors.Is(err, ErrLinkNotFound) {
t.Fatalf("New() error = %v, want %v", err, ErrLinkNotFound)
}
}

View File

@@ -0,0 +1,72 @@
package logger
import (
"bytes"
"log"
"strings"
"testing"
)
func captureLogs(t *testing.T) *bytes.Buffer {
t.Helper()
var buf bytes.Buffer
oldWriter := log.Writer()
oldFlags := log.Flags()
log.SetOutput(&buf)
log.SetFlags(0)
t.Cleanup(func() {
log.SetOutput(oldWriter)
log.SetFlags(oldFlags)
SetVerbose(false)
})
return &buf
}
func TestVerboseFlag(t *testing.T) {
SetVerbose(true)
if !IsVerbose() {
t.Fatal("IsVerbose() = false, want true")
}
SetVerbose(false)
if IsVerbose() {
t.Fatal("IsVerbose() = true, want false")
}
}
func TestLoggingFunctions(t *testing.T) {
buf := captureLogs(t)
Info("info")
Infof("%s", "infof")
Warn("warn")
Warnf("%s", "warnf")
Error("error")
Errorf("%s", "errorf")
got := buf.String()
for _, want := range []string{"info", "infof", "warn", "warnf", "error", "errorf"} {
if !strings.Contains(got, want) {
t.Fatalf("log output %q does not contain %q", got, want)
}
}
}
func TestVerboseAndDebugLogging(t *testing.T) {
buf := captureLogs(t)
Verbosef("%s", "hidden")
Debugf("%s", "hidden-debug")
if got := buf.String(); got != "" {
t.Fatalf("unexpected log output when verbose disabled: %q", got)
}
SetVerbose(true)
Verbosef("%s", "visible")
Debugf("%s", "visible-debug")
got := buf.String()
for _, want := range []string{"visible", "visible-debug"} {
if !strings.Contains(got, want) {
t.Fatalf("log output %q does not contain %q", got, want)
}
}
}

View File

@@ -0,0 +1,198 @@
package muxconn
import (
"bytes"
"context"
"errors"
"io"
"sync"
"testing"
"time"
cryptopkg "github.com/openlibrecommunity/olcrtc/internal/crypto"
)
type stubLink struct {
mu sync.Mutex
canSend bool
sendErr error
sent [][]byte
canSendFn func() bool
}
func (s *stubLink) Connect(context.Context) error { return nil }
func (s *stubLink) Close() error { return nil }
func (s *stubLink) SetReconnectCallback(func()) {}
func (s *stubLink) SetShouldReconnect(func() bool) {}
func (s *stubLink) SetEndedCallback(func(string)) {}
func (s *stubLink) WatchConnection(context.Context) {}
func (s *stubLink) Send(data []byte) error {
s.mu.Lock()
defer s.mu.Unlock()
s.sent = append(s.sent, append([]byte(nil), data...))
return s.sendErr
}
func (s *stubLink) CanSend() bool {
if s.canSendFn != nil {
return s.canSendFn()
}
s.mu.Lock()
defer s.mu.Unlock()
return s.canSend
}
func newTestCipher(t *testing.T) *cryptopkg.Cipher {
t.Helper()
c, err := cryptopkg.NewCipher("01234567890123456789012345678901")
if err != nil {
t.Fatalf("NewCipher() error = %v", err)
}
return c
}
func TestPushAndReadRoundTrip(t *testing.T) {
cipher := newTestCipher(t)
conn := New(&stubLink{canSend: true}, cipher)
msg1, err := cipher.Encrypt([]byte("hello "))
if err != nil {
t.Fatalf("Encrypt(msg1) error = %v", err)
}
msg2, err := cipher.Encrypt([]byte("world"))
if err != nil {
t.Fatalf("Encrypt(msg2) error = %v", err)
}
conn.Push(msg1)
conn.Push(msg2)
buf := make([]byte, 11)
n, err := conn.Read(buf)
if err != nil {
t.Fatalf("Read() error = %v", err)
}
if got := string(buf[:n]); got != "hello world" {
t.Fatalf("Read() = %q, want %q", got, "hello world")
}
}
func TestPushIgnoresInvalidCiphertext(t *testing.T) {
cipher := newTestCipher(t)
conn := New(&stubLink{canSend: true}, cipher)
conn.Push([]byte("bad"))
if err := conn.Close(); err != nil {
t.Fatalf("Close() error = %v", err)
}
buf := make([]byte, 8)
n, err := conn.Read(buf)
if !errors.Is(err, io.EOF) || n != 0 {
t.Fatalf("Read() = (%d, %v), want (0, EOF)", n, err)
}
}
func TestWriteEncryptsAndSends(t *testing.T) {
cipher := newTestCipher(t)
ln := &stubLink{canSend: true}
conn := New(ln, cipher)
n, err := conn.Write([]byte("payload"))
if err != nil {
t.Fatalf("Write() error = %v", err)
}
if n != len("payload") {
t.Fatalf("Write() n = %d, want %d", n, len("payload"))
}
if len(ln.sent) != 1 {
t.Fatalf("sent packets = %d, want 1", len(ln.sent))
}
got, err := cipher.Decrypt(ln.sent[0])
if err != nil {
t.Fatalf("Decrypt(sent) error = %v", err)
}
if !bytes.Equal(got, []byte("payload")) {
t.Fatalf("decrypted payload = %q, want %q", got, "payload")
}
}
func TestWriteWaitsForCanSend(t *testing.T) {
cipher := newTestCipher(t)
start := time.Now()
readyAt := start.Add(15 * time.Millisecond)
ln := &stubLink{
canSendFn: func() bool {
return time.Now().After(readyAt)
},
}
conn := New(ln, cipher)
if _, err := conn.Write([]byte("payload")); err != nil {
t.Fatalf("Write() error = %v", err)
}
if len(ln.sent) != 1 {
t.Fatalf("sent packets = %d, want 1", len(ln.sent))
}
}
func TestWriteReturnsErrClosedWhileWaiting(t *testing.T) {
cipher := newTestCipher(t)
conn := New(&stubLink{canSend: false}, cipher)
done := make(chan error, 1)
go func() {
_, err := conn.Write([]byte("payload"))
done <- err
}()
time.Sleep(10 * time.Millisecond)
if err := conn.Close(); err != nil {
t.Fatalf("Close() error = %v", err)
}
select {
case err := <-done:
if !errors.Is(err, ErrClosed) {
t.Fatalf("Write() error = %v, want %v", err, ErrClosed)
}
case <-time.After(200 * time.Millisecond):
t.Fatal("Write() did not unblock after Close")
}
}
func TestWriteWrapsSendError(t *testing.T) {
cipher := newTestCipher(t)
conn := New(&stubLink{canSend: true, sendErr: errors.New("boom")}, cipher)
_, err := conn.Write([]byte("payload"))
if err == nil || err.Error() != "send: boom" {
t.Fatalf("Write() error = %v", err)
}
}
func TestCloseMakesReadReturnEOF(t *testing.T) {
cipher := newTestCipher(t)
conn := New(&stubLink{canSend: true}, cipher)
done := make(chan struct{})
go func() {
defer close(done)
buf := make([]byte, 4)
n, err := conn.Read(buf)
if !errors.Is(err, io.EOF) || n != 0 {
t.Errorf("Read() = (%d, %v), want (0, EOF)", n, err)
}
}()
time.Sleep(10 * time.Millisecond)
if err := conn.Close(); err != nil {
t.Fatalf("Close() error = %v", err)
}
select {
case <-done:
case <-time.After(200 * time.Millisecond):
t.Fatal("Read() did not unblock after Close")
}
}

View File

@@ -0,0 +1,107 @@
package names
import (
"os"
"path/filepath"
"reflect"
"strings"
"testing"
)
func TestParseEmbedded(t *testing.T) {
got := parseEmbedded(" Alice \n\n Bob\n")
want := []string{"Alice", "Bob"}
if !reflect.DeepEqual(got, want) {
t.Fatalf("parseEmbedded() = %#v, want %#v", got, want)
}
}
func TestLoadNames(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "names.txt")
if err := os.WriteFile(path, []byte(" Alice \n\nBob\n"), 0o600); err != nil {
t.Fatalf("WriteFile() error = %v", err)
}
got, err := loadNames(path)
if err != nil {
t.Fatalf("loadNames() error = %v", err)
}
want := []string{"Alice", "Bob"}
if !reflect.DeepEqual(got, want) {
t.Fatalf("loadNames() = %#v, want %#v", got, want)
}
}
func TestLoadNameFilesOverridesGlobals(t *testing.T) {
oldFirst, oldLast := append([]string(nil), firstNames...), append([]string(nil), lastNames...)
t.Cleanup(func() {
firstNames = oldFirst
lastNames = oldLast
})
dir := t.TempDir()
first := filepath.Join(dir, "first.txt")
last := filepath.Join(dir, "last.txt")
if err := os.WriteFile(first, []byte("Neo\n"), 0o600); err != nil {
t.Fatalf("WriteFile(first) error = %v", err)
}
if err := os.WriteFile(last, []byte("Anderson\n"), 0o600); err != nil {
t.Fatalf("WriteFile(last) error = %v", err)
}
if err := LoadNameFiles(first, last); err != nil {
t.Fatalf("LoadNameFiles() error = %v", err)
}
if got := Generate(); got != "Neo Anderson" {
t.Fatalf("Generate() = %q, want %q", got, "Neo Anderson")
}
}
func TestGenerateFallsBackWhenNamesEmpty(t *testing.T) {
oldFirst, oldLast := append([]string(nil), firstNames...), append([]string(nil), lastNames...)
t.Cleanup(func() {
firstNames = oldFirst
lastNames = oldLast
})
firstNames = nil
lastNames = nil
if got := Generate(); got != "anonymous user" {
t.Fatalf("Generate() = %q, want anonymous user", got)
}
}
func TestRandomIndexBounds(t *testing.T) {
for i := 0; i < 20; i++ {
got := randomIndex(2)
if got < 0 || got > 1 {
t.Fatalf("randomIndex(2) = %d, out of range", got)
}
}
if got := randomIndex(0); got != 0 {
t.Fatalf("randomIndex(0) = %d, want 0", got)
}
}
func TestLoadNameFilesIgnoresMissingFiles(t *testing.T) {
oldFirst, oldLast := append([]string(nil), firstNames...), append([]string(nil), lastNames...)
t.Cleanup(func() {
firstNames = oldFirst
lastNames = oldLast
})
firstNames = []string{"Kept"}
lastNames = []string{"Value"}
if err := LoadNameFiles("missing-first", "missing-last"); err != nil {
t.Fatalf("LoadNameFiles() error = %v", err)
}
got := Generate()
if !strings.Contains(got, "Kept") || !strings.Contains(got, "Value") {
t.Fatalf("Generate() = %q, want preserved names", got)
}
}

View File

@@ -0,0 +1,142 @@
package protect
import (
"context"
"errors"
"net"
"net/http"
"syscall"
"testing"
"time"
)
type rawConnStub struct {
controlFn func(func(uintptr)) error
}
func (r rawConnStub) Control(fn func(uintptr)) error {
if r.controlFn != nil {
return r.controlFn(fn)
}
fn(42)
return nil
}
func (r rawConnStub) Read(func(uintptr) bool) error { return nil }
func (r rawConnStub) Write(func(uintptr) bool) error { return nil }
func TestControlFuncWithoutProtector(t *testing.T) {
old := Protector
Protector = nil
t.Cleanup(func() { Protector = old })
if err := controlFunc("tcp4", "", rawConnStub{}); err != nil {
t.Fatalf("controlFunc() error = %v", err)
}
}
func TestControlFuncWithProtector(t *testing.T) {
old := Protector
t.Cleanup(func() { Protector = old })
called := 0
Protector = func(fd int) bool {
called++
if fd != 42 {
t.Fatalf("Protector fd = %d, want 42", fd)
}
return true
}
if err := controlFunc("tcp4", "", rawConnStub{}); err != nil {
t.Fatalf("controlFunc() error = %v", err)
}
if called != 1 {
t.Fatalf("Protector calls = %d, want 1", called)
}
Protector = func(int) bool { return false }
err := controlFunc("tcp4", "", rawConnStub{})
var opErr *net.OpError
if !errors.As(err, &opErr) || opErr.Op != "protect" {
t.Fatalf("controlFunc() error = %v, want protect op error", err)
}
}
func TestControlFuncWrapsControlError(t *testing.T) {
old := Protector
Protector = func(int) bool { return true }
t.Cleanup(func() { Protector = old })
err := controlFunc("tcp4", "", rawConnStub{
controlFn: func(func(uintptr)) error { return errors.New("boom") },
})
if err == nil || err.Error() != "control failed: boom" {
t.Fatalf("controlFunc() error = %v", err)
}
}
func TestNewDialerAndHTTPClient(t *testing.T) {
dialer := NewDialer()
if dialer.Timeout != 10*time.Second || dialer.KeepAlive != 30*time.Second || dialer.Control == nil {
t.Fatalf("NewDialer() = %+v", dialer)
}
client := NewHTTPClient()
tr, ok := client.Transport.(*http.Transport)
if !ok {
t.Fatalf("Transport type = %T, want *http.Transport", client.Transport)
}
if tr.DialContext == nil || !tr.ForceAttemptHTTP2 || tr.MaxIdleConns != 10 ||
tr.IdleConnTimeout != 30*time.Second || tr.TLSHandshakeTimeout != 10*time.Second ||
tr.ResponseHeaderTimeout != 10*time.Second {
t.Fatalf("transport = %+v", tr)
}
}
func TestDialContextAndProxyDialer(t *testing.T) {
ln, err := net.Listen("tcp4", "127.0.0.1:0")
if err != nil {
t.Fatalf("Listen() error = %v", err)
}
defer func() { _ = ln.Close() }()
accepted := make(chan struct{}, 2)
go func() {
for i := 0; i < 2; i++ {
conn, err := ln.Accept()
if err != nil {
return
}
_ = conn.Close()
accepted <- struct{}{}
}
}()
conn, err := DialContext(context.Background(), "tcp4", ln.Addr().String())
if err != nil {
t.Fatalf("DialContext() error = %v", err)
}
_ = conn.Close()
proxyConn, err := NewProxyDialer().Dial("tcp4", ln.Addr().String())
if err != nil {
t.Fatalf("ProxyDialer.Dial() error = %v", err)
}
_ = proxyConn.Close()
<-accepted
<-accepted
}
func TestDialFailuresAreWrapped(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel()
if _, err := DialContext(ctx, "tcp4", "127.0.0.1:1"); err == nil {
t.Fatal("DialContext() unexpectedly succeeded")
}
if _, err := NewProxyDialer().Dial("tcp4", "127.0.0.1:1"); err == nil {
t.Fatal("ProxyDialer.Dial() unexpectedly succeeded")
}
}
var _ syscall.RawConn = rawConnStub{}

View File

@@ -13,10 +13,9 @@ import (
"github.com/openlibrecommunity/olcrtc/internal/protect"
)
const (
apiBase = "https://bk.salutejazz.ru"
authTypeAnonymous = "ANONYMOUS"
)
const authTypeAnonymous = "ANONYMOUS"
var apiBase = "https://bk.salutejazz.ru" //nolint:gochecknoglobals // Tests redirect HTTP API calls to httptest.
// RoomInfo contains connection details for a SaluteJazz room.
type RoomInfo struct {

View File

@@ -0,0 +1,141 @@
package jazz
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func withJazzAPIServer(t *testing.T, h http.Handler) string {
t.Helper()
old := apiBase
srv := httptest.NewServer(h)
t.Cleanup(func() {
apiBase = old
srv.Close()
})
apiBase = srv.URL
return srv.URL
}
func TestCreateMeetingAndPreconnect(t *testing.T) {
withJazzAPIServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("X-Jazz-AuthType") != authTypeAnonymous {
t.Fatalf("missing auth header: %v", r.Header)
}
switch r.URL.Path {
case "/room/create-meeting":
if r.Method != http.MethodPost {
t.Fatalf("create method = %s", r.Method)
}
_ = json.NewEncoder(w).Encode(createResponse{RoomID: "room-1", Password: "pass"})
case "/room/room-1/preconnect":
if r.Method != http.MethodPost {
t.Fatalf("preconnect method = %s", r.Method)
}
_ = json.NewEncoder(w).Encode(map[string]string{"connectorUrl": "wss://connector"})
default:
http.NotFound(w, r)
}
}))
headers := map[string]string{
"X-Jazz-AuthType": authTypeAnonymous,
"Content-Type": "application/json",
}
created, err := createMeeting(context.Background(), headers)
if err != nil {
t.Fatalf("createMeeting() error = %v", err)
}
if created.RoomID != "room-1" || created.Password != "pass" {
t.Fatalf("createMeeting() = %+v", created)
}
connector, err := preconnect(context.Background(), "room-1", "pass", headers)
if err != nil {
t.Fatalf("preconnect() error = %v", err)
}
if connector != "wss://connector" {
t.Fatalf("preconnect() = %q", connector)
}
}
func TestCreateRoomAndJoinRoom(t *testing.T) {
withJazzAPIServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/room/create-meeting":
_ = json.NewEncoder(w).Encode(createResponse{RoomID: "new-room", Password: "new-pass"})
case "/room/new-room/preconnect", "/room/existing/preconnect":
_ = json.NewEncoder(w).Encode(map[string]string{"connectorUrl": "wss://connector"})
default:
http.NotFound(w, r)
}
}))
room, err := createRoom(context.Background())
if err != nil {
t.Fatalf("createRoom() error = %v", err)
}
if room.RoomID != "new-room" || room.Password != "new-pass" || room.ConnectorURL != "wss://connector" {
t.Fatalf("createRoom() = %+v", room)
}
room, err = joinRoom(context.Background(), "existing", "secret")
if err != nil {
t.Fatalf("joinRoom() error = %v", err)
}
if room.RoomID != "existing" || room.Password != "secret" || room.ConnectorURL != "wss://connector" {
t.Fatalf("joinRoom() = %+v", room)
}
}
func TestJazzAPIErrors(t *testing.T) {
withJazzAPIServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "create-meeting"):
http.Error(w, "bad", http.StatusTeapot)
default:
http.Error(w, "bad", http.StatusInternalServerError)
}
}))
if _, err := createMeeting(context.Background(), nil); !errors.Is(err, errCreateRoomFailed) {
t.Fatalf("createMeeting() error = %v, want %v", err, errCreateRoomFailed)
}
if _, err := preconnect(context.Background(), "room", "pass", nil); !errors.Is(err, errPreconnectFailed) {
t.Fatalf("preconnect() error = %v, want %v", err, errPreconnectFailed)
}
}
func TestNewPeerUsesRoomAPI(t *testing.T) {
withJazzAPIServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/room/create-meeting":
_ = json.NewEncoder(w).Encode(createResponse{RoomID: "new-room", Password: "new-pass"})
case "/room/new-room/preconnect", "/room/existing/preconnect":
_ = json.NewEncoder(w).Encode(map[string]string{"connectorUrl": "wss://connector"})
default:
http.NotFound(w, r)
}
}))
created, err := NewPeer(context.Background(), "any", "peer", nil)
if err != nil {
t.Fatalf("NewPeer(create) error = %v", err)
}
if created.roomInfo.RoomID != "new-room" {
t.Fatalf("created room = %+v", created.roomInfo)
}
joined, err := NewPeer(context.Background(), "existing:secret", "peer", nil)
if err != nil {
t.Fatalf("NewPeer(join) error = %v", err)
}
if joined.roomInfo.RoomID != "existing" || joined.roomInfo.Password != "secret" {
t.Fatalf("joined room = %+v", joined.roomInfo)
}
}

View File

@@ -0,0 +1,70 @@
package jazz
import (
"bytes"
"errors"
"io"
"testing"
)
func TestDataPacketRoundTrip(t *testing.T) {
payload := []byte("hello jazz")
raw := EncodeDataPacket(payload)
got, ok := DecodeDataPacket(raw)
if !ok {
t.Fatal("DecodeDataPacket() ok = false")
}
if !bytes.Equal(got, payload) {
t.Fatalf("DecodeDataPacket() = %q, want %q", got, payload)
}
}
func TestDecodeDataPacketRejectsMalformedPackets(t *testing.T) {
tests := [][]byte{
nil,
{0xff},
encodeField(1, 0, encodeVarint(0)),
{byte(2<<3 | 2), 10, 1},
{byte(3<<3 | 7), 0},
}
for _, raw := range tests {
if payload, ok := DecodeDataPacket(raw); ok {
t.Fatalf("DecodeDataPacket(%v) = (%q, true), want false", raw, payload)
}
}
}
func TestParseFieldsSkipsSupportedNonTargetWireTypes(t *testing.T) {
data := encodeField(1, 0, encodeVarint(150))
data = append(data, encodeField(3, 1, []byte("12345678"))...)
data = append(data, encodeField(4, 5, []byte("1234"))...)
data = append(data, encodeField(2, 2, []byte("target"))...)
got, ok := parseFields(data, 2)
if !ok || string(got) != "target" {
t.Fatalf("parseFields() = (%q, %v), want target", got, ok)
}
}
func TestByteReader(t *testing.T) {
r := &byteReader{data: []byte{1, 2, 3}}
b, err := r.ReadByte()
if err != nil || b != 1 {
t.Fatalf("ReadByte() = (%d, %v), want (1, nil)", b, err)
}
buf := make([]byte, 4)
n, err := r.Read(buf)
if err != nil || n != 2 || !bytes.Equal(buf[:n], []byte{2, 3}) {
t.Fatalf("Read() = (%d, %v, %v), want two bytes", n, err, buf[:n])
}
if _, err := r.ReadByte(); !errors.Is(err, io.EOF) {
t.Fatalf("ReadByte() error = %v, want EOF", err)
}
if n, err := r.Read(buf); !errors.Is(err, io.EOF) || n != 0 {
t.Fatalf("Read() = (%d, %v), want (0, EOF)", n, err)
}
}

View File

@@ -0,0 +1,112 @@
package jazz
import (
"context"
"errors"
"testing"
"github.com/openlibrecommunity/olcrtc/internal/provider"
"github.com/pion/webrtc/v4"
)
func TestPeerStateHelpers(t *testing.T) {
p := &Peer{
reconnectCh: make(chan struct{}, 1),
closeCh: make(chan struct{}),
sessionCloseCh: make(chan struct{}),
sendQueue: make(chan []byte, 1),
subscriberConn: make(chan struct{}),
publisherConn: make(chan struct{}),
}
p.resetMediaState()
if p.subscriberReady.Load() || p.publisherReady.Load() || p.subscriberConn == nil || p.publisherConn == nil {
t.Fatal("resetMediaState() did not reset readiness")
}
if p.hasLocalVideoTracks() {
t.Fatal("hasLocalVideoTracks() = true without tracks")
}
if err := p.AddVideoTrack(nil); err != nil {
t.Fatalf("AddVideoTrack(nil) error = %v", err)
}
if !p.hasLocalVideoTracks() {
t.Fatal("hasLocalVideoTracks() = false after AddVideoTrack")
}
p.SetVideoTrackHandler(func(*webrtc.TrackRemote, *webrtc.RTPReceiver) {})
if p.videoTrackHandler() == nil {
t.Fatal("videoTrackHandler() = nil")
}
cfg := defaultWebRTCConfig()
if cfg.SDPSemantics != webrtc.SDPSemanticsUnifiedPlan || cfg.BundlePolicy != webrtc.BundlePolicyMaxBundle {
t.Fatalf("defaultWebRTCConfig() = %+v", cfg)
}
if p.buildAPI() == nil {
t.Fatal("buildAPI() returned nil")
}
}
func TestPeerCallbacksQueueReconnectAndClose(t *testing.T) {
p := &Peer{
reconnectCh: make(chan struct{}, 1),
closeCh: make(chan struct{}),
sessionCloseCh: make(chan struct{}),
sendQueue: make(chan []byte, 1),
}
p.SetReconnectCallback(func(*webrtc.DataChannel) {})
p.SetShouldReconnect(func() bool { return true })
p.SetEndedCallback(func(string) {})
if p.onReconnect == nil || p.shouldReconnect == nil || p.onEnded == nil {
t.Fatal("callbacks were not stored")
}
p.queueReconnect()
select {
case <-p.reconnectCh:
default:
t.Fatal("queueReconnect() did not enqueue")
}
p.SetShouldReconnect(func() bool { return false })
p.queueReconnect()
select {
case <-p.reconnectCh:
t.Fatal("queueReconnect() enqueued despite policy=false")
default:
}
done := make(chan struct{})
go func() {
p.WatchConnection(context.Background())
close(done)
}()
if err := p.Close(); err != nil {
t.Fatalf("Close() error = %v", err)
}
<-done
if err := p.Send([]byte("closed")); !errors.Is(err, provider.ErrDataChannelNotReady) {
t.Fatalf("Send() error = %v, want datachannel not ready", err)
}
}
func TestPeerCanSendVideoOnlyModes(t *testing.T) {
p := &Peer{sendQueue: make(chan []byte, 1)}
p.subscriberReady.Store(true)
if !p.CanSend() {
t.Fatal("CanSend() = false for subscriber-ready peer without local video")
}
_ = p.AddVideoTrack(nil)
if p.CanSend() {
t.Fatal("CanSend() = true with local video but publisher not ready")
}
p.publisherReady.Store(true)
if !p.CanSend() {
t.Fatal("CanSend() = false with subscriber and publisher ready")
}
p.closed.Store(true)
if p.CanSend() {
t.Fatal("CanSend() = true for closed peer")
}
}

View File

@@ -0,0 +1,51 @@
package jazz
import (
"context"
"errors"
"testing"
"github.com/openlibrecommunity/olcrtc/internal/provider"
"github.com/pion/webrtc/v4"
)
func TestJazzProviderForwardsPeerMethods(t *testing.T) {
peer := &Peer{
reconnectCh: make(chan struct{}, 1),
closeCh: make(chan struct{}),
sessionCloseCh: make(chan struct{}),
sendQueue: make(chan []byte, 1),
}
p := &jazzProvider{peer: peer}
p.SetReconnectCallback(func(*webrtc.DataChannel) {})
p.SetShouldReconnect(func() bool { return true })
p.SetEndedCallback(func(string) {})
p.SetVideoTrackHandler(func(*webrtc.TrackRemote, *webrtc.RTPReceiver) {})
if peer.onReconnect == nil || peer.shouldReconnect == nil || peer.onEnded == nil || peer.onVideoTrack == nil {
t.Fatal("callbacks were not forwarded")
}
if p.GetSendQueue() != peer.sendQueue {
t.Fatal("GetSendQueue() did not forward")
}
if p.GetBufferedAmount() != 0 {
t.Fatal("GetBufferedAmount() != 0 with nil datachannel")
}
if err := p.AddVideoTrack(nil); err != nil {
t.Fatalf("AddVideoTrack(nil) error = %v", err)
}
if err := p.Send([]byte("x")); !errors.Is(err, provider.ErrDataChannelNotReady) {
t.Fatalf("Send() error = %v, want datachannel not ready", err)
}
done := make(chan struct{})
go func() {
p.WatchConnection(context.Background())
close(done)
}()
if err := p.Close(); err != nil {
t.Fatalf("Close() error = %v", err)
}
<-done
}

View File

@@ -0,0 +1,75 @@
package provider
import (
"context"
"errors"
"reflect"
"testing"
"github.com/pion/webrtc/v4"
)
type stubProvider struct{}
func (s *stubProvider) Connect(context.Context) error { return nil }
func (s *stubProvider) Send([]byte) error { return nil }
func (s *stubProvider) Close() error { return nil }
func (s *stubProvider) SetReconnectCallback(func(*webrtc.DataChannel)) {}
func (s *stubProvider) SetShouldReconnect(func() bool) {}
func (s *stubProvider) SetEndedCallback(func(string)) {}
func (s *stubProvider) WatchConnection(context.Context) {}
func (s *stubProvider) CanSend() bool { return true }
func (s *stubProvider) GetSendQueue() chan []byte { return nil }
func (s *stubProvider) GetBufferedAmount() uint64 { return 0 }
func snapshotProviderRegistry() map[string]Factory {
out := make(map[string]Factory, len(registry))
for k, v := range registry {
out[k] = v
}
return out
}
func restoreProviderRegistry(src map[string]Factory) {
registry = make(map[string]Factory, len(src))
for k, v := range src {
registry[k] = v
}
}
func TestNewAndAvailable(t *testing.T) {
old := snapshotProviderRegistry()
t.Cleanup(func() { restoreProviderRegistry(old) })
called := false
Register("test-provider", func(_ context.Context, cfg Config) (Provider, error) {
called = cfg.Name == "peer"
return &stubProvider{}, nil
})
got, err := New(context.Background(), "test-provider", Config{Name: "peer"})
if err != nil {
t.Fatalf("New() error = %v", err)
}
if !called {
t.Fatal("factory did not receive config")
}
if _, ok := got.(*stubProvider); !ok {
t.Fatalf("New() returned %T, want *stubProvider", got)
}
if !reflect.DeepEqual(Available(), []string{"test-provider"}) {
t.Fatalf("Available() = %#v, want %#v", Available(), []string{"test-provider"})
}
}
func TestNewReturnsErrProviderNotFound(t *testing.T) {
old := snapshotProviderRegistry()
t.Cleanup(func() { restoreProviderRegistry(old) })
registry = map[string]Factory{}
_, err := New(context.Background(), "missing", Config{})
if !errors.Is(err, ErrProviderNotFound) {
t.Fatalf("New() error = %v, want %v", err, ErrProviderNotFound)
}
}

View File

@@ -13,14 +13,14 @@ import (
"github.com/openlibrecommunity/olcrtc/internal/protect"
)
const apiBase = "https://cloud-api.yandex.ru/telemost_front/v2/telemost"
var apiBase = "https://cloud-api.yandex.ru/telemost_front/v2/telemost" //nolint:gochecknoglobals // Tests redirect HTTP API calls to httptest.
var ErrAPI = errors.New("api error") //nolint:revive
type ConnectionInfo struct { //nolint:revive
RoomID string `json:"room_id"` //nolint:tagliatelle
PeerID string `json:"peer_id"` //nolint:tagliatelle
Credentials string `json:"credentials"` //nolint:tagliatelle
RoomID string `json:"room_id"` //nolint:tagliatelle
PeerID string `json:"peer_id"` //nolint:tagliatelle
Credentials string `json:"credentials"` //nolint:tagliatelle
ClientConfig struct {
MediaServerURL string `json:"media_server_url"` //nolint:tagliatelle
} `json:"client_configuration"` //nolint:tagliatelle

View File

@@ -0,0 +1,83 @@
package telemost
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func withTelemostAPIServer(t *testing.T, h http.Handler) {
t.Helper()
old := apiBase
srv := httptest.NewServer(h)
t.Cleanup(func() {
apiBase = old
srv.Close()
})
apiBase = srv.URL
}
func TestGetConnectionInfo(t *testing.T) {
withTelemostAPIServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
t.Fatalf("method = %s", r.Method)
}
if !strings.Contains(r.URL.EscapedPath(), "/conferences/room%2Fid/connection") {
t.Fatalf("path = %q escaped=%q", r.URL.Path, r.URL.EscapedPath())
}
if r.URL.Query().Get("display_name") != "peer" {
t.Fatalf("display_name query = %q", r.URL.Query().Get("display_name"))
}
_ = json.NewEncoder(w).Encode(ConnectionInfo{
RoomID: "room",
PeerID: "peer-id",
Credentials: "creds",
})
}))
info, err := GetConnectionInfo(context.Background(), "room/id", "peer")
if err != nil {
t.Fatalf("GetConnectionInfo() error = %v", err)
}
if info.RoomID != "room" || info.PeerID != "peer-id" || info.Credentials != "creds" {
t.Fatalf("GetConnectionInfo() = %+v", info)
}
}
func TestGetConnectionInfoErrors(t *testing.T) {
withTelemostAPIServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "bad", http.StatusForbidden)
}))
if _, err := GetConnectionInfo(context.Background(), "room", "peer"); !errors.Is(err, ErrAPI) {
t.Fatalf("GetConnectionInfo() error = %v, want %v", err, ErrAPI)
}
withTelemostAPIServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("{"))
}))
if _, err := GetConnectionInfo(context.Background(), "room", "peer"); err == nil {
t.Fatal("GetConnectionInfo() unexpectedly accepted bad json")
}
}
func TestTelemostNewPeerUsesConnectionInfo(t *testing.T) {
withTelemostAPIServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = json.NewEncoder(w).Encode(ConnectionInfo{
RoomID: "room",
PeerID: "peer-id",
Credentials: "creds",
})
}))
p, err := NewPeer(context.Background(), "room", "name", nil)
if err != nil {
t.Fatalf("NewPeer() error = %v", err)
}
if p.roomURL != "room" || p.name != "name" || p.conn.PeerID != "peer-id" || p.sendQueue == nil {
t.Fatalf("NewPeer() = %+v", p)
}
}

View File

@@ -0,0 +1,195 @@
package telemost
import (
"testing"
"time"
"github.com/pion/webrtc/v4"
)
func TestCloseSignal(t *testing.T) {
closeSignal(nil)
ch := make(chan struct{})
closeSignal(ch)
select {
case <-ch:
default:
t.Fatal("closeSignal() did not close channel")
}
closeSignal(ch)
}
func TestTrafficShapeAndDelay(t *testing.T) {
p := &Peer{}
p.SetTrafficShape(TrafficShape{MaxMessageSize: -1, MinDelay: 5 * time.Millisecond, MaxDelay: 2 * time.Millisecond})
if p.trafficShape.MaxMessageSize != realDataChannelMessageLimit {
t.Fatalf("MaxMessageSize = %d, want default", p.trafficShape.MaxMessageSize)
}
if p.trafficShape.MaxDelay != p.trafficShape.MinDelay {
t.Fatalf("MaxDelay = %v, want %v", p.trafficShape.MaxDelay, p.trafficShape.MinDelay)
}
if got := p.calculateDelay(); got != 5*time.Millisecond {
t.Fatalf("calculateDelay() = %v, want 5ms", got)
}
p.SetTrafficShape(TrafficShape{MaxMessageSize: 10, MinDelay: time.Millisecond, MaxDelay: 4 * time.Millisecond})
for i := 0; i < 20; i++ {
got := p.calculateDelay()
if got < time.Millisecond || got >= 4*time.Millisecond {
t.Fatalf("calculateDelay() = %v, out of range", got)
}
}
}
func TestICEParsingFiltersTURN(t *testing.T) {
if isNonTURNURL("") || isNonTURNURL("turn:host") || isNonTURNURL("turns:host") {
t.Fatal("isNonTURNURL accepted empty or TURN URL")
}
if !isNonTURNURL("stun:host") {
t.Fatal("isNonTURNURL rejected STUN URL")
}
urls := parseICEURLs(map[string]interface{}{"urls": []interface{}{"turn:x", "stun:a", 123, "turns:y"}})
if len(urls) != 1 || urls[0] != "stun:a" {
t.Fatalf("parseICEURLs(interface) = %v, want [stun:a]", urls)
}
urls = parseICEURLs(map[string]interface{}{"urls": []string{"stun:a", "turn:b"}})
if len(urls) != 1 || urls[0] != "stun:a" {
t.Fatalf("parseICEURLs(strings) = %v, want [stun:a]", urls)
}
}
func TestParseICEServer(t *testing.T) {
if _, ok := parseICEServer("bad"); ok {
t.Fatal("parseICEServer() accepted non-map")
}
if _, ok := parseICEServer(map[string]interface{}{"urls": []interface{}{"turn:x"}}); ok {
t.Fatal("parseICEServer() accepted TURN-only server")
}
ice, ok := parseICEServer(map[string]interface{}{
"urls": []interface{}{"stun:a", "turn:b"},
"username": "user",
"credential": "pass",
})
if !ok {
t.Fatal("parseICEServer() ok = false")
}
if len(ice.URLs) != 1 || ice.URLs[0] != "stun:a" || ice.Username != "user" || ice.Credential != "pass" {
t.Fatalf("parseICEServer() = %+v", ice)
}
}
func TestConferenceEndParsing(t *testing.T) {
for _, msg := range []map[string]interface{}{
{"conferenceClosed": true},
{"conference": map[string]interface{}{"state": "ENDED"}},
{"conferenceState": map[string]interface{}{"state": "terminated"}},
} {
if !isConferenceEndMessage(msg) {
t.Fatalf("isConferenceEndMessage(%v) = false", msg)
}
}
if isConferenceEndMessage(map[string]interface{}{"conference": map[string]interface{}{"state": "open"}}) {
t.Fatal("isConferenceEndMessage() accepted active conference")
}
for _, state := range []string{"closed", "ended", "finished", "terminated"} {
if !isEndedState(state) {
t.Fatalf("isEndedState(%q) = false", state)
}
}
if isEndedState("active") {
t.Fatal("isEndedState(active) = true")
}
}
func TestPeerSmallStateHelpers(t *testing.T) {
p := &Peer{
reconnectCh: make(chan struct{}, 1),
closeCh: make(chan struct{}),
sendQueue: make(chan []byte, 2),
ackWaiters: make(map[string]chan struct{}),
}
p.SetEndedCallback(func(string) {})
if p.onEnded == nil {
t.Fatal("SetEndedCallback() did not store callback")
}
p.SetReconnectCallback(func(*webrtc.DataChannel) {})
if p.onReconnect == nil {
t.Fatal("SetReconnectCallback() did not store callback")
}
p.SetShouldReconnect(func() bool { return true })
if p.shouldReconnect == nil || !p.shouldReconnect() {
t.Fatal("SetShouldReconnect() did not store callback")
}
p.subscriberReady.Store(true)
if !p.CanSend() {
t.Fatal("CanSend() = false for subscriber-only ready peer")
}
p.closed.Store(true)
if p.CanSend() {
t.Fatal("CanSend() = true for closed peer")
}
ch := p.registerAckWaiter("uid-1")
p.resolveAck("uid-1")
select {
case <-ch:
default:
t.Fatal("resolveAck() did not close waiter")
}
if p.waitForAck("", make(chan struct{}), time.Millisecond) {
t.Fatal("waitForAck(empty uid) = true")
}
ch = p.registerAckWaiter("uid-2")
go p.resolveAck("uid-2")
if !p.waitForAck("uid-2", ch, time.Second) {
t.Fatal("waitForAck() = false after resolveAck")
}
if err := p.AddVideoTrack(nil); err != nil {
t.Fatalf("AddVideoTrack(nil) error = %v", err)
}
if !p.hasLocalVideoTracks() {
t.Fatal("hasLocalVideoTracks() = false after AddVideoTrack")
}
p.SetVideoTrackHandler(func(*webrtc.TrackRemote, *webrtc.RTPReceiver) {})
if p.videoTrackHandler() == nil {
t.Fatal("videoTrackHandler() = nil")
}
}
func TestTelemetryCfgParsing(t *testing.T) {
if _, _, ok := parseTelemetryCfg(map[string]interface{}{}); ok {
t.Fatal("parseTelemetryCfg() accepted missing config")
}
if _, _, ok := parseTelemetryCfg(map[string]interface{}{
"telemetryConfiguration": map[string]interface{}{},
}); ok {
t.Fatal("parseTelemetryCfg() accepted missing endpoint")
}
endpoint, interval, ok := parseTelemetryCfg(map[string]interface{}{
"telemetryConfiguration": map[string]interface{}{
"endpoint": "https://example.test/log",
"sendingInterval": float64(250),
},
})
if !ok || endpoint != "https://example.test/log" || interval != 250*time.Millisecond {
t.Fatalf("parseTelemetryCfg() = (%q, %v, %v)", endpoint, interval, ok)
}
endpoint, interval, ok = parseTelemetryCfg(map[string]interface{}{
"telemetryConfiguration": map[string]interface{}{
"url": "https://example.test/url",
},
})
if !ok || endpoint != "https://example.test/url" || interval != defaultTelemetryInterval {
t.Fatalf("parseTelemetryCfg(default) = (%q, %v, %v)", endpoint, interval, ok)
}
}

View File

@@ -0,0 +1,54 @@
package telemost
import (
"context"
"errors"
"testing"
"github.com/pion/webrtc/v4"
)
func TestTelemostProviderForwardsPeerMethods(t *testing.T) {
peer := &Peer{
reconnectCh: make(chan struct{}, 1),
closeCh: make(chan struct{}),
sendQueue: make(chan []byte, 1),
ackWaiters: make(map[string]chan struct{}),
}
p := &telemostProvider{peer: peer}
p.SetReconnectCallback(func(*webrtc.DataChannel) {})
p.SetShouldReconnect(func() bool { return true })
p.SetEndedCallback(func(string) {})
p.SetVideoTrackHandler(func(*webrtc.TrackRemote, *webrtc.RTPReceiver) {})
if peer.onReconnect == nil || peer.shouldReconnect == nil || peer.onEnded == nil || peer.onVideoTrack == nil {
t.Fatal("callbacks were not forwarded")
}
if p.GetSendQueue() != peer.sendQueue {
t.Fatal("GetSendQueue() did not forward")
}
if p.GetBufferedAmount() != 0 {
t.Fatal("GetBufferedAmount() != 0 with nil datachannel")
}
if err := p.AddVideoTrack(nil); err != nil {
t.Fatalf("AddVideoTrack(nil) error = %v", err)
}
if p.CanSend() {
t.Fatal("CanSend() = true for unready peer")
}
done := make(chan struct{})
go func() {
p.WatchConnection(context.Background())
close(done)
}()
if err := p.Close(); err != nil {
t.Fatalf("Close() error = %v", err)
}
<-done
if err := p.Send([]byte("x")); !errors.Is(err, ErrDataChannelNotReady) {
t.Fatalf("Send() error = %v, want datachannel not ready", err)
}
}

View File

@@ -0,0 +1,84 @@
package telemost
import (
"testing"
"time"
)
func TestSessionReconnectAndEndedHelpers(t *testing.T) {
p := &Peer{
reconnectCh: make(chan struct{}, 2),
closeCh: make(chan struct{}),
keepAliveCh: make(chan struct{}),
sessionCloseCh: make(chan struct{}),
telemetryCh: make(chan struct{}, 1),
}
keepAliveCh, sessionCloseCh := p.resetSession()
if keepAliveCh == nil || sessionCloseCh == nil || keepAliveCh != p.keepAliveCh || sessionCloseCh != p.sessionCloseCh {
t.Fatal("resetSession() did not replace session channels")
}
p.subscriberReady.Store(true)
p.publisherReady.Store(true)
p.resetMediaState()
if p.subscriberReady.Load() || p.publisherReady.Load() || p.subscriberConn == nil || p.publisherConn == nil {
t.Fatal("resetMediaState() did not reset readiness")
}
p.queueReconnect()
select {
case <-p.reconnectCh:
default:
t.Fatal("queueReconnect() did not enqueue")
}
p.SetShouldReconnect(func() bool { return false })
p.queueReconnect()
select {
case <-p.reconnectCh:
t.Fatal("queueReconnect() enqueued despite policy=false")
default:
}
p.reconnectCh <- struct{}{}
p.reconnectCh <- struct{}{}
p.drainReconnectQueue()
select {
case <-p.reconnectCh:
t.Fatal("drainReconnectQueue() left queued item")
default:
}
p.telemetryActive.Store(true)
p.stopTelemetry()
select {
case <-p.telemetryCh:
default:
t.Fatal("stopTelemetry() did not signal active telemetry")
}
ended := ""
p.SetEndedCallback(func(reason string) { ended = reason })
p.signalEnded("done")
if !p.closed.Load() || ended != "done" {
t.Fatalf("signalEnded() closed=%v reason=%q", p.closed.Load(), ended)
}
}
func TestWaitForAckTimeoutAndClose(t *testing.T) {
p := &Peer{
closeCh: make(chan struct{}),
ackWaiters: make(map[string]chan struct{}),
}
ch := p.registerAckWaiter("timeout")
if p.waitForAck("timeout", ch, time.Millisecond) {
t.Fatal("waitForAck(timeout) = true")
}
ch = p.registerAckWaiter("closed")
close(p.closeCh)
if p.waitForAck("closed", ch, time.Second) {
t.Fatal("waitForAck(closeCh) = true")
}
}

View File

@@ -12,7 +12,7 @@ import (
"github.com/openlibrecommunity/olcrtc/internal/protect"
)
const apiBase = "https://stream.wb.ru"
var apiBase = "https://stream.wb.ru" //nolint:gochecknoglobals // Tests redirect HTTP API calls to httptest.
var (
errGuestRegister = errors.New("guest register failed")

View File

@@ -0,0 +1,123 @@
package wbstream
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
)
func withWBAPIServer(t *testing.T, h http.Handler) {
t.Helper()
old := apiBase
srv := httptest.NewServer(h)
t.Cleanup(func() {
apiBase = old
srv.Close()
})
apiBase = srv.URL
}
func TestWBStreamAPIHappyPath(t *testing.T) {
withWBAPIServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/auth/api/v1/auth/user/guest-register":
if r.Method != http.MethodPost {
t.Fatalf("guest method = %s", r.Method)
}
_ = json.NewEncoder(w).Encode(guestRegisterResponse{AccessToken: "access"})
case "/api-room/api/v2/room":
if r.Header.Get("Authorization") != "Bearer access" {
t.Fatalf("room auth = %q", r.Header.Get("Authorization"))
}
w.WriteHeader(http.StatusCreated)
_ = json.NewEncoder(w).Encode(createRoomResponse{RoomID: "room"})
case "/api-room/api/v1/room/room/join":
w.WriteHeader(http.StatusOK)
case "/api-room-manager/api/v1/room/room/token":
if r.URL.Query().Get("displayName") != "peer" {
t.Fatalf("displayName query = %q", r.URL.Query().Get("displayName"))
}
_ = json.NewEncoder(w).Encode(tokenResponse{RoomToken: "token"})
default:
http.NotFound(w, r)
}
}))
access, err := registerGuest(context.Background(), "peer")
if err != nil {
t.Fatalf("registerGuest() error = %v", err)
}
if access != "access" {
t.Fatalf("registerGuest() = %q", access)
}
room, err := createRoom(context.Background(), access)
if err != nil {
t.Fatalf("createRoom() error = %v", err)
}
if room != "room" {
t.Fatalf("createRoom() = %q", room)
}
if err := joinRoom(context.Background(), access, room); err != nil {
t.Fatalf("joinRoom() error = %v", err)
}
token, err := getToken(context.Background(), access, room, "peer")
if err != nil {
t.Fatalf("getToken() error = %v", err)
}
if token != "token" {
t.Fatalf("getToken() = %q", token)
}
}
func TestWBStreamAPIErrors(t *testing.T) {
withWBAPIServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "bad", http.StatusBadGateway)
}))
if _, err := registerGuest(context.Background(), "peer"); !errors.Is(err, errGuestRegister) {
t.Fatalf("registerGuest() error = %v, want %v", err, errGuestRegister)
}
if _, err := createRoom(context.Background(), "access"); !errors.Is(err, errCreateRoom) {
t.Fatalf("createRoom() error = %v, want %v", err, errCreateRoom)
}
if err := joinRoom(context.Background(), "access", "room"); !errors.Is(err, errJoinRoom) {
t.Fatalf("joinRoom() error = %v, want %v", err, errJoinRoom)
}
if _, err := getToken(context.Background(), "access", "room", "peer"); !errors.Is(err, errGetToken) {
t.Fatalf("getToken() error = %v, want %v", err, errGetToken)
}
}
func TestWBStreamGetRoomToken(t *testing.T) {
withWBAPIServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/auth/api/v1/auth/user/guest-register":
_ = json.NewEncoder(w).Encode(guestRegisterResponse{AccessToken: "access"})
case "/api-room/api/v2/room":
_ = json.NewEncoder(w).Encode(createRoomResponse{RoomID: "created"})
case "/api-room/api/v1/room/created/join":
w.WriteHeader(http.StatusOK)
case "/api-room-manager/api/v1/room/created/token":
_ = json.NewEncoder(w).Encode(tokenResponse{RoomToken: "token"})
default:
http.NotFound(w, r)
}
}))
p, err := NewPeer(context.Background(), "any", "peer", nil)
if err != nil {
t.Fatalf("NewPeer() error = %v", err)
}
token, err := p.getRoomToken(context.Background())
if err != nil {
t.Fatalf("getRoomToken() error = %v", err)
}
if token != "token" {
t.Fatalf("getRoomToken() = %q", token)
}
}

View File

@@ -0,0 +1,76 @@
package wbstream
import (
"context"
"errors"
"testing"
"github.com/pion/webrtc/v4"
)
func TestNewPeerAndSimpleAccessors(t *testing.T) {
p, err := NewPeer(context.Background(), "room", "name", func([]byte) {})
if err != nil {
t.Fatalf("NewPeer() error = %v", err)
}
if p.roomURL != "room" || p.name != "name" || p.sendQueue == nil || p.done == nil {
t.Fatalf("NewPeer() = %+v", p)
}
if p.GetSendQueue() != p.sendQueue {
t.Fatal("GetSendQueue() did not return sendQueue")
}
if p.GetBufferedAmount() != 0 {
t.Fatal("GetBufferedAmount() != 0")
}
if p.CanSend() {
t.Fatal("CanSend() = true without room")
}
}
func TestSendQueueAndClose(t *testing.T) {
p, err := NewPeer(context.Background(), "room", "name", nil)
if err != nil {
t.Fatalf("NewPeer() error = %v", err)
}
p.sendQueue = make(chan []byte, 1)
if err := p.Send([]byte("one")); err != nil {
t.Fatalf("Send() error = %v", err)
}
if err := p.Send([]byte("two")); !errors.Is(err, ErrSendQueueFull) {
t.Fatalf("Send() error = %v, want %v", err, ErrSendQueueFull)
}
if err := p.Close(); err != nil {
t.Fatalf("Close() error = %v", err)
}
if err := p.Send([]byte("closed")); !errors.Is(err, ErrPeerClosed) {
t.Fatalf("Send() error = %v, want %v", err, ErrPeerClosed)
}
if err := p.Close(); err != nil {
t.Fatalf("second Close() error = %v", err)
}
}
func TestCallbacksAndVideoTrackStorage(t *testing.T) {
p, err := NewPeer(context.Background(), "room", "name", nil)
if err != nil {
t.Fatalf("NewPeer() error = %v", err)
}
p.SetReconnectCallback(func(*webrtc.DataChannel) {})
p.SetShouldReconnect(func() bool { return true })
p.SetEndedCallback(func(string) {})
p.SetVideoTrackHandler(func(*webrtc.TrackRemote, *webrtc.RTPReceiver) {})
p.WatchConnection(context.Background())
if p.onReconnect == nil || p.shouldReconnect == nil || p.onEnded == nil || p.onVideoTrack == nil {
t.Fatal("callbacks were not stored")
}
if err := p.AddVideoTrack(nil); err != nil {
t.Fatalf("AddVideoTrack(nil) error = %v", err)
}
if len(p.videoTracks) != 1 {
t.Fatalf("videoTracks len = %d, want 1", len(p.videoTracks))
}
}

View File

@@ -0,0 +1,49 @@
package wbstream
import (
"context"
"errors"
"testing"
"github.com/pion/webrtc/v4"
)
func TestWBStreamProviderForwardsPeerMethods(t *testing.T) {
peer, err := NewPeer(context.Background(), "room", "name", nil)
if err != nil {
t.Fatalf("NewPeer() error = %v", err)
}
p := &wbStreamProvider{peer: peer}
p.SetReconnectCallback(func(*webrtc.DataChannel) {})
p.SetShouldReconnect(func() bool { return true })
p.SetEndedCallback(func(string) {})
p.SetVideoTrackHandler(func(*webrtc.TrackRemote, *webrtc.RTPReceiver) {})
if peer.onReconnect == nil || peer.shouldReconnect == nil || peer.onEnded == nil || peer.onVideoTrack == nil {
t.Fatal("callbacks were not forwarded")
}
if p.GetSendQueue() != peer.sendQueue {
t.Fatal("GetSendQueue() did not forward")
}
if p.GetBufferedAmount() != 0 {
t.Fatal("GetBufferedAmount() != 0")
}
if err := p.AddVideoTrack(nil); err != nil {
t.Fatalf("AddVideoTrack(nil) error = %v", err)
}
if p.CanSend() {
t.Fatal("CanSend() = true without LiveKit room")
}
p.WatchConnection(context.Background())
if err := p.Send([]byte("x")); err != nil {
t.Fatalf("Send() error = %v", err)
}
if err := p.Close(); err != nil {
t.Fatalf("Close() error = %v", err)
}
if err := p.Send([]byte("x")); !errors.Is(err, ErrPeerClosed) {
t.Fatalf("Send() error = %v, want peer closed", err)
}
}

View File

@@ -0,0 +1,343 @@
package server
import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"net"
"strings"
"testing"
cryptopkg "github.com/openlibrecommunity/olcrtc/internal/crypto"
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
"github.com/xtaci/smux"
)
func TestSetupCipher(t *testing.T) {
keyHex := "00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff"
cipher, err := setupCipher(keyHex)
if err != nil {
t.Fatalf("setupCipher() error = %v", err)
}
if cipher == nil {
t.Fatal("setupCipher() returned nil cipher")
}
}
func TestSetupCipherRejectsBadInput(t *testing.T) {
if _, err := setupCipher(""); !errors.Is(err, ErrKeyRequired) {
t.Fatalf("setupCipher() error = %v, want %v", err, ErrKeyRequired)
}
if _, err := setupCipher("zz"); err == nil {
t.Fatal("setupCipher() unexpectedly succeeded for bad hex")
}
if _, err := setupCipher("00"); !errors.Is(err, ErrKeySize) {
t.Fatalf("setupCipher() error = %v, want ErrKeySize", err)
}
}
func TestSmuxConfig(t *testing.T) {
cfg := smuxConfig()
if cfg.Version != 2 || cfg.MaxFrameSize != 32768 || cfg.MaxReceiveBuffer != 16*1024*1024 {
t.Fatalf("smuxConfig() = %+v", cfg)
}
}
func TestParseConnectRequest(t *testing.T) {
buf, err := json.Marshal(ConnectRequest{
Cmd: "connect",
ClientID: "client-1",
Addr: "example.com",
Port: 443,
})
if err != nil {
t.Fatalf("Marshal() error = %v", err)
}
req, ok := parseConnectRequest(buf)
if !ok {
t.Fatal("parseConnectRequest() returned ok=false")
}
if req.ClientID != "client-1" || req.Addr != "example.com" || req.Port != 443 {
t.Fatalf("parseConnectRequest() = %+v", req)
}
if _, ok := parseConnectRequest([]byte("not-json")); ok {
t.Fatal("parseConnectRequest() unexpectedly accepted invalid json")
}
if _, ok := parseConnectRequest([]byte(`{"cmd":"other"}`)); ok {
t.Fatal("parseConnectRequest() unexpectedly accepted wrong command")
}
}
func TestAuthorizeRequest(t *testing.T) {
s := &Server{clientID: "client-1"}
if !s.authorizeRequest(ConnectRequest{ClientID: "client-1"}) {
t.Fatal("authorizeRequest() rejected valid client")
}
if s.authorizeRequest(ConnectRequest{ClientID: "client-2"}) {
t.Fatal("authorizeRequest() accepted wrong client")
}
}
func TestSocks5ConnectSuccess(t *testing.T) {
s := &Server{}
server, client := net.Pipe()
defer func() {
_ = server.Close()
_ = client.Close()
}()
done := make(chan error, 1)
go func() {
done <- s.socks5Connect(server, "example.com", 443)
}()
auth := make([]byte, 3)
if _, err := io.ReadFull(client, auth); err != nil {
t.Fatalf("ReadFull(auth) error = %v", err)
}
if !bytes.Equal(auth, []byte{5, 1, 0}) {
t.Fatalf("auth request = %v", auth)
}
if _, err := client.Write([]byte{5, 0}); err != nil {
t.Fatalf("Write(auth resp) error = %v", err)
}
req := make([]byte, 18)
if _, err := io.ReadFull(client, req); err != nil {
t.Fatalf("ReadFull(connect req) error = %v", err)
}
if req[0] != 5 || req[1] != 1 || req[3] != 3 || req[4] != byte(len("example.com")) {
t.Fatalf("connect request header = %v", req[:5])
}
if string(req[5:16]) != "example.com" {
t.Fatalf("connect request addr = %q", req[5:16])
}
if req[16] != 0x01 || req[17] != 0xbb {
t.Fatalf("connect request port bytes = %v", req[16:18])
}
if _, err := client.Write([]byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0}); err != nil {
t.Fatalf("Write(connect resp) error = %v", err)
}
if err := <-done; err != nil {
t.Fatalf("socks5Connect() error = %v", err)
}
}
func TestSocks5ConnectErrors(t *testing.T) {
s := &Server{}
server, client := net.Pipe()
defer func() {
_ = server.Close()
_ = client.Close()
}()
done := make(chan error, 1)
go func() {
done <- s.socks5Connect(server, "example.com", 443)
}()
auth := make([]byte, 3)
if _, err := io.ReadFull(client, auth); err != nil {
t.Fatalf("ReadFull(auth) error = %v", err)
}
if _, err := client.Write([]byte{5, 1}); err != nil {
t.Fatalf("Write(auth resp) error = %v", err)
}
if err := <-done; !errors.Is(err, ErrSocks5AuthFailed) {
t.Fatalf("socks5Connect() error = %v, want %v", err, ErrSocks5AuthFailed)
}
server2, client2 := net.Pipe()
defer func() {
_ = server2.Close()
_ = client2.Close()
}()
done = make(chan error, 1)
go func() {
done <- s.socks5Connect(server2, "example.com", 443)
}()
if _, err := io.ReadFull(client2, auth); err != nil {
t.Fatalf("ReadFull(auth2) error = %v", err)
}
if _, err := client2.Write([]byte{5, 0}); err != nil {
t.Fatalf("Write(auth2 resp) error = %v", err)
}
req := make([]byte, 18)
if _, err := io.ReadFull(client2, req); err != nil {
t.Fatalf("ReadFull(req2) error = %v", err)
}
if _, err := client2.Write([]byte{5, 4, 0, 1, 0, 0, 0, 0, 0, 0}); err != nil {
t.Fatalf("Write(connect2 resp) error = %v", err)
}
if err := <-done; !errors.Is(err, ErrSocks5ConnectFailed) {
t.Fatalf("socks5Connect() error = %v, want %v", err, ErrSocks5ConnectFailed)
}
}
func TestSetupResolver(t *testing.T) {
s := &Server{dnsServer: "127.0.0.1:53"}
s.setupResolver()
if s.resolver == nil || !s.resolver.PreferGo || s.resolver.Dial == nil {
t.Fatalf("setupResolver() = %+v", s.resolver)
}
}
func TestOnDataWithNilConn(t *testing.T) {
s := &Server{}
s.onData([]byte("ignored"))
}
type serverLinkStub struct {
closed bool
}
func (s *serverLinkStub) Connect(context.Context) error { return nil }
func (s *serverLinkStub) Send([]byte) error { return nil }
func (s *serverLinkStub) Close() error { s.closed = true; return nil }
func (s *serverLinkStub) SetReconnectCallback(func()) {}
func (s *serverLinkStub) SetShouldReconnect(func() bool) {}
func (s *serverLinkStub) SetEndedCallback(func(string)) {}
func (s *serverLinkStub) WatchConnection(context.Context) {}
func (s *serverLinkStub) CanSend() bool { return true }
func TestShutdownClosesLinkAndConn(t *testing.T) {
cipher, err := cryptopkg.NewCipher("01234567890123456789012345678901")
if err != nil {
t.Fatalf("NewCipher() error = %v", err)
}
ln := &serverLinkStub{}
s := &Server{
ln: ln,
cipher: cipher,
conn: muxconn.New(ln, cipher),
}
s.shutdown()
if !ln.closed {
t.Fatal("shutdown() did not close link")
}
}
func TestDialWithoutProxy(t *testing.T) {
ln, err := net.Listen("tcp4", "127.0.0.1:0")
if err != nil {
t.Fatalf("Listen() error = %v", err)
}
defer func() { _ = ln.Close() }()
done := make(chan struct{})
go func() {
conn, err := ln.Accept()
if err == nil {
_ = conn.Close()
close(done)
}
}()
tcpAddr := ln.Addr().(*net.TCPAddr)
s := &Server{resolver: net.DefaultResolver}
conn, err := s.dial(ConnectRequest{Addr: "127.0.0.1", Port: tcpAddr.Port})
if err != nil {
t.Fatalf("dial() error = %v", err)
}
_ = conn.Close()
<-done
}
func TestDialProxyError(t *testing.T) {
s := &Server{socksProxyAddr: "127.0.0.1", socksProxyPort: 1}
if _, err := s.dial(ConnectRequest{Addr: "example.com", Port: 443}); err == nil || !strings.Contains(err.Error(), "failed to dial proxy") {
t.Fatalf("dial() error = %v", err)
}
}
func TestSocks5ConnectTruncatesLongDomain(t *testing.T) {
s := &Server{}
server, client := net.Pipe()
defer func() {
_ = server.Close()
_ = client.Close()
}()
longHost := strings.Repeat("a", 300)
done := make(chan error, 1)
go func() {
done <- s.socks5Connect(server, longHost, 443)
}()
auth := make([]byte, 3)
if _, err := io.ReadFull(client, auth); err != nil {
t.Fatalf("ReadFull(auth) error = %v", err)
}
if _, err := client.Write([]byte{5, 0}); err != nil {
t.Fatalf("Write(auth resp) error = %v", err)
}
req := make([]byte, 262)
if _, err := io.ReadFull(client, req); err != nil {
t.Fatalf("ReadFull(connect req) error = %v", err)
}
if req[4] != 255 {
t.Fatalf("domain len byte = %d, want 255", req[4])
}
if _, err := client.Write([]byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0}); err != nil {
t.Fatalf("Write(connect resp) error = %v", err)
}
if err := <-done; err != nil {
t.Fatalf("socks5Connect() error = %v", err)
}
}
func TestHandleStreamRejectsWrongClientID(t *testing.T) {
a, b := net.Pipe()
defer func() {
_ = a.Close()
_ = b.Close()
}()
serverSess, err := smux.Server(a, smuxConfig())
if err != nil {
t.Fatalf("smux.Server() error = %v", err)
}
defer func() { _ = serverSess.Close() }()
clientSess, err := smux.Client(b, smuxConfig())
if err != nil {
t.Fatalf("smux.Client() error = %v", err)
}
defer func() { _ = clientSess.Close() }()
done := make(chan struct{})
go func() {
stream, err := serverSess.AcceptStream()
if err == nil {
(&Server{clientID: "expected"}).handleStream(context.Background(), stream)
}
close(done)
}()
stream, err := clientSess.OpenStream()
if err != nil {
t.Fatalf("OpenStream() error = %v", err)
}
req, err := json.Marshal(ConnectRequest{
Cmd: "connect",
ClientID: "wrong",
Addr: "example.com",
Port: 443,
})
if err != nil {
t.Fatalf("Marshal() error = %v", err)
}
if _, err := stream.Write(req); err != nil {
t.Fatalf("Write() error = %v", err)
}
<-done
}

View File

@@ -0,0 +1,139 @@
package datachannel
import (
"context"
"errors"
"testing"
"github.com/openlibrecommunity/olcrtc/internal/carrier"
"github.com/openlibrecommunity/olcrtc/internal/transport"
)
type stubSession struct {
stream carrier.ByteStream
streamErr error
}
func (s *stubSession) Capabilities() carrier.Capabilities {
return carrier.Capabilities{ByteStream: true}
}
func (s *stubSession) OpenByteStream() (carrier.ByteStream, error) {
if s.streamErr != nil {
return nil, s.streamErr
}
return s.stream, nil
}
type nonByteStreamSession struct{}
func (s *nonByteStreamSession) Capabilities() carrier.Capabilities { return carrier.Capabilities{} }
type stubByteStream struct {
connectErr error
sendErr error
closeErr error
canSend bool
connectCalled bool
sent []byte
watched bool
reconnectCB func()
shouldFn func() bool
endedCB func(string)
}
func (s *stubByteStream) Connect(context.Context) error { s.connectCalled = true; return s.connectErr }
func (s *stubByteStream) Send(data []byte) error {
s.sent = append([]byte(nil), data...)
return s.sendErr
}
func (s *stubByteStream) Close() error { return s.closeErr }
func (s *stubByteStream) SetReconnectCallback(cb func()) { s.reconnectCB = cb }
func (s *stubByteStream) SetShouldReconnect(fn func() bool) { s.shouldFn = fn }
func (s *stubByteStream) SetEndedCallback(cb func(string)) { s.endedCB = cb }
func (s *stubByteStream) WatchConnection(context.Context) { s.watched = true }
func (s *stubByteStream) CanSend() bool { return s.canSend }
func TestNewAndFeatures(t *testing.T) {
stream := &stubByteStream{canSend: true}
carrier.Register("datachannel-test-new-and-features", func(context.Context, carrier.Config) (carrier.Session, error) {
return &stubSession{stream: stream}, nil
})
tr, err := New(context.Background(), transport.Config{Carrier: "datachannel-test-new-and-features"})
if err != nil {
t.Fatalf("New() error = %v", err)
}
if err := tr.Connect(context.Background()); err != nil {
t.Fatalf("Connect() error = %v", err)
}
if !stream.connectCalled {
t.Fatal("Connect() was not forwarded")
}
if err := tr.Send([]byte("payload")); err != nil {
t.Fatalf("Send() error = %v", err)
}
if string(stream.sent) != "payload" {
t.Fatalf("Send() forwarded %q, want payload", stream.sent)
}
tr.SetReconnectCallback(func() {})
tr.SetShouldReconnect(func() bool { return true })
tr.SetEndedCallback(func(string) {})
tr.WatchConnection(context.Background())
if stream.reconnectCB == nil || stream.shouldFn == nil || stream.endedCB == nil || !stream.watched {
t.Fatal("callbacks/watch were not forwarded")
}
if !tr.CanSend() {
t.Fatal("CanSend() = false, want true")
}
features := tr.Features()
if !features.Reliable || !features.Ordered || !features.MessageOriented || features.MaxPayloadSize != defaultMaxPayloadSize {
t.Fatalf("Features() = %+v", features)
}
if err := tr.Close(); err != nil {
t.Fatalf("Close() error = %v", err)
}
}
func TestNewErrorPaths(t *testing.T) {
carrier.Register("datachannel-fail-create", func(context.Context, carrier.Config) (carrier.Session, error) {
return nil, errors.New("boom")
})
if _, err := New(context.Background(), transport.Config{Carrier: "datachannel-fail-create"}); err == nil || err.Error() != "create provider transport: boom" {
t.Fatalf("New() error = %v", err)
}
carrier.Register("datachannel-no-stream", func(context.Context, carrier.Config) (carrier.Session, error) {
return &nonByteStreamSession{}, nil
})
if _, err := New(context.Background(), transport.Config{Carrier: "datachannel-no-stream"}); !errors.Is(err, carrier.ErrByteStreamUnsupported) {
t.Fatalf("New() error = %v, want %v", err, carrier.ErrByteStreamUnsupported)
}
carrier.Register("datachannel-open-stream-fails", func(context.Context, carrier.Config) (carrier.Session, error) {
return &stubSession{streamErr: errors.New("open boom")}, nil
})
if _, err := New(context.Background(), transport.Config{Carrier: "datachannel-open-stream-fails"}); err == nil || err.Error() != "open byte stream: open boom" {
t.Fatalf("New() error = %v", err)
}
}
func TestStreamTransportWrapsErrors(t *testing.T) {
tr := &streamTransport{stream: &stubByteStream{
connectErr: errors.New("connect boom"),
sendErr: errors.New("send boom"),
closeErr: errors.New("close boom"),
}}
if err := tr.Connect(context.Background()); err == nil || err.Error() != "stream connect: connect boom" {
t.Fatalf("Connect() error = %v", err)
}
if err := tr.Send([]byte("x")); err == nil || err.Error() != "stream send: send boom" {
t.Fatalf("Send() error = %v", err)
}
if err := tr.Close(); err == nil || err.Error() != "stream close: close boom" {
t.Fatalf("Close() error = %v", err)
}
}

View File

@@ -0,0 +1,84 @@
package seichannel
import (
"bytes"
"errors"
"testing"
)
func TestFragmentPayload(t *testing.T) {
frags := fragmentPayload([]byte("abcdef"), 2)
want := [][]byte{[]byte("ab"), []byte("cd"), []byte("ef")}
if len(frags) != len(want) {
t.Fatalf("fragment count = %d, want %d", len(frags), len(want))
}
for i := range frags {
if !bytes.Equal(frags[i], want[i]) {
t.Fatalf("frag %d = %q, want %q", i, frags[i], want[i])
}
}
empty := fragmentPayload(nil, 10)
if len(empty) != 1 || len(empty[0]) != 0 {
t.Fatalf("fragmentPayload(nil) = %#v, want one empty frag", empty)
}
}
func TestDecodeTransportFrameErrorsAndAck(t *testing.T) {
tests := []struct {
data []byte
want error
}{
{data: []byte{1, 2, 3}, want: ErrFrameTooShort},
{data: []byte{0, 0, 0, 0, protocolVersion, frameTypeAck}, want: ErrUnexpectedMagic},
{data: []byte{0x4f, 0x56, 0x43, 0x31, 9, frameTypeAck}, want: ErrUnexpectedVersion},
{data: []byte{0x4f, 0x56, 0x43, 0x31, protocolVersion, frameTypeAck}, want: ErrAckTooShort},
{data: []byte{0x4f, 0x56, 0x43, 0x31, protocolVersion, frameTypeData}, want: ErrDataTooShort},
{data: []byte{0x4f, 0x56, 0x43, 0x31, protocolVersion, 99}, want: ErrUnexpectedFrameType},
}
for _, tt := range tests {
if _, err := decodeTransportFrame(tt.data); !errors.Is(err, tt.want) {
t.Fatalf("decodeTransportFrame(%v) error = %v, want %v", tt.data, err, tt.want)
}
}
ack, err := decodeTransportFrame(encodeAckFrame(7, 0x1234))
if err != nil {
t.Fatalf("decode ack error = %v", err)
}
if ack.typ != frameTypeAck || ack.seq != 7 || ack.crc != 0x1234 {
t.Fatalf("ack = %+v", ack)
}
}
func TestSEIHelpersAndErrors(t *testing.T) {
escaped := escapeRBSP([]byte{0, 0, 1, 0, 0, 2, 3})
if !bytes.Equal(unescapeRBSP(escaped), []byte{0, 0, 1, 0, 0, 2, 3}) {
t.Fatalf("unescapeRBSP(escapeRBSP()) = %v", unescapeRBSP(escaped))
}
value := appendSEIValue(nil, 300)
got, next, err := consumeSEIValue(value, 0)
if err != nil || got != 300 || next != len(value) {
t.Fatalf("consumeSEIValue() = (%d, %d, %v), want 300", got, next, err)
}
if _, _, err := consumeSEIValue([]byte{0xff}, 0); !errors.Is(err, ErrSEIValueTruncated) {
t.Fatalf("consumeSEIValue() error = %v, want %v", err, ErrSEIValueTruncated)
}
rbsp := appendSEIValue(nil, 5)
rbsp = append(rbsp, appendSEIValue(nil, len(videoSEIUUID)+5)...)
rbsp = append(rbsp, videoSEIUUID[:]...)
rbsp = append(rbsp, []byte{1, 2}...)
if _, err := extractTransportSEI(rbsp); !errors.Is(err, ErrSEIPayloadTruncated) {
t.Fatalf("extractTransportSEI() error = %v, want %v", err, ErrSEIPayloadTruncated)
}
payloads, err := extractTransportSEI([]byte{4, 1, 0, 0x80})
if err != nil {
t.Fatalf("extractTransportSEI(non-transport) error = %v", err)
}
if len(payloads) != 0 {
t.Fatalf("extractTransportSEI(non-transport) = %v, want none", payloads)
}
}

View File

@@ -0,0 +1,111 @@
package seichannel
import (
"bytes"
"hash/crc32"
"testing"
)
func TestInboundAssemblyAndAck(t *testing.T) {
var got []byte
tr := &streamTransport{
onData: func(data []byte) { got = append([]byte(nil), data...) },
outboundAck: make(chan []byte, 4),
inbound: make(map[uint32]*inboundMessage),
delivered: make(map[uint32]uint32),
}
payload := []byte("hello world")
crc := crc32.ChecksumIEEE(payload)
tr.handleInboundFrame(transportFrame{
typ: frameTypeData,
seq: 1,
crc: crc,
totalLen: uint32(len(payload)),
fragIdx: 1,
fragTotal: 2,
payload: []byte(" world"),
})
if len(got) != 0 {
t.Fatalf("onData called before message complete: %q", got)
}
tr.handleInboundFrame(transportFrame{
typ: frameTypeData,
seq: 1,
crc: crc,
totalLen: uint32(len(payload)),
fragIdx: 0,
fragTotal: 2,
payload: []byte("hello"),
})
if !bytes.Equal(got, payload) {
t.Fatalf("assembled payload = %q, want %q", got, payload)
}
select {
case ack := <-tr.outboundAck:
frame, err := decodeTransportFrame(ack)
if err != nil || frame.typ != frameTypeAck || frame.seq != 1 || frame.crc != crc {
t.Fatalf("ack frame = %+v err=%v", frame, err)
}
default:
t.Fatal("handleInboundFrame() did not enqueue ack")
}
got = nil
tr.handleInboundFrame(transportFrame{
typ: frameTypeData,
seq: 1,
crc: crc,
totalLen: uint32(len(payload)),
fragIdx: 0,
fragTotal: 2,
payload: []byte("hello"),
})
if got != nil {
t.Fatalf("duplicate delivered payload again: %q", got)
}
}
func TestInboundRejectsBadFragmentsAndCRC(t *testing.T) {
tr := &streamTransport{
outboundAck: make(chan []byte, 2),
inbound: make(map[uint32]*inboundMessage),
delivered: make(map[uint32]uint32),
}
msg, complete := tr.upsertInbound(transportFrame{
seq: 1,
crc: 1,
totalLen: 3,
fragIdx: 3,
fragTotal: 1,
payload: []byte("bad"),
})
if msg != nil || complete {
t.Fatalf("upsertInbound(out of range) = (%v, %v), want nil false", msg, complete)
}
called := false
tr.onData = func([]byte) { called = true }
tr.handleInboundFrame(transportFrame{
seq: 2,
crc: 123,
totalLen: 3,
fragIdx: 0,
fragTotal: 1,
payload: []byte("abc"),
})
if called {
t.Fatal("handleInboundFrame() delivered payload with bad crc")
}
msg = &inboundMessage{
totalLen: 3,
crc: crc32.ChecksumIEEE([]byte("abcdef")),
frags: [][]byte{[]byte("abc"), []byte("def")},
}
if got := tr.assembleMessage(msg); string(got) != "abc" {
t.Fatalf("assembleMessage() = %q, want abc", got)
}
}

View File

@@ -0,0 +1,72 @@
package transport
import (
"context"
"errors"
"reflect"
"testing"
)
type stubTransport struct{}
func (s *stubTransport) Connect(context.Context) error { return nil }
func (s *stubTransport) Send([]byte) error { return nil }
func (s *stubTransport) Close() error { return nil }
func (s *stubTransport) SetReconnectCallback(func()) {}
func (s *stubTransport) SetShouldReconnect(func() bool) {}
func (s *stubTransport) SetEndedCallback(func(string)) {}
func (s *stubTransport) WatchConnection(context.Context) {}
func (s *stubTransport) CanSend() bool { return true }
func (s *stubTransport) Features() Features { return Features{Reliable: true} }
func snapshotTransportRegistry() map[string]Factory {
out := make(map[string]Factory, len(registry))
for k, v := range registry {
out[k] = v
}
return out
}
func restoreTransportRegistry(src map[string]Factory) {
registry = make(map[string]Factory, len(src))
for k, v := range src {
registry[k] = v
}
}
func TestNewAndAvailable(t *testing.T) {
old := snapshotTransportRegistry()
t.Cleanup(func() { restoreTransportRegistry(old) })
called := false
Register("test-transport", func(_ context.Context, cfg Config) (Transport, error) {
called = cfg.ClientID == "client-1"
return &stubTransport{}, nil
})
got, err := New(context.Background(), "test-transport", Config{ClientID: "client-1"})
if err != nil {
t.Fatalf("New() error = %v", err)
}
if !called {
t.Fatal("factory did not receive config")
}
if _, ok := got.(*stubTransport); !ok {
t.Fatalf("New() returned %T, want *stubTransport", got)
}
if !reflect.DeepEqual(Available(), []string{"test-transport"}) {
t.Fatalf("Available() = %#v, want %#v", Available(), []string{"test-transport"})
}
}
func TestNewReturnsErrTransportNotFound(t *testing.T) {
old := snapshotTransportRegistry()
t.Cleanup(func() { restoreTransportRegistry(old) })
registry = map[string]Factory{}
_, err := New(context.Background(), "missing", Config{})
if !errors.Is(err, ErrTransportNotFound) {
t.Fatalf("New() error = %v, want %v", err, ErrTransportNotFound)
}
}

View File

@@ -0,0 +1,139 @@
package videochannel
import (
"bytes"
"errors"
"io"
"slices"
"strings"
"testing"
"github.com/pion/webrtc/v4"
)
func TestFragmentPayload(t *testing.T) {
frags := fragmentPayload([]byte("abcdef"), 2)
want := [][]byte{[]byte("ab"), []byte("cd"), []byte("ef")}
if len(frags) != len(want) {
t.Fatalf("fragment count = %d, want %d", len(frags), len(want))
}
for i := range frags {
if !bytes.Equal(frags[i], want[i]) {
t.Fatalf("frag %d = %q, want %q", i, frags[i], want[i])
}
}
empty := fragmentPayload(nil, 10)
if len(empty) != 1 || len(empty[0]) != 0 {
t.Fatalf("fragmentPayload(nil) = %#v, want one empty frag", empty)
}
}
func TestDecodeTransportFrameErrorsAndAck(t *testing.T) {
tests := []struct {
data []byte
want error
}{
{data: []byte{1, 2, 3}, want: ErrFrameTooShort},
{data: []byte{0, 0, 0, 0, protocolVersion, frameTypeAck}, want: ErrUnexpectedMagic},
{data: []byte{0x4f, 0x56, 0x56, 0x32, 9, frameTypeAck}, want: ErrUnexpectedVersion},
{data: []byte{0x4f, 0x56, 0x56, 0x32, protocolVersion, frameTypeAck}, want: ErrAckTooShort},
{data: []byte{0x4f, 0x56, 0x56, 0x32, protocolVersion, frameTypeData}, want: ErrDataTooShort},
{data: []byte{0x4f, 0x56, 0x56, 0x32, protocolVersion, 99}, want: ErrUnexpectedFrameType},
}
for _, tt := range tests {
if _, err := decodeTransportFrame(tt.data); !errors.Is(err, tt.want) {
t.Fatalf("decodeTransportFrame(%v) error = %v, want %v", tt.data, err, tt.want)
}
}
ack, err := decodeTransportFrame(encodeAckFrame(7, 0x1234))
if err != nil {
t.Fatalf("decode ack error = %v", err)
}
if ack.typ != frameTypeAck || ack.seq != 7 || ack.crc != 0x1234 {
t.Fatalf("ack = %+v", ack)
}
}
func TestCodecSpecsAndArgs(t *testing.T) {
for _, mime := range []string{webrtc.MimeTypeH264, webrtc.MimeTypeVP8, webrtc.MimeTypeVP9} {
spec, ok := codecSpecForMime(mime)
if !ok {
t.Fatalf("codecSpecForMime(%q) ok = false", mime)
}
if spec.mimeType != mime || spec.depacketizer == nil || spec.capability.ClockRate != 90000 {
t.Fatalf("codec spec = %+v", spec)
}
}
if _, ok := codecSpecForMime("video/unknown"); ok {
t.Fatal("codecSpecForMime() accepted unknown mime")
}
if got := resolveEncoderCodec(h264CodecSpec(), "nvenc"); got != "h264_nvenc" {
t.Fatalf("resolveEncoderCodec(h264,nvenc) = %q", got)
}
if got := resolveEncoderCodec(vp8CodecSpec(), "none"); got != "libvpx" {
t.Fatalf("resolveEncoderCodec(vp8,none) = %q", got)
}
args := buildEncoderArgs(vp8CodecSpec(), "vp8_nvenc", 320, 240, 30, "1M")
for _, want := range []string{"-video_size", "320x240", "-framerate", "30", "vp8_nvenc", "-b:v", "1M", "ivf"} {
if !slices.Contains(args, want) {
t.Fatalf("buildEncoderArgs() = %v, missing %q", args, want)
}
}
h264Args := buildEncoderArgs(h264CodecSpec(), "libx264", 320, 240, 30, "1M")
if h264Args[len(h264Args)-2] != "h264" {
t.Fatalf("h264 encoder args = %v", h264Args)
}
}
type shortWriter struct {
writes int
}
func (w *shortWriter) Write(p []byte) (int, error) {
w.writes++
if w.writes == 1 {
return 1, nil
}
return len(p), nil
}
type errWriter struct{}
func (w errWriter) Write([]byte) (int, error) { return 0, io.ErrClosedPipe }
func TestIVFWritersAndWithStderr(t *testing.T) {
var buf bytes.Buffer
if err := writeIVFHeader(&buf, "VP80", 320, 240, 30); err != nil {
t.Fatalf("writeIVFHeader() error = %v", err)
}
if buf.Len() != 32 || string(buf.Bytes()[:4]) != "DKIF" {
t.Fatalf("IVF header = %v", buf.Bytes())
}
buf.Reset()
if err := writeIVFFrame(&buf, 3, []byte("abc")); err != nil {
t.Fatalf("writeIVFFrame() error = %v", err)
}
if buf.Len() != 15 {
t.Fatalf("IVF frame len = %d, want 15", buf.Len())
}
if err := writeAll(&shortWriter{}, []byte("abc")); err != nil {
t.Fatalf("writeAll(shortWriter) error = %v", err)
}
if err := writeAll(errWriter{}, []byte("abc")); err == nil || !strings.Contains(err.Error(), "write:") {
t.Fatalf("writeAll(errWriter) error = %v", err)
}
baseErr := errors.New("base")
if got := withStderr(baseErr, bytes.NewBufferString(" details \n")); got == nil || got.Error() != "base: details" {
t.Fatalf("withStderr() = %v", got)
}
if got := withStderr(nil, bytes.NewBufferString("details")); got != nil {
t.Fatalf("withStderr(nil) = %v", got)
}
}

View File

@@ -0,0 +1,97 @@
package videochannel
import (
"bytes"
"hash/crc32"
"testing"
)
func TestInboundAssemblyAndAck(t *testing.T) {
var got []byte
tr := &streamTransport{
onData: func(data []byte) { got = append([]byte(nil), data...) },
outboundAck: make(chan []byte, 4),
inbound: make(map[uint32]*inboundMessage),
delivered: make(map[uint32]uint32),
}
payload := []byte("hello video")
crc := crc32.ChecksumIEEE(payload)
tr.handleInboundFrame(transportFrame{
typ: frameTypeData,
seq: 1,
crc: crc,
totalLen: uint32(len(payload)),
fragIdx: 1,
fragTotal: 2,
payload: []byte(" video"),
})
if len(got) != 0 {
t.Fatalf("onData called before message complete: %q", got)
}
tr.handleInboundFrame(transportFrame{
typ: frameTypeData,
seq: 1,
crc: crc,
totalLen: uint32(len(payload)),
fragIdx: 0,
fragTotal: 2,
payload: []byte("hello"),
})
if !bytes.Equal(got, payload) {
t.Fatalf("assembled payload = %q, want %q", got, payload)
}
select {
case ack := <-tr.outboundAck:
frame, err := decodeTransportFrame(ack)
if err != nil || frame.typ != frameTypeAck || frame.seq != 1 || frame.crc != crc {
t.Fatalf("ack frame = %+v err=%v", frame, err)
}
default:
t.Fatal("handleInboundFrame() did not enqueue ack")
}
}
func TestInboundRejectsBadFragmentsAndCRC(t *testing.T) {
tr := &streamTransport{
outboundAck: make(chan []byte, 2),
inbound: make(map[uint32]*inboundMessage),
delivered: make(map[uint32]uint32),
}
msg, complete := tr.upsertInbound(transportFrame{
seq: 1,
crc: 1,
totalLen: 3,
fragIdx: 3,
fragTotal: 1,
payload: []byte("bad"),
})
if msg != nil || complete {
t.Fatalf("upsertInbound(out of range) = (%v, %v), want nil false", msg, complete)
}
called := false
tr.onData = func([]byte) { called = true }
tr.handleInboundFrame(transportFrame{
seq: 2,
crc: 123,
totalLen: 3,
fragIdx: 0,
fragTotal: 1,
payload: []byte("abc"),
})
if called {
t.Fatal("handleInboundFrame() delivered payload with bad crc")
}
msg = &inboundMessage{
totalLen: 3,
crc: crc32.ChecksumIEEE([]byte("abcdef")),
frags: [][]byte{[]byte("abc"), []byte("def")},
}
if got := tr.assembleMessage(msg); string(got) != "abc" {
t.Fatalf("assembleMessage() = %q, want abc", got)
}
}

View File

@@ -0,0 +1,71 @@
package vp8channel
import (
"bytes"
"errors"
"net"
"testing"
"time"
)
func TestKCPConnReadWriteDeadlinesAndClose(t *testing.T) {
out := make(chan []byte, 1)
hdr := testEpochHdr(9)
conn := newKCPConn(out, 1, hdr)
if err := conn.SetDeadline(time.Now().Add(time.Second)); err != nil {
t.Fatalf("SetDeadline() error = %v", err)
}
if conn.LocalAddr().String() != "127.0.0.1:1" {
t.Fatalf("LocalAddr() = %v", conn.LocalAddr())
}
n, err := conn.WriteTo([]byte("payload"), nil)
if err != nil || n != len("payload") {
t.Fatalf("WriteTo() = (%d, %v), want payload length", n, err)
}
wire := <-out
if !bytes.Equal(wire[:epochHdrLen], hdr[:]) || string(wire[epochHdrLen:]) != "payload" {
t.Fatalf("wire packet = %v", wire)
}
conn.deliver([]byte("incoming"))
buf := make([]byte, 64)
n, addr, err := conn.ReadFrom(buf)
if err != nil || addr == nil || string(buf[:n]) != "incoming" {
t.Fatalf("ReadFrom() = (%d, %v, %v), payload %q", n, addr, err, buf[:n])
}
if err := conn.Close(); err != nil {
t.Fatalf("Close() error = %v", err)
}
if _, _, err := conn.ReadFrom(buf); !errors.Is(err, net.ErrClosed) {
t.Fatalf("ReadFrom() error = %v, want net.ErrClosed", err)
}
closedWrite := newKCPConn(make(chan []byte), 1, hdr)
_ = closedWrite.Close()
if _, err := closedWrite.WriteTo([]byte("x"), nil); !errors.Is(err, net.ErrClosed) {
t.Fatalf("WriteTo() error = %v, want net.ErrClosed", err)
}
}
func TestKCPConnTimeouts(t *testing.T) {
conn := newKCPConn(make(chan []byte), 1, testEpochHdr(1))
if err := conn.SetReadDeadline(time.Now().Add(-time.Millisecond)); err != nil {
t.Fatalf("SetReadDeadline() error = %v", err)
}
buf := make([]byte, 4)
if _, _, err := conn.ReadFrom(buf); err == nil {
t.Fatal("ReadFrom() unexpectedly succeeded")
} else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() || !netErr.Temporary() {
t.Fatalf("ReadFrom() error = %T %v, want timeout net.Error", err, err)
}
if err := conn.SetWriteDeadline(time.Now().Add(-time.Millisecond)); err != nil {
t.Fatalf("SetWriteDeadline() error = %v", err)
}
if _, err := conn.WriteTo([]byte("x"), nil); err == nil {
t.Fatal("WriteTo() unexpectedly succeeded")
}
}

206
mobile/mobile_test.go Normal file
View File

@@ -0,0 +1,206 @@
package mobile
import (
"errors"
"log"
"strings"
"sync"
"testing"
"time"
"github.com/openlibrecommunity/olcrtc/internal/logger"
"github.com/openlibrecommunity/olcrtc/internal/protect"
)
type testProtector struct {
called int
}
func (p *testProtector) Protect(fd int) bool {
p.called = fd
return true
}
type testLogWriter struct {
got string
}
func (w *testLogWriter) WriteLog(msg string) {
w.got += msg
}
func resetMobileGlobals(t *testing.T) {
t.Helper()
mu.Lock()
if cancel != nil {
cancel()
}
cancel = nil
done = nil
ready = nil
errRun = nil
defaults = mobileConfig{}
defaultsSet = sync.Once{}
mu.Unlock()
protect.Protector = nil
logger.SetVerbose(false)
}
func TestProtectorAndLogging(t *testing.T) {
resetMobileGlobals(t)
p := &testProtector{}
SetProtector(p)
if protect.Protector == nil || !protect.Protector(123) || p.called != 123 {
t.Fatal("SetProtector() did not install adapter")
}
SetProtector(nil)
if protect.Protector != nil {
t.Fatal("SetProtector(nil) did not clear protector")
}
w := &testLogWriter{}
SetLogWriter(w)
log.Print("hello")
if !strings.Contains(w.got, "hello") {
t.Fatalf("log writer got %q, want hello", w.got)
}
}
func TestDefaultsAndSetters(t *testing.T) {
resetMobileGlobals(t)
SetTransport("dc")
SetLink("direct")
SetDNS("9.9.9.9:53")
SetVP8Options(-1, 999)
mu.Lock()
got := defaults
mu.Unlock()
if got.transport != dataTransport || got.link != defaultLink || got.dnsServer != "9.9.9.9:53" ||
got.vp8FPS != 1 || got.vp8BatchSize != 64 {
t.Fatalf("defaults = %+v", got)
}
SetDebug(true)
if !logger.IsVerbose() {
t.Fatal("SetDebug(true) did not enable verbose")
}
SetDebug(false)
if logger.IsVerbose() {
t.Fatal("SetDebug(false) did not disable verbose")
}
}
func TestNormalizeBuildRoomAndClamp(t *testing.T) {
tests := map[string]string{
"datachannel": dataTransport,
"data": dataTransport,
"dc": dataTransport,
"vp8channel": defaultTransport,
"vp8": defaultTransport,
"bad": defaultTransport,
}
for in, want := range tests {
if got := normalizeTransport(in); got != want {
t.Fatalf("normalizeTransport(%q) = %q, want %q", in, got, want)
}
}
if normalizeCarrier(carrierWBStream) != carrierWBStream || normalizeCarrier("jazz") != "jazz" {
t.Fatal("normalizeCarrier() returned unexpected value")
}
if got := buildRoomURL("telemost", "abc"); got != "https://telemost.yandex.ru/j/abc" {
t.Fatalf("telemost room URL = %q", got)
}
if got := buildRoomURL("jazz", ""); got != "any" {
t.Fatalf("jazz empty room URL = %q", got)
}
if got := buildRoomURL(carrierWBStream, "room"); got != "room" {
t.Fatalf("wbstream room URL = %q", got)
}
if clamp(0, 1, 10) != 1 || clamp(11, 1, 10) != 10 || clamp(5, 1, 10) != 5 {
t.Fatal("clamp() returned unexpected value")
}
}
func TestStartValidation(t *testing.T) {
resetMobileGlobals(t)
if err := startWithConfig("", dataTransport, "room", "client", "key", 1080, "", "", mobileConfig{}); !errors.Is(err, errCarrierRequired) {
t.Fatalf("startWithConfig(missing carrier) = %v", err)
}
if err := startWithConfig("telemost", dataTransport, "", "client", "key", 1080, "", "", mobileConfig{}); !errors.Is(err, errRoomIDRequired) {
t.Fatalf("startWithConfig(missing room) = %v", err)
}
if err := startWithConfig("jazz", dataTransport, "", "", "key", 1080, "", "", mobileConfig{}); !errors.Is(err, errClientIDRequired) {
t.Fatalf("startWithConfig(missing client) = %v", err)
}
if err := startWithConfig("jazz", dataTransport, "", "client", "", 1080, "", "", mobileConfig{}); !errors.Is(err, errKeyHexRequired) {
t.Fatalf("startWithConfig(missing key) = %v", err)
}
mu.Lock()
cancel = func() {}
mu.Unlock()
if err := startWithConfig("jazz", dataTransport, "", "client", "key", 1080, "", "", mobileConfig{}); !errors.Is(err, errAlreadyRunning) {
t.Fatalf("startWithConfig(running) = %v", err)
}
resetMobileGlobals(t)
}
func TestWaitReadyStatesAndStop(t *testing.T) {
resetMobileGlobals(t)
if err := WaitReady(1); !errors.Is(err, errNotRunning) {
t.Fatalf("WaitReady(not running) = %v", err)
}
mu.Lock()
errRun = errors.New("run failed")
mu.Unlock()
if err := WaitReady(1); err == nil || err.Error() != "run failed" {
t.Fatalf("WaitReady(run err) = %v", err)
}
mu.Lock()
errRun = nil
ready = make(chan struct{})
done = make(chan struct{})
cancel = func() {}
mu.Unlock()
if err := WaitReady(1); !errors.Is(err, errStartTimedOut) {
t.Fatalf("WaitReady(timeout) = %v", err)
}
mu.Lock()
close(ready)
mu.Unlock()
if err := WaitReady(1); err != nil {
t.Fatalf("WaitReady(ready) error = %v", err)
}
mu.Lock()
cancel = func() {}
done = make(chan struct{})
doneCh := done
mu.Unlock()
go func() {
time.Sleep(time.Millisecond)
close(doneCh)
}()
Stop()
mu.Lock()
cancel = nil
mu.Unlock()
}
func TestLogBridge(t *testing.T) {
w := &testLogWriter{}
n, err := (&logBridge{w: w}).Write([]byte("abc"))
if err != nil || n != 3 || w.got != "abc" {
t.Fatalf("logBridge.Write() = (%d, %v), got %q", n, err, w.got)
}
}