diff --git a/internal/e2e/tunnel_test.go b/internal/e2e/tunnel_test.go index 1908aad..75a1bf6 100644 --- a/internal/e2e/tunnel_test.go +++ b/internal/e2e/tunnel_test.go @@ -17,7 +17,9 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/app/session" "github.com/openlibrecommunity/olcrtc/internal/carrier" "github.com/openlibrecommunity/olcrtc/internal/client" + "github.com/openlibrecommunity/olcrtc/internal/link" "github.com/openlibrecommunity/olcrtc/internal/server" + "github.com/pion/webrtc/v4" ) const testKeyHex = "00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff" @@ -27,13 +29,17 @@ type memorySession struct { } func (s *memorySession) Capabilities() carrier.Capabilities { - return carrier.Capabilities{ByteStream: true} + return carrier.Capabilities{ByteStream: true, VideoTrack: true} } func (s *memorySession) OpenByteStream() (carrier.ByteStream, error) { return s.stream, nil } +func (s *memorySession) OpenVideoTrack() (carrier.VideoTrack, error) { + return s.stream, nil +} + type memoryRoom struct { mu sync.Mutex streams map[*memoryStream]struct{} @@ -100,6 +106,8 @@ type memoryStream struct { closed bool reconnect func() ended func(string) + track webrtc.TrackLocal + trackCB func(*webrtc.TrackRemote, *webrtc.RTPReceiver) pending [][]byte } @@ -183,6 +191,19 @@ func (s *memoryStream) CanSend() bool { return s.isConnected() } +func (s *memoryStream) AddTrack(track webrtc.TrackLocal) error { + s.mu.Lock() + s.track = track + s.mu.Unlock() + return nil +} + +func (s *memoryStream) SetTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) { + s.mu.Lock() + s.trackCB = cb + s.mu.Unlock() +} + func (s *memoryStream) isConnected() bool { s.mu.Lock() defer s.mu.Unlock() @@ -225,6 +246,83 @@ func registerMemoryCarrier(t *testing.T) (string, *memoryRoom) { return name, room } +func registerMemoryCarrierAs(t *testing.T, name string) *memoryRoom { + t.Helper() + + room := &memoryRoom{streams: make(map[*memoryStream]struct{})} + carrier.Register(name, func(_ context.Context, cfg carrier.Config) (carrier.Session, error) { + stream := &memoryStream{room: room, onData: cfg.OnData} + room.mu.Lock() + room.streams[stream] = struct{}{} + room.mu.Unlock() + return &memorySession{stream: stream}, nil + }) + return room +} + +func builtInCarrierNames() []string { + return []string{"jazz", "telemost", "wbstream"} +} + +func builtInTransportNames() []string { + return []string{"datachannel", "videochannel", "seichannel", "vp8channel"} +} + +func validSessionConfig(mode, carrierName, transportName string) session.Config { + return session.Config{ + Mode: mode, + Link: "direct", + Transport: transportName, + Carrier: carrierName, + RoomID: "room", + ClientID: "client-1", + KeyHex: testKeyHex, + SOCKSHost: "127.0.0.1", + SOCKSPort: 1080, + DNSServer: "127.0.0.1:53", + VideoWidth: 1080, + VideoHeight: 1080, + VideoFPS: 30, + VideoBitrate: "1M", + VideoHW: "none", + VideoCodec: "tile", + VideoTileModule: 4, + VideoTileRS: 20, + VP8FPS: 60, + VP8BatchSize: 8, + SEIFPS: 30, + SEIBatchSize: 4, + SEIFragmentSize: 512, + SEIAckTimeoutMS: 1500, + } +} + +func validLinkConfig(carrierName, transportName string) link.Config { + cfg := validSessionConfig("cnc", carrierName, transportName) + return link.Config{ + Transport: cfg.Transport, + Carrier: cfg.Carrier, + RoomURL: "room", + ClientID: cfg.ClientID, + Name: "e2e-" + carrierName + "-" + transportName, + DNSServer: cfg.DNSServer, + VideoWidth: cfg.VideoWidth, + VideoHeight: cfg.VideoHeight, + VideoFPS: cfg.VideoFPS, + VideoBitrate: cfg.VideoBitrate, + VideoHW: cfg.VideoHW, + VideoCodec: cfg.VideoCodec, + VideoTileModule: cfg.VideoTileModule, + VideoTileRS: cfg.VideoTileRS, + VP8FPS: cfg.VP8FPS, + VP8BatchSize: cfg.VP8BatchSize, + SEIFPS: cfg.SEIFPS, + SEIBatchSize: cfg.SEIBatchSize, + SEIFragmentSize: cfg.SEIFragmentSize, + SEIAckTimeoutMS: cfg.SEIAckTimeoutMS, + } +} + func startEchoServer(t *testing.T) string { t.Helper() @@ -483,6 +581,81 @@ func connectViaSOCKSExpectFailure(t *testing.T, socksAddr, targetAddr string) [] return reply } +func TestBuiltInProviderTransportMatrixValidates(t *testing.T) { + session.RegisterDefaults() + + for _, mode := range []string{"srv", "cnc"} { + t.Run(mode, func(t *testing.T) { + for _, carrierName := range builtInCarrierNames() { + t.Run(carrierName, func(t *testing.T) { + for _, transportName := range builtInTransportNames() { + t.Run(transportName, func(t *testing.T) { + cfg := validSessionConfig(mode, carrierName, transportName) + if err := session.Validate(cfg); err != nil { + t.Fatalf("Validate() error = %v", err) + } + }) + } + }) + } + }) + } +} + +func TestDirectLinkCreatesAllProviderTransportCombinations(t *testing.T) { + session.RegisterDefaults() + + for _, carrierName := range builtInCarrierNames() { + registerMemoryCarrierAs(t, carrierName) + } + + for _, carrierName := range builtInCarrierNames() { + t.Run(carrierName, func(t *testing.T) { + for _, transportName := range builtInTransportNames() { + t.Run(transportName, func(t *testing.T) { + ln, err := link.New(context.Background(), "direct", validLinkConfig(carrierName, transportName)) + if err != nil { + t.Fatalf("link.New() error = %v", err) + } + if err := ln.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + }) + } + }) + } +} + +func TestDirectLinkConnectsFastProviderTransportMatrix(t *testing.T) { + session.RegisterDefaults() + + for _, carrierName := range builtInCarrierNames() { + registerMemoryCarrierAs(t, carrierName) + } + + for _, carrierName := range builtInCarrierNames() { + t.Run(carrierName, func(t *testing.T) { + for _, transportName := range []string{"datachannel", "seichannel", "vp8channel"} { + t.Run(transportName, func(t *testing.T) { + ln, err := link.New(context.Background(), "direct", validLinkConfig(carrierName, transportName)) + if err != nil { + t.Fatalf("link.New() error = %v", err) + } + if err := ln.Connect(context.Background()); err != nil { + t.Fatalf("Connect() error = %v", err) + } + if !ln.CanSend() { + t.Fatal("CanSend() = false, want true") + } + if err := ln.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + }) + } + }) + } +} + func TestClientServerSOCKSTunnelOverMemoryDatachannel(t *testing.T) { echoAddr := startEchoServer(t) rt := startTunnel(t, "client-1", "client-1")