diff --git a/internal/app/session/session.go b/internal/app/session/session.go index 665d0cc..6ef7d23 100644 --- a/internal/app/session/session.go +++ b/internal/app/session/session.go @@ -11,11 +11,10 @@ import ( "time" "github.com/openlibrecommunity/olcrtc/internal/auth" - "github.com/openlibrecommunity/olcrtc/internal/carrier" - "github.com/openlibrecommunity/olcrtc/internal/carrier/builtin" "github.com/openlibrecommunity/olcrtc/internal/client" "github.com/openlibrecommunity/olcrtc/internal/control" "github.com/openlibrecommunity/olcrtc/internal/crypto" + enginebuiltin "github.com/openlibrecommunity/olcrtc/internal/engine/builtin" "github.com/openlibrecommunity/olcrtc/internal/logger" "github.com/openlibrecommunity/olcrtc/internal/names" "github.com/openlibrecommunity/olcrtc/internal/server" @@ -191,7 +190,7 @@ type Config struct { // RegisterDefaults registers built-in carriers and transports. func RegisterDefaults() { - builtin.Register() + enginebuiltin.RegisterDefaults() transport.Register("datachannel", datachannel.New) transport.Register("videochannel", videochannel.New) transport.Register("seichannel", seichannel.New) @@ -352,8 +351,8 @@ func validateAuth(cfg Config) error { if cfg.Auth == "" { return ErrAuthRequired } - if !slices.Contains(carrier.Available(), cfg.Auth) { - return fmt.Errorf("%w: %s (available: %v)", ErrUnsupportedCarrier, cfg.Auth, carrier.Available()) + if !slices.Contains(enginebuiltin.Available(), cfg.Auth) { + return fmt.Errorf("%w: %s (available: %v)", ErrUnsupportedCarrier, cfg.Auth, enginebuiltin.Available()) } return nil } @@ -724,8 +723,8 @@ func ValidateGen(cfg Config) error { if cfg.Auth == "" { return ErrAuthRequired } - if !slices.Contains(carrier.Available(), cfg.Auth) { - return fmt.Errorf("%w: %s (available: %v)", ErrUnsupportedCarrier, cfg.Auth, carrier.Available()) + if !slices.Contains(enginebuiltin.Available(), cfg.Auth) { + return fmt.Errorf("%w: %s (available: %v)", ErrUnsupportedCarrier, cfg.Auth, enginebuiltin.Available()) } if cfg.DNSServer == "" { return ErrDNSServerRequired diff --git a/internal/carrier/builtin/engine_adapter.go b/internal/carrier/builtin/engine_adapter.go deleted file mode 100644 index 981d72d..0000000 --- a/internal/carrier/builtin/engine_adapter.go +++ /dev/null @@ -1,187 +0,0 @@ -package builtin - -import ( - "context" - "errors" - "fmt" - - "github.com/openlibrecommunity/olcrtc/internal/auth" - "github.com/openlibrecommunity/olcrtc/internal/carrier" - "github.com/openlibrecommunity/olcrtc/internal/engine" - "github.com/pion/webrtc/v4" -) - -// registerDirect registers a carrier that skips auth entirely — the caller -// supplies the engine name, SFU URL, and access token directly via -// carrier.Config.Engine / carrier.Config.URL / carrier.Config.Token. -func registerDirect(carrierName string) { - carrier.Register(carrierName, func(ctx context.Context, cfg carrier.Config) (carrier.Session, error) { - engineName := cfg.Engine - if engineName == "" { - engineName = "livekit" - } - sess, err := engine.New(ctx, engineName, engine.Config{ - URL: cfg.URL, - Token: cfg.Token, - Name: cfg.Name, - OnData: cfg.OnData, - DNSServer: cfg.DNSServer, - ProxyAddr: cfg.ProxyAddr, - ProxyPort: cfg.ProxyPort, - }) - if err != nil { - return nil, fmt.Errorf("engine new: %w", err) - } - return &engineSession{session: sess}, nil - }) -} - -// registerEngineAuth registers a carrier name that resolves credentials -// through an auth provider and connects via the engine the auth provider -// reports. -func registerEngineAuth(carrierName string, authProvider auth.Provider) { - carrier.Register(carrierName, func(ctx context.Context, cfg carrier.Config) (carrier.Session, error) { - authCfg := auth.Config{ - RoomURL: cfg.RoomURL, - Name: cfg.Name, - DNSServer: cfg.DNSServer, - ProxyAddr: cfg.ProxyAddr, - ProxyPort: cfg.ProxyPort, - } - creds, err := authProvider.Issue(ctx, authCfg) - if err != nil { - return nil, fmt.Errorf("auth issue: %w", errors.Join(carrier.ErrAuthFailed, err)) - } - - sess, err := engine.New(ctx, authProvider.Engine(), engine.Config{ - URL: creds.URL, - Token: creds.Token, - Name: cfg.Name, - Extra: creds.Extra, - OnData: cfg.OnData, - DNSServer: cfg.DNSServer, - ProxyAddr: cfg.ProxyAddr, - ProxyPort: cfg.ProxyPort, - Refresh: func(ctx context.Context) (engine.Credentials, error) { - fresh, err := authProvider.Issue(ctx, authCfg) - if err != nil { - return engine.Credentials{}, fmt.Errorf("auth refresh: %w", err) - } - return engine.Credentials{URL: fresh.URL, Token: fresh.Token, Extra: fresh.Extra}, nil - }, - }) - if err != nil { - return nil, fmt.Errorf("engine new: %w", err) - } - return &engineSession{session: sess}, nil - }) -} - -type engineSession struct { - session engine.Session -} - -func (s *engineSession) Capabilities() carrier.Capabilities { - caps := s.session.Capabilities() - return carrier.Capabilities{ByteStream: caps.ByteStream, VideoTrack: caps.VideoTrack} -} - -func (s *engineSession) OpenByteStream() (carrier.ByteStream, error) { - if !s.session.Capabilities().ByteStream { - return nil, carrier.ErrByteStreamUnsupported - } - return &engineByteStream{session: s.session}, nil -} - -func (s *engineSession) OpenVideoTrack() (carrier.VideoTrack, error) { - vt, ok := s.session.(engine.VideoTrackCapable) - if !ok { - return nil, carrier.ErrVideoTrackUnsupported - } - return &engineVideoTrack{session: s.session, vt: vt}, nil -} - -type engineByteStream struct { - session engine.Session -} - -func (b *engineByteStream) Connect(ctx context.Context) error { - if err := b.session.Connect(ctx); err != nil { - return fmt.Errorf("connect: %w", err) - } - return nil -} - -func (b *engineByteStream) Send(data []byte) error { - if err := b.session.Send(data); err != nil { - return fmt.Errorf("send: %w", err) - } - return nil -} - -func (b *engineByteStream) Close() error { - if err := b.session.Close(); err != nil { - return fmt.Errorf("close: %w", err) - } - return nil -} - -func (b *engineByteStream) SetReconnectCallback(cb func()) { - b.session.SetReconnectCallback(func(_ *webrtc.DataChannel) { - if cb != nil { - cb() - } - }) -} - -func (b *engineByteStream) SetShouldReconnect(fn func() bool) { b.session.SetShouldReconnect(fn) } -func (b *engineByteStream) SetEndedCallback(cb func(string)) { b.session.SetEndedCallback(cb) } -func (b *engineByteStream) WatchConnection(ctx context.Context) { - b.session.WatchConnection(ctx) -} -func (b *engineByteStream) CanSend() bool { return b.session.CanSend() } - -type engineVideoTrack struct { - session engine.Session - vt engine.VideoTrackCapable -} - -func (v *engineVideoTrack) Connect(ctx context.Context) error { - if err := v.session.Connect(ctx); err != nil { - return fmt.Errorf("connect: %w", err) - } - return nil -} - -func (v *engineVideoTrack) Close() error { - if err := v.session.Close(); err != nil { - return fmt.Errorf("close: %w", err) - } - return nil -} - -func (v *engineVideoTrack) SetReconnectCallback(cb func()) { - v.session.SetReconnectCallback(func(_ *webrtc.DataChannel) { - if cb != nil { - cb() - } - }) -} - -func (v *engineVideoTrack) SetShouldReconnect(fn func() bool) { v.session.SetShouldReconnect(fn) } -func (v *engineVideoTrack) SetEndedCallback(cb func(string)) { v.session.SetEndedCallback(cb) } -func (v *engineVideoTrack) WatchConnection(ctx context.Context) { - v.session.WatchConnection(ctx) -} -func (v *engineVideoTrack) CanSend() bool { return v.session.CanSend() } - -func (v *engineVideoTrack) AddTrack(track webrtc.TrackLocal) error { - if err := v.vt.AddVideoTrack(track); err != nil { - return fmt.Errorf("add track: %w", err) - } - return nil -} - -func (v *engineVideoTrack) SetTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) { - v.vt.SetVideoTrackHandler(cb) -} diff --git a/internal/carrier/builtin/register.go b/internal/carrier/builtin/register.go deleted file mode 100644 index 50ded3a..0000000 --- a/internal/carrier/builtin/register.go +++ /dev/null @@ -1,22 +0,0 @@ -// Package builtin registers the built-in carrier implementations. -package builtin - -import ( - authJitsi "github.com/openlibrecommunity/olcrtc/internal/auth/jitsi" - authSaluteJazz "github.com/openlibrecommunity/olcrtc/internal/auth/salutejazz" - authTelemost "github.com/openlibrecommunity/olcrtc/internal/auth/telemost" - authWBStream "github.com/openlibrecommunity/olcrtc/internal/auth/wbstream" - _ "github.com/openlibrecommunity/olcrtc/internal/engine/goolom" // engine registration via init - _ "github.com/openlibrecommunity/olcrtc/internal/engine/jitsi" // engine registration via init - _ "github.com/openlibrecommunity/olcrtc/internal/engine/livekit" // engine registration via init - _ "github.com/openlibrecommunity/olcrtc/internal/engine/salutejazz" // engine registration via init -) - -// Register wires the built-in carriers into the carrier registry. -func Register() { - registerEngineAuth("wbstream", authWBStream.Provider{}) - registerEngineAuth("jazz", authSaluteJazz.Provider{}) - registerEngineAuth("telemost", authTelemost.Provider{}) - registerEngineAuth("jitsi", authJitsi.Provider{}) - registerDirect("none") -} diff --git a/internal/carrier/builtin/register_test.go b/internal/carrier/builtin/register_test.go deleted file mode 100644 index 633d8d3..0000000 --- a/internal/carrier/builtin/register_test.go +++ /dev/null @@ -1,18 +0,0 @@ -package builtin - -import ( - "slices" - "testing" - - "github.com/openlibrecommunity/olcrtc/internal/carrier" -) - -func TestRegister(t *testing.T) { - Register() - available := carrier.Available() - for _, want := range []string{"jazz", "telemost", "wbstream"} { - if !slices.Contains(available, want) { - t.Fatalf("Available() = %v, missing %q", available, want) - } - } -} diff --git a/internal/carrier/bytestream.go b/internal/carrier/bytestream.go deleted file mode 100644 index 6803e03..0000000 --- a/internal/carrier/bytestream.go +++ /dev/null @@ -1,32 +0,0 @@ -package carrier - -import ( - "context" - - "github.com/pion/webrtc/v4" -) - -// ByteStream is a carrier capability for bidirectional byte transport. -type ByteStream interface { - Connect(ctx context.Context) error - Send(data []byte) error - Close() error - SetReconnectCallback(cb func()) - SetShouldReconnect(fn func() bool) - SetEndedCallback(cb func(string)) - WatchConnection(ctx context.Context) - CanSend() bool -} - -// VideoTrack is a carrier capability for bidirectional video transport. -type VideoTrack interface { - Connect(ctx context.Context) error - Close() error - SetReconnectCallback(cb func()) - SetShouldReconnect(fn func() bool) - SetEndedCallback(cb func(string)) - WatchConnection(ctx context.Context) - CanSend() bool - AddTrack(track webrtc.TrackLocal) error - SetTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) -} diff --git a/internal/carrier/carrier.go b/internal/carrier/carrier.go deleted file mode 100644 index cf5e7c8..0000000 --- a/internal/carrier/carrier.go +++ /dev/null @@ -1,81 +0,0 @@ -// Package carrier exposes carrier-oriented registration and construction APIs. -package carrier - -import ( - "context" - "errors" -) - -var ( - // ErrCarrierNotFound is returned when a requested carrier is not registered. - ErrCarrierNotFound = errors.New("carrier not found") - // ErrByteStreamUnsupported is returned when a carrier cannot provide a byte stream. - ErrByteStreamUnsupported = errors.New("carrier does not support byte stream") - // ErrVideoTrackUnsupported is returned when a carrier cannot exchange video tracks. - ErrVideoTrackUnsupported = errors.New("carrier does not support video tracks") - // ErrAuthFailed is returned when a carrier's auth provider rejects the request. - ErrAuthFailed = errors.New("carrier auth failed") -) - -// Capabilities describes the transport primitives a carrier can expose. -type Capabilities struct { - ByteStream bool - VideoTrack bool -} - -// Session is the carrier-level runtime handle. -type Session interface { - Capabilities() Capabilities -} - -// ByteStreamCapable is implemented by carriers that can expose a byte stream. -type ByteStreamCapable interface { - OpenByteStream() (ByteStream, error) -} - -// VideoTrackCapable is implemented by carriers that can exchange video tracks. -type VideoTrackCapable interface { - OpenVideoTrack() (VideoTrack, error) -} - -// Config holds carrier configuration. -type Config struct { - RoomURL string - Name string - OnData func([]byte) - DNSServer string - ProxyAddr string - ProxyPort int - // URL, Token, and Engine are used by the "none" auth carrier (direct engine access). - URL string - Token string - Engine string -} - -// Factory creates a new carrier session. -type Factory func(ctx context.Context, cfg Config) (Session, error) - -var registry = make(map[string]Factory) //nolint:gochecknoglobals // package-level state intentional - -// Register adds a carrier factory to the registry. -func Register(name string, factory Factory) { - registry[name] = factory -} - -// New creates a carrier session by name. -func New(ctx context.Context, name string, cfg Config) (Session, error) { - factory, ok := registry[name] - if !ok { - return nil, ErrCarrierNotFound - } - return factory(ctx, cfg) -} - -// Available returns a list of registered carriers. -func Available() []string { - names := make([]string, 0, len(registry)) - for name := range registry { - names = append(names, name) - } - return names -} diff --git a/internal/carrier/carrier_test.go b/internal/carrier/carrier_test.go deleted file mode 100644 index 9244d4b..0000000 --- a/internal/carrier/carrier_test.go +++ /dev/null @@ -1,66 +0,0 @@ -package carrier - -import ( - "context" - "errors" - "reflect" - "testing" -) - -type stubSession struct{} - -func (s *stubSession) Capabilities() Capabilities { - return Capabilities{ByteStream: true, VideoTrack: true} -} - -func snapshotCarrierRegistry() map[string]Factory { - out := make(map[string]Factory, len(registry)) - for k, v := range registry { - out[k] = v - } - return out -} - -func restoreCarrierRegistry(src map[string]Factory) { - registry = make(map[string]Factory, len(src)) - for k, v := range src { - registry[k] = v - } -} - -func TestRegisterAndAvailable(t *testing.T) { - old := snapshotCarrierRegistry() - t.Cleanup(func() { restoreCarrierRegistry(old) }) - - Register("test-carrier", func(_ context.Context, cfg Config) (Session, error) { - if cfg.Name != "peer" { - t.Fatalf("carrier config name = %q, want peer", cfg.Name) - } - return &stubSession{}, nil - }) - - sess, err := New(context.Background(), "test-carrier", Config{Name: "peer"}) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - caps := sess.Capabilities() - if !caps.ByteStream || !caps.VideoTrack { - t.Fatalf("Capabilities() = %+v, want byte and video true", caps) - } - - if !reflect.DeepEqual(Available(), []string{"test-carrier"}) { - t.Fatalf("Available() = %#v, want %#v", Available(), []string{"test-carrier"}) - } -} - -func TestNewReturnsErrCarrierNotFound(t *testing.T) { - old := snapshotCarrierRegistry() - t.Cleanup(func() { restoreCarrierRegistry(old) }) - registry = map[string]Factory{} - - _, err := New(context.Background(), "missing", Config{}) - if !errors.Is(err, ErrCarrierNotFound) { - t.Fatalf("New() error = %v, want %v", err, ErrCarrierNotFound) - } -} diff --git a/internal/e2e/tunnel_test.go b/internal/e2e/tunnel_test.go index 1af02f0..2f2fe38 100644 --- a/internal/e2e/tunnel_test.go +++ b/internal/e2e/tunnel_test.go @@ -21,9 +21,10 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/app/session" "github.com/openlibrecommunity/olcrtc/internal/auth" + "github.com/openlibrecommunity/olcrtc/internal/engine" + enginebuiltin "github.com/openlibrecommunity/olcrtc/internal/engine/builtin" authSaluteJazz "github.com/openlibrecommunity/olcrtc/internal/auth/salutejazz" authWBStream "github.com/openlibrecommunity/olcrtc/internal/auth/wbstream" - "github.com/openlibrecommunity/olcrtc/internal/carrier" "github.com/openlibrecommunity/olcrtc/internal/client" "github.com/openlibrecommunity/olcrtc/internal/server" "github.com/openlibrecommunity/olcrtc/internal/supervisor" @@ -116,21 +117,10 @@ const ( realE2EExpectUnstable ) -type memorySession struct { - stream *memoryStream -} - -func (s *memorySession) Capabilities() carrier.Capabilities { - return carrier.Capabilities{ByteStream: true, VideoTrack: true} -} - -func (s *memorySession) OpenByteStream() (carrier.ByteStream, error) { - return s.stream, nil -} - -func (s *memorySession) OpenVideoTrack() (carrier.VideoTrack, error) { - return s.stream, nil -} +// memoryStream is registered as an engine.Session directly: it implements +// every Session method plus engine.VideoTrackCapable (AddVideoTrack / +// SetVideoTrackHandler aliases below). The wrapper that used to live in +// memorySession is no longer needed after the carrier-layer collapse. type memoryRoom struct { mu sync.Mutex @@ -271,9 +261,13 @@ func (s *memoryStream) Close() error { return nil } -func (s *memoryStream) SetReconnectCallback(cb func()) { +func (s *memoryStream) SetReconnectCallback(cb func(*webrtc.DataChannel)) { s.mu.Lock() - s.reconnect = cb + if cb == nil { + s.reconnect = nil + } else { + s.reconnect = func() { cb(nil) } + } s.mu.Unlock() } func (s *memoryStream) SetShouldReconnect(func() bool) {} @@ -288,15 +282,20 @@ func (s *memoryStream) WatchConnection(ctx context.Context) { func (s *memoryStream) CanSend() bool { return s.isConnected() } +func (s *memoryStream) GetSendQueue() chan []byte { return nil } +func (s *memoryStream) GetBufferedAmount() uint64 { return 0 } +func (s *memoryStream) Capabilities() engine.Capabilities { + return engine.Capabilities{ByteStream: true, VideoTrack: true} +} -func (s *memoryStream) AddTrack(track webrtc.TrackLocal) error { +func (s *memoryStream) AddVideoTrack(track webrtc.TrackLocal) error { s.mu.Lock() s.track = track s.mu.Unlock() return nil } -func (s *memoryStream) SetTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) { +func (s *memoryStream) SetVideoTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) { s.mu.Lock() s.trackCB = cb s.mu.Unlock() @@ -334,12 +333,12 @@ func registerMemoryCarrier(t *testing.T) (string, *memoryRoom) { name := "e2e-memory-" + t.Name() room := &memoryRoom{streams: make(map[*memoryStream]struct{})} - carrier.Register(name, func(_ context.Context, cfg carrier.Config) (carrier.Session, error) { + enginebuiltin.Register(name, func(_ context.Context, cfg enginebuiltin.Config) (engine.Session, error) { stream := &memoryStream{room: room, onData: cfg.OnData} room.mu.Lock() room.streams[stream] = struct{}{} room.mu.Unlock() - return &memorySession{stream: stream}, nil + return stream, nil }) return name, room } @@ -348,12 +347,12 @@ func registerMemoryCarrierAs(t *testing.T, name string) { t.Helper() room := &memoryRoom{streams: make(map[*memoryStream]struct{})} - carrier.Register(name, func(_ context.Context, cfg carrier.Config) (carrier.Session, error) { + enginebuiltin.Register(name, func(_ context.Context, cfg enginebuiltin.Config) (engine.Session, error) { stream := &memoryStream{room: room, onData: cfg.OnData} room.mu.Lock() room.streams[stream] = struct{}{} room.mu.Unlock() - return &memorySession{stream: stream}, nil + return stream, nil }) } @@ -362,7 +361,7 @@ func registerFailingCarrier(t *testing.T) string { session.RegisterDefaults() name := "e2e-fail-" + t.Name() - carrier.Register(name, func(context.Context, carrier.Config) (carrier.Session, error) { + enginebuiltin.Register(name, func(context.Context, enginebuiltin.Config) (engine.Session, error) { return nil, errFailoverCarrier }) return name @@ -1094,7 +1093,7 @@ func TestRealProviderTransportMatrix(t *testing.T) { expectation := realE2ECaseExpectation(carrierName, transportName) label := realE2EExpectationLabel(expectation) err := runRealE2ECase(t, carrierName, transportName, roomURL, echoAddr) - if err != nil && errors.Is(err, carrier.ErrAuthFailed) { + if err != nil && errors.Is(err, enginebuiltin.ErrAuthFailed) { authFailed = true t.Skipf("skip %s real e2e: auth failed: %v", carrierName, err) } diff --git a/internal/engine/builtin/builtin.go b/internal/engine/builtin/builtin.go new file mode 100644 index 0000000..dc94815 --- /dev/null +++ b/internal/engine/builtin/builtin.go @@ -0,0 +1,148 @@ +// Package builtin wires the built-in auth providers to their engines and +// registers a name-keyed factory that transports use to obtain an +// [engine.Session]. The factory replaces the former carrier layer: when +// the auth provider is "none" the caller supplies engine/URL/token +// directly; otherwise the named provider issues credentials and the +// matching engine is constructed. +package builtin + +import ( + "context" + "errors" + "fmt" + + "github.com/openlibrecommunity/olcrtc/internal/auth" + authJitsi "github.com/openlibrecommunity/olcrtc/internal/auth/jitsi" + authSaluteJazz "github.com/openlibrecommunity/olcrtc/internal/auth/salutejazz" + authTelemost "github.com/openlibrecommunity/olcrtc/internal/auth/telemost" + authWBStream "github.com/openlibrecommunity/olcrtc/internal/auth/wbstream" + "github.com/openlibrecommunity/olcrtc/internal/engine" + _ "github.com/openlibrecommunity/olcrtc/internal/engine/goolom" // register goolom engine via init + _ "github.com/openlibrecommunity/olcrtc/internal/engine/jitsi" // register jitsi engine via init + _ "github.com/openlibrecommunity/olcrtc/internal/engine/livekit" // register livekit engine via init + _ "github.com/openlibrecommunity/olcrtc/internal/engine/salutejazz" // register salutejazz engine via init +) + +// ErrCarrierNotFound is returned when an unregistered carrier name is requested. +var ErrCarrierNotFound = errors.New("carrier not found") + +// ErrAuthFailed wraps an auth provider rejection. It pairs with the inner +// provider error returned from [Open]. +var ErrAuthFailed = errors.New("carrier auth failed") + +// Config holds the inputs to [Open]. The fields mirror the subset of +// transport.Config that engines consume. +type Config struct { + RoomURL string + Name string + OnData func([]byte) + DNSServer string + ProxyAddr string + ProxyPort int + // Engine, URL, Token are honoured only for the "none" carrier (direct + // engine access); other carriers derive them from their auth provider. + Engine string + URL string + Token string +} + +// Factory creates an engine session for a given carrier. +type Factory func(ctx context.Context, cfg Config) (engine.Session, error) + +var registry = map[string]Factory{} //nolint:gochecknoglobals // package-level registry + +// Register adds a carrier factory. +func Register(name string, f Factory) { + registry[name] = f +} + +// Open looks up the carrier factory and creates an engine session. +func Open(ctx context.Context, name string, cfg Config) (engine.Session, error) { + f, ok := registry[name] + if !ok { + return nil, fmt.Errorf("%w: %q", ErrCarrierNotFound, name) + } + return f(ctx, cfg) +} + +// Available reports all registered carrier names. +func Available() []string { + names := make([]string, 0, len(registry)) + for name := range registry { + names = append(names, name) + } + return names +} + +// RegisterDefaults wires the built-in carriers: jitsi, telemost, jazz, wbstream +// and "none" (direct engine access). +func RegisterDefaults() { + registerEngineAuth("wbstream", authWBStream.Provider{}) + registerEngineAuth("jazz", authSaluteJazz.Provider{}) + registerEngineAuth("telemost", authTelemost.Provider{}) + registerEngineAuth("jitsi", authJitsi.Provider{}) + registerDirect("none") +} + +// registerDirect registers a carrier that skips auth: the caller supplies +// engine/URL/token directly via [Config]. +func registerDirect(name string) { + Register(name, func(ctx context.Context, cfg Config) (engine.Session, error) { + engineName := cfg.Engine + if engineName == "" { + engineName = "livekit" + } + sess, err := engine.New(ctx, engineName, engine.Config{ + URL: cfg.URL, + Token: cfg.Token, + Name: cfg.Name, + OnData: cfg.OnData, + DNSServer: cfg.DNSServer, + ProxyAddr: cfg.ProxyAddr, + ProxyPort: cfg.ProxyPort, + }) + if err != nil { + return nil, fmt.Errorf("engine new: %w", err) + } + return sess, nil + }) +} + +// registerEngineAuth registers a carrier that resolves credentials through an +// auth provider and connects via the engine the auth provider reports. +func registerEngineAuth(name string, provider auth.Provider) { + Register(name, func(ctx context.Context, cfg Config) (engine.Session, error) { + authCfg := auth.Config{ + RoomURL: cfg.RoomURL, + Name: cfg.Name, + DNSServer: cfg.DNSServer, + ProxyAddr: cfg.ProxyAddr, + ProxyPort: cfg.ProxyPort, + } + creds, err := provider.Issue(ctx, authCfg) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrAuthFailed, err) + } + sess, err := engine.New(ctx, provider.Engine(), engine.Config{ + URL: creds.URL, + Token: creds.Token, + Name: cfg.Name, + Extra: creds.Extra, + OnData: cfg.OnData, + DNSServer: cfg.DNSServer, + ProxyAddr: cfg.ProxyAddr, + ProxyPort: cfg.ProxyPort, + Refresh: func(ctx context.Context) (engine.Credentials, error) { + fresh, err := provider.Issue(ctx, authCfg) + if err != nil { + return engine.Credentials{}, fmt.Errorf("auth refresh: %w", err) + } + return engine.Credentials{URL: fresh.URL, Token: fresh.Token, Extra: fresh.Extra}, nil + }, + }) + if err != nil { + return nil, fmt.Errorf("engine new: %w", err) + } + return sess, nil + }) +} diff --git a/internal/transport/datachannel/transport.go b/internal/transport/datachannel/transport.go index 8a4f783..4fc2ad7 100644 --- a/internal/transport/datachannel/transport.go +++ b/internal/transport/datachannel/transport.go @@ -1,23 +1,29 @@ -// Package datachannel provides a transport backed by the current carriers. +// Package datachannel provides a transport backed by a carrier's data channel. package datachannel import ( "context" + "errors" "fmt" - "github.com/openlibrecommunity/olcrtc/internal/carrier" + "github.com/openlibrecommunity/olcrtc/internal/engine" + enginebuiltin "github.com/openlibrecommunity/olcrtc/internal/engine/builtin" "github.com/openlibrecommunity/olcrtc/internal/transport" + "github.com/pion/webrtc/v4" ) const defaultMaxPayloadSize = 12 * 1024 +// ErrByteStreamUnsupported is returned when a carrier engine cannot expose a byte stream. +var ErrByteStreamUnsupported = errors.New("engine does not support byte stream") + type streamTransport struct { - stream carrier.ByteStream + session engine.Session } -// New creates a datachannel transport backed by a carrier. +// New creates a datachannel transport backed by a carrier engine. func New(ctx context.Context, cfg transport.Config) (transport.Transport, error) { - session, err := carrier.New(ctx, cfg.Carrier, carrier.Config{ + sess, err := enginebuiltin.Open(ctx, cfg.Carrier, enginebuiltin.Config{ RoomURL: cfg.RoomURL, Name: cfg.Name, OnData: cfg.OnData, @@ -29,69 +35,68 @@ func New(ctx context.Context, cfg transport.Config) (transport.Transport, error) Token: cfg.Token, }) if err != nil { - return nil, fmt.Errorf("create carrier transport: %w", err) + return nil, fmt.Errorf("open engine session: %w", err) } - streamCapable, ok := session.(carrier.ByteStreamCapable) - if !ok { - return nil, carrier.ErrByteStreamUnsupported + if !sess.Capabilities().ByteStream { + _ = sess.Close() + return nil, ErrByteStreamUnsupported } - stream, err := streamCapable.OpenByteStream() - if err != nil { - return nil, fmt.Errorf("open byte stream: %w", err) - } - - return &streamTransport{stream: stream}, nil + return &streamTransport{session: sess}, nil } // Connect starts the transport connection. func (p *streamTransport) Connect(ctx context.Context) error { - if err := p.stream.Connect(ctx); err != nil { - return fmt.Errorf("stream connect: %w", err) + if err := p.session.Connect(ctx); err != nil { + return fmt.Errorf("session connect: %w", err) } return nil } // Send transmits data through the transport. func (p *streamTransport) Send(data []byte) error { - if err := p.stream.Send(data); err != nil { - return fmt.Errorf("stream send: %w", err) + if err := p.session.Send(data); err != nil { + return fmt.Errorf("session send: %w", err) } return nil } // Close terminates the transport. func (p *streamTransport) Close() error { - if err := p.stream.Close(); err != nil { - return fmt.Errorf("stream close: %w", err) + if err := p.session.Close(); err != nil { + return fmt.Errorf("session close: %w", err) } return nil } // SetReconnectCallback registers reconnect handling. func (p *streamTransport) SetReconnectCallback(cb func()) { - p.stream.SetReconnectCallback(cb) + p.session.SetReconnectCallback(func(*webrtc.DataChannel) { + if cb != nil { + cb() + } + }) } // SetShouldReconnect configures reconnect policy. func (p *streamTransport) SetShouldReconnect(fn func() bool) { - p.stream.SetShouldReconnect(fn) + p.session.SetShouldReconnect(fn) } // SetEndedCallback registers end-of-session handling. func (p *streamTransport) SetEndedCallback(cb func(string)) { - p.stream.SetEndedCallback(cb) + p.session.SetEndedCallback(cb) } // WatchConnection monitors connection lifecycle. func (p *streamTransport) WatchConnection(ctx context.Context) { - p.stream.WatchConnection(ctx) + p.session.WatchConnection(ctx) } // CanSend reports whether transport is ready for sending. func (p *streamTransport) CanSend() bool { - return p.stream.CanSend() + return p.session.CanSend() } // Features describes the current datachannel transport semantics. diff --git a/internal/transport/datachannel/transport_test.go b/internal/transport/datachannel/transport_test.go index 1f4e4f7..3113f4b 100644 --- a/internal/transport/datachannel/transport_test.go +++ b/internal/transport/datachannel/transport_test.go @@ -5,69 +5,61 @@ import ( "errors" "testing" - "github.com/openlibrecommunity/olcrtc/internal/carrier" + "github.com/openlibrecommunity/olcrtc/internal/engine" + enginebuiltin "github.com/openlibrecommunity/olcrtc/internal/engine/builtin" "github.com/openlibrecommunity/olcrtc/internal/transport" + "github.com/pion/webrtc/v4" ) var ( errDCBoom = errors.New("boom") - errDCOpenBoom = errors.New("open boom") errDCConnectBoom = errors.New("connect boom") errDCSendBoom = errors.New("send boom") errDCCloseBoom = errors.New("close boom") ) type stubSession struct { - stream carrier.ByteStream - streamErr error -} - -func (s *stubSession) Capabilities() carrier.Capabilities { - return carrier.Capabilities{ByteStream: true} -} -func (s *stubSession) OpenByteStream() (carrier.ByteStream, error) { - if s.streamErr != nil { - return nil, s.streamErr - } - return s.stream, nil -} - -type nonByteStreamSession struct{} - -func (s *nonByteStreamSession) Capabilities() carrier.Capabilities { return carrier.Capabilities{} } - -type stubByteStream struct { - connectErr error - sendErr error - closeErr error - canSend bool - + caps engine.Capabilities + connectErr error + sendErr error + closeErr error + canSend bool connectCalled bool - sent []byte - watched bool - reconnectCB func() - shouldFn func() bool - endedCB func(string) + sent []byte + watched bool + reconnectCB func(*webrtc.DataChannel) + shouldFn func() bool + endedCB func(string) } -func (s *stubByteStream) Connect(context.Context) error { s.connectCalled = true; return s.connectErr } -func (s *stubByteStream) Send(data []byte) error { +func (s *stubSession) Capabilities() engine.Capabilities { return s.caps } +func (s *stubSession) Connect(context.Context) error { s.connectCalled = true; return s.connectErr } +func (s *stubSession) Send(data []byte) error { s.sent = append([]byte(nil), data...) return s.sendErr } -func (s *stubByteStream) Close() error { return s.closeErr } -func (s *stubByteStream) SetReconnectCallback(cb func()) { s.reconnectCB = cb } -func (s *stubByteStream) SetShouldReconnect(fn func() bool) { s.shouldFn = fn } -func (s *stubByteStream) SetEndedCallback(cb func(string)) { s.endedCB = cb } -func (s *stubByteStream) WatchConnection(context.Context) { s.watched = true } -func (s *stubByteStream) CanSend() bool { return s.canSend } +func (s *stubSession) Close() error { return s.closeErr } +func (s *stubSession) SetReconnectCallback(cb func(*webrtc.DataChannel)) { s.reconnectCB = cb } +func (s *stubSession) SetShouldReconnect(fn func() bool) { s.shouldFn = fn } +func (s *stubSession) SetEndedCallback(cb func(string)) { s.endedCB = cb } +func (s *stubSession) WatchConnection(context.Context) { s.watched = true } +func (s *stubSession) CanSend() bool { return s.canSend } +func (s *stubSession) GetSendQueue() chan []byte { return nil } +func (s *stubSession) GetBufferedAmount() uint64 { return 0 } + +func registerCarrier(name string, sess engine.Session, err error) { + enginebuiltin.Register(name, func(context.Context, enginebuiltin.Config) (engine.Session, error) { + if err != nil { + return nil, err + } + return sess, nil + }) +} //nolint:cyclop // table-driven test naturally has many branches func TestNewAndFeatures(t *testing.T) { - stream := &stubByteStream{canSend: true} - carrier.Register("datachannel-test-new-and-features", func(context.Context, carrier.Config) (carrier.Session, error) { - return &stubSession{stream: stream}, nil - }) + sess := &stubSession{caps: engine.Capabilities{ByteStream: true}, canSend: true} + registerCarrier("datachannel-test-new-and-features", sess, nil) tr, err := New(context.Background(), transport.Config{Carrier: "datachannel-test-new-and-features"}) if err != nil { @@ -77,20 +69,20 @@ func TestNewAndFeatures(t *testing.T) { if err := tr.Connect(context.Background()); err != nil { t.Fatalf("Connect() error = %v", err) } - if !stream.connectCalled { + if !sess.connectCalled { t.Fatal("Connect() was not forwarded") } if err := tr.Send([]byte("payload")); err != nil { t.Fatalf("Send() error = %v", err) } - if string(stream.sent) != "payload" { - t.Fatalf("Send() forwarded %q, want payload", stream.sent) + if string(sess.sent) != "payload" { + t.Fatalf("Send() forwarded %q, want payload", sess.sent) } tr.SetReconnectCallback(func() {}) tr.SetShouldReconnect(func() bool { return true }) tr.SetEndedCallback(func(string) {}) tr.WatchConnection(context.Background()) - if stream.reconnectCB == nil || stream.shouldFn == nil || stream.endedCB == nil || !stream.watched { + if sess.reconnectCB == nil || sess.shouldFn == nil || sess.endedCB == nil || !sess.watched { t.Fatal("callbacks/watch were not forwarded") } if !tr.CanSend() { @@ -107,42 +99,33 @@ func TestNewAndFeatures(t *testing.T) { } func TestNewErrorPaths(t *testing.T) { - carrier.Register("datachannel-fail-create", func(context.Context, carrier.Config) (carrier.Session, error) { - return nil, errDCBoom - }) - if _, err := New(context.Background(), transport.Config{Carrier: "datachannel-fail-create"}); err == nil || err.Error() != "create carrier transport: boom" { //nolint:lll // long test description + registerCarrier("datachannel-fail-create", nil, errDCBoom) + if _, err := New(context.Background(), transport.Config{Carrier: "datachannel-fail-create"}); err == nil || err.Error() != "open engine session: boom" { t.Fatalf("New() error = %v", err) } - carrier.Register("datachannel-no-stream", func(context.Context, carrier.Config) (carrier.Session, error) { - return &nonByteStreamSession{}, nil - }) - if _, err := New(context.Background(), transport.Config{Carrier: "datachannel-no-stream"}); !errors.Is(err, carrier.ErrByteStreamUnsupported) { //nolint:lll // long test description - t.Fatalf("New() error = %v, want %v", err, carrier.ErrByteStreamUnsupported) - } - - carrier.Register("datachannel-open-stream-fails", func(context.Context, carrier.Config) (carrier.Session, error) { - return &stubSession{streamErr: errDCOpenBoom}, nil - }) - if _, err := New(context.Background(), transport.Config{Carrier: "datachannel-open-stream-fails"}); err == nil || err.Error() != "open byte stream: open boom" { //nolint:lll // long test description - t.Fatalf("New() error = %v", err) + nonByteStream := &stubSession{caps: engine.Capabilities{}} + registerCarrier("datachannel-no-stream", nonByteStream, nil) + if _, err := New(context.Background(), transport.Config{Carrier: "datachannel-no-stream"}); !errors.Is(err, ErrByteStreamUnsupported) { + t.Fatalf("New() error = %v, want %v", err, ErrByteStreamUnsupported) } } func TestStreamTransportWrapsErrors(t *testing.T) { - tr := &streamTransport{stream: &stubByteStream{ + tr := &streamTransport{session: &stubSession{ + caps: engine.Capabilities{ByteStream: true}, connectErr: errDCConnectBoom, sendErr: errDCSendBoom, closeErr: errDCCloseBoom, }} - if err := tr.Connect(context.Background()); err == nil || err.Error() != "stream connect: connect boom" { + if err := tr.Connect(context.Background()); err == nil || err.Error() != "session connect: connect boom" { t.Fatalf("Connect() error = %v", err) } - if err := tr.Send([]byte("x")); err == nil || err.Error() != "stream send: send boom" { + if err := tr.Send([]byte("x")); err == nil || err.Error() != "session send: send boom" { t.Fatalf("Send() error = %v", err) } - if err := tr.Close(); err == nil || err.Error() != "stream close: close boom" { + if err := tr.Close(); err == nil || err.Error() != "session close: close boom" { t.Fatalf("Close() error = %v", err) } } diff --git a/internal/transport/seichannel/engine_session.go b/internal/transport/seichannel/engine_session.go new file mode 100644 index 0000000..59fbb83 --- /dev/null +++ b/internal/transport/seichannel/engine_session.go @@ -0,0 +1,56 @@ +package seichannel + +import ( + "context" + "fmt" + + "github.com/openlibrecommunity/olcrtc/internal/engine" + "github.com/pion/webrtc/v4" +) + +// engineVideoSession adapts engine.Session + engine.VideoTrackCapable to the +// videoSession interface seichannel consumes. +type engineVideoSession struct { + session engine.Session + vt engine.VideoTrackCapable +} + +func (v *engineVideoSession) Connect(ctx context.Context) error { + if err := v.session.Connect(ctx); err != nil { + return fmt.Errorf("connect: %w", err) + } + return nil +} + +func (v *engineVideoSession) Close() error { + if err := v.session.Close(); err != nil { + return fmt.Errorf("close: %w", err) + } + return nil +} + +func (v *engineVideoSession) SetReconnectCallback(cb func()) { + v.session.SetReconnectCallback(func(*webrtc.DataChannel) { + if cb != nil { + cb() + } + }) +} + +func (v *engineVideoSession) SetShouldReconnect(fn func() bool) { v.session.SetShouldReconnect(fn) } +func (v *engineVideoSession) SetEndedCallback(cb func(string)) { v.session.SetEndedCallback(cb) } +func (v *engineVideoSession) WatchConnection(ctx context.Context) { + v.session.WatchConnection(ctx) +} +func (v *engineVideoSession) CanSend() bool { return v.session.CanSend() } + +func (v *engineVideoSession) AddTrack(track webrtc.TrackLocal) error { + if err := v.vt.AddVideoTrack(track); err != nil { + return fmt.Errorf("add track: %w", err) + } + return nil +} + +func (v *engineVideoSession) SetTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) { + v.vt.SetVideoTrackHandler(cb) +} diff --git a/internal/transport/seichannel/transport.go b/internal/transport/seichannel/transport.go index 0f9bbfc..f4f9620 100644 --- a/internal/transport/seichannel/transport.go +++ b/internal/transport/seichannel/transport.go @@ -13,7 +13,8 @@ import ( "sync/atomic" "time" - "github.com/openlibrecommunity/olcrtc/internal/carrier" + "github.com/openlibrecommunity/olcrtc/internal/engine" + enginebuiltin "github.com/openlibrecommunity/olcrtc/internal/engine/builtin" "github.com/openlibrecommunity/olcrtc/internal/transport" "github.com/pion/rtp/codecs" "github.com/pion/webrtc/v4" @@ -76,8 +77,22 @@ type inboundMessage struct { remain int } +// videoSession is the subset of engine.Session + engine.VideoTrackCapable the +// seichannel transport relies on. +type videoSession interface { + Connect(ctx context.Context) error + Close() error + SetReconnectCallback(cb func()) + SetShouldReconnect(fn func() bool) + SetEndedCallback(cb func(string)) + WatchConnection(ctx context.Context) + CanSend() bool + AddTrack(track webrtc.TrackLocal) error + SetTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) +} + type streamTransport struct { - stream carrier.VideoTrack + stream videoSession track *webrtc.TrackLocalStaticSample onData func([]byte) outbound chan []byte @@ -108,7 +123,7 @@ func New(ctx context.Context, cfg transport.Config) (transport.Transport, error) return nil, err } - session, err := carrier.New(ctx, cfg.Carrier, carrier.Config{ + session, err := enginebuiltin.Open(ctx, cfg.Carrier, enginebuiltin.Config{ RoomURL: cfg.RoomURL, Name: cfg.Name, OnData: nil, @@ -120,18 +135,15 @@ func New(ctx context.Context, cfg transport.Config) (transport.Transport, error) Token: cfg.Token, }) if err != nil { - return nil, fmt.Errorf("create carrier transport: %w", err) + return nil, fmt.Errorf("open engine session: %w", err) } - videoCapable, ok := session.(carrier.VideoTrackCapable) - if !ok { + vt, ok := session.(engine.VideoTrackCapable) + if !ok || !session.Capabilities().VideoTrack { + _ = session.Close() return nil, ErrVideoTrackUnsupported } - - stream, err := videoCapable.OpenVideoTrack() - if err != nil { - return nil, fmt.Errorf("open video track: %w", err) - } + stream := &engineVideoSession{session: session, vt: vt} // Stream/track IDs must be unique per peer — Jitsi rejects session-accept // when msid collides with another participant in the conference. diff --git a/internal/transport/seichannel/transport_unit_test.go b/internal/transport/seichannel/transport_unit_test.go index 0310887..c055d01 100644 --- a/internal/transport/seichannel/transport_unit_test.go +++ b/internal/transport/seichannel/transport_unit_test.go @@ -7,31 +7,16 @@ import ( "testing" "time" - "github.com/openlibrecommunity/olcrtc/internal/carrier" + "github.com/openlibrecommunity/olcrtc/internal/engine" + enginebuiltin "github.com/openlibrecommunity/olcrtc/internal/engine/builtin" "github.com/openlibrecommunity/olcrtc/internal/transport" "github.com/pion/webrtc/v4" ) -var ( - errBoom = errors.New("boom") - errOpenBoom = errors.New("open boom") -) - -type fakeVideoSession struct { - stream *fakeVideoStream - err error -} - -func (s *fakeVideoSession) Capabilities() carrier.Capabilities { - return carrier.Capabilities{VideoTrack: true} -} -func (s *fakeVideoSession) OpenVideoTrack() (carrier.VideoTrack, error) { - if s.err != nil { - return nil, s.err - } - return s.stream, nil -} +var errBoom = errors.New("boom") +// fakeVideoStream is the stub implementation of the videoSession interface +// the seichannel transport consumes after engine.Session adaptation. type fakeVideoStream struct { connectErr error closeErr error @@ -61,16 +46,49 @@ func (s *fakeVideoStream) SetTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.R s.trackCB = cb } -type nonVideoSession struct{} +// fakeEngineSession implements engine.Session and engine.VideoTrackCapable so +// it can be returned by enginebuiltin.Open in tests. It wraps a fakeVideoStream +// for the video-track methods the real engine session exposes. +type fakeEngineSession struct { + stream *fakeVideoStream + noVideo bool +} -func (s *nonVideoSession) Capabilities() carrier.Capabilities { return carrier.Capabilities{} } +func (s *fakeEngineSession) Capabilities() engine.Capabilities { + if s.noVideo { + return engine.Capabilities{} + } + return engine.Capabilities{VideoTrack: true} +} +func (s *fakeEngineSession) Connect(ctx context.Context) error { return s.stream.Connect(ctx) } +func (s *fakeEngineSession) Send([]byte) error { return nil } +func (s *fakeEngineSession) Close() error { return s.stream.Close() } +func (s *fakeEngineSession) SetReconnectCallback(cb func(*webrtc.DataChannel)) { + s.stream.SetReconnectCallback(func() { + if cb != nil { + cb(nil) + } + }) +} +func (s *fakeEngineSession) SetShouldReconnect(fn func() bool) { s.stream.SetShouldReconnect(fn) } +func (s *fakeEngineSession) SetEndedCallback(cb func(string)) { s.stream.SetEndedCallback(cb) } +func (s *fakeEngineSession) WatchConnection(ctx context.Context) { + s.stream.WatchConnection(ctx) +} +func (s *fakeEngineSession) CanSend() bool { return s.stream.CanSend() } +func (s *fakeEngineSession) GetSendQueue() chan []byte { return nil } +func (s *fakeEngineSession) GetBufferedAmount() uint64 { return 0 } +func (s *fakeEngineSession) AddVideoTrack(t webrtc.TrackLocal) error { return s.stream.AddTrack(t) } +func (s *fakeEngineSession) SetVideoTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) { + s.stream.SetTrackHandler(cb) +} //nolint:cyclop // table-driven test naturally has many branches func TestNewConnectCallbacksAndFeatures(t *testing.T) { stream := &fakeVideoStream{canSend: true} name := "seichannel-unit-new" - carrier.Register(name, func(context.Context, carrier.Config) (carrier.Session, error) { - return &fakeVideoSession{stream: stream}, nil + enginebuiltin.Register(name, func(context.Context, enginebuiltin.Config) (engine.Session, error) { + return &fakeEngineSession{stream: stream}, nil }) trIface, err := New(t.Context(), transport.Config{ @@ -126,26 +144,19 @@ func TestNewConnectCallbacksAndFeatures(t *testing.T) { } func TestNewErrorPaths(t *testing.T) { - carrier.Register("seichannel-create-fails", func(context.Context, carrier.Config) (carrier.Session, error) { + enginebuiltin.Register("seichannel-create-fails", func(context.Context, enginebuiltin.Config) (engine.Session, error) { return nil, errBoom }) - if _, err := New(context.Background(), transport.Config{Carrier: "seichannel-create-fails"}); err == nil || err.Error() != "create carrier transport: boom" { //nolint:lll // long test description + if _, err := New(context.Background(), transport.Config{Carrier: "seichannel-create-fails"}); err == nil || err.Error() != "open engine session: boom" { //nolint:lll // long test description t.Fatalf("New() error = %v", err) } - carrier.Register("seichannel-no-video", func(context.Context, carrier.Config) (carrier.Session, error) { - return &nonVideoSession{}, nil + enginebuiltin.Register("seichannel-no-video", func(context.Context, enginebuiltin.Config) (engine.Session, error) { + return &fakeEngineSession{stream: &fakeVideoStream{}, noVideo: true}, nil }) - if _, err := New(context.Background(), transport.Config{Carrier: "seichannel-no-video"}); !errors.Is(err, ErrVideoTrackUnsupported) { //nolint:lll // long test description + if _, err := New(context.Background(), transport.Config{Carrier: "seichannel-no-video"}); !errors.Is(err, ErrVideoTrackUnsupported) { t.Fatalf("New() error = %v, want %v", err, ErrVideoTrackUnsupported) } - - carrier.Register("seichannel-open-fails", func(context.Context, carrier.Config) (carrier.Session, error) { - return &fakeVideoSession{err: errOpenBoom}, nil - }) - if _, err := New(context.Background(), transport.Config{Carrier: "seichannel-open-fails"}); err == nil || err.Error() != "open video track: open boom" { //nolint:lll // long test description - t.Fatalf("New() error = %v", err) - } } func TestSendAckAndClosePaths(t *testing.T) { diff --git a/internal/transport/videochannel/engine_session.go b/internal/transport/videochannel/engine_session.go new file mode 100644 index 0000000..2b3e411 --- /dev/null +++ b/internal/transport/videochannel/engine_session.go @@ -0,0 +1,59 @@ +package videochannel + +import ( + "context" + "fmt" + + "github.com/openlibrecommunity/olcrtc/internal/engine" + "github.com/pion/webrtc/v4" +) + +// engineVideoSession adapts engine.Session + engine.VideoTrackCapable to the +// videoSession interface the videochannel transport consumes. The wrapper +// drops the *webrtc.DataChannel argument from the engine reconnect callback +// (videochannel does not use data channels) and exposes the video-track +// helpers under shorter names. +type engineVideoSession struct { + session engine.Session + vt engine.VideoTrackCapable +} + +func (v *engineVideoSession) Connect(ctx context.Context) error { + if err := v.session.Connect(ctx); err != nil { + return fmt.Errorf("connect: %w", err) + } + return nil +} + +func (v *engineVideoSession) Close() error { + if err := v.session.Close(); err != nil { + return fmt.Errorf("close: %w", err) + } + return nil +} + +func (v *engineVideoSession) SetReconnectCallback(cb func()) { + v.session.SetReconnectCallback(func(*webrtc.DataChannel) { + if cb != nil { + cb() + } + }) +} + +func (v *engineVideoSession) SetShouldReconnect(fn func() bool) { v.session.SetShouldReconnect(fn) } +func (v *engineVideoSession) SetEndedCallback(cb func(string)) { v.session.SetEndedCallback(cb) } +func (v *engineVideoSession) WatchConnection(ctx context.Context) { + v.session.WatchConnection(ctx) +} +func (v *engineVideoSession) CanSend() bool { return v.session.CanSend() } + +func (v *engineVideoSession) AddTrack(track webrtc.TrackLocal) error { + if err := v.vt.AddVideoTrack(track); err != nil { + return fmt.Errorf("add track: %w", err) + } + return nil +} + +func (v *engineVideoSession) SetTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) { + v.vt.SetVideoTrackHandler(cb) +} diff --git a/internal/transport/videochannel/transport.go b/internal/transport/videochannel/transport.go index e1ad18f..5bb5288 100644 --- a/internal/transport/videochannel/transport.go +++ b/internal/transport/videochannel/transport.go @@ -12,7 +12,8 @@ import ( "sync/atomic" "time" - "github.com/openlibrecommunity/olcrtc/internal/carrier" + "github.com/openlibrecommunity/olcrtc/internal/engine" + enginebuiltin "github.com/openlibrecommunity/olcrtc/internal/engine/builtin" "github.com/openlibrecommunity/olcrtc/internal/logger" "github.com/openlibrecommunity/olcrtc/internal/transport" "github.com/pion/webrtc/v4" @@ -39,8 +40,22 @@ var ( ErrTransportClosed = errors.New("videochannel transport closed") ) +// videoSession is the subset of engine.Session + engine.VideoTrackCapable +// the videochannel transport relies on. +type videoSession interface { + Connect(ctx context.Context) error + Close() error + SetReconnectCallback(cb func()) + SetShouldReconnect(fn func() bool) + SetEndedCallback(cb func(string)) + WatchConnection(ctx context.Context) + CanSend() bool + AddTrack(track webrtc.TrackLocal) error + SetTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) +} + type streamTransport struct { - stream carrier.VideoTrack + stream videoSession track *webrtc.TrackLocalStaticSample codec codecSpec encoder *ffmpegEncoder @@ -81,14 +96,14 @@ type streamTransport struct { idleFrameMu sync.Mutex } -// New creates a visual videochannel transport backed by a carrier. +// New creates a visual videochannel transport backed by a carrier engine. func New(ctx context.Context, cfg transport.Config) (transport.Transport, error) { opts, err := optionsFrom(cfg) if err != nil { return nil, err } - session, err := carrier.New(ctx, cfg.Carrier, carrier.Config{ + session, err := enginebuiltin.Open(ctx, cfg.Carrier, enginebuiltin.Config{ RoomURL: cfg.RoomURL, Name: cfg.Name, OnData: nil, @@ -100,18 +115,15 @@ func New(ctx context.Context, cfg transport.Config) (transport.Transport, error) Token: cfg.Token, }) if err != nil { - return nil, fmt.Errorf("create carrier transport: %w", err) + return nil, fmt.Errorf("open engine session: %w", err) } - videoCapable, ok := session.(carrier.VideoTrackCapable) - if !ok { + vt, ok := session.(engine.VideoTrackCapable) + if !ok || !session.Capabilities().VideoTrack { + _ = session.Close() return nil, ErrVideoTrackUnsupported } - - stream, err := videoCapable.OpenVideoTrack() - if err != nil { - return nil, fmt.Errorf("open video track: %w", err) - } + stream := &engineVideoSession{session: session, vt: vt} codec := codecSpecForCarrier(cfg.Carrier) // Stream/track IDs must be unique per peer: Jitsi/Jicofo keys participant diff --git a/internal/transport/videochannel/transport_unit_test.go b/internal/transport/videochannel/transport_unit_test.go index 00420c6..e0050a8 100644 --- a/internal/transport/videochannel/transport_unit_test.go +++ b/internal/transport/videochannel/transport_unit_test.go @@ -7,30 +7,13 @@ import ( "testing" "time" - "github.com/openlibrecommunity/olcrtc/internal/carrier" + "github.com/openlibrecommunity/olcrtc/internal/engine" + enginebuiltin "github.com/openlibrecommunity/olcrtc/internal/engine/builtin" "github.com/openlibrecommunity/olcrtc/internal/transport" "github.com/pion/webrtc/v4" ) -var ( - errVideoUnitBoom = errors.New("boom") - errVideoUnitOpenBoom = errors.New("open boom") -) - -type fakeVideoSession struct { - stream *fakeVideoStream - err error -} - -func (s *fakeVideoSession) Capabilities() carrier.Capabilities { - return carrier.Capabilities{VideoTrack: true} -} -func (s *fakeVideoSession) OpenVideoTrack() (carrier.VideoTrack, error) { - if s.err != nil { - return nil, s.err - } - return s.stream, nil -} +var errVideoUnitBoom = errors.New("boom") type fakeVideoStream struct { closeErr error @@ -56,16 +39,49 @@ func (s *fakeVideoStream) SetTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.R s.trackCB = cb } -type nonVideoSession struct{} +// fakeEngineSession adapts fakeVideoStream so it satisfies engine.Session and +// engine.VideoTrackCapable, the two interfaces the videochannel transport +// looks up after the carrier-layer collapse. +type fakeEngineSession struct { + stream *fakeVideoStream + noVideo bool +} -func (s *nonVideoSession) Capabilities() carrier.Capabilities { return carrier.Capabilities{} } +func (s *fakeEngineSession) Capabilities() engine.Capabilities { + if s.noVideo { + return engine.Capabilities{} + } + return engine.Capabilities{VideoTrack: true} +} +func (s *fakeEngineSession) Connect(ctx context.Context) error { return s.stream.Connect(ctx) } +func (s *fakeEngineSession) Send([]byte) error { return nil } +func (s *fakeEngineSession) Close() error { return s.stream.Close() } +func (s *fakeEngineSession) SetReconnectCallback(cb func(*webrtc.DataChannel)) { + s.stream.SetReconnectCallback(func() { + if cb != nil { + cb(nil) + } + }) +} +func (s *fakeEngineSession) SetShouldReconnect(fn func() bool) { s.stream.SetShouldReconnect(fn) } +func (s *fakeEngineSession) SetEndedCallback(cb func(string)) { s.stream.SetEndedCallback(cb) } +func (s *fakeEngineSession) WatchConnection(ctx context.Context) { + s.stream.WatchConnection(ctx) +} +func (s *fakeEngineSession) CanSend() bool { return s.stream.CanSend() } +func (s *fakeEngineSession) GetSendQueue() chan []byte { return nil } +func (s *fakeEngineSession) GetBufferedAmount() uint64 { return 0 } +func (s *fakeEngineSession) AddVideoTrack(t webrtc.TrackLocal) error { return s.stream.AddTrack(t) } +func (s *fakeEngineSession) SetVideoTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) { + s.stream.SetTrackHandler(cb) +} //nolint:cyclop // table-driven test naturally has many branches func TestNewCallbacksFeaturesAndClose(t *testing.T) { stream := &fakeVideoStream{canSend: true} name := "videochannel-unit-new" - carrier.Register(name, func(context.Context, carrier.Config) (carrier.Session, error) { - return &fakeVideoSession{stream: stream}, nil + enginebuiltin.Register(name, func(context.Context, enginebuiltin.Config) (engine.Session, error) { + return &fakeEngineSession{stream: stream}, nil }) trIface, err := New(context.Background(), transport.Config{ @@ -112,26 +128,19 @@ func TestNewCallbacksFeaturesAndClose(t *testing.T) { } func TestNewErrorPaths(t *testing.T) { - carrier.Register("videochannel-create-fails", func(context.Context, carrier.Config) (carrier.Session, error) { + enginebuiltin.Register("videochannel-create-fails", func(context.Context, enginebuiltin.Config) (engine.Session, error) { return nil, errVideoUnitBoom }) - if _, err := New(context.Background(), transport.Config{Carrier: "videochannel-create-fails"}); err == nil || err.Error() != "create carrier transport: boom" { //nolint:lll // long test description + if _, err := New(context.Background(), transport.Config{Carrier: "videochannel-create-fails"}); err == nil || err.Error() != "open engine session: boom" { //nolint:lll // long test description t.Fatalf("New() error = %v", err) } - carrier.Register("videochannel-no-video", func(context.Context, carrier.Config) (carrier.Session, error) { - return &nonVideoSession{}, nil + enginebuiltin.Register("videochannel-no-video", func(context.Context, enginebuiltin.Config) (engine.Session, error) { + return &fakeEngineSession{stream: &fakeVideoStream{}, noVideo: true}, nil }) - if _, err := New(context.Background(), transport.Config{Carrier: "videochannel-no-video"}); !errors.Is(err, ErrVideoTrackUnsupported) { //nolint:lll // long test description + if _, err := New(context.Background(), transport.Config{Carrier: "videochannel-no-video"}); !errors.Is(err, ErrVideoTrackUnsupported) { t.Fatalf("New() error = %v, want %v", err, ErrVideoTrackUnsupported) } - - carrier.Register("videochannel-open-fails", func(context.Context, carrier.Config) (carrier.Session, error) { - return &fakeVideoSession{err: errVideoUnitOpenBoom}, nil - }) - if _, err := New(context.Background(), transport.Config{Carrier: "videochannel-open-fails"}); err == nil || err.Error() != "open video track: open boom" { //nolint:lll // long test description - t.Fatalf("New() error = %v", err) - } } func TestSendAckAndClosePaths(t *testing.T) { diff --git a/internal/transport/vp8channel/engine_session.go b/internal/transport/vp8channel/engine_session.go new file mode 100644 index 0000000..3b1a231 --- /dev/null +++ b/internal/transport/vp8channel/engine_session.go @@ -0,0 +1,56 @@ +package vp8channel + +import ( + "context" + "fmt" + + "github.com/openlibrecommunity/olcrtc/internal/engine" + "github.com/pion/webrtc/v4" +) + +// engineVideoSession adapts engine.Session + engine.VideoTrackCapable to the +// videoSession interface vp8channel consumes. +type engineVideoSession struct { + session engine.Session + vt engine.VideoTrackCapable +} + +func (v *engineVideoSession) Connect(ctx context.Context) error { + if err := v.session.Connect(ctx); err != nil { + return fmt.Errorf("connect: %w", err) + } + return nil +} + +func (v *engineVideoSession) Close() error { + if err := v.session.Close(); err != nil { + return fmt.Errorf("close: %w", err) + } + return nil +} + +func (v *engineVideoSession) SetReconnectCallback(cb func()) { + v.session.SetReconnectCallback(func(*webrtc.DataChannel) { + if cb != nil { + cb() + } + }) +} + +func (v *engineVideoSession) SetShouldReconnect(fn func() bool) { v.session.SetShouldReconnect(fn) } +func (v *engineVideoSession) SetEndedCallback(cb func(string)) { v.session.SetEndedCallback(cb) } +func (v *engineVideoSession) WatchConnection(ctx context.Context) { + v.session.WatchConnection(ctx) +} +func (v *engineVideoSession) CanSend() bool { return v.session.CanSend() } + +func (v *engineVideoSession) AddTrack(track webrtc.TrackLocal) error { + if err := v.vt.AddVideoTrack(track); err != nil { + return fmt.Errorf("add track: %w", err) + } + return nil +} + +func (v *engineVideoSession) SetTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) { + v.vt.SetVideoTrackHandler(cb) +} diff --git a/internal/transport/vp8channel/transport.go b/internal/transport/vp8channel/transport.go index b3996d2..b0df02d 100644 --- a/internal/transport/vp8channel/transport.go +++ b/internal/transport/vp8channel/transport.go @@ -38,7 +38,8 @@ import ( "sync/atomic" "time" - "github.com/openlibrecommunity/olcrtc/internal/carrier" + "github.com/openlibrecommunity/olcrtc/internal/engine" + enginebuiltin "github.com/openlibrecommunity/olcrtc/internal/engine/builtin" "github.com/openlibrecommunity/olcrtc/internal/logger" "github.com/openlibrecommunity/olcrtc/internal/transport" "github.com/pion/rtp" @@ -87,8 +88,22 @@ const ( epochHdrLen = 32 ) +// videoSession is the subset of engine.Session + engine.VideoTrackCapable +// the vp8channel transport relies on. +type videoSession interface { + Connect(ctx context.Context) error + Close() error + SetReconnectCallback(cb func()) + SetShouldReconnect(fn func() bool) + SetEndedCallback(cb func(string)) + WatchConnection(ctx context.Context) + CanSend() bool + AddTrack(track webrtc.TrackLocal) error + SetTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) +} + type streamTransport struct { - stream carrier.VideoTrack + stream videoSession track *webrtc.TrackLocalStaticSample onData func([]byte) outbound chan []byte @@ -115,14 +130,14 @@ type streamTransport struct { reconnectFn func() } -// New creates a vp8channel transport backed by a carrier. +// New creates a vp8channel transport backed by a carrier engine. func New(ctx context.Context, cfg transport.Config) (transport.Transport, error) { opts, err := optionsFrom(cfg) if err != nil { return nil, err } - session, err := carrier.New(ctx, cfg.Carrier, carrier.Config{ + session, err := enginebuiltin.Open(ctx, cfg.Carrier, enginebuiltin.Config{ RoomURL: cfg.RoomURL, Name: cfg.Name, OnData: nil, @@ -134,18 +149,15 @@ func New(ctx context.Context, cfg transport.Config) (transport.Transport, error) Token: cfg.Token, }) if err != nil { - return nil, fmt.Errorf("create carrier transport: %w", err) + return nil, fmt.Errorf("open engine session: %w", err) } - videoCapable, ok := session.(carrier.VideoTrackCapable) - if !ok { + vt, ok := session.(engine.VideoTrackCapable) + if !ok || !session.Capabilities().VideoTrack { + _ = session.Close() return nil, ErrVideoTrackUnsupported } - - stream, err := videoCapable.OpenVideoTrack() - if err != nil { - return nil, fmt.Errorf("open video track: %w", err) - } + stream := &engineVideoSession{session: session, vt: vt} // Stream/track IDs must be unique per peer — Jitsi rejects session-accept // when msid collides with another participant in the conference. diff --git a/internal/transport/vp8channel/transport_unit_test.go b/internal/transport/vp8channel/transport_unit_test.go index 427111e..7821232 100644 --- a/internal/transport/vp8channel/transport_unit_test.go +++ b/internal/transport/vp8channel/transport_unit_test.go @@ -8,21 +8,14 @@ import ( "testing" "time" - "github.com/openlibrecommunity/olcrtc/internal/carrier" + "github.com/openlibrecommunity/olcrtc/internal/engine" + enginebuiltin "github.com/openlibrecommunity/olcrtc/internal/engine/builtin" "github.com/openlibrecommunity/olcrtc/internal/transport" "github.com/pion/rtp" "github.com/pion/webrtc/v4" ) -var ( - errVP8UnitBoom = errors.New("boom") - errVP8UnitOpenBoom = errors.New("open boom") -) - -type fakeVideoSession struct { - stream *fakeVideoStream - err error -} +var errVP8UnitBoom = errors.New("boom") func TestSampleIntervalWithBatch(t *testing.T) { tr := &streamTransport{ @@ -40,16 +33,6 @@ func TestSampleIntervalWithBatch(t *testing.T) { } } -func (s *fakeVideoSession) Capabilities() carrier.Capabilities { - return carrier.Capabilities{VideoTrack: true} -} -func (s *fakeVideoSession) OpenVideoTrack() (carrier.VideoTrack, error) { - if s.err != nil { - return nil, s.err - } - return s.stream, nil -} - type fakeVideoStream struct { connectErr error closeErr error @@ -78,16 +61,49 @@ func (s *fakeVideoStream) SetTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.R s.trackCB = cb } -type nonVideoSession struct{} +// fakeEngineSession adapts fakeVideoStream so it satisfies engine.Session and +// engine.VideoTrackCapable, the two interfaces the vp8channel transport +// looks up after the carrier-layer collapse. +type fakeEngineSession struct { + stream *fakeVideoStream + noVideo bool +} -func (s *nonVideoSession) Capabilities() carrier.Capabilities { return carrier.Capabilities{} } +func (s *fakeEngineSession) Capabilities() engine.Capabilities { + if s.noVideo { + return engine.Capabilities{} + } + return engine.Capabilities{VideoTrack: true} +} +func (s *fakeEngineSession) Connect(ctx context.Context) error { return s.stream.Connect(ctx) } +func (s *fakeEngineSession) Send([]byte) error { return nil } +func (s *fakeEngineSession) Close() error { return s.stream.Close() } +func (s *fakeEngineSession) SetReconnectCallback(cb func(*webrtc.DataChannel)) { + s.stream.SetReconnectCallback(func() { + if cb != nil { + cb(nil) + } + }) +} +func (s *fakeEngineSession) SetShouldReconnect(fn func() bool) { s.stream.SetShouldReconnect(fn) } +func (s *fakeEngineSession) SetEndedCallback(cb func(string)) { s.stream.SetEndedCallback(cb) } +func (s *fakeEngineSession) WatchConnection(ctx context.Context) { + s.stream.WatchConnection(ctx) +} +func (s *fakeEngineSession) CanSend() bool { return s.stream.CanSend() } +func (s *fakeEngineSession) GetSendQueue() chan []byte { return nil } +func (s *fakeEngineSession) GetBufferedAmount() uint64 { return 0 } +func (s *fakeEngineSession) AddVideoTrack(t webrtc.TrackLocal) error { return s.stream.AddTrack(t) } +func (s *fakeEngineSession) SetVideoTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) { + s.stream.SetTrackHandler(cb) +} //nolint:cyclop // table-driven test naturally has many branches func TestNewConnectSendCallbacksFeaturesAndClose(t *testing.T) { stream := &fakeVideoStream{canSend: true} name := "vp8channel-unit-new" - carrier.Register(name, func(context.Context, carrier.Config) (carrier.Session, error) { - return &fakeVideoSession{stream: stream}, nil + enginebuiltin.Register(name, func(context.Context, enginebuiltin.Config) (engine.Session, error) { + return &fakeEngineSession{stream: stream}, nil }) trIface, err := New(context.Background(), transport.Config{ @@ -150,26 +166,19 @@ func TestNewConnectSendCallbacksFeaturesAndClose(t *testing.T) { } func TestNewErrorPaths(t *testing.T) { - carrier.Register("vp8channel-create-fails", func(context.Context, carrier.Config) (carrier.Session, error) { + enginebuiltin.Register("vp8channel-create-fails", func(context.Context, enginebuiltin.Config) (engine.Session, error) { return nil, errVP8UnitBoom }) - if _, err := New(context.Background(), transport.Config{Carrier: "vp8channel-create-fails"}); err == nil || err.Error() != "create carrier transport: boom" { //nolint:lll // long test description + if _, err := New(context.Background(), transport.Config{Carrier: "vp8channel-create-fails"}); err == nil || err.Error() != "open engine session: boom" { //nolint:lll // long test description t.Fatalf("New() error = %v", err) } - carrier.Register("vp8channel-no-video", func(context.Context, carrier.Config) (carrier.Session, error) { - return &nonVideoSession{}, nil + enginebuiltin.Register("vp8channel-no-video", func(context.Context, enginebuiltin.Config) (engine.Session, error) { + return &fakeEngineSession{stream: &fakeVideoStream{}, noVideo: true}, nil }) - if _, err := New(context.Background(), transport.Config{Carrier: "vp8channel-no-video"}); !errors.Is(err, ErrVideoTrackUnsupported) { //nolint:lll // long test description + if _, err := New(context.Background(), transport.Config{Carrier: "vp8channel-no-video"}); !errors.Is(err, ErrVideoTrackUnsupported) { t.Fatalf("New() error = %v, want %v", err, ErrVideoTrackUnsupported) } - - carrier.Register("vp8channel-open-fails", func(context.Context, carrier.Config) (carrier.Session, error) { - return &fakeVideoSession{err: errVP8UnitOpenBoom}, nil - }) - if _, err := New(context.Background(), transport.Config{Carrier: "vp8channel-open-fails"}); err == nil || err.Error() != "open video track: open boom" { //nolint:lll // long test description - t.Fatalf("New() error = %v", err) - } } //nolint:cyclop // table-driven test naturally has many branches diff --git a/pkg/olcrtc/olcrtc.go b/pkg/olcrtc/olcrtc.go index b0b442f..dee25dc 100644 --- a/pkg/olcrtc/olcrtc.go +++ b/pkg/olcrtc/olcrtc.go @@ -34,8 +34,8 @@ import ( "net" "github.com/openlibrecommunity/olcrtc/internal/auth" - "github.com/openlibrecommunity/olcrtc/internal/carrier/builtin" "github.com/openlibrecommunity/olcrtc/internal/engine" + enginebuiltin "github.com/openlibrecommunity/olcrtc/internal/engine/builtin" ) var ( @@ -88,7 +88,7 @@ type Session struct { // Call once at program start if you want the full set without manual blank // imports. Safe to call multiple times. func RegisterDefaults() { - builtin.Register() + enginebuiltin.RegisterDefaults() } // New creates a Session from cfg. The session is not connected yet; call