diff --git a/.gitignore b/.gitignore index 78a67e1..1e74a6c 100644 --- a/.gitignore +++ b/.gitignore @@ -247,3 +247,5 @@ build/ GEMINI.md code/package-lock.json olcrtc +!cmd/olcrtc/ +!cmd/olcrtc/main_test.go diff --git a/cmd/olcrtc/main_test.go b/cmd/olcrtc/main_test.go new file mode 100644 index 0000000..8467992 --- /dev/null +++ b/cmd/olcrtc/main_test.go @@ -0,0 +1,137 @@ +package main + +import ( + "errors" + "os" + "path/filepath" + "testing" + + "github.com/openlibrecommunity/olcrtc/internal/app/session" + "github.com/openlibrecommunity/olcrtc/internal/logger" +) + +func TestToSessionConfigAndFirstNonEmpty(t *testing.T) { + cfg := config{ + mode: "cnc", + link: "direct", + transport: "vp8channel", + provider: "jazz", + roomID: "room", + clientID: "client", + keyHex: "key", + socksHost: "127.0.0.1", + socksPort: 1080, + dnsServer: "1.1.1.1:53", + socksProxyAddr: "proxy", + socksProxyPort: 1081, + videoWidth: 640, + videoHeight: 480, + videoFPS: 30, + videoBitrate: "1M", + videoHW: "none", + videoQRSize: 4, + videoQRRecovery: "low", + videoCodec: "qrcode", + videoTileModule: 4, + videoTileRS: 20, + vp8FPS: 25, + vp8BatchSize: 8, + } + + got := toSessionConfig(cfg) + if got.Mode != cfg.mode || got.Carrier != "jazz" || got.SOCKSPort != cfg.socksPort || + got.VideoTileRS != cfg.videoTileRS || got.VP8BatchSize != cfg.vp8BatchSize { + t.Fatalf("toSessionConfig() = %+v", got) + } + + cfg.carrier = "telemost" + got = toSessionConfig(cfg) + if got.Carrier != "telemost" { + t.Fatalf("carrier precedence = %q, want telemost", got.Carrier) + } + + if got := firstNonEmpty("", "", "x", "y"); got != "x" { + t.Fatalf("firstNonEmpty() = %q, want x", got) + } + if got := firstNonEmpty("", ""); got != "" { + t.Fatalf("firstNonEmpty(empty) = %q, want empty", got) + } +} + +func TestConfigureLogging(t *testing.T) { + logger.SetVerbose(false) + configureLogging(true) + if !logger.IsVerbose() { + t.Fatal("configureLogging(true) did not enable verbose logging") + } + + logger.SetVerbose(false) + configureLogging(false) + if logger.IsVerbose() { + t.Fatal("configureLogging(false) enabled verbose logging") + } +} + +func TestResolveDataDir(t *testing.T) { + abs := filepath.Join(t.TempDir(), "data") + got, err := resolveDataDir(abs) + if err != nil { + t.Fatalf("resolveDataDir(abs) error = %v", err) + } + if got != abs { + t.Fatalf("resolveDataDir(abs) = %q, want %q", got, abs) + } + + got, err = resolveDataDir("data") + if err != nil { + t.Fatalf("resolveDataDir(rel) error = %v", err) + } + if filepath.Base(got) != "data" || !filepath.IsAbs(got) { + t.Fatalf("resolveDataDir(rel) = %q, want absolute path ending in data", got) + } +} + +func TestLoadNames(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "names"), []byte("A\n"), 0o600); err != nil { + t.Fatalf("WriteFile(names) error = %v", err) + } + if err := os.WriteFile(filepath.Join(dir, "surnames"), []byte("B\n"), 0o600); err != nil { + t.Fatalf("WriteFile(surnames) error = %v", err) + } + if err := loadNames(dir); err != nil { + t.Fatalf("loadNames() error = %v", err) + } +} + +func TestWaitForShutdown(t *testing.T) { + errCh := make(chan error, 1) + errCh <- nil + if err := waitForShutdown(errCh); err != nil { + t.Fatalf("waitForShutdown(nil) error = %v", err) + } + + want := errors.New("boom") + errCh = make(chan error, 1) + errCh <- want + if err := waitForShutdown(errCh); !errors.Is(err, want) { + t.Fatalf("waitForShutdown(error) = %v, want %v", err, want) + } +} + +func TestValidateConfigAliasStillValidates(t *testing.T) { + session.RegisterDefaults() + cfg := config{ + mode: "srv", + link: "direct", + transport: "datachannel", + provider: "jazz", + clientID: "client", + keyHex: "key", + dnsServer: "1.1.1.1:53", + videoCodec: "qrcode", + } + if err := session.Validate(toSessionConfig(cfg)); err != nil { + t.Fatalf("Validate(toSessionConfig(alias)) error = %v", err) + } +} diff --git a/internal/app/session/session_test.go b/internal/app/session/session_test.go new file mode 100644 index 0000000..5cf46a4 --- /dev/null +++ b/internal/app/session/session_test.go @@ -0,0 +1,303 @@ +package session + +import ( + "errors" + "testing" +) + +func TestValidate(t *testing.T) { + RegisterDefaults() + + base := Config{ + Mode: modeSRV, + Link: "direct", + Transport: "datachannel", + Carrier: "telemost", + RoomID: "room-1", + ClientID: "client-1", + KeyHex: "00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff", + DNSServer: "1.1.1.1:53", + } + + tests := []struct { + name string + cfg Config + want error + }{ + {name: "valid baseline", cfg: base}, + { + name: "jazz allows empty room id", + cfg: func() Config { + cfg := base + cfg.Carrier = "jazz" + cfg.RoomID = "" + return cfg + }(), + }, + { + name: "cnc requires socks host and port", + cfg: func() Config { + cfg := base + cfg.Mode = modeCNC + cfg.SOCKSHost = "127.0.0.1" + cfg.SOCKSPort = 1080 + return cfg + }(), + }, + { + name: "missing mode", + cfg: func() Config { + cfg := base + cfg.Mode = "" + return cfg + }(), + want: ErrModeRequired, + }, + { + name: "unsupported carrier", + cfg: func() Config { + cfg := base + cfg.Carrier = "unknown" + return cfg + }(), + want: ErrUnsupportedCarrier, + }, + { + name: "unsupported link", + cfg: func() Config { + cfg := base + cfg.Link = "unknown" + return cfg + }(), + want: ErrUnsupportedLink, + }, + { + name: "unsupported transport", + cfg: func() Config { + cfg := base + cfg.Transport = "unknown" + return cfg + }(), + want: ErrUnsupportedTransport, + }, + { + name: "room id required for non jazz", + cfg: func() Config { + cfg := base + cfg.RoomID = "" + return cfg + }(), + want: ErrRoomIDRequired, + }, + { + name: "client id required", + cfg: func() Config { + cfg := base + cfg.ClientID = "" + return cfg + }(), + want: ErrClientIDRequired, + }, + { + name: "key required", + cfg: func() Config { + cfg := base + cfg.KeyHex = "" + return cfg + }(), + want: ErrKeyRequired, + }, + { + name: "dns server required", + cfg: func() Config { + cfg := base + cfg.DNSServer = "" + return cfg + }(), + want: ErrDNSServerRequired, + }, + { + name: "videochannel requires dimensions and bitrate settings", + cfg: func() Config { + cfg := base + cfg.Transport = "videochannel" + return cfg + }(), + want: ErrVideoWidthRequired, + }, + { + name: "videochannel rejects invalid codec", + cfg: func() Config { + cfg := base + cfg.Transport = "videochannel" + cfg.VideoWidth = 640 + cfg.VideoHeight = 480 + cfg.VideoFPS = 30 + cfg.VideoBitrate = "1M" + cfg.VideoHW = "none" + cfg.VideoCodec = "bogus" + return cfg + }(), + want: ErrVideoCodecInvalid, + }, + { + name: "videochannel requires height", + cfg: func() Config { + cfg := base + cfg.Transport = "videochannel" + cfg.VideoWidth = 640 + return cfg + }(), + want: ErrVideoHeightRequired, + }, + { + name: "videochannel requires fps", + cfg: func() Config { + cfg := base + cfg.Transport = "videochannel" + cfg.VideoWidth = 640 + cfg.VideoHeight = 480 + return cfg + }(), + want: ErrVideoFPSRequired, + }, + { + name: "videochannel requires bitrate", + cfg: func() Config { + cfg := base + cfg.Transport = "videochannel" + cfg.VideoWidth = 640 + cfg.VideoHeight = 480 + cfg.VideoFPS = 30 + return cfg + }(), + want: ErrVideoBitrateRequired, + }, + { + name: "videochannel requires hw", + cfg: func() Config { + cfg := base + cfg.Transport = "videochannel" + cfg.VideoWidth = 640 + cfg.VideoHeight = 480 + cfg.VideoFPS = 30 + cfg.VideoBitrate = "1M" + return cfg + }(), + want: ErrVideoHWRequired, + }, + { + name: "tile codec requires square 1080 dimensions", + cfg: func() Config { + cfg := base + cfg.Transport = "videochannel" + cfg.VideoWidth = 640 + cfg.VideoHeight = 480 + cfg.VideoFPS = 30 + cfg.VideoBitrate = "1M" + cfg.VideoHW = "none" + cfg.VideoCodec = "tile" + return cfg + }(), + want: ErrTileCodecDimensions, + }, + { + name: "videochannel valid", + cfg: func() Config { + cfg := base + cfg.Transport = "videochannel" + cfg.VideoWidth = 1080 + cfg.VideoHeight = 1080 + cfg.VideoFPS = 30 + cfg.VideoBitrate = "1M" + cfg.VideoHW = "none" + cfg.VideoCodec = "tile" + return cfg + }(), + }, + { + name: "vp8channel requires fps", + cfg: func() Config { + cfg := base + cfg.Transport = "vp8channel" + return cfg + }(), + want: ErrVP8FPSRequired, + }, + { + name: "vp8channel requires batch size", + cfg: func() Config { + cfg := base + cfg.Transport = "vp8channel" + cfg.VP8FPS = 25 + return cfg + }(), + want: ErrVP8BatchSizeRequired, + }, + { + name: "vp8channel valid", + cfg: func() Config { + cfg := base + cfg.Transport = "vp8channel" + cfg.VP8FPS = 25 + cfg.VP8BatchSize = 16 + return cfg + }(), + }, + { + name: "cnc requires socks host", + cfg: func() Config { + cfg := base + cfg.Mode = modeCNC + cfg.SOCKSPort = 1080 + return cfg + }(), + want: ErrSOCKSHostRequired, + }, + { + name: "cnc requires socks port", + cfg: func() Config { + cfg := base + cfg.Mode = modeCNC + cfg.SOCKSHost = "127.0.0.1" + return cfg + }(), + want: ErrSOCKSPortRequired, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := Validate(tt.cfg) + if tt.want == nil { + if err != nil { + t.Fatalf("Validate() error = %v", err) + } + return + } + if !errors.Is(err, tt.want) { + t.Fatalf("Validate() error = %v, want %v", err, tt.want) + } + }) + } +} + +func TestBuildRoomURL(t *testing.T) { + tests := []struct { + carrier string + roomID string + want string + }{ + {carrier: "telemost", roomID: "abc", want: "https://telemost.yandex.ru/j/abc"}, + {carrier: "jazz", roomID: "", want: "any"}, + {carrier: "jazz", roomID: "room", want: "room"}, + {carrier: "wbstream", roomID: "wb", want: "wb"}, + {carrier: "other", roomID: "raw", want: "raw"}, + } + + for _, tt := range tests { + if got := buildRoomURL(tt.carrier, tt.roomID); got != tt.want { + t.Fatalf("buildRoomURL(%q, %q) = %q, want %q", tt.carrier, tt.roomID, got, tt.want) + } + } +} diff --git a/internal/carrier/builtin/register_test.go b/internal/carrier/builtin/register_test.go new file mode 100644 index 0000000..633d8d3 --- /dev/null +++ b/internal/carrier/builtin/register_test.go @@ -0,0 +1,18 @@ +package builtin + +import ( + "slices" + "testing" + + "github.com/openlibrecommunity/olcrtc/internal/carrier" +) + +func TestRegister(t *testing.T) { + Register() + available := carrier.Available() + for _, want := range []string{"jazz", "telemost", "wbstream"} { + if !slices.Contains(available, want) { + t.Fatalf("Available() = %v, missing %q", available, want) + } + } +} diff --git a/internal/carrier/carrier_test.go b/internal/carrier/carrier_test.go new file mode 100644 index 0000000..6299ba5 --- /dev/null +++ b/internal/carrier/carrier_test.go @@ -0,0 +1,251 @@ +package carrier + +import ( + "context" + "errors" + "reflect" + "testing" + + "github.com/openlibrecommunity/olcrtc/internal/provider" + "github.com/pion/webrtc/v4" +) + +type stubProvider struct { + connectErr error + sendErr error + closeErr error + canSend bool + reconnectCallback func(*webrtc.DataChannel) + shouldReconnect func() bool + endedCallback func(string) + watchCalled bool + addTrackErr error + trackHandlerCalled bool +} + +func (s *stubProvider) Connect(context.Context) error { return s.connectErr } +func (s *stubProvider) Send([]byte) error { return s.sendErr } +func (s *stubProvider) Close() error { return s.closeErr } +func (s *stubProvider) SetReconnectCallback(cb func(*webrtc.DataChannel)) { s.reconnectCallback = cb } +func (s *stubProvider) SetShouldReconnect(fn func() bool) { s.shouldReconnect = fn } +func (s *stubProvider) SetEndedCallback(cb func(string)) { s.endedCallback = cb } +func (s *stubProvider) WatchConnection(context.Context) { s.watchCalled = true } +func (s *stubProvider) CanSend() bool { return s.canSend } +func (s *stubProvider) GetSendQueue() chan []byte { return nil } +func (s *stubProvider) GetBufferedAmount() uint64 { return 0 } +func (s *stubProvider) AddVideoTrack(webrtc.TrackLocal) error { return s.addTrackErr } +func (s *stubProvider) SetVideoTrackHandler(func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) { + s.trackHandlerCalled = true +} + +type plainProvider struct { + connectErr error + sendErr error + closeErr error + canSend bool + reconnectCallback func(*webrtc.DataChannel) + shouldReconnect func() bool + endedCallback func(string) + watchCalled bool +} + +func (p *plainProvider) Connect(context.Context) error { return p.connectErr } +func (p *plainProvider) Send([]byte) error { return p.sendErr } +func (p *plainProvider) Close() error { return p.closeErr } +func (p *plainProvider) SetReconnectCallback(cb func(*webrtc.DataChannel)) { p.reconnectCallback = cb } +func (p *plainProvider) SetShouldReconnect(fn func() bool) { p.shouldReconnect = fn } +func (p *plainProvider) SetEndedCallback(cb func(string)) { p.endedCallback = cb } +func (p *plainProvider) WatchConnection(context.Context) { p.watchCalled = true } +func (p *plainProvider) CanSend() bool { return p.canSend } +func (p *plainProvider) GetSendQueue() chan []byte { return nil } +func (p *plainProvider) GetBufferedAmount() uint64 { return 0 } + +func snapshotCarrierRegistry() map[string]Factory { + out := make(map[string]Factory, len(registry)) + for k, v := range registry { + out[k] = v + } + return out +} + +func restoreCarrierRegistry(src map[string]Factory) { + registry = make(map[string]Factory, len(src)) + for k, v := range src { + registry[k] = v + } +} + +func TestRegisterLegacyAndAvailable(t *testing.T) { + old := snapshotCarrierRegistry() + t.Cleanup(func() { restoreCarrierRegistry(old) }) + + RegisterLegacy("legacy-test", func(_ context.Context, cfg provider.Config) (provider.Provider, error) { + if cfg.Name != "peer" { + t.Fatalf("provider config name = %q, want peer", cfg.Name) + } + return &stubProvider{canSend: true}, nil + }) + + sess, err := New(context.Background(), "legacy-test", Config{Name: "peer"}) + if err != nil { + t.Fatalf("New() error = %v", err) + } + + caps := sess.Capabilities() + if !caps.ByteStream || !caps.VideoTrack { + t.Fatalf("Capabilities() = %+v, want byte and video true", caps) + } + + if !reflect.DeepEqual(Available(), []string{"legacy-test"}) { + t.Fatalf("Available() = %#v, want %#v", Available(), []string{"legacy-test"}) + } +} + +func TestNewReturnsErrCarrierNotFound(t *testing.T) { + old := snapshotCarrierRegistry() + t.Cleanup(func() { restoreCarrierRegistry(old) }) + registry = map[string]Factory{} + + _, err := New(context.Background(), "missing", Config{}) + if !errors.Is(err, ErrCarrierNotFound) { + t.Fatalf("New() error = %v, want %v", err, ErrCarrierNotFound) + } +} + +func TestLegacySessionOpenVideoTrackUnsupported(t *testing.T) { + sess := &legacySession{provider: &plainProvider{}} + + caps := sess.Capabilities() + if !caps.ByteStream || caps.VideoTrack { + t.Fatalf("Capabilities() = %+v, want byte true and video false", caps) + } + + _, err := sess.OpenVideoTrack() + if !errors.Is(err, ErrVideoTrackUnsupported) { + t.Fatalf("OpenVideoTrack() error = %v, want %v", err, ErrVideoTrackUnsupported) + } +} + +func TestLegacyByteStreamWrapsProviderAndCallbacks(t *testing.T) { + prov := &stubProvider{canSend: true} + stream := &legacyByteStream{provider: prov} + + called := false + stream.SetReconnectCallback(func() { called = true }) + if prov.reconnectCallback == nil { + t.Fatal("SetReconnectCallback() did not install provider callback") + } + prov.reconnectCallback(nil) + if !called { + t.Fatal("reconnect callback was not adapted") + } + + reconnectAllowed := false + stream.SetShouldReconnect(func() bool { reconnectAllowed = true; return true }) + if prov.shouldReconnect == nil || !prov.shouldReconnect() || !reconnectAllowed { + t.Fatal("SetShouldReconnect() was not forwarded") + } + + ended := "" + stream.SetEndedCallback(func(reason string) { ended = reason }) + if prov.endedCallback == nil { + t.Fatal("SetEndedCallback() was not forwarded") + } + prov.endedCallback("bye") + if ended != "bye" { + t.Fatalf("ended callback reason = %q, want bye", ended) + } + + stream.WatchConnection(context.Background()) + if !prov.watchCalled { + t.Fatal("WatchConnection() was not forwarded") + } + if !stream.CanSend() { + t.Fatal("CanSend() = false, want true") + } +} + +func TestLegacyByteStreamWrapsErrors(t *testing.T) { + prov := &stubProvider{ + connectErr: errors.New("connect boom"), + sendErr: errors.New("send boom"), + closeErr: errors.New("close boom"), + } + stream := &legacyByteStream{provider: prov} + + if err := stream.Connect(context.Background()); err == nil || err.Error() != "connect: connect boom" { + t.Fatalf("Connect() error = %v", err) + } + if err := stream.Send([]byte("x")); err == nil || err.Error() != "send: send boom" { + t.Fatalf("Send() error = %v", err) + } + if err := stream.Close(); err == nil || err.Error() != "close: close boom" { + t.Fatalf("Close() error = %v", err) + } +} + +func TestLegacySessionOpenByteStreamAndVideoTrack(t *testing.T) { + prov := &stubProvider{canSend: true} + sess := &legacySession{provider: prov} + + stream, err := sess.OpenByteStream() + if err != nil { + t.Fatalf("OpenByteStream() error = %v", err) + } + if !stream.CanSend() { + t.Fatal("byte stream CanSend() = false, want true") + } + + video, err := sess.OpenVideoTrack() + if err != nil { + t.Fatalf("OpenVideoTrack() error = %v", err) + } + if err := video.Connect(context.Background()); err != nil { + t.Fatalf("video Connect() error = %v", err) + } + if err := video.Close(); err != nil { + t.Fatalf("video Close() error = %v", err) + } + video.SetShouldReconnect(func() bool { return true }) + video.SetEndedCallback(func(string) {}) + video.WatchConnection(context.Background()) + if !video.CanSend() || prov.shouldReconnect == nil || prov.endedCallback == nil || !prov.watchCalled { + t.Fatal("video adapter did not forward calls") + } +} + +func TestLegacyVideoTrackWrapsOperations(t *testing.T) { + prov := &stubProvider{canSend: true, addTrackErr: errors.New("track boom")} + track := &legacyVideoTrack{provider: prov} + + called := false + track.SetReconnectCallback(func() { called = true }) + prov.reconnectCallback(nil) + if !called { + t.Fatal("reconnect callback was not adapted") + } + + track.SetTrackHandler(func(*webrtc.TrackRemote, *webrtc.RTPReceiver) {}) + if !prov.trackHandlerCalled { + t.Fatal("SetTrackHandler() was not forwarded") + } + + if err := track.AddTrack(nil); err == nil || err.Error() != "add track: track boom" { + t.Fatalf("AddTrack() error = %v", err) + } +} + +func TestLegacyVideoTrackWrapsConnectCloseErrors(t *testing.T) { + prov := &stubProvider{ + connectErr: errors.New("connect boom"), + closeErr: errors.New("close boom"), + } + track := &legacyVideoTrack{provider: prov} + + if err := track.Connect(context.Background()); err == nil || err.Error() != "connect: connect boom" { + t.Fatalf("Connect() error = %v", err) + } + if err := track.Close(); err == nil || err.Error() != "close: close boom" { + t.Fatalf("Close() error = %v", err) + } +} diff --git a/internal/client/client_test.go b/internal/client/client_test.go new file mode 100644 index 0000000..ddd66a1 --- /dev/null +++ b/internal/client/client_test.go @@ -0,0 +1,419 @@ +package client + +import ( + "bytes" + "context" + "encoding/binary" + "encoding/json" + "errors" + "io" + "net" + "testing" + "time" + + cryptopkg "github.com/openlibrecommunity/olcrtc/internal/crypto" + "github.com/openlibrecommunity/olcrtc/internal/muxconn" + "github.com/xtaci/smux" +) + +func TestSetupCipher(t *testing.T) { + keyHex := "00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff" + cipher, err := setupCipher(keyHex) + if err != nil { + t.Fatalf("setupCipher() error = %v", err) + } + if cipher == nil { + t.Fatal("setupCipher() returned nil cipher") + } +} + +func TestSetupCipherRejectsBadInput(t *testing.T) { + if _, err := setupCipher("zz"); err == nil { + t.Fatal("setupCipher() unexpectedly succeeded for bad hex") + } + if _, err := setupCipher("00"); !errors.Is(err, ErrKeySize) { + t.Fatalf("setupCipher() error = %v, want ErrKeySize", err) + } +} + +func TestSmuxConfig(t *testing.T) { + cfg := smuxConfig() + if cfg.Version != 2 || cfg.MaxFrameSize != 32768 || cfg.MaxReceiveBuffer != 16*1024*1024 { + t.Fatalf("smuxConfig() = %+v", cfg) + } +} + +func TestSocks5Handshake(t *testing.T) { + c := &Client{} + server, client := net.Pipe() + defer func() { + _ = server.Close() + _ = client.Close() + }() + + done := make(chan error, 1) + go func() { + done <- c.socks5Handshake(server) + }() + + if _, err := client.Write([]byte{5, 1, 0}); err != nil { + t.Fatalf("Write() error = %v", err) + } + resp := make([]byte, 2) + if _, err := io.ReadFull(client, resp); err != nil { + t.Fatalf("ReadFull() error = %v", err) + } + + if err := <-done; err != nil { + t.Fatalf("socks5Handshake() error = %v", err) + } + if !bytes.Equal(resp, []byte{5, 0}) { + t.Fatalf("handshake response = %v, want [5 0]", resp) + } +} + +func TestSocks5HandshakeRejectsVersion(t *testing.T) { + c := &Client{} + server, client := net.Pipe() + defer func() { + _ = server.Close() + _ = client.Close() + }() + + done := make(chan error, 1) + go func() { + done <- c.socks5Handshake(server) + }() + + if _, err := client.Write([]byte{4, 1}); err != nil { + t.Fatalf("Write() error = %v", err) + } + + if err := <-done; !errors.Is(err, ErrInvalidSOCKSVersion) { + t.Fatalf("socks5Handshake() error = %v, want %v", err, ErrInvalidSOCKSVersion) + } +} + +func TestSocks5HandshakeReadMethodsError(t *testing.T) { + c := &Client{} + server, client := net.Pipe() + defer func() { + _ = server.Close() + _ = client.Close() + }() + + done := make(chan error, 1) + go func() { + done <- c.socks5Handshake(server) + }() + + if _, err := client.Write([]byte{5, 2, 0}); err != nil { + t.Fatalf("Write() error = %v", err) + } + _ = client.Close() + if err := <-done; err == nil { + t.Fatal("socks5Handshake() unexpectedly succeeded") + } +} + +func TestSocks5RequestIPv4(t *testing.T) { + c := &Client{} + server, client := net.Pipe() + defer func() { + _ = server.Close() + _ = client.Close() + }() + + done := make(chan struct { + addr string + port int + err error + }, 1) + go func() { + addr, port, err := c.socks5Request(server) + done <- struct { + addr string + port int + err error + }{addr: addr, port: port, err: err} + }() + + req := []byte{5, 1, 0, 1, 127, 0, 0, 1} + port := make([]byte, 2) + binary.BigEndian.PutUint16(port, 8080) + if _, err := client.Write(append(req, port...)); err != nil { + t.Fatalf("Write() error = %v", err) + } + + res := <-done + if res.err != nil { + t.Fatalf("socks5Request() error = %v", res.err) + } + if res.addr != "127.0.0.1" || res.port != 8080 { + t.Fatalf("socks5Request() = (%q, %d), want (127.0.0.1, 8080)", res.addr, res.port) + } +} + +func TestSocks5RequestDomain(t *testing.T) { + c := &Client{} + server, client := net.Pipe() + defer func() { + _ = server.Close() + _ = client.Close() + }() + + done := make(chan struct { + addr string + port int + err error + }, 1) + go func() { + addr, port, err := c.socks5Request(server) + done <- struct { + addr string + port int + err error + }{addr: addr, port: port, err: err} + }() + + req := []byte{5, 1, 0, 3, 11} + req = append(req, []byte("example.com")...) + port := make([]byte, 2) + binary.BigEndian.PutUint16(port, 443) + if _, err := client.Write(append(req, port...)); err != nil { + t.Fatalf("Write() error = %v", err) + } + + res := <-done + if res.err != nil { + t.Fatalf("socks5Request() error = %v", res.err) + } + if res.addr != "example.com" || res.port != 443 { + t.Fatalf("socks5Request() = (%q, %d), want (example.com, 443)", res.addr, res.port) + } +} + +func TestSocks5RequestRejectsCommandAndAddressType(t *testing.T) { + c := &Client{} + server, client := net.Pipe() + defer func() { + _ = server.Close() + _ = client.Close() + }() + + done := make(chan error, 1) + go func() { + _, _, err := c.socks5Request(server) + done <- err + }() + + if _, err := client.Write([]byte{5, 2, 0, 1}); err != nil { + t.Fatalf("Write() error = %v", err) + } + + if err := <-done; !errors.Is(err, ErrUnsupportedSOCKSCommand) { + t.Fatalf("socks5Request() error = %v, want %v", err, ErrUnsupportedSOCKSCommand) + } + + server2, client2 := net.Pipe() + defer func() { + _ = server2.Close() + _ = client2.Close() + }() + + done = make(chan error, 1) + go func() { + _, _, err := c.socks5Request(server2) + done <- err + }() + + if _, err := client2.Write([]byte{5, 1, 0, 9}); err != nil { + t.Fatalf("Write() error = %v", err) + } + + if err := <-done; !errors.Is(err, ErrUnsupportedAddressType) { + t.Fatalf("socks5Request() error = %v, want %v", err, ErrUnsupportedAddressType) + } +} + +func TestSocks5RequestReadPortError(t *testing.T) { + c := &Client{} + server, client := net.Pipe() + defer func() { + _ = server.Close() + _ = client.Close() + }() + + done := make(chan error, 1) + go func() { + _, _, err := c.socks5Request(server) + done <- err + }() + + if _, err := client.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1, 0}); err != nil { + t.Fatalf("Write() error = %v", err) + } + _ = client.Close() + if err := <-done; err == nil { + t.Fatal("socks5Request() unexpectedly succeeded") + } +} + +func TestReplyBuffers(t *testing.T) { + if !bytes.Equal(replySuccess(), []byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0}) { + t.Fatalf("replySuccess() = %v", replySuccess()) + } + if !bytes.Equal(replyHostUnreachable(), []byte{5, 4, 0, 1, 0, 0, 0, 0, 0, 0}) { + t.Fatalf("replyHostUnreachable() = %v", replyHostUnreachable()) + } +} + +func TestReadSocks5AddrReadErrors(t *testing.T) { + c := &Client{} + server, client := net.Pipe() + defer func() { + _ = server.Close() + _ = client.Close() + }() + + done := make(chan error, 1) + go func() { + _, err := c.readSocks5Addr(server, 1) + done <- err + }() + + time.Sleep(10 * time.Millisecond) + _ = client.Close() + if err := <-done; err == nil { + t.Fatal("readSocks5Addr() unexpectedly succeeded") + } +} + +func TestSendConnectRequestOverSmux(t *testing.T) { + a, b := net.Pipe() + defer func() { + _ = a.Close() + _ = b.Close() + }() + + serverSess, err := smux.Server(a, smuxConfig()) + if err != nil { + t.Fatalf("smux.Server() error = %v", err) + } + defer func() { _ = serverSess.Close() }() + clientSess, err := smux.Client(b, smuxConfig()) + if err != nil { + t.Fatalf("smux.Client() error = %v", err) + } + defer func() { _ = clientSess.Close() }() + + done := make(chan error, 1) + go func() { + stream, err := serverSess.AcceptStream() + if err != nil { + done <- err + return + } + defer func() { _ = stream.Close() }() + + var req map[string]any + if err := json.NewDecoder(stream).Decode(&req); err != nil { + done <- err + return + } + if req["cmd"] != "connect" || req["clientId"] != "client-1" || req["addr"] != "example.com" { + done <- errors.New("unexpected connect request") + return + } + _, err = stream.Write([]byte{0x00}) + done <- err + }() + + stream, err := clientSess.OpenStream() + if err != nil { + t.Fatalf("OpenStream() error = %v", err) + } + defer func() { _ = stream.Close() }() + + c := &Client{clientID: "client-1"} + if err := c.sendConnectRequest(stream, "example.com", 443); err != nil { + t.Fatalf("sendConnectRequest() error = %v", err) + } + if err := <-done; err != nil { + t.Fatalf("server side error = %v", err) + } +} + +func TestSendConnectRequestRejectsBadAck(t *testing.T) { + a, b := net.Pipe() + defer func() { + _ = a.Close() + _ = b.Close() + }() + serverSess, err := smux.Server(a, smuxConfig()) + if err != nil { + t.Fatalf("smux.Server() error = %v", err) + } + defer func() { _ = serverSess.Close() }() + clientSess, err := smux.Client(b, smuxConfig()) + if err != nil { + t.Fatalf("smux.Client() error = %v", err) + } + defer func() { _ = clientSess.Close() }() + + go func() { + stream, err := serverSess.AcceptStream() + if err != nil { + return + } + defer func() { _ = stream.Close() }() + _, _ = io.CopyN(io.Discard, stream, 1) + _, _ = stream.Write([]byte{0x01}) + }() + + stream, err := clientSess.OpenStream() + if err != nil { + t.Fatalf("OpenStream() error = %v", err) + } + defer func() { _ = stream.Close() }() + + c := &Client{clientID: "client-1"} + if err := c.sendConnectRequest(stream, "example.com", 443); !errors.Is(err, ErrRemoteNotReady) { + t.Fatalf("sendConnectRequest() error = %v, want %v", err, ErrRemoteNotReady) + } +} + +type closerLinkStub struct { + closed bool +} + +func (s *closerLinkStub) Connect(context.Context) error { return nil } +func (s *closerLinkStub) Send([]byte) error { return nil } +func (s *closerLinkStub) Close() error { s.closed = true; return nil } +func (s *closerLinkStub) SetReconnectCallback(func()) {} +func (s *closerLinkStub) SetShouldReconnect(func() bool) {} +func (s *closerLinkStub) SetEndedCallback(func(string)) {} +func (s *closerLinkStub) WatchConnection(context.Context) {} +func (s *closerLinkStub) CanSend() bool { return true } + +func TestOnDataWithNilConn(t *testing.T) { + c := &Client{} + c.onData([]byte("ignored")) +} + +func TestShutdownClosesLinkAndConn(t *testing.T) { + cipher, err := cryptopkg.NewCipher("01234567890123456789012345678901") + if err != nil { + t.Fatalf("NewCipher() error = %v", err) + } + ln := &closerLinkStub{} + c := &Client{ + ln: ln, + cipher: cipher, + conn: muxconn.New(ln, cipher), + } + c.shutdown() + if !ln.closed { + t.Fatal("shutdown() did not close link") + } +} diff --git a/internal/crypto/chacha_test.go b/internal/crypto/chacha_test.go new file mode 100644 index 0000000..a0ed9cc --- /dev/null +++ b/internal/crypto/chacha_test.go @@ -0,0 +1,50 @@ +package crypto + +import ( + "bytes" + "errors" + "testing" +) + +func TestNewCipherRejectsWrongKeySize(t *testing.T) { + _, err := NewCipher("short") + if !errors.Is(err, ErrInvalidKeySize) { + t.Fatalf("NewCipher() error = %v, want %v", err, ErrInvalidKeySize) + } +} + +func TestCipherRoundTrip(t *testing.T) { + c, err := NewCipher("01234567890123456789012345678901") + if err != nil { + t.Fatalf("NewCipher() error = %v", err) + } + + plaintext := []byte("hello world") + ciphertext, err := c.Encrypt(plaintext) + if err != nil { + t.Fatalf("Encrypt() error = %v", err) + } + if bytes.Equal(ciphertext, plaintext) { + t.Fatal("ciphertext unexpectedly matches plaintext") + } + + got, err := c.Decrypt(ciphertext) + if err != nil { + t.Fatalf("Decrypt() error = %v", err) + } + if !bytes.Equal(got, plaintext) { + t.Fatalf("Decrypt() = %q, want %q", got, plaintext) + } +} + +func TestDecryptRejectsShortCiphertext(t *testing.T) { + c, err := NewCipher("01234567890123456789012345678901") + if err != nil { + t.Fatalf("NewCipher() error = %v", err) + } + + _, err = c.Decrypt([]byte("short")) + if !errors.Is(err, ErrCiphertextTooShort) { + t.Fatalf("Decrypt() error = %v, want %v", err, ErrCiphertextTooShort) + } +} diff --git a/internal/link/direct/direct_test.go b/internal/link/direct/direct_test.go new file mode 100644 index 0000000..ffe291c --- /dev/null +++ b/internal/link/direct/direct_test.go @@ -0,0 +1,137 @@ +package direct + +import ( + "context" + "errors" + "testing" + + "github.com/openlibrecommunity/olcrtc/internal/link" + "github.com/openlibrecommunity/olcrtc/internal/transport" +) + +type stubTransport struct { + connectErr error + sendErr error + closeErr error + canSend bool + + connectCalled bool + sendData []byte + watched bool + reconnectCB func() + shouldFn func() bool + endedCB func(string) +} + +func (s *stubTransport) Connect(context.Context) error { + s.connectCalled = true + return s.connectErr +} +func (s *stubTransport) Send(data []byte) error { + s.sendData = append([]byte(nil), data...) + return s.sendErr +} +func (s *stubTransport) Close() error { return s.closeErr } +func (s *stubTransport) SetReconnectCallback(cb func()) { + s.reconnectCB = cb +} +func (s *stubTransport) SetShouldReconnect(fn func() bool) { s.shouldFn = fn } +func (s *stubTransport) SetEndedCallback(cb func(string)) { s.endedCB = cb } +func (s *stubTransport) WatchConnection(context.Context) { s.watched = true } +func (s *stubTransport) CanSend() bool { return s.canSend } +func (s *stubTransport) Features() transport.Features { return transport.Features{} } + +func TestNewForwardsConfigAndMethods(t *testing.T) { + name := "direct-test-forward" + var seen transport.Config + tr := &stubTransport{canSend: true} + transport.Register(name, func(_ context.Context, cfg transport.Config) (transport.Transport, error) { + seen = cfg + return tr, nil + }) + + ln, err := New(context.Background(), link.Config{ + Transport: name, + Carrier: "carrier", + RoomURL: "room", + ClientID: "client", + Name: "peer", + DNSServer: "1.1.1.1:53", + ProxyAddr: "127.0.0.1", + ProxyPort: 1080, + VideoWidth: 640, + VideoHeight: 480, + VideoFPS: 30, + VideoBitrate: "1M", + VideoHW: "none", + VideoQRSize: 4, + VideoQRRecovery: "low", + VideoCodec: "qrcode", + VideoTileModule: 3, + VideoTileRS: 20, + VP8FPS: 25, + VP8BatchSize: 8, + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + + if seen.ClientID != "client" || seen.ProxyPort != 1080 || seen.VideoTileRS != 20 || seen.VP8BatchSize != 8 { + t.Fatalf("forwarded config = %+v", seen) + } + + if err := ln.Connect(context.Background()); err != nil { + t.Fatalf("Connect() error = %v", err) + } + if !tr.connectCalled { + t.Fatal("Connect() was not forwarded") + } + + if err := ln.Send([]byte("payload")); err != nil { + t.Fatalf("Send() error = %v", err) + } + if string(tr.sendData) != "payload" { + t.Fatalf("Send() forwarded %q, want payload", tr.sendData) + } + + ln.SetReconnectCallback(func() {}) + ln.SetShouldReconnect(func() bool { return true }) + ln.SetEndedCallback(func(string) {}) + ln.WatchConnection(context.Background()) + if tr.reconnectCB == nil || tr.shouldFn == nil || tr.endedCB == nil || !tr.watched { + t.Fatal("callbacks/watch were not forwarded") + } + if !ln.CanSend() { + t.Fatal("CanSend() = false, want true") + } +} + +func TestNewWrapsFactoryError(t *testing.T) { + name := "direct-test-error" + transport.Register(name, func(context.Context, transport.Config) (transport.Transport, error) { + return nil, errors.New("boom") + }) + + _, err := New(context.Background(), link.Config{Transport: name}) + if err == nil || err.Error() != "create transport for direct link: boom" { + t.Fatalf("New() error = %v", err) + } +} + +func TestDirectLinkWrapsTransportErrors(t *testing.T) { + ln := &directLink{transport: &stubTransport{ + connectErr: errors.New("connect boom"), + sendErr: errors.New("send boom"), + closeErr: errors.New("close boom"), + }} + + if err := ln.Connect(context.Background()); err == nil || err.Error() != "transport connect: connect boom" { + t.Fatalf("Connect() error = %v", err) + } + if err := ln.Send([]byte("x")); err == nil || err.Error() != "transport send: send boom" { + t.Fatalf("Send() error = %v", err) + } + if err := ln.Close(); err == nil || err.Error() != "transport close: close boom" { + t.Fatalf("Close() error = %v", err) + } +} diff --git a/internal/link/link_test.go b/internal/link/link_test.go new file mode 100644 index 0000000..94b24d9 --- /dev/null +++ b/internal/link/link_test.go @@ -0,0 +1,71 @@ +package link + +import ( + "context" + "errors" + "reflect" + "testing" +) + +type stubLink struct{} + +func (s *stubLink) Connect(context.Context) error { return nil } +func (s *stubLink) Send([]byte) error { return nil } +func (s *stubLink) Close() error { return nil } +func (s *stubLink) SetReconnectCallback(func()) {} +func (s *stubLink) SetShouldReconnect(func() bool) {} +func (s *stubLink) SetEndedCallback(func(string)) {} +func (s *stubLink) WatchConnection(context.Context) {} +func (s *stubLink) CanSend() bool { return true } + +func snapshotLinkRegistry() map[string]Factory { + out := make(map[string]Factory, len(registry)) + for k, v := range registry { + out[k] = v + } + return out +} + +func restoreLinkRegistry(src map[string]Factory) { + registry = make(map[string]Factory, len(src)) + for k, v := range src { + registry[k] = v + } +} + +func TestNewAndAvailable(t *testing.T) { + old := snapshotLinkRegistry() + t.Cleanup(func() { restoreLinkRegistry(old) }) + + called := false + Register("test-link", func(_ context.Context, cfg Config) (Link, error) { + called = cfg.ClientID == "client-1" + return &stubLink{}, nil + }) + + got, err := New(context.Background(), "test-link", Config{ClientID: "client-1"}) + if err != nil { + t.Fatalf("New() error = %v", err) + } + if !called { + t.Fatal("factory did not receive config") + } + if _, ok := got.(*stubLink); !ok { + t.Fatalf("New() returned %T, want *stubLink", got) + } + + if !reflect.DeepEqual(Available(), []string{"test-link"}) { + t.Fatalf("Available() = %#v, want %#v", Available(), []string{"test-link"}) + } +} + +func TestNewReturnsErrLinkNotFound(t *testing.T) { + old := snapshotLinkRegistry() + t.Cleanup(func() { restoreLinkRegistry(old) }) + registry = map[string]Factory{} + + _, err := New(context.Background(), "missing", Config{}) + if !errors.Is(err, ErrLinkNotFound) { + t.Fatalf("New() error = %v, want %v", err, ErrLinkNotFound) + } +} diff --git a/internal/logger/logger_test.go b/internal/logger/logger_test.go new file mode 100644 index 0000000..dfe58a1 --- /dev/null +++ b/internal/logger/logger_test.go @@ -0,0 +1,72 @@ +package logger + +import ( + "bytes" + "log" + "strings" + "testing" +) + +func captureLogs(t *testing.T) *bytes.Buffer { + t.Helper() + var buf bytes.Buffer + oldWriter := log.Writer() + oldFlags := log.Flags() + log.SetOutput(&buf) + log.SetFlags(0) + t.Cleanup(func() { + log.SetOutput(oldWriter) + log.SetFlags(oldFlags) + SetVerbose(false) + }) + return &buf +} + +func TestVerboseFlag(t *testing.T) { + SetVerbose(true) + if !IsVerbose() { + t.Fatal("IsVerbose() = false, want true") + } + SetVerbose(false) + if IsVerbose() { + t.Fatal("IsVerbose() = true, want false") + } +} + +func TestLoggingFunctions(t *testing.T) { + buf := captureLogs(t) + + Info("info") + Infof("%s", "infof") + Warn("warn") + Warnf("%s", "warnf") + Error("error") + Errorf("%s", "errorf") + + got := buf.String() + for _, want := range []string{"info", "infof", "warn", "warnf", "error", "errorf"} { + if !strings.Contains(got, want) { + t.Fatalf("log output %q does not contain %q", got, want) + } + } +} + +func TestVerboseAndDebugLogging(t *testing.T) { + buf := captureLogs(t) + + Verbosef("%s", "hidden") + Debugf("%s", "hidden-debug") + if got := buf.String(); got != "" { + t.Fatalf("unexpected log output when verbose disabled: %q", got) + } + + SetVerbose(true) + Verbosef("%s", "visible") + Debugf("%s", "visible-debug") + got := buf.String() + for _, want := range []string{"visible", "visible-debug"} { + if !strings.Contains(got, want) { + t.Fatalf("log output %q does not contain %q", got, want) + } + } +} diff --git a/internal/muxconn/conn_test.go b/internal/muxconn/conn_test.go new file mode 100644 index 0000000..a2ad165 --- /dev/null +++ b/internal/muxconn/conn_test.go @@ -0,0 +1,198 @@ +package muxconn + +import ( + "bytes" + "context" + "errors" + "io" + "sync" + "testing" + "time" + + cryptopkg "github.com/openlibrecommunity/olcrtc/internal/crypto" +) + +type stubLink struct { + mu sync.Mutex + canSend bool + sendErr error + sent [][]byte + canSendFn func() bool +} + +func (s *stubLink) Connect(context.Context) error { return nil } +func (s *stubLink) Close() error { return nil } +func (s *stubLink) SetReconnectCallback(func()) {} +func (s *stubLink) SetShouldReconnect(func() bool) {} +func (s *stubLink) SetEndedCallback(func(string)) {} +func (s *stubLink) WatchConnection(context.Context) {} +func (s *stubLink) Send(data []byte) error { + s.mu.Lock() + defer s.mu.Unlock() + s.sent = append(s.sent, append([]byte(nil), data...)) + return s.sendErr +} +func (s *stubLink) CanSend() bool { + if s.canSendFn != nil { + return s.canSendFn() + } + s.mu.Lock() + defer s.mu.Unlock() + return s.canSend +} + +func newTestCipher(t *testing.T) *cryptopkg.Cipher { + t.Helper() + c, err := cryptopkg.NewCipher("01234567890123456789012345678901") + if err != nil { + t.Fatalf("NewCipher() error = %v", err) + } + return c +} + +func TestPushAndReadRoundTrip(t *testing.T) { + cipher := newTestCipher(t) + conn := New(&stubLink{canSend: true}, cipher) + + msg1, err := cipher.Encrypt([]byte("hello ")) + if err != nil { + t.Fatalf("Encrypt(msg1) error = %v", err) + } + msg2, err := cipher.Encrypt([]byte("world")) + if err != nil { + t.Fatalf("Encrypt(msg2) error = %v", err) + } + + conn.Push(msg1) + conn.Push(msg2) + + buf := make([]byte, 11) + n, err := conn.Read(buf) + if err != nil { + t.Fatalf("Read() error = %v", err) + } + if got := string(buf[:n]); got != "hello world" { + t.Fatalf("Read() = %q, want %q", got, "hello world") + } +} + +func TestPushIgnoresInvalidCiphertext(t *testing.T) { + cipher := newTestCipher(t) + conn := New(&stubLink{canSend: true}, cipher) + + conn.Push([]byte("bad")) + if err := conn.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + + buf := make([]byte, 8) + n, err := conn.Read(buf) + if !errors.Is(err, io.EOF) || n != 0 { + t.Fatalf("Read() = (%d, %v), want (0, EOF)", n, err) + } +} + +func TestWriteEncryptsAndSends(t *testing.T) { + cipher := newTestCipher(t) + ln := &stubLink{canSend: true} + conn := New(ln, cipher) + + n, err := conn.Write([]byte("payload")) + if err != nil { + t.Fatalf("Write() error = %v", err) + } + if n != len("payload") { + t.Fatalf("Write() n = %d, want %d", n, len("payload")) + } + if len(ln.sent) != 1 { + t.Fatalf("sent packets = %d, want 1", len(ln.sent)) + } + + got, err := cipher.Decrypt(ln.sent[0]) + if err != nil { + t.Fatalf("Decrypt(sent) error = %v", err) + } + if !bytes.Equal(got, []byte("payload")) { + t.Fatalf("decrypted payload = %q, want %q", got, "payload") + } +} + +func TestWriteWaitsForCanSend(t *testing.T) { + cipher := newTestCipher(t) + start := time.Now() + readyAt := start.Add(15 * time.Millisecond) + ln := &stubLink{ + canSendFn: func() bool { + return time.Now().After(readyAt) + }, + } + conn := New(ln, cipher) + + if _, err := conn.Write([]byte("payload")); err != nil { + t.Fatalf("Write() error = %v", err) + } + if len(ln.sent) != 1 { + t.Fatalf("sent packets = %d, want 1", len(ln.sent)) + } +} + +func TestWriteReturnsErrClosedWhileWaiting(t *testing.T) { + cipher := newTestCipher(t) + conn := New(&stubLink{canSend: false}, cipher) + + done := make(chan error, 1) + go func() { + _, err := conn.Write([]byte("payload")) + done <- err + }() + + time.Sleep(10 * time.Millisecond) + if err := conn.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + + select { + case err := <-done: + if !errors.Is(err, ErrClosed) { + t.Fatalf("Write() error = %v, want %v", err, ErrClosed) + } + case <-time.After(200 * time.Millisecond): + t.Fatal("Write() did not unblock after Close") + } +} + +func TestWriteWrapsSendError(t *testing.T) { + cipher := newTestCipher(t) + conn := New(&stubLink{canSend: true, sendErr: errors.New("boom")}, cipher) + + _, err := conn.Write([]byte("payload")) + if err == nil || err.Error() != "send: boom" { + t.Fatalf("Write() error = %v", err) + } +} + +func TestCloseMakesReadReturnEOF(t *testing.T) { + cipher := newTestCipher(t) + conn := New(&stubLink{canSend: true}, cipher) + + done := make(chan struct{}) + go func() { + defer close(done) + buf := make([]byte, 4) + n, err := conn.Read(buf) + if !errors.Is(err, io.EOF) || n != 0 { + t.Errorf("Read() = (%d, %v), want (0, EOF)", n, err) + } + }() + + time.Sleep(10 * time.Millisecond) + if err := conn.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + + select { + case <-done: + case <-time.After(200 * time.Millisecond): + t.Fatal("Read() did not unblock after Close") + } +} diff --git a/internal/names/names_test.go b/internal/names/names_test.go new file mode 100644 index 0000000..b8cdb38 --- /dev/null +++ b/internal/names/names_test.go @@ -0,0 +1,107 @@ +package names + +import ( + "os" + "path/filepath" + "reflect" + "strings" + "testing" +) + +func TestParseEmbedded(t *testing.T) { + got := parseEmbedded(" Alice \n\n Bob\n") + want := []string{"Alice", "Bob"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("parseEmbedded() = %#v, want %#v", got, want) + } +} + +func TestLoadNames(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "names.txt") + if err := os.WriteFile(path, []byte(" Alice \n\nBob\n"), 0o600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + got, err := loadNames(path) + if err != nil { + t.Fatalf("loadNames() error = %v", err) + } + want := []string{"Alice", "Bob"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("loadNames() = %#v, want %#v", got, want) + } +} + +func TestLoadNameFilesOverridesGlobals(t *testing.T) { + oldFirst, oldLast := append([]string(nil), firstNames...), append([]string(nil), lastNames...) + t.Cleanup(func() { + firstNames = oldFirst + lastNames = oldLast + }) + + dir := t.TempDir() + first := filepath.Join(dir, "first.txt") + last := filepath.Join(dir, "last.txt") + if err := os.WriteFile(first, []byte("Neo\n"), 0o600); err != nil { + t.Fatalf("WriteFile(first) error = %v", err) + } + if err := os.WriteFile(last, []byte("Anderson\n"), 0o600); err != nil { + t.Fatalf("WriteFile(last) error = %v", err) + } + + if err := LoadNameFiles(first, last); err != nil { + t.Fatalf("LoadNameFiles() error = %v", err) + } + + if got := Generate(); got != "Neo Anderson" { + t.Fatalf("Generate() = %q, want %q", got, "Neo Anderson") + } +} + +func TestGenerateFallsBackWhenNamesEmpty(t *testing.T) { + oldFirst, oldLast := append([]string(nil), firstNames...), append([]string(nil), lastNames...) + t.Cleanup(func() { + firstNames = oldFirst + lastNames = oldLast + }) + + firstNames = nil + lastNames = nil + + if got := Generate(); got != "anonymous user" { + t.Fatalf("Generate() = %q, want anonymous user", got) + } +} + +func TestRandomIndexBounds(t *testing.T) { + for i := 0; i < 20; i++ { + got := randomIndex(2) + if got < 0 || got > 1 { + t.Fatalf("randomIndex(2) = %d, out of range", got) + } + } + + if got := randomIndex(0); got != 0 { + t.Fatalf("randomIndex(0) = %d, want 0", got) + } +} + +func TestLoadNameFilesIgnoresMissingFiles(t *testing.T) { + oldFirst, oldLast := append([]string(nil), firstNames...), append([]string(nil), lastNames...) + t.Cleanup(func() { + firstNames = oldFirst + lastNames = oldLast + }) + + firstNames = []string{"Kept"} + lastNames = []string{"Value"} + if err := LoadNameFiles("missing-first", "missing-last"); err != nil { + t.Fatalf("LoadNameFiles() error = %v", err) + } + + got := Generate() + if !strings.Contains(got, "Kept") || !strings.Contains(got, "Value") { + t.Fatalf("Generate() = %q, want preserved names", got) + } +} diff --git a/internal/protect/protect_test.go b/internal/protect/protect_test.go new file mode 100644 index 0000000..dc14c9b --- /dev/null +++ b/internal/protect/protect_test.go @@ -0,0 +1,142 @@ +package protect + +import ( + "context" + "errors" + "net" + "net/http" + "syscall" + "testing" + "time" +) + +type rawConnStub struct { + controlFn func(func(uintptr)) error +} + +func (r rawConnStub) Control(fn func(uintptr)) error { + if r.controlFn != nil { + return r.controlFn(fn) + } + fn(42) + return nil +} +func (r rawConnStub) Read(func(uintptr) bool) error { return nil } +func (r rawConnStub) Write(func(uintptr) bool) error { return nil } + +func TestControlFuncWithoutProtector(t *testing.T) { + old := Protector + Protector = nil + t.Cleanup(func() { Protector = old }) + + if err := controlFunc("tcp4", "", rawConnStub{}); err != nil { + t.Fatalf("controlFunc() error = %v", err) + } +} + +func TestControlFuncWithProtector(t *testing.T) { + old := Protector + t.Cleanup(func() { Protector = old }) + + called := 0 + Protector = func(fd int) bool { + called++ + if fd != 42 { + t.Fatalf("Protector fd = %d, want 42", fd) + } + return true + } + if err := controlFunc("tcp4", "", rawConnStub{}); err != nil { + t.Fatalf("controlFunc() error = %v", err) + } + if called != 1 { + t.Fatalf("Protector calls = %d, want 1", called) + } + + Protector = func(int) bool { return false } + err := controlFunc("tcp4", "", rawConnStub{}) + var opErr *net.OpError + if !errors.As(err, &opErr) || opErr.Op != "protect" { + t.Fatalf("controlFunc() error = %v, want protect op error", err) + } +} + +func TestControlFuncWrapsControlError(t *testing.T) { + old := Protector + Protector = func(int) bool { return true } + t.Cleanup(func() { Protector = old }) + + err := controlFunc("tcp4", "", rawConnStub{ + controlFn: func(func(uintptr)) error { return errors.New("boom") }, + }) + if err == nil || err.Error() != "control failed: boom" { + t.Fatalf("controlFunc() error = %v", err) + } +} + +func TestNewDialerAndHTTPClient(t *testing.T) { + dialer := NewDialer() + if dialer.Timeout != 10*time.Second || dialer.KeepAlive != 30*time.Second || dialer.Control == nil { + t.Fatalf("NewDialer() = %+v", dialer) + } + + client := NewHTTPClient() + tr, ok := client.Transport.(*http.Transport) + if !ok { + t.Fatalf("Transport type = %T, want *http.Transport", client.Transport) + } + if tr.DialContext == nil || !tr.ForceAttemptHTTP2 || tr.MaxIdleConns != 10 || + tr.IdleConnTimeout != 30*time.Second || tr.TLSHandshakeTimeout != 10*time.Second || + tr.ResponseHeaderTimeout != 10*time.Second { + t.Fatalf("transport = %+v", tr) + } +} + +func TestDialContextAndProxyDialer(t *testing.T) { + ln, err := net.Listen("tcp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen() error = %v", err) + } + defer func() { _ = ln.Close() }() + + accepted := make(chan struct{}, 2) + go func() { + for i := 0; i < 2; i++ { + conn, err := ln.Accept() + if err != nil { + return + } + _ = conn.Close() + accepted <- struct{}{} + } + }() + + conn, err := DialContext(context.Background(), "tcp4", ln.Addr().String()) + if err != nil { + t.Fatalf("DialContext() error = %v", err) + } + _ = conn.Close() + + proxyConn, err := NewProxyDialer().Dial("tcp4", ln.Addr().String()) + if err != nil { + t.Fatalf("ProxyDialer.Dial() error = %v", err) + } + _ = proxyConn.Close() + + <-accepted + <-accepted +} + +func TestDialFailuresAreWrapped(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + + if _, err := DialContext(ctx, "tcp4", "127.0.0.1:1"); err == nil { + t.Fatal("DialContext() unexpectedly succeeded") + } + if _, err := NewProxyDialer().Dial("tcp4", "127.0.0.1:1"); err == nil { + t.Fatal("ProxyDialer.Dial() unexpectedly succeeded") + } +} + +var _ syscall.RawConn = rawConnStub{} diff --git a/internal/provider/jazz/api.go b/internal/provider/jazz/api.go index 453f94a..9e51432 100644 --- a/internal/provider/jazz/api.go +++ b/internal/provider/jazz/api.go @@ -13,10 +13,9 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/protect" ) -const ( - apiBase = "https://bk.salutejazz.ru" - authTypeAnonymous = "ANONYMOUS" -) +const authTypeAnonymous = "ANONYMOUS" + +var apiBase = "https://bk.salutejazz.ru" //nolint:gochecknoglobals // Tests redirect HTTP API calls to httptest. // RoomInfo contains connection details for a SaluteJazz room. type RoomInfo struct { diff --git a/internal/provider/jazz/api_test.go b/internal/provider/jazz/api_test.go new file mode 100644 index 0000000..37cbd8b --- /dev/null +++ b/internal/provider/jazz/api_test.go @@ -0,0 +1,141 @@ +package jazz + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func withJazzAPIServer(t *testing.T, h http.Handler) string { + t.Helper() + old := apiBase + srv := httptest.NewServer(h) + t.Cleanup(func() { + apiBase = old + srv.Close() + }) + apiBase = srv.URL + return srv.URL +} + +func TestCreateMeetingAndPreconnect(t *testing.T) { + withJazzAPIServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("X-Jazz-AuthType") != authTypeAnonymous { + t.Fatalf("missing auth header: %v", r.Header) + } + switch r.URL.Path { + case "/room/create-meeting": + if r.Method != http.MethodPost { + t.Fatalf("create method = %s", r.Method) + } + _ = json.NewEncoder(w).Encode(createResponse{RoomID: "room-1", Password: "pass"}) + case "/room/room-1/preconnect": + if r.Method != http.MethodPost { + t.Fatalf("preconnect method = %s", r.Method) + } + _ = json.NewEncoder(w).Encode(map[string]string{"connectorUrl": "wss://connector"}) + default: + http.NotFound(w, r) + } + })) + + headers := map[string]string{ + "X-Jazz-AuthType": authTypeAnonymous, + "Content-Type": "application/json", + } + created, err := createMeeting(context.Background(), headers) + if err != nil { + t.Fatalf("createMeeting() error = %v", err) + } + if created.RoomID != "room-1" || created.Password != "pass" { + t.Fatalf("createMeeting() = %+v", created) + } + + connector, err := preconnect(context.Background(), "room-1", "pass", headers) + if err != nil { + t.Fatalf("preconnect() error = %v", err) + } + if connector != "wss://connector" { + t.Fatalf("preconnect() = %q", connector) + } +} + +func TestCreateRoomAndJoinRoom(t *testing.T) { + withJazzAPIServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/room/create-meeting": + _ = json.NewEncoder(w).Encode(createResponse{RoomID: "new-room", Password: "new-pass"}) + case "/room/new-room/preconnect", "/room/existing/preconnect": + _ = json.NewEncoder(w).Encode(map[string]string{"connectorUrl": "wss://connector"}) + default: + http.NotFound(w, r) + } + })) + + room, err := createRoom(context.Background()) + if err != nil { + t.Fatalf("createRoom() error = %v", err) + } + if room.RoomID != "new-room" || room.Password != "new-pass" || room.ConnectorURL != "wss://connector" { + t.Fatalf("createRoom() = %+v", room) + } + + room, err = joinRoom(context.Background(), "existing", "secret") + if err != nil { + t.Fatalf("joinRoom() error = %v", err) + } + if room.RoomID != "existing" || room.Password != "secret" || room.ConnectorURL != "wss://connector" { + t.Fatalf("joinRoom() = %+v", room) + } +} + +func TestJazzAPIErrors(t *testing.T) { + withJazzAPIServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "create-meeting"): + http.Error(w, "bad", http.StatusTeapot) + default: + http.Error(w, "bad", http.StatusInternalServerError) + } + })) + + if _, err := createMeeting(context.Background(), nil); !errors.Is(err, errCreateRoomFailed) { + t.Fatalf("createMeeting() error = %v, want %v", err, errCreateRoomFailed) + } + if _, err := preconnect(context.Background(), "room", "pass", nil); !errors.Is(err, errPreconnectFailed) { + t.Fatalf("preconnect() error = %v, want %v", err, errPreconnectFailed) + } +} + +func TestNewPeerUsesRoomAPI(t *testing.T) { + withJazzAPIServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/room/create-meeting": + _ = json.NewEncoder(w).Encode(createResponse{RoomID: "new-room", Password: "new-pass"}) + case "/room/new-room/preconnect", "/room/existing/preconnect": + _ = json.NewEncoder(w).Encode(map[string]string{"connectorUrl": "wss://connector"}) + default: + http.NotFound(w, r) + } + })) + + created, err := NewPeer(context.Background(), "any", "peer", nil) + if err != nil { + t.Fatalf("NewPeer(create) error = %v", err) + } + if created.roomInfo.RoomID != "new-room" { + t.Fatalf("created room = %+v", created.roomInfo) + } + + joined, err := NewPeer(context.Background(), "existing:secret", "peer", nil) + if err != nil { + t.Fatalf("NewPeer(join) error = %v", err) + } + if joined.roomInfo.RoomID != "existing" || joined.roomInfo.Password != "secret" { + t.Fatalf("joined room = %+v", joined.roomInfo) + } +} diff --git a/internal/provider/jazz/datapacket_test.go b/internal/provider/jazz/datapacket_test.go new file mode 100644 index 0000000..7f87a30 --- /dev/null +++ b/internal/provider/jazz/datapacket_test.go @@ -0,0 +1,70 @@ +package jazz + +import ( + "bytes" + "errors" + "io" + "testing" +) + +func TestDataPacketRoundTrip(t *testing.T) { + payload := []byte("hello jazz") + raw := EncodeDataPacket(payload) + + got, ok := DecodeDataPacket(raw) + if !ok { + t.Fatal("DecodeDataPacket() ok = false") + } + if !bytes.Equal(got, payload) { + t.Fatalf("DecodeDataPacket() = %q, want %q", got, payload) + } +} + +func TestDecodeDataPacketRejectsMalformedPackets(t *testing.T) { + tests := [][]byte{ + nil, + {0xff}, + encodeField(1, 0, encodeVarint(0)), + {byte(2<<3 | 2), 10, 1}, + {byte(3<<3 | 7), 0}, + } + + for _, raw := range tests { + if payload, ok := DecodeDataPacket(raw); ok { + t.Fatalf("DecodeDataPacket(%v) = (%q, true), want false", raw, payload) + } + } +} + +func TestParseFieldsSkipsSupportedNonTargetWireTypes(t *testing.T) { + data := encodeField(1, 0, encodeVarint(150)) + data = append(data, encodeField(3, 1, []byte("12345678"))...) + data = append(data, encodeField(4, 5, []byte("1234"))...) + data = append(data, encodeField(2, 2, []byte("target"))...) + + got, ok := parseFields(data, 2) + if !ok || string(got) != "target" { + t.Fatalf("parseFields() = (%q, %v), want target", got, ok) + } +} + +func TestByteReader(t *testing.T) { + r := &byteReader{data: []byte{1, 2, 3}} + b, err := r.ReadByte() + if err != nil || b != 1 { + t.Fatalf("ReadByte() = (%d, %v), want (1, nil)", b, err) + } + + buf := make([]byte, 4) + n, err := r.Read(buf) + if err != nil || n != 2 || !bytes.Equal(buf[:n], []byte{2, 3}) { + t.Fatalf("Read() = (%d, %v, %v), want two bytes", n, err, buf[:n]) + } + + if _, err := r.ReadByte(); !errors.Is(err, io.EOF) { + t.Fatalf("ReadByte() error = %v, want EOF", err) + } + if n, err := r.Read(buf); !errors.Is(err, io.EOF) || n != 0 { + t.Fatalf("Read() = (%d, %v), want (0, EOF)", n, err) + } +} diff --git a/internal/provider/jazz/peer_helpers_test.go b/internal/provider/jazz/peer_helpers_test.go new file mode 100644 index 0000000..24729fc --- /dev/null +++ b/internal/provider/jazz/peer_helpers_test.go @@ -0,0 +1,112 @@ +package jazz + +import ( + "context" + "errors" + "testing" + + "github.com/openlibrecommunity/olcrtc/internal/provider" + "github.com/pion/webrtc/v4" +) + +func TestPeerStateHelpers(t *testing.T) { + p := &Peer{ + reconnectCh: make(chan struct{}, 1), + closeCh: make(chan struct{}), + sessionCloseCh: make(chan struct{}), + sendQueue: make(chan []byte, 1), + subscriberConn: make(chan struct{}), + publisherConn: make(chan struct{}), + } + + p.resetMediaState() + if p.subscriberReady.Load() || p.publisherReady.Load() || p.subscriberConn == nil || p.publisherConn == nil { + t.Fatal("resetMediaState() did not reset readiness") + } + if p.hasLocalVideoTracks() { + t.Fatal("hasLocalVideoTracks() = true without tracks") + } + if err := p.AddVideoTrack(nil); err != nil { + t.Fatalf("AddVideoTrack(nil) error = %v", err) + } + if !p.hasLocalVideoTracks() { + t.Fatal("hasLocalVideoTracks() = false after AddVideoTrack") + } + + p.SetVideoTrackHandler(func(*webrtc.TrackRemote, *webrtc.RTPReceiver) {}) + if p.videoTrackHandler() == nil { + t.Fatal("videoTrackHandler() = nil") + } + + cfg := defaultWebRTCConfig() + if cfg.SDPSemantics != webrtc.SDPSemanticsUnifiedPlan || cfg.BundlePolicy != webrtc.BundlePolicyMaxBundle { + t.Fatalf("defaultWebRTCConfig() = %+v", cfg) + } + if p.buildAPI() == nil { + t.Fatal("buildAPI() returned nil") + } +} + +func TestPeerCallbacksQueueReconnectAndClose(t *testing.T) { + p := &Peer{ + reconnectCh: make(chan struct{}, 1), + closeCh: make(chan struct{}), + sessionCloseCh: make(chan struct{}), + sendQueue: make(chan []byte, 1), + } + + p.SetReconnectCallback(func(*webrtc.DataChannel) {}) + p.SetShouldReconnect(func() bool { return true }) + p.SetEndedCallback(func(string) {}) + if p.onReconnect == nil || p.shouldReconnect == nil || p.onEnded == nil { + t.Fatal("callbacks were not stored") + } + + p.queueReconnect() + select { + case <-p.reconnectCh: + default: + t.Fatal("queueReconnect() did not enqueue") + } + + p.SetShouldReconnect(func() bool { return false }) + p.queueReconnect() + select { + case <-p.reconnectCh: + t.Fatal("queueReconnect() enqueued despite policy=false") + default: + } + + done := make(chan struct{}) + go func() { + p.WatchConnection(context.Background()) + close(done) + }() + if err := p.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + <-done + if err := p.Send([]byte("closed")); !errors.Is(err, provider.ErrDataChannelNotReady) { + t.Fatalf("Send() error = %v, want datachannel not ready", err) + } +} + +func TestPeerCanSendVideoOnlyModes(t *testing.T) { + p := &Peer{sendQueue: make(chan []byte, 1)} + p.subscriberReady.Store(true) + if !p.CanSend() { + t.Fatal("CanSend() = false for subscriber-ready peer without local video") + } + _ = p.AddVideoTrack(nil) + if p.CanSend() { + t.Fatal("CanSend() = true with local video but publisher not ready") + } + p.publisherReady.Store(true) + if !p.CanSend() { + t.Fatal("CanSend() = false with subscriber and publisher ready") + } + p.closed.Store(true) + if p.CanSend() { + t.Fatal("CanSend() = true for closed peer") + } +} diff --git a/internal/provider/jazz/provider_test.go b/internal/provider/jazz/provider_test.go new file mode 100644 index 0000000..ab6741c --- /dev/null +++ b/internal/provider/jazz/provider_test.go @@ -0,0 +1,51 @@ +package jazz + +import ( + "context" + "errors" + "testing" + + "github.com/openlibrecommunity/olcrtc/internal/provider" + "github.com/pion/webrtc/v4" +) + +func TestJazzProviderForwardsPeerMethods(t *testing.T) { + peer := &Peer{ + reconnectCh: make(chan struct{}, 1), + closeCh: make(chan struct{}), + sessionCloseCh: make(chan struct{}), + sendQueue: make(chan []byte, 1), + } + p := &jazzProvider{peer: peer} + + p.SetReconnectCallback(func(*webrtc.DataChannel) {}) + p.SetShouldReconnect(func() bool { return true }) + p.SetEndedCallback(func(string) {}) + p.SetVideoTrackHandler(func(*webrtc.TrackRemote, *webrtc.RTPReceiver) {}) + if peer.onReconnect == nil || peer.shouldReconnect == nil || peer.onEnded == nil || peer.onVideoTrack == nil { + t.Fatal("callbacks were not forwarded") + } + + if p.GetSendQueue() != peer.sendQueue { + t.Fatal("GetSendQueue() did not forward") + } + if p.GetBufferedAmount() != 0 { + t.Fatal("GetBufferedAmount() != 0 with nil datachannel") + } + if err := p.AddVideoTrack(nil); err != nil { + t.Fatalf("AddVideoTrack(nil) error = %v", err) + } + if err := p.Send([]byte("x")); !errors.Is(err, provider.ErrDataChannelNotReady) { + t.Fatalf("Send() error = %v, want datachannel not ready", err) + } + + done := make(chan struct{}) + go func() { + p.WatchConnection(context.Background()) + close(done) + }() + if err := p.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + <-done +} diff --git a/internal/provider/provider_test.go b/internal/provider/provider_test.go new file mode 100644 index 0000000..3e080cb --- /dev/null +++ b/internal/provider/provider_test.go @@ -0,0 +1,75 @@ +package provider + +import ( + "context" + "errors" + "reflect" + "testing" + + "github.com/pion/webrtc/v4" +) + +type stubProvider struct{} + +func (s *stubProvider) Connect(context.Context) error { return nil } +func (s *stubProvider) Send([]byte) error { return nil } +func (s *stubProvider) Close() error { return nil } +func (s *stubProvider) SetReconnectCallback(func(*webrtc.DataChannel)) {} +func (s *stubProvider) SetShouldReconnect(func() bool) {} +func (s *stubProvider) SetEndedCallback(func(string)) {} +func (s *stubProvider) WatchConnection(context.Context) {} +func (s *stubProvider) CanSend() bool { return true } +func (s *stubProvider) GetSendQueue() chan []byte { return nil } +func (s *stubProvider) GetBufferedAmount() uint64 { return 0 } + +func snapshotProviderRegistry() map[string]Factory { + out := make(map[string]Factory, len(registry)) + for k, v := range registry { + out[k] = v + } + return out +} + +func restoreProviderRegistry(src map[string]Factory) { + registry = make(map[string]Factory, len(src)) + for k, v := range src { + registry[k] = v + } +} + +func TestNewAndAvailable(t *testing.T) { + old := snapshotProviderRegistry() + t.Cleanup(func() { restoreProviderRegistry(old) }) + + called := false + Register("test-provider", func(_ context.Context, cfg Config) (Provider, error) { + called = cfg.Name == "peer" + return &stubProvider{}, nil + }) + + got, err := New(context.Background(), "test-provider", Config{Name: "peer"}) + if err != nil { + t.Fatalf("New() error = %v", err) + } + if !called { + t.Fatal("factory did not receive config") + } + if _, ok := got.(*stubProvider); !ok { + t.Fatalf("New() returned %T, want *stubProvider", got) + } + + if !reflect.DeepEqual(Available(), []string{"test-provider"}) { + t.Fatalf("Available() = %#v, want %#v", Available(), []string{"test-provider"}) + } +} + +func TestNewReturnsErrProviderNotFound(t *testing.T) { + old := snapshotProviderRegistry() + t.Cleanup(func() { restoreProviderRegistry(old) }) + registry = map[string]Factory{} + + _, err := New(context.Background(), "missing", Config{}) + if !errors.Is(err, ErrProviderNotFound) { + t.Fatalf("New() error = %v, want %v", err, ErrProviderNotFound) + } +} diff --git a/internal/provider/telemost/api.go b/internal/provider/telemost/api.go index 00b1045..20cca40 100644 --- a/internal/provider/telemost/api.go +++ b/internal/provider/telemost/api.go @@ -13,14 +13,14 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/protect" ) -const apiBase = "https://cloud-api.yandex.ru/telemost_front/v2/telemost" +var apiBase = "https://cloud-api.yandex.ru/telemost_front/v2/telemost" //nolint:gochecknoglobals // Tests redirect HTTP API calls to httptest. var ErrAPI = errors.New("api error") //nolint:revive type ConnectionInfo struct { //nolint:revive - RoomID string `json:"room_id"` //nolint:tagliatelle - PeerID string `json:"peer_id"` //nolint:tagliatelle - Credentials string `json:"credentials"` //nolint:tagliatelle + RoomID string `json:"room_id"` //nolint:tagliatelle + PeerID string `json:"peer_id"` //nolint:tagliatelle + Credentials string `json:"credentials"` //nolint:tagliatelle ClientConfig struct { MediaServerURL string `json:"media_server_url"` //nolint:tagliatelle } `json:"client_configuration"` //nolint:tagliatelle diff --git a/internal/provider/telemost/api_test.go b/internal/provider/telemost/api_test.go new file mode 100644 index 0000000..d072cfc --- /dev/null +++ b/internal/provider/telemost/api_test.go @@ -0,0 +1,83 @@ +package telemost + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func withTelemostAPIServer(t *testing.T, h http.Handler) { + t.Helper() + old := apiBase + srv := httptest.NewServer(h) + t.Cleanup(func() { + apiBase = old + srv.Close() + }) + apiBase = srv.URL +} + +func TestGetConnectionInfo(t *testing.T) { + withTelemostAPIServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Fatalf("method = %s", r.Method) + } + if !strings.Contains(r.URL.EscapedPath(), "/conferences/room%2Fid/connection") { + t.Fatalf("path = %q escaped=%q", r.URL.Path, r.URL.EscapedPath()) + } + if r.URL.Query().Get("display_name") != "peer" { + t.Fatalf("display_name query = %q", r.URL.Query().Get("display_name")) + } + _ = json.NewEncoder(w).Encode(ConnectionInfo{ + RoomID: "room", + PeerID: "peer-id", + Credentials: "creds", + }) + })) + + info, err := GetConnectionInfo(context.Background(), "room/id", "peer") + if err != nil { + t.Fatalf("GetConnectionInfo() error = %v", err) + } + if info.RoomID != "room" || info.PeerID != "peer-id" || info.Credentials != "creds" { + t.Fatalf("GetConnectionInfo() = %+v", info) + } +} + +func TestGetConnectionInfoErrors(t *testing.T) { + withTelemostAPIServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "bad", http.StatusForbidden) + })) + if _, err := GetConnectionInfo(context.Background(), "room", "peer"); !errors.Is(err, ErrAPI) { + t.Fatalf("GetConnectionInfo() error = %v, want %v", err, ErrAPI) + } + + withTelemostAPIServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("{")) + })) + if _, err := GetConnectionInfo(context.Background(), "room", "peer"); err == nil { + t.Fatal("GetConnectionInfo() unexpectedly accepted bad json") + } +} + +func TestTelemostNewPeerUsesConnectionInfo(t *testing.T) { + withTelemostAPIServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(ConnectionInfo{ + RoomID: "room", + PeerID: "peer-id", + Credentials: "creds", + }) + })) + + p, err := NewPeer(context.Background(), "room", "name", nil) + if err != nil { + t.Fatalf("NewPeer() error = %v", err) + } + if p.roomURL != "room" || p.name != "name" || p.conn.PeerID != "peer-id" || p.sendQueue == nil { + t.Fatalf("NewPeer() = %+v", p) + } +} diff --git a/internal/provider/telemost/peer_helpers_test.go b/internal/provider/telemost/peer_helpers_test.go new file mode 100644 index 0000000..2c9d6f4 --- /dev/null +++ b/internal/provider/telemost/peer_helpers_test.go @@ -0,0 +1,195 @@ +package telemost + +import ( + "testing" + "time" + + "github.com/pion/webrtc/v4" +) + +func TestCloseSignal(t *testing.T) { + closeSignal(nil) + + ch := make(chan struct{}) + closeSignal(ch) + select { + case <-ch: + default: + t.Fatal("closeSignal() did not close channel") + } + closeSignal(ch) +} + +func TestTrafficShapeAndDelay(t *testing.T) { + p := &Peer{} + p.SetTrafficShape(TrafficShape{MaxMessageSize: -1, MinDelay: 5 * time.Millisecond, MaxDelay: 2 * time.Millisecond}) + if p.trafficShape.MaxMessageSize != realDataChannelMessageLimit { + t.Fatalf("MaxMessageSize = %d, want default", p.trafficShape.MaxMessageSize) + } + if p.trafficShape.MaxDelay != p.trafficShape.MinDelay { + t.Fatalf("MaxDelay = %v, want %v", p.trafficShape.MaxDelay, p.trafficShape.MinDelay) + } + if got := p.calculateDelay(); got != 5*time.Millisecond { + t.Fatalf("calculateDelay() = %v, want 5ms", got) + } + + p.SetTrafficShape(TrafficShape{MaxMessageSize: 10, MinDelay: time.Millisecond, MaxDelay: 4 * time.Millisecond}) + for i := 0; i < 20; i++ { + got := p.calculateDelay() + if got < time.Millisecond || got >= 4*time.Millisecond { + t.Fatalf("calculateDelay() = %v, out of range", got) + } + } +} + +func TestICEParsingFiltersTURN(t *testing.T) { + if isNonTURNURL("") || isNonTURNURL("turn:host") || isNonTURNURL("turns:host") { + t.Fatal("isNonTURNURL accepted empty or TURN URL") + } + if !isNonTURNURL("stun:host") { + t.Fatal("isNonTURNURL rejected STUN URL") + } + + urls := parseICEURLs(map[string]interface{}{"urls": []interface{}{"turn:x", "stun:a", 123, "turns:y"}}) + if len(urls) != 1 || urls[0] != "stun:a" { + t.Fatalf("parseICEURLs(interface) = %v, want [stun:a]", urls) + } + + urls = parseICEURLs(map[string]interface{}{"urls": []string{"stun:a", "turn:b"}}) + if len(urls) != 1 || urls[0] != "stun:a" { + t.Fatalf("parseICEURLs(strings) = %v, want [stun:a]", urls) + } +} + +func TestParseICEServer(t *testing.T) { + if _, ok := parseICEServer("bad"); ok { + t.Fatal("parseICEServer() accepted non-map") + } + if _, ok := parseICEServer(map[string]interface{}{"urls": []interface{}{"turn:x"}}); ok { + t.Fatal("parseICEServer() accepted TURN-only server") + } + + ice, ok := parseICEServer(map[string]interface{}{ + "urls": []interface{}{"stun:a", "turn:b"}, + "username": "user", + "credential": "pass", + }) + if !ok { + t.Fatal("parseICEServer() ok = false") + } + if len(ice.URLs) != 1 || ice.URLs[0] != "stun:a" || ice.Username != "user" || ice.Credential != "pass" { + t.Fatalf("parseICEServer() = %+v", ice) + } +} + +func TestConferenceEndParsing(t *testing.T) { + for _, msg := range []map[string]interface{}{ + {"conferenceClosed": true}, + {"conference": map[string]interface{}{"state": "ENDED"}}, + {"conferenceState": map[string]interface{}{"state": "terminated"}}, + } { + if !isConferenceEndMessage(msg) { + t.Fatalf("isConferenceEndMessage(%v) = false", msg) + } + } + if isConferenceEndMessage(map[string]interface{}{"conference": map[string]interface{}{"state": "open"}}) { + t.Fatal("isConferenceEndMessage() accepted active conference") + } + + for _, state := range []string{"closed", "ended", "finished", "terminated"} { + if !isEndedState(state) { + t.Fatalf("isEndedState(%q) = false", state) + } + } + if isEndedState("active") { + t.Fatal("isEndedState(active) = true") + } +} + +func TestPeerSmallStateHelpers(t *testing.T) { + p := &Peer{ + reconnectCh: make(chan struct{}, 1), + closeCh: make(chan struct{}), + sendQueue: make(chan []byte, 2), + ackWaiters: make(map[string]chan struct{}), + } + p.SetEndedCallback(func(string) {}) + if p.onEnded == nil { + t.Fatal("SetEndedCallback() did not store callback") + } + p.SetReconnectCallback(func(*webrtc.DataChannel) {}) + if p.onReconnect == nil { + t.Fatal("SetReconnectCallback() did not store callback") + } + p.SetShouldReconnect(func() bool { return true }) + if p.shouldReconnect == nil || !p.shouldReconnect() { + t.Fatal("SetShouldReconnect() did not store callback") + } + + p.subscriberReady.Store(true) + if !p.CanSend() { + t.Fatal("CanSend() = false for subscriber-only ready peer") + } + p.closed.Store(true) + if p.CanSend() { + t.Fatal("CanSend() = true for closed peer") + } + + ch := p.registerAckWaiter("uid-1") + p.resolveAck("uid-1") + select { + case <-ch: + default: + t.Fatal("resolveAck() did not close waiter") + } + if p.waitForAck("", make(chan struct{}), time.Millisecond) { + t.Fatal("waitForAck(empty uid) = true") + } + + ch = p.registerAckWaiter("uid-2") + go p.resolveAck("uid-2") + if !p.waitForAck("uid-2", ch, time.Second) { + t.Fatal("waitForAck() = false after resolveAck") + } + + if err := p.AddVideoTrack(nil); err != nil { + t.Fatalf("AddVideoTrack(nil) error = %v", err) + } + if !p.hasLocalVideoTracks() { + t.Fatal("hasLocalVideoTracks() = false after AddVideoTrack") + } + p.SetVideoTrackHandler(func(*webrtc.TrackRemote, *webrtc.RTPReceiver) {}) + if p.videoTrackHandler() == nil { + t.Fatal("videoTrackHandler() = nil") + } +} + +func TestTelemetryCfgParsing(t *testing.T) { + if _, _, ok := parseTelemetryCfg(map[string]interface{}{}); ok { + t.Fatal("parseTelemetryCfg() accepted missing config") + } + if _, _, ok := parseTelemetryCfg(map[string]interface{}{ + "telemetryConfiguration": map[string]interface{}{}, + }); ok { + t.Fatal("parseTelemetryCfg() accepted missing endpoint") + } + + endpoint, interval, ok := parseTelemetryCfg(map[string]interface{}{ + "telemetryConfiguration": map[string]interface{}{ + "endpoint": "https://example.test/log", + "sendingInterval": float64(250), + }, + }) + if !ok || endpoint != "https://example.test/log" || interval != 250*time.Millisecond { + t.Fatalf("parseTelemetryCfg() = (%q, %v, %v)", endpoint, interval, ok) + } + + endpoint, interval, ok = parseTelemetryCfg(map[string]interface{}{ + "telemetryConfiguration": map[string]interface{}{ + "url": "https://example.test/url", + }, + }) + if !ok || endpoint != "https://example.test/url" || interval != defaultTelemetryInterval { + t.Fatalf("parseTelemetryCfg(default) = (%q, %v, %v)", endpoint, interval, ok) + } +} diff --git a/internal/provider/telemost/provider_test.go b/internal/provider/telemost/provider_test.go new file mode 100644 index 0000000..d70c4e4 --- /dev/null +++ b/internal/provider/telemost/provider_test.go @@ -0,0 +1,54 @@ +package telemost + +import ( + "context" + "errors" + "testing" + + "github.com/pion/webrtc/v4" +) + +func TestTelemostProviderForwardsPeerMethods(t *testing.T) { + peer := &Peer{ + reconnectCh: make(chan struct{}, 1), + closeCh: make(chan struct{}), + sendQueue: make(chan []byte, 1), + ackWaiters: make(map[string]chan struct{}), + } + p := &telemostProvider{peer: peer} + + p.SetReconnectCallback(func(*webrtc.DataChannel) {}) + p.SetShouldReconnect(func() bool { return true }) + p.SetEndedCallback(func(string) {}) + p.SetVideoTrackHandler(func(*webrtc.TrackRemote, *webrtc.RTPReceiver) {}) + if peer.onReconnect == nil || peer.shouldReconnect == nil || peer.onEnded == nil || peer.onVideoTrack == nil { + t.Fatal("callbacks were not forwarded") + } + + if p.GetSendQueue() != peer.sendQueue { + t.Fatal("GetSendQueue() did not forward") + } + if p.GetBufferedAmount() != 0 { + t.Fatal("GetBufferedAmount() != 0 with nil datachannel") + } + if err := p.AddVideoTrack(nil); err != nil { + t.Fatalf("AddVideoTrack(nil) error = %v", err) + } + if p.CanSend() { + t.Fatal("CanSend() = true for unready peer") + } + + done := make(chan struct{}) + go func() { + p.WatchConnection(context.Background()) + close(done) + }() + if err := p.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + <-done + + if err := p.Send([]byte("x")); !errors.Is(err, ErrDataChannelNotReady) { + t.Fatalf("Send() error = %v, want datachannel not ready", err) + } +} diff --git a/internal/provider/telemost/state_helpers_test.go b/internal/provider/telemost/state_helpers_test.go new file mode 100644 index 0000000..f072429 --- /dev/null +++ b/internal/provider/telemost/state_helpers_test.go @@ -0,0 +1,84 @@ +package telemost + +import ( + "testing" + "time" +) + +func TestSessionReconnectAndEndedHelpers(t *testing.T) { + p := &Peer{ + reconnectCh: make(chan struct{}, 2), + closeCh: make(chan struct{}), + keepAliveCh: make(chan struct{}), + sessionCloseCh: make(chan struct{}), + telemetryCh: make(chan struct{}, 1), + } + + keepAliveCh, sessionCloseCh := p.resetSession() + if keepAliveCh == nil || sessionCloseCh == nil || keepAliveCh != p.keepAliveCh || sessionCloseCh != p.sessionCloseCh { + t.Fatal("resetSession() did not replace session channels") + } + + p.subscriberReady.Store(true) + p.publisherReady.Store(true) + p.resetMediaState() + if p.subscriberReady.Load() || p.publisherReady.Load() || p.subscriberConn == nil || p.publisherConn == nil { + t.Fatal("resetMediaState() did not reset readiness") + } + + p.queueReconnect() + select { + case <-p.reconnectCh: + default: + t.Fatal("queueReconnect() did not enqueue") + } + + p.SetShouldReconnect(func() bool { return false }) + p.queueReconnect() + select { + case <-p.reconnectCh: + t.Fatal("queueReconnect() enqueued despite policy=false") + default: + } + + p.reconnectCh <- struct{}{} + p.reconnectCh <- struct{}{} + p.drainReconnectQueue() + select { + case <-p.reconnectCh: + t.Fatal("drainReconnectQueue() left queued item") + default: + } + + p.telemetryActive.Store(true) + p.stopTelemetry() + select { + case <-p.telemetryCh: + default: + t.Fatal("stopTelemetry() did not signal active telemetry") + } + + ended := "" + p.SetEndedCallback(func(reason string) { ended = reason }) + p.signalEnded("done") + if !p.closed.Load() || ended != "done" { + t.Fatalf("signalEnded() closed=%v reason=%q", p.closed.Load(), ended) + } +} + +func TestWaitForAckTimeoutAndClose(t *testing.T) { + p := &Peer{ + closeCh: make(chan struct{}), + ackWaiters: make(map[string]chan struct{}), + } + ch := p.registerAckWaiter("timeout") + if p.waitForAck("timeout", ch, time.Millisecond) { + t.Fatal("waitForAck(timeout) = true") + } + + ch = p.registerAckWaiter("closed") + close(p.closeCh) + if p.waitForAck("closed", ch, time.Second) { + t.Fatal("waitForAck(closeCh) = true") + } +} diff --git a/internal/provider/wbstream/api.go b/internal/provider/wbstream/api.go index ac66071..988ec5c 100644 --- a/internal/provider/wbstream/api.go +++ b/internal/provider/wbstream/api.go @@ -12,7 +12,7 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/protect" ) -const apiBase = "https://stream.wb.ru" +var apiBase = "https://stream.wb.ru" //nolint:gochecknoglobals // Tests redirect HTTP API calls to httptest. var ( errGuestRegister = errors.New("guest register failed") diff --git a/internal/provider/wbstream/api_test.go b/internal/provider/wbstream/api_test.go new file mode 100644 index 0000000..99ef8b1 --- /dev/null +++ b/internal/provider/wbstream/api_test.go @@ -0,0 +1,123 @@ +package wbstream + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" +) + +func withWBAPIServer(t *testing.T, h http.Handler) { + t.Helper() + old := apiBase + srv := httptest.NewServer(h) + t.Cleanup(func() { + apiBase = old + srv.Close() + }) + apiBase = srv.URL +} + +func TestWBStreamAPIHappyPath(t *testing.T) { + withWBAPIServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/auth/api/v1/auth/user/guest-register": + if r.Method != http.MethodPost { + t.Fatalf("guest method = %s", r.Method) + } + _ = json.NewEncoder(w).Encode(guestRegisterResponse{AccessToken: "access"}) + case "/api-room/api/v2/room": + if r.Header.Get("Authorization") != "Bearer access" { + t.Fatalf("room auth = %q", r.Header.Get("Authorization")) + } + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(createRoomResponse{RoomID: "room"}) + case "/api-room/api/v1/room/room/join": + w.WriteHeader(http.StatusOK) + case "/api-room-manager/api/v1/room/room/token": + if r.URL.Query().Get("displayName") != "peer" { + t.Fatalf("displayName query = %q", r.URL.Query().Get("displayName")) + } + _ = json.NewEncoder(w).Encode(tokenResponse{RoomToken: "token"}) + default: + http.NotFound(w, r) + } + })) + + access, err := registerGuest(context.Background(), "peer") + if err != nil { + t.Fatalf("registerGuest() error = %v", err) + } + if access != "access" { + t.Fatalf("registerGuest() = %q", access) + } + + room, err := createRoom(context.Background(), access) + if err != nil { + t.Fatalf("createRoom() error = %v", err) + } + if room != "room" { + t.Fatalf("createRoom() = %q", room) + } + + if err := joinRoom(context.Background(), access, room); err != nil { + t.Fatalf("joinRoom() error = %v", err) + } + token, err := getToken(context.Background(), access, room, "peer") + if err != nil { + t.Fatalf("getToken() error = %v", err) + } + if token != "token" { + t.Fatalf("getToken() = %q", token) + } +} + +func TestWBStreamAPIErrors(t *testing.T) { + withWBAPIServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "bad", http.StatusBadGateway) + })) + + if _, err := registerGuest(context.Background(), "peer"); !errors.Is(err, errGuestRegister) { + t.Fatalf("registerGuest() error = %v, want %v", err, errGuestRegister) + } + if _, err := createRoom(context.Background(), "access"); !errors.Is(err, errCreateRoom) { + t.Fatalf("createRoom() error = %v, want %v", err, errCreateRoom) + } + if err := joinRoom(context.Background(), "access", "room"); !errors.Is(err, errJoinRoom) { + t.Fatalf("joinRoom() error = %v, want %v", err, errJoinRoom) + } + if _, err := getToken(context.Background(), "access", "room", "peer"); !errors.Is(err, errGetToken) { + t.Fatalf("getToken() error = %v, want %v", err, errGetToken) + } +} + +func TestWBStreamGetRoomToken(t *testing.T) { + withWBAPIServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/auth/api/v1/auth/user/guest-register": + _ = json.NewEncoder(w).Encode(guestRegisterResponse{AccessToken: "access"}) + case "/api-room/api/v2/room": + _ = json.NewEncoder(w).Encode(createRoomResponse{RoomID: "created"}) + case "/api-room/api/v1/room/created/join": + w.WriteHeader(http.StatusOK) + case "/api-room-manager/api/v1/room/created/token": + _ = json.NewEncoder(w).Encode(tokenResponse{RoomToken: "token"}) + default: + http.NotFound(w, r) + } + })) + + p, err := NewPeer(context.Background(), "any", "peer", nil) + if err != nil { + t.Fatalf("NewPeer() error = %v", err) + } + token, err := p.getRoomToken(context.Background()) + if err != nil { + t.Fatalf("getRoomToken() error = %v", err) + } + if token != "token" { + t.Fatalf("getRoomToken() = %q", token) + } +} diff --git a/internal/provider/wbstream/peer_test.go b/internal/provider/wbstream/peer_test.go new file mode 100644 index 0000000..17a7df4 --- /dev/null +++ b/internal/provider/wbstream/peer_test.go @@ -0,0 +1,76 @@ +package wbstream + +import ( + "context" + "errors" + "testing" + + "github.com/pion/webrtc/v4" +) + +func TestNewPeerAndSimpleAccessors(t *testing.T) { + p, err := NewPeer(context.Background(), "room", "name", func([]byte) {}) + if err != nil { + t.Fatalf("NewPeer() error = %v", err) + } + if p.roomURL != "room" || p.name != "name" || p.sendQueue == nil || p.done == nil { + t.Fatalf("NewPeer() = %+v", p) + } + if p.GetSendQueue() != p.sendQueue { + t.Fatal("GetSendQueue() did not return sendQueue") + } + if p.GetBufferedAmount() != 0 { + t.Fatal("GetBufferedAmount() != 0") + } + if p.CanSend() { + t.Fatal("CanSend() = true without room") + } +} + +func TestSendQueueAndClose(t *testing.T) { + p, err := NewPeer(context.Background(), "room", "name", nil) + if err != nil { + t.Fatalf("NewPeer() error = %v", err) + } + p.sendQueue = make(chan []byte, 1) + + if err := p.Send([]byte("one")); err != nil { + t.Fatalf("Send() error = %v", err) + } + if err := p.Send([]byte("two")); !errors.Is(err, ErrSendQueueFull) { + t.Fatalf("Send() error = %v, want %v", err, ErrSendQueueFull) + } + if err := p.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + if err := p.Send([]byte("closed")); !errors.Is(err, ErrPeerClosed) { + t.Fatalf("Send() error = %v, want %v", err, ErrPeerClosed) + } + if err := p.Close(); err != nil { + t.Fatalf("second Close() error = %v", err) + } +} + +func TestCallbacksAndVideoTrackStorage(t *testing.T) { + p, err := NewPeer(context.Background(), "room", "name", nil) + if err != nil { + t.Fatalf("NewPeer() error = %v", err) + } + + p.SetReconnectCallback(func(*webrtc.DataChannel) {}) + p.SetShouldReconnect(func() bool { return true }) + p.SetEndedCallback(func(string) {}) + p.SetVideoTrackHandler(func(*webrtc.TrackRemote, *webrtc.RTPReceiver) {}) + p.WatchConnection(context.Background()) + + if p.onReconnect == nil || p.shouldReconnect == nil || p.onEnded == nil || p.onVideoTrack == nil { + t.Fatal("callbacks were not stored") + } + + if err := p.AddVideoTrack(nil); err != nil { + t.Fatalf("AddVideoTrack(nil) error = %v", err) + } + if len(p.videoTracks) != 1 { + t.Fatalf("videoTracks len = %d, want 1", len(p.videoTracks)) + } +} diff --git a/internal/provider/wbstream/provider_test.go b/internal/provider/wbstream/provider_test.go new file mode 100644 index 0000000..f33f6f7 --- /dev/null +++ b/internal/provider/wbstream/provider_test.go @@ -0,0 +1,49 @@ +package wbstream + +import ( + "context" + "errors" + "testing" + + "github.com/pion/webrtc/v4" +) + +func TestWBStreamProviderForwardsPeerMethods(t *testing.T) { + peer, err := NewPeer(context.Background(), "room", "name", nil) + if err != nil { + t.Fatalf("NewPeer() error = %v", err) + } + p := &wbStreamProvider{peer: peer} + + p.SetReconnectCallback(func(*webrtc.DataChannel) {}) + p.SetShouldReconnect(func() bool { return true }) + p.SetEndedCallback(func(string) {}) + p.SetVideoTrackHandler(func(*webrtc.TrackRemote, *webrtc.RTPReceiver) {}) + if peer.onReconnect == nil || peer.shouldReconnect == nil || peer.onEnded == nil || peer.onVideoTrack == nil { + t.Fatal("callbacks were not forwarded") + } + + if p.GetSendQueue() != peer.sendQueue { + t.Fatal("GetSendQueue() did not forward") + } + if p.GetBufferedAmount() != 0 { + t.Fatal("GetBufferedAmount() != 0") + } + if err := p.AddVideoTrack(nil); err != nil { + t.Fatalf("AddVideoTrack(nil) error = %v", err) + } + if p.CanSend() { + t.Fatal("CanSend() = true without LiveKit room") + } + p.WatchConnection(context.Background()) + + if err := p.Send([]byte("x")); err != nil { + t.Fatalf("Send() error = %v", err) + } + if err := p.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + if err := p.Send([]byte("x")); !errors.Is(err, ErrPeerClosed) { + t.Fatalf("Send() error = %v, want peer closed", err) + } +} diff --git a/internal/server/server_test.go b/internal/server/server_test.go new file mode 100644 index 0000000..bbe2ab3 --- /dev/null +++ b/internal/server/server_test.go @@ -0,0 +1,343 @@ +package server + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net" + "strings" + "testing" + + cryptopkg "github.com/openlibrecommunity/olcrtc/internal/crypto" + "github.com/openlibrecommunity/olcrtc/internal/muxconn" + "github.com/xtaci/smux" +) + +func TestSetupCipher(t *testing.T) { + keyHex := "00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff" + cipher, err := setupCipher(keyHex) + if err != nil { + t.Fatalf("setupCipher() error = %v", err) + } + if cipher == nil { + t.Fatal("setupCipher() returned nil cipher") + } +} + +func TestSetupCipherRejectsBadInput(t *testing.T) { + if _, err := setupCipher(""); !errors.Is(err, ErrKeyRequired) { + t.Fatalf("setupCipher() error = %v, want %v", err, ErrKeyRequired) + } + if _, err := setupCipher("zz"); err == nil { + t.Fatal("setupCipher() unexpectedly succeeded for bad hex") + } + if _, err := setupCipher("00"); !errors.Is(err, ErrKeySize) { + t.Fatalf("setupCipher() error = %v, want ErrKeySize", err) + } +} + +func TestSmuxConfig(t *testing.T) { + cfg := smuxConfig() + if cfg.Version != 2 || cfg.MaxFrameSize != 32768 || cfg.MaxReceiveBuffer != 16*1024*1024 { + t.Fatalf("smuxConfig() = %+v", cfg) + } +} + +func TestParseConnectRequest(t *testing.T) { + buf, err := json.Marshal(ConnectRequest{ + Cmd: "connect", + ClientID: "client-1", + Addr: "example.com", + Port: 443, + }) + if err != nil { + t.Fatalf("Marshal() error = %v", err) + } + + req, ok := parseConnectRequest(buf) + if !ok { + t.Fatal("parseConnectRequest() returned ok=false") + } + if req.ClientID != "client-1" || req.Addr != "example.com" || req.Port != 443 { + t.Fatalf("parseConnectRequest() = %+v", req) + } + + if _, ok := parseConnectRequest([]byte("not-json")); ok { + t.Fatal("parseConnectRequest() unexpectedly accepted invalid json") + } + if _, ok := parseConnectRequest([]byte(`{"cmd":"other"}`)); ok { + t.Fatal("parseConnectRequest() unexpectedly accepted wrong command") + } +} + +func TestAuthorizeRequest(t *testing.T) { + s := &Server{clientID: "client-1"} + if !s.authorizeRequest(ConnectRequest{ClientID: "client-1"}) { + t.Fatal("authorizeRequest() rejected valid client") + } + if s.authorizeRequest(ConnectRequest{ClientID: "client-2"}) { + t.Fatal("authorizeRequest() accepted wrong client") + } +} + +func TestSocks5ConnectSuccess(t *testing.T) { + s := &Server{} + server, client := net.Pipe() + defer func() { + _ = server.Close() + _ = client.Close() + }() + + done := make(chan error, 1) + go func() { + done <- s.socks5Connect(server, "example.com", 443) + }() + + auth := make([]byte, 3) + if _, err := io.ReadFull(client, auth); err != nil { + t.Fatalf("ReadFull(auth) error = %v", err) + } + if !bytes.Equal(auth, []byte{5, 1, 0}) { + t.Fatalf("auth request = %v", auth) + } + if _, err := client.Write([]byte{5, 0}); err != nil { + t.Fatalf("Write(auth resp) error = %v", err) + } + + req := make([]byte, 18) + if _, err := io.ReadFull(client, req); err != nil { + t.Fatalf("ReadFull(connect req) error = %v", err) + } + if req[0] != 5 || req[1] != 1 || req[3] != 3 || req[4] != byte(len("example.com")) { + t.Fatalf("connect request header = %v", req[:5]) + } + if string(req[5:16]) != "example.com" { + t.Fatalf("connect request addr = %q", req[5:16]) + } + if req[16] != 0x01 || req[17] != 0xbb { + t.Fatalf("connect request port bytes = %v", req[16:18]) + } + if _, err := client.Write([]byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0}); err != nil { + t.Fatalf("Write(connect resp) error = %v", err) + } + + if err := <-done; err != nil { + t.Fatalf("socks5Connect() error = %v", err) + } +} + +func TestSocks5ConnectErrors(t *testing.T) { + s := &Server{} + + server, client := net.Pipe() + defer func() { + _ = server.Close() + _ = client.Close() + }() + + done := make(chan error, 1) + go func() { + done <- s.socks5Connect(server, "example.com", 443) + }() + + auth := make([]byte, 3) + if _, err := io.ReadFull(client, auth); err != nil { + t.Fatalf("ReadFull(auth) error = %v", err) + } + if _, err := client.Write([]byte{5, 1}); err != nil { + t.Fatalf("Write(auth resp) error = %v", err) + } + if err := <-done; !errors.Is(err, ErrSocks5AuthFailed) { + t.Fatalf("socks5Connect() error = %v, want %v", err, ErrSocks5AuthFailed) + } + + server2, client2 := net.Pipe() + defer func() { + _ = server2.Close() + _ = client2.Close() + }() + + done = make(chan error, 1) + go func() { + done <- s.socks5Connect(server2, "example.com", 443) + }() + + if _, err := io.ReadFull(client2, auth); err != nil { + t.Fatalf("ReadFull(auth2) error = %v", err) + } + if _, err := client2.Write([]byte{5, 0}); err != nil { + t.Fatalf("Write(auth2 resp) error = %v", err) + } + + req := make([]byte, 18) + if _, err := io.ReadFull(client2, req); err != nil { + t.Fatalf("ReadFull(req2) error = %v", err) + } + if _, err := client2.Write([]byte{5, 4, 0, 1, 0, 0, 0, 0, 0, 0}); err != nil { + t.Fatalf("Write(connect2 resp) error = %v", err) + } + if err := <-done; !errors.Is(err, ErrSocks5ConnectFailed) { + t.Fatalf("socks5Connect() error = %v, want %v", err, ErrSocks5ConnectFailed) + } +} + +func TestSetupResolver(t *testing.T) { + s := &Server{dnsServer: "127.0.0.1:53"} + s.setupResolver() + if s.resolver == nil || !s.resolver.PreferGo || s.resolver.Dial == nil { + t.Fatalf("setupResolver() = %+v", s.resolver) + } +} + +func TestOnDataWithNilConn(t *testing.T) { + s := &Server{} + s.onData([]byte("ignored")) +} + +type serverLinkStub struct { + closed bool +} + +func (s *serverLinkStub) Connect(context.Context) error { return nil } +func (s *serverLinkStub) Send([]byte) error { return nil } +func (s *serverLinkStub) Close() error { s.closed = true; return nil } +func (s *serverLinkStub) SetReconnectCallback(func()) {} +func (s *serverLinkStub) SetShouldReconnect(func() bool) {} +func (s *serverLinkStub) SetEndedCallback(func(string)) {} +func (s *serverLinkStub) WatchConnection(context.Context) {} +func (s *serverLinkStub) CanSend() bool { return true } + +func TestShutdownClosesLinkAndConn(t *testing.T) { + cipher, err := cryptopkg.NewCipher("01234567890123456789012345678901") + if err != nil { + t.Fatalf("NewCipher() error = %v", err) + } + ln := &serverLinkStub{} + s := &Server{ + ln: ln, + cipher: cipher, + conn: muxconn.New(ln, cipher), + } + s.shutdown() + if !ln.closed { + t.Fatal("shutdown() did not close link") + } +} + +func TestDialWithoutProxy(t *testing.T) { + ln, err := net.Listen("tcp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen() error = %v", err) + } + defer func() { _ = ln.Close() }() + + done := make(chan struct{}) + go func() { + conn, err := ln.Accept() + if err == nil { + _ = conn.Close() + close(done) + } + }() + + tcpAddr := ln.Addr().(*net.TCPAddr) + s := &Server{resolver: net.DefaultResolver} + conn, err := s.dial(ConnectRequest{Addr: "127.0.0.1", Port: tcpAddr.Port}) + if err != nil { + t.Fatalf("dial() error = %v", err) + } + _ = conn.Close() + <-done +} + +func TestDialProxyError(t *testing.T) { + s := &Server{socksProxyAddr: "127.0.0.1", socksProxyPort: 1} + if _, err := s.dial(ConnectRequest{Addr: "example.com", Port: 443}); err == nil || !strings.Contains(err.Error(), "failed to dial proxy") { + t.Fatalf("dial() error = %v", err) + } +} + +func TestSocks5ConnectTruncatesLongDomain(t *testing.T) { + s := &Server{} + server, client := net.Pipe() + defer func() { + _ = server.Close() + _ = client.Close() + }() + + longHost := strings.Repeat("a", 300) + done := make(chan error, 1) + go func() { + done <- s.socks5Connect(server, longHost, 443) + }() + + auth := make([]byte, 3) + if _, err := io.ReadFull(client, auth); err != nil { + t.Fatalf("ReadFull(auth) error = %v", err) + } + if _, err := client.Write([]byte{5, 0}); err != nil { + t.Fatalf("Write(auth resp) error = %v", err) + } + + req := make([]byte, 262) + if _, err := io.ReadFull(client, req); err != nil { + t.Fatalf("ReadFull(connect req) error = %v", err) + } + if req[4] != 255 { + t.Fatalf("domain len byte = %d, want 255", req[4]) + } + if _, err := client.Write([]byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0}); err != nil { + t.Fatalf("Write(connect resp) error = %v", err) + } + if err := <-done; err != nil { + t.Fatalf("socks5Connect() error = %v", err) + } +} + +func TestHandleStreamRejectsWrongClientID(t *testing.T) { + a, b := net.Pipe() + defer func() { + _ = a.Close() + _ = b.Close() + }() + + serverSess, err := smux.Server(a, smuxConfig()) + if err != nil { + t.Fatalf("smux.Server() error = %v", err) + } + defer func() { _ = serverSess.Close() }() + clientSess, err := smux.Client(b, smuxConfig()) + if err != nil { + t.Fatalf("smux.Client() error = %v", err) + } + defer func() { _ = clientSess.Close() }() + + done := make(chan struct{}) + go func() { + stream, err := serverSess.AcceptStream() + if err == nil { + (&Server{clientID: "expected"}).handleStream(context.Background(), stream) + } + close(done) + }() + + stream, err := clientSess.OpenStream() + if err != nil { + t.Fatalf("OpenStream() error = %v", err) + } + req, err := json.Marshal(ConnectRequest{ + Cmd: "connect", + ClientID: "wrong", + Addr: "example.com", + Port: 443, + }) + if err != nil { + t.Fatalf("Marshal() error = %v", err) + } + if _, err := stream.Write(req); err != nil { + t.Fatalf("Write() error = %v", err) + } + <-done +} diff --git a/internal/transport/datachannel/transport_test.go b/internal/transport/datachannel/transport_test.go new file mode 100644 index 0000000..b5a33ea --- /dev/null +++ b/internal/transport/datachannel/transport_test.go @@ -0,0 +1,139 @@ +package datachannel + +import ( + "context" + "errors" + "testing" + + "github.com/openlibrecommunity/olcrtc/internal/carrier" + "github.com/openlibrecommunity/olcrtc/internal/transport" +) + +type stubSession struct { + stream carrier.ByteStream + streamErr error +} + +func (s *stubSession) Capabilities() carrier.Capabilities { + return carrier.Capabilities{ByteStream: true} +} +func (s *stubSession) OpenByteStream() (carrier.ByteStream, error) { + if s.streamErr != nil { + return nil, s.streamErr + } + return s.stream, nil +} + +type nonByteStreamSession struct{} + +func (s *nonByteStreamSession) Capabilities() carrier.Capabilities { return carrier.Capabilities{} } + +type stubByteStream struct { + connectErr error + sendErr error + closeErr error + canSend bool + + connectCalled bool + sent []byte + watched bool + reconnectCB func() + shouldFn func() bool + endedCB func(string) +} + +func (s *stubByteStream) Connect(context.Context) error { s.connectCalled = true; return s.connectErr } +func (s *stubByteStream) Send(data []byte) error { + s.sent = append([]byte(nil), data...) + return s.sendErr +} +func (s *stubByteStream) Close() error { return s.closeErr } +func (s *stubByteStream) SetReconnectCallback(cb func()) { s.reconnectCB = cb } +func (s *stubByteStream) SetShouldReconnect(fn func() bool) { s.shouldFn = fn } +func (s *stubByteStream) SetEndedCallback(cb func(string)) { s.endedCB = cb } +func (s *stubByteStream) WatchConnection(context.Context) { s.watched = true } +func (s *stubByteStream) CanSend() bool { return s.canSend } + +func TestNewAndFeatures(t *testing.T) { + stream := &stubByteStream{canSend: true} + carrier.Register("datachannel-test-new-and-features", func(context.Context, carrier.Config) (carrier.Session, error) { + return &stubSession{stream: stream}, nil + }) + + tr, err := New(context.Background(), transport.Config{Carrier: "datachannel-test-new-and-features"}) + if err != nil { + t.Fatalf("New() error = %v", err) + } + + if err := tr.Connect(context.Background()); err != nil { + t.Fatalf("Connect() error = %v", err) + } + if !stream.connectCalled { + t.Fatal("Connect() was not forwarded") + } + if err := tr.Send([]byte("payload")); err != nil { + t.Fatalf("Send() error = %v", err) + } + if string(stream.sent) != "payload" { + t.Fatalf("Send() forwarded %q, want payload", stream.sent) + } + tr.SetReconnectCallback(func() {}) + tr.SetShouldReconnect(func() bool { return true }) + tr.SetEndedCallback(func(string) {}) + tr.WatchConnection(context.Background()) + if stream.reconnectCB == nil || stream.shouldFn == nil || stream.endedCB == nil || !stream.watched { + t.Fatal("callbacks/watch were not forwarded") + } + if !tr.CanSend() { + t.Fatal("CanSend() = false, want true") + } + + features := tr.Features() + if !features.Reliable || !features.Ordered || !features.MessageOriented || features.MaxPayloadSize != defaultMaxPayloadSize { + t.Fatalf("Features() = %+v", features) + } + if err := tr.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } +} + +func TestNewErrorPaths(t *testing.T) { + carrier.Register("datachannel-fail-create", func(context.Context, carrier.Config) (carrier.Session, error) { + return nil, errors.New("boom") + }) + if _, err := New(context.Background(), transport.Config{Carrier: "datachannel-fail-create"}); err == nil || err.Error() != "create provider transport: boom" { + t.Fatalf("New() error = %v", err) + } + + carrier.Register("datachannel-no-stream", func(context.Context, carrier.Config) (carrier.Session, error) { + return &nonByteStreamSession{}, nil + }) + if _, err := New(context.Background(), transport.Config{Carrier: "datachannel-no-stream"}); !errors.Is(err, carrier.ErrByteStreamUnsupported) { + t.Fatalf("New() error = %v, want %v", err, carrier.ErrByteStreamUnsupported) + } + + carrier.Register("datachannel-open-stream-fails", func(context.Context, carrier.Config) (carrier.Session, error) { + return &stubSession{streamErr: errors.New("open boom")}, nil + }) + if _, err := New(context.Background(), transport.Config{Carrier: "datachannel-open-stream-fails"}); err == nil || err.Error() != "open byte stream: open boom" { + t.Fatalf("New() error = %v", err) + } +} + +func TestStreamTransportWrapsErrors(t *testing.T) { + tr := &streamTransport{stream: &stubByteStream{ + connectErr: errors.New("connect boom"), + sendErr: errors.New("send boom"), + closeErr: errors.New("close boom"), + }} + + if err := tr.Connect(context.Background()); err == nil || err.Error() != "stream connect: connect boom" { + t.Fatalf("Connect() error = %v", err) + } + if err := tr.Send([]byte("x")); err == nil || err.Error() != "stream send: send boom" { + t.Fatalf("Send() error = %v", err) + } + if err := tr.Close(); err == nil || err.Error() != "stream close: close boom" { + t.Fatalf("Close() error = %v", err) + } +} diff --git a/internal/transport/seichannel/frame_extra_test.go b/internal/transport/seichannel/frame_extra_test.go new file mode 100644 index 0000000..206e403 --- /dev/null +++ b/internal/transport/seichannel/frame_extra_test.go @@ -0,0 +1,84 @@ +package seichannel + +import ( + "bytes" + "errors" + "testing" +) + +func TestFragmentPayload(t *testing.T) { + frags := fragmentPayload([]byte("abcdef"), 2) + want := [][]byte{[]byte("ab"), []byte("cd"), []byte("ef")} + if len(frags) != len(want) { + t.Fatalf("fragment count = %d, want %d", len(frags), len(want)) + } + for i := range frags { + if !bytes.Equal(frags[i], want[i]) { + t.Fatalf("frag %d = %q, want %q", i, frags[i], want[i]) + } + } + + empty := fragmentPayload(nil, 10) + if len(empty) != 1 || len(empty[0]) != 0 { + t.Fatalf("fragmentPayload(nil) = %#v, want one empty frag", empty) + } +} + +func TestDecodeTransportFrameErrorsAndAck(t *testing.T) { + tests := []struct { + data []byte + want error + }{ + {data: []byte{1, 2, 3}, want: ErrFrameTooShort}, + {data: []byte{0, 0, 0, 0, protocolVersion, frameTypeAck}, want: ErrUnexpectedMagic}, + {data: []byte{0x4f, 0x56, 0x43, 0x31, 9, frameTypeAck}, want: ErrUnexpectedVersion}, + {data: []byte{0x4f, 0x56, 0x43, 0x31, protocolVersion, frameTypeAck}, want: ErrAckTooShort}, + {data: []byte{0x4f, 0x56, 0x43, 0x31, protocolVersion, frameTypeData}, want: ErrDataTooShort}, + {data: []byte{0x4f, 0x56, 0x43, 0x31, protocolVersion, 99}, want: ErrUnexpectedFrameType}, + } + for _, tt := range tests { + if _, err := decodeTransportFrame(tt.data); !errors.Is(err, tt.want) { + t.Fatalf("decodeTransportFrame(%v) error = %v, want %v", tt.data, err, tt.want) + } + } + + ack, err := decodeTransportFrame(encodeAckFrame(7, 0x1234)) + if err != nil { + t.Fatalf("decode ack error = %v", err) + } + if ack.typ != frameTypeAck || ack.seq != 7 || ack.crc != 0x1234 { + t.Fatalf("ack = %+v", ack) + } +} + +func TestSEIHelpersAndErrors(t *testing.T) { + escaped := escapeRBSP([]byte{0, 0, 1, 0, 0, 2, 3}) + if !bytes.Equal(unescapeRBSP(escaped), []byte{0, 0, 1, 0, 0, 2, 3}) { + t.Fatalf("unescapeRBSP(escapeRBSP()) = %v", unescapeRBSP(escaped)) + } + + value := appendSEIValue(nil, 300) + got, next, err := consumeSEIValue(value, 0) + if err != nil || got != 300 || next != len(value) { + t.Fatalf("consumeSEIValue() = (%d, %d, %v), want 300", got, next, err) + } + if _, _, err := consumeSEIValue([]byte{0xff}, 0); !errors.Is(err, ErrSEIValueTruncated) { + t.Fatalf("consumeSEIValue() error = %v, want %v", err, ErrSEIValueTruncated) + } + + rbsp := appendSEIValue(nil, 5) + rbsp = append(rbsp, appendSEIValue(nil, len(videoSEIUUID)+5)...) + rbsp = append(rbsp, videoSEIUUID[:]...) + rbsp = append(rbsp, []byte{1, 2}...) + if _, err := extractTransportSEI(rbsp); !errors.Is(err, ErrSEIPayloadTruncated) { + t.Fatalf("extractTransportSEI() error = %v, want %v", err, ErrSEIPayloadTruncated) + } + + payloads, err := extractTransportSEI([]byte{4, 1, 0, 0x80}) + if err != nil { + t.Fatalf("extractTransportSEI(non-transport) error = %v", err) + } + if len(payloads) != 0 { + t.Fatalf("extractTransportSEI(non-transport) = %v, want none", payloads) + } +} diff --git a/internal/transport/seichannel/inbound_test.go b/internal/transport/seichannel/inbound_test.go new file mode 100644 index 0000000..31b54ae --- /dev/null +++ b/internal/transport/seichannel/inbound_test.go @@ -0,0 +1,111 @@ +package seichannel + +import ( + "bytes" + "hash/crc32" + "testing" +) + +func TestInboundAssemblyAndAck(t *testing.T) { + var got []byte + tr := &streamTransport{ + onData: func(data []byte) { got = append([]byte(nil), data...) }, + outboundAck: make(chan []byte, 4), + inbound: make(map[uint32]*inboundMessage), + delivered: make(map[uint32]uint32), + } + + payload := []byte("hello world") + crc := crc32.ChecksumIEEE(payload) + tr.handleInboundFrame(transportFrame{ + typ: frameTypeData, + seq: 1, + crc: crc, + totalLen: uint32(len(payload)), + fragIdx: 1, + fragTotal: 2, + payload: []byte(" world"), + }) + if len(got) != 0 { + t.Fatalf("onData called before message complete: %q", got) + } + + tr.handleInboundFrame(transportFrame{ + typ: frameTypeData, + seq: 1, + crc: crc, + totalLen: uint32(len(payload)), + fragIdx: 0, + fragTotal: 2, + payload: []byte("hello"), + }) + if !bytes.Equal(got, payload) { + t.Fatalf("assembled payload = %q, want %q", got, payload) + } + select { + case ack := <-tr.outboundAck: + frame, err := decodeTransportFrame(ack) + if err != nil || frame.typ != frameTypeAck || frame.seq != 1 || frame.crc != crc { + t.Fatalf("ack frame = %+v err=%v", frame, err) + } + default: + t.Fatal("handleInboundFrame() did not enqueue ack") + } + + got = nil + tr.handleInboundFrame(transportFrame{ + typ: frameTypeData, + seq: 1, + crc: crc, + totalLen: uint32(len(payload)), + fragIdx: 0, + fragTotal: 2, + payload: []byte("hello"), + }) + if got != nil { + t.Fatalf("duplicate delivered payload again: %q", got) + } +} + +func TestInboundRejectsBadFragmentsAndCRC(t *testing.T) { + tr := &streamTransport{ + outboundAck: make(chan []byte, 2), + inbound: make(map[uint32]*inboundMessage), + delivered: make(map[uint32]uint32), + } + + msg, complete := tr.upsertInbound(transportFrame{ + seq: 1, + crc: 1, + totalLen: 3, + fragIdx: 3, + fragTotal: 1, + payload: []byte("bad"), + }) + if msg != nil || complete { + t.Fatalf("upsertInbound(out of range) = (%v, %v), want nil false", msg, complete) + } + + called := false + tr.onData = func([]byte) { called = true } + tr.handleInboundFrame(transportFrame{ + seq: 2, + crc: 123, + totalLen: 3, + fragIdx: 0, + fragTotal: 1, + payload: []byte("abc"), + }) + if called { + t.Fatal("handleInboundFrame() delivered payload with bad crc") + } + + msg = &inboundMessage{ + totalLen: 3, + crc: crc32.ChecksumIEEE([]byte("abcdef")), + frags: [][]byte{[]byte("abc"), []byte("def")}, + } + if got := tr.assembleMessage(msg); string(got) != "abc" { + t.Fatalf("assembleMessage() = %q, want abc", got) + } +} diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go new file mode 100644 index 0000000..6330b6a --- /dev/null +++ b/internal/transport/transport_test.go @@ -0,0 +1,72 @@ +package transport + +import ( + "context" + "errors" + "reflect" + "testing" +) + +type stubTransport struct{} + +func (s *stubTransport) Connect(context.Context) error { return nil } +func (s *stubTransport) Send([]byte) error { return nil } +func (s *stubTransport) Close() error { return nil } +func (s *stubTransport) SetReconnectCallback(func()) {} +func (s *stubTransport) SetShouldReconnect(func() bool) {} +func (s *stubTransport) SetEndedCallback(func(string)) {} +func (s *stubTransport) WatchConnection(context.Context) {} +func (s *stubTransport) CanSend() bool { return true } +func (s *stubTransport) Features() Features { return Features{Reliable: true} } + +func snapshotTransportRegistry() map[string]Factory { + out := make(map[string]Factory, len(registry)) + for k, v := range registry { + out[k] = v + } + return out +} + +func restoreTransportRegistry(src map[string]Factory) { + registry = make(map[string]Factory, len(src)) + for k, v := range src { + registry[k] = v + } +} + +func TestNewAndAvailable(t *testing.T) { + old := snapshotTransportRegistry() + t.Cleanup(func() { restoreTransportRegistry(old) }) + + called := false + Register("test-transport", func(_ context.Context, cfg Config) (Transport, error) { + called = cfg.ClientID == "client-1" + return &stubTransport{}, nil + }) + + got, err := New(context.Background(), "test-transport", Config{ClientID: "client-1"}) + if err != nil { + t.Fatalf("New() error = %v", err) + } + if !called { + t.Fatal("factory did not receive config") + } + if _, ok := got.(*stubTransport); !ok { + t.Fatalf("New() returned %T, want *stubTransport", got) + } + + if !reflect.DeepEqual(Available(), []string{"test-transport"}) { + t.Fatalf("Available() = %#v, want %#v", Available(), []string{"test-transport"}) + } +} + +func TestNewReturnsErrTransportNotFound(t *testing.T) { + old := snapshotTransportRegistry() + t.Cleanup(func() { restoreTransportRegistry(old) }) + registry = map[string]Factory{} + + _, err := New(context.Background(), "missing", Config{}) + if !errors.Is(err, ErrTransportNotFound) { + t.Fatalf("New() error = %v, want %v", err, ErrTransportNotFound) + } +} diff --git a/internal/transport/videochannel/frame_extra_test.go b/internal/transport/videochannel/frame_extra_test.go new file mode 100644 index 0000000..d782d94 --- /dev/null +++ b/internal/transport/videochannel/frame_extra_test.go @@ -0,0 +1,139 @@ +package videochannel + +import ( + "bytes" + "errors" + "io" + "slices" + "strings" + "testing" + + "github.com/pion/webrtc/v4" +) + +func TestFragmentPayload(t *testing.T) { + frags := fragmentPayload([]byte("abcdef"), 2) + want := [][]byte{[]byte("ab"), []byte("cd"), []byte("ef")} + if len(frags) != len(want) { + t.Fatalf("fragment count = %d, want %d", len(frags), len(want)) + } + for i := range frags { + if !bytes.Equal(frags[i], want[i]) { + t.Fatalf("frag %d = %q, want %q", i, frags[i], want[i]) + } + } + + empty := fragmentPayload(nil, 10) + if len(empty) != 1 || len(empty[0]) != 0 { + t.Fatalf("fragmentPayload(nil) = %#v, want one empty frag", empty) + } +} + +func TestDecodeTransportFrameErrorsAndAck(t *testing.T) { + tests := []struct { + data []byte + want error + }{ + {data: []byte{1, 2, 3}, want: ErrFrameTooShort}, + {data: []byte{0, 0, 0, 0, protocolVersion, frameTypeAck}, want: ErrUnexpectedMagic}, + {data: []byte{0x4f, 0x56, 0x56, 0x32, 9, frameTypeAck}, want: ErrUnexpectedVersion}, + {data: []byte{0x4f, 0x56, 0x56, 0x32, protocolVersion, frameTypeAck}, want: ErrAckTooShort}, + {data: []byte{0x4f, 0x56, 0x56, 0x32, protocolVersion, frameTypeData}, want: ErrDataTooShort}, + {data: []byte{0x4f, 0x56, 0x56, 0x32, protocolVersion, 99}, want: ErrUnexpectedFrameType}, + } + for _, tt := range tests { + if _, err := decodeTransportFrame(tt.data); !errors.Is(err, tt.want) { + t.Fatalf("decodeTransportFrame(%v) error = %v, want %v", tt.data, err, tt.want) + } + } + + ack, err := decodeTransportFrame(encodeAckFrame(7, 0x1234)) + if err != nil { + t.Fatalf("decode ack error = %v", err) + } + if ack.typ != frameTypeAck || ack.seq != 7 || ack.crc != 0x1234 { + t.Fatalf("ack = %+v", ack) + } +} + +func TestCodecSpecsAndArgs(t *testing.T) { + for _, mime := range []string{webrtc.MimeTypeH264, webrtc.MimeTypeVP8, webrtc.MimeTypeVP9} { + spec, ok := codecSpecForMime(mime) + if !ok { + t.Fatalf("codecSpecForMime(%q) ok = false", mime) + } + if spec.mimeType != mime || spec.depacketizer == nil || spec.capability.ClockRate != 90000 { + t.Fatalf("codec spec = %+v", spec) + } + } + if _, ok := codecSpecForMime("video/unknown"); ok { + t.Fatal("codecSpecForMime() accepted unknown mime") + } + + if got := resolveEncoderCodec(h264CodecSpec(), "nvenc"); got != "h264_nvenc" { + t.Fatalf("resolveEncoderCodec(h264,nvenc) = %q", got) + } + if got := resolveEncoderCodec(vp8CodecSpec(), "none"); got != "libvpx" { + t.Fatalf("resolveEncoderCodec(vp8,none) = %q", got) + } + + args := buildEncoderArgs(vp8CodecSpec(), "vp8_nvenc", 320, 240, 30, "1M") + for _, want := range []string{"-video_size", "320x240", "-framerate", "30", "vp8_nvenc", "-b:v", "1M", "ivf"} { + if !slices.Contains(args, want) { + t.Fatalf("buildEncoderArgs() = %v, missing %q", args, want) + } + } + h264Args := buildEncoderArgs(h264CodecSpec(), "libx264", 320, 240, 30, "1M") + if h264Args[len(h264Args)-2] != "h264" { + t.Fatalf("h264 encoder args = %v", h264Args) + } +} + +type shortWriter struct { + writes int +} + +func (w *shortWriter) Write(p []byte) (int, error) { + w.writes++ + if w.writes == 1 { + return 1, nil + } + return len(p), nil +} + +type errWriter struct{} + +func (w errWriter) Write([]byte) (int, error) { return 0, io.ErrClosedPipe } + +func TestIVFWritersAndWithStderr(t *testing.T) { + var buf bytes.Buffer + if err := writeIVFHeader(&buf, "VP80", 320, 240, 30); err != nil { + t.Fatalf("writeIVFHeader() error = %v", err) + } + if buf.Len() != 32 || string(buf.Bytes()[:4]) != "DKIF" { + t.Fatalf("IVF header = %v", buf.Bytes()) + } + + buf.Reset() + if err := writeIVFFrame(&buf, 3, []byte("abc")); err != nil { + t.Fatalf("writeIVFFrame() error = %v", err) + } + if buf.Len() != 15 { + t.Fatalf("IVF frame len = %d, want 15", buf.Len()) + } + + if err := writeAll(&shortWriter{}, []byte("abc")); err != nil { + t.Fatalf("writeAll(shortWriter) error = %v", err) + } + if err := writeAll(errWriter{}, []byte("abc")); err == nil || !strings.Contains(err.Error(), "write:") { + t.Fatalf("writeAll(errWriter) error = %v", err) + } + + baseErr := errors.New("base") + if got := withStderr(baseErr, bytes.NewBufferString(" details \n")); got == nil || got.Error() != "base: details" { + t.Fatalf("withStderr() = %v", got) + } + if got := withStderr(nil, bytes.NewBufferString("details")); got != nil { + t.Fatalf("withStderr(nil) = %v", got) + } +} diff --git a/internal/transport/videochannel/inbound_test.go b/internal/transport/videochannel/inbound_test.go new file mode 100644 index 0000000..6a76c72 --- /dev/null +++ b/internal/transport/videochannel/inbound_test.go @@ -0,0 +1,97 @@ +package videochannel + +import ( + "bytes" + "hash/crc32" + "testing" +) + +func TestInboundAssemblyAndAck(t *testing.T) { + var got []byte + tr := &streamTransport{ + onData: func(data []byte) { got = append([]byte(nil), data...) }, + outboundAck: make(chan []byte, 4), + inbound: make(map[uint32]*inboundMessage), + delivered: make(map[uint32]uint32), + } + + payload := []byte("hello video") + crc := crc32.ChecksumIEEE(payload) + tr.handleInboundFrame(transportFrame{ + typ: frameTypeData, + seq: 1, + crc: crc, + totalLen: uint32(len(payload)), + fragIdx: 1, + fragTotal: 2, + payload: []byte(" video"), + }) + if len(got) != 0 { + t.Fatalf("onData called before message complete: %q", got) + } + + tr.handleInboundFrame(transportFrame{ + typ: frameTypeData, + seq: 1, + crc: crc, + totalLen: uint32(len(payload)), + fragIdx: 0, + fragTotal: 2, + payload: []byte("hello"), + }) + if !bytes.Equal(got, payload) { + t.Fatalf("assembled payload = %q, want %q", got, payload) + } + select { + case ack := <-tr.outboundAck: + frame, err := decodeTransportFrame(ack) + if err != nil || frame.typ != frameTypeAck || frame.seq != 1 || frame.crc != crc { + t.Fatalf("ack frame = %+v err=%v", frame, err) + } + default: + t.Fatal("handleInboundFrame() did not enqueue ack") + } +} + +func TestInboundRejectsBadFragmentsAndCRC(t *testing.T) { + tr := &streamTransport{ + outboundAck: make(chan []byte, 2), + inbound: make(map[uint32]*inboundMessage), + delivered: make(map[uint32]uint32), + } + + msg, complete := tr.upsertInbound(transportFrame{ + seq: 1, + crc: 1, + totalLen: 3, + fragIdx: 3, + fragTotal: 1, + payload: []byte("bad"), + }) + if msg != nil || complete { + t.Fatalf("upsertInbound(out of range) = (%v, %v), want nil false", msg, complete) + } + + called := false + tr.onData = func([]byte) { called = true } + tr.handleInboundFrame(transportFrame{ + seq: 2, + crc: 123, + totalLen: 3, + fragIdx: 0, + fragTotal: 1, + payload: []byte("abc"), + }) + if called { + t.Fatal("handleInboundFrame() delivered payload with bad crc") + } + + msg = &inboundMessage{ + totalLen: 3, + crc: crc32.ChecksumIEEE([]byte("abcdef")), + frags: [][]byte{[]byte("abc"), []byte("def")}, + } + if got := tr.assembleMessage(msg); string(got) != "abc" { + t.Fatalf("assembleMessage() = %q, want abc", got) + } +} diff --git a/internal/transport/vp8channel/kcpconn_test.go b/internal/transport/vp8channel/kcpconn_test.go new file mode 100644 index 0000000..e7d8d20 --- /dev/null +++ b/internal/transport/vp8channel/kcpconn_test.go @@ -0,0 +1,71 @@ +package vp8channel + +import ( + "bytes" + "errors" + "net" + "testing" + "time" +) + +func TestKCPConnReadWriteDeadlinesAndClose(t *testing.T) { + out := make(chan []byte, 1) + hdr := testEpochHdr(9) + conn := newKCPConn(out, 1, hdr) + + if err := conn.SetDeadline(time.Now().Add(time.Second)); err != nil { + t.Fatalf("SetDeadline() error = %v", err) + } + if conn.LocalAddr().String() != "127.0.0.1:1" { + t.Fatalf("LocalAddr() = %v", conn.LocalAddr()) + } + + n, err := conn.WriteTo([]byte("payload"), nil) + if err != nil || n != len("payload") { + t.Fatalf("WriteTo() = (%d, %v), want payload length", n, err) + } + wire := <-out + if !bytes.Equal(wire[:epochHdrLen], hdr[:]) || string(wire[epochHdrLen:]) != "payload" { + t.Fatalf("wire packet = %v", wire) + } + + conn.deliver([]byte("incoming")) + buf := make([]byte, 64) + n, addr, err := conn.ReadFrom(buf) + if err != nil || addr == nil || string(buf[:n]) != "incoming" { + t.Fatalf("ReadFrom() = (%d, %v, %v), payload %q", n, addr, err, buf[:n]) + } + + if err := conn.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + if _, _, err := conn.ReadFrom(buf); !errors.Is(err, net.ErrClosed) { + t.Fatalf("ReadFrom() error = %v, want net.ErrClosed", err) + } + + closedWrite := newKCPConn(make(chan []byte), 1, hdr) + _ = closedWrite.Close() + if _, err := closedWrite.WriteTo([]byte("x"), nil); !errors.Is(err, net.ErrClosed) { + t.Fatalf("WriteTo() error = %v, want net.ErrClosed", err) + } +} + +func TestKCPConnTimeouts(t *testing.T) { + conn := newKCPConn(make(chan []byte), 1, testEpochHdr(1)) + if err := conn.SetReadDeadline(time.Now().Add(-time.Millisecond)); err != nil { + t.Fatalf("SetReadDeadline() error = %v", err) + } + buf := make([]byte, 4) + if _, _, err := conn.ReadFrom(buf); err == nil { + t.Fatal("ReadFrom() unexpectedly succeeded") + } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() || !netErr.Temporary() { + t.Fatalf("ReadFrom() error = %T %v, want timeout net.Error", err, err) + } + + if err := conn.SetWriteDeadline(time.Now().Add(-time.Millisecond)); err != nil { + t.Fatalf("SetWriteDeadline() error = %v", err) + } + if _, err := conn.WriteTo([]byte("x"), nil); err == nil { + t.Fatal("WriteTo() unexpectedly succeeded") + } +} diff --git a/mobile/mobile_test.go b/mobile/mobile_test.go new file mode 100644 index 0000000..72a26ee --- /dev/null +++ b/mobile/mobile_test.go @@ -0,0 +1,206 @@ +package mobile + +import ( + "errors" + "log" + "strings" + "sync" + "testing" + "time" + + "github.com/openlibrecommunity/olcrtc/internal/logger" + "github.com/openlibrecommunity/olcrtc/internal/protect" +) + +type testProtector struct { + called int +} + +func (p *testProtector) Protect(fd int) bool { + p.called = fd + return true +} + +type testLogWriter struct { + got string +} + +func (w *testLogWriter) WriteLog(msg string) { + w.got += msg +} + +func resetMobileGlobals(t *testing.T) { + t.Helper() + mu.Lock() + if cancel != nil { + cancel() + } + cancel = nil + done = nil + ready = nil + errRun = nil + defaults = mobileConfig{} + defaultsSet = sync.Once{} + mu.Unlock() + protect.Protector = nil + logger.SetVerbose(false) +} + +func TestProtectorAndLogging(t *testing.T) { + resetMobileGlobals(t) + p := &testProtector{} + SetProtector(p) + if protect.Protector == nil || !protect.Protector(123) || p.called != 123 { + t.Fatal("SetProtector() did not install adapter") + } + SetProtector(nil) + if protect.Protector != nil { + t.Fatal("SetProtector(nil) did not clear protector") + } + + w := &testLogWriter{} + SetLogWriter(w) + log.Print("hello") + if !strings.Contains(w.got, "hello") { + t.Fatalf("log writer got %q, want hello", w.got) + } +} + +func TestDefaultsAndSetters(t *testing.T) { + resetMobileGlobals(t) + + SetTransport("dc") + SetLink("direct") + SetDNS("9.9.9.9:53") + SetVP8Options(-1, 999) + + mu.Lock() + got := defaults + mu.Unlock() + if got.transport != dataTransport || got.link != defaultLink || got.dnsServer != "9.9.9.9:53" || + got.vp8FPS != 1 || got.vp8BatchSize != 64 { + t.Fatalf("defaults = %+v", got) + } + + SetDebug(true) + if !logger.IsVerbose() { + t.Fatal("SetDebug(true) did not enable verbose") + } + SetDebug(false) + if logger.IsVerbose() { + t.Fatal("SetDebug(false) did not disable verbose") + } +} + +func TestNormalizeBuildRoomAndClamp(t *testing.T) { + tests := map[string]string{ + "datachannel": dataTransport, + "data": dataTransport, + "dc": dataTransport, + "vp8channel": defaultTransport, + "vp8": defaultTransport, + "bad": defaultTransport, + } + for in, want := range tests { + if got := normalizeTransport(in); got != want { + t.Fatalf("normalizeTransport(%q) = %q, want %q", in, got, want) + } + } + + if normalizeCarrier(carrierWBStream) != carrierWBStream || normalizeCarrier("jazz") != "jazz" { + t.Fatal("normalizeCarrier() returned unexpected value") + } + + if got := buildRoomURL("telemost", "abc"); got != "https://telemost.yandex.ru/j/abc" { + t.Fatalf("telemost room URL = %q", got) + } + if got := buildRoomURL("jazz", ""); got != "any" { + t.Fatalf("jazz empty room URL = %q", got) + } + if got := buildRoomURL(carrierWBStream, "room"); got != "room" { + t.Fatalf("wbstream room URL = %q", got) + } + + if clamp(0, 1, 10) != 1 || clamp(11, 1, 10) != 10 || clamp(5, 1, 10) != 5 { + t.Fatal("clamp() returned unexpected value") + } +} + +func TestStartValidation(t *testing.T) { + resetMobileGlobals(t) + + if err := startWithConfig("", dataTransport, "room", "client", "key", 1080, "", "", mobileConfig{}); !errors.Is(err, errCarrierRequired) { + t.Fatalf("startWithConfig(missing carrier) = %v", err) + } + if err := startWithConfig("telemost", dataTransport, "", "client", "key", 1080, "", "", mobileConfig{}); !errors.Is(err, errRoomIDRequired) { + t.Fatalf("startWithConfig(missing room) = %v", err) + } + if err := startWithConfig("jazz", dataTransport, "", "", "key", 1080, "", "", mobileConfig{}); !errors.Is(err, errClientIDRequired) { + t.Fatalf("startWithConfig(missing client) = %v", err) + } + if err := startWithConfig("jazz", dataTransport, "", "client", "", 1080, "", "", mobileConfig{}); !errors.Is(err, errKeyHexRequired) { + t.Fatalf("startWithConfig(missing key) = %v", err) + } + + mu.Lock() + cancel = func() {} + mu.Unlock() + if err := startWithConfig("jazz", dataTransport, "", "client", "key", 1080, "", "", mobileConfig{}); !errors.Is(err, errAlreadyRunning) { + t.Fatalf("startWithConfig(running) = %v", err) + } + resetMobileGlobals(t) +} + +func TestWaitReadyStatesAndStop(t *testing.T) { + resetMobileGlobals(t) + + if err := WaitReady(1); !errors.Is(err, errNotRunning) { + t.Fatalf("WaitReady(not running) = %v", err) + } + + mu.Lock() + errRun = errors.New("run failed") + mu.Unlock() + if err := WaitReady(1); err == nil || err.Error() != "run failed" { + t.Fatalf("WaitReady(run err) = %v", err) + } + + mu.Lock() + errRun = nil + ready = make(chan struct{}) + done = make(chan struct{}) + cancel = func() {} + mu.Unlock() + if err := WaitReady(1); !errors.Is(err, errStartTimedOut) { + t.Fatalf("WaitReady(timeout) = %v", err) + } + + mu.Lock() + close(ready) + mu.Unlock() + if err := WaitReady(1); err != nil { + t.Fatalf("WaitReady(ready) error = %v", err) + } + + mu.Lock() + cancel = func() {} + done = make(chan struct{}) + doneCh := done + mu.Unlock() + go func() { + time.Sleep(time.Millisecond) + close(doneCh) + }() + Stop() + mu.Lock() + cancel = nil + mu.Unlock() +} + +func TestLogBridge(t *testing.T) { + w := &testLogWriter{} + n, err := (&logBridge{w: w}).Write([]byte("abc")) + if err != nil || n != 3 || w.got != "abc" { + t.Fatalf("logBridge.Write() = (%d, %v), got %q", n, err, w.got) + } +}