From 9e09975165ebd9abc94659bd8df75ff5efd2b643 Mon Sep 17 00:00:00 2001 From: zarazaex69 Date: Tue, 21 Apr 2026 01:32:17 +0300 Subject: [PATCH] feat: implement video channel transport --- cmd/olcrtc/main.go | 14 +- go.mod | 5 +- go.sum | 6 + internal/app/session/session.go | 16 + internal/carrier/bytestream.go | 38 +- internal/carrier/carrier.go | 4 +- internal/client/client.go | 53 +- internal/link/direct/direct.go | 18 +- internal/link/link.go | 20 +- internal/provider/jazz/peer.go | 175 +++++- internal/provider/jazz/provider.go | 7 +- internal/provider/provider.go | 5 +- internal/provider/telemost/peer.go | 419 ++++++++++++-- internal/provider/telemost/provider.go | 6 +- internal/provider/wbstream/peer.go | 61 +- internal/provider/wbstream/provider.go | 6 +- internal/server/server.go | 28 +- internal/transport/seichannel/h264.go | 188 ++++++ internal/transport/seichannel/transport.go | 534 ++++++++++++++++++ .../transport/seichannel/transport_test.go | 42 ++ internal/transport/transport.go | 18 +- internal/transport/videochannel/ffmpeg.go | 444 +++++++++++++++ internal/transport/videochannel/frame.go | 110 ++++ internal/transport/videochannel/transport.go | 478 ++++++++++++++++ .../transport/videochannel/transport_test.go | 51 ++ internal/transport/videochannel/visual.go | 112 ++++ 26 files changed, 2716 insertions(+), 142 deletions(-) create mode 100644 internal/transport/seichannel/h264.go create mode 100644 internal/transport/seichannel/transport.go create mode 100644 internal/transport/seichannel/transport_test.go create mode 100644 internal/transport/videochannel/ffmpeg.go create mode 100644 internal/transport/videochannel/frame.go create mode 100644 internal/transport/videochannel/transport.go create mode 100644 internal/transport/videochannel/transport_test.go create mode 100644 internal/transport/videochannel/visual.go diff --git a/cmd/olcrtc/main.go b/cmd/olcrtc/main.go index 10b5068..9cba868 100644 --- a/cmd/olcrtc/main.go +++ b/cmd/olcrtc/main.go @@ -31,6 +31,10 @@ type config struct { dnsServer string socksProxyAddr string socksProxyPort int + videoWidth int + videoHeight int + videoFPS int + videoBitrate string } func main() { @@ -85,7 +89,7 @@ func parseFlags() config { flag.StringVar(&cfg.mode, "mode", "", "Mode: srv or cnc") flag.StringVar(&cfg.link, "link", "direct", "Link: direct") - flag.StringVar(&cfg.transport, "transport", "datachannel", "Transport: datachannel") + flag.StringVar(&cfg.transport, "transport", "datachannel", "Transport: datachannel, videochannel, seichannel") flag.StringVar(&cfg.carrier, "carrier", "", "Carrier: telemost, jazz, wb_stream") flag.StringVar(&cfg.roomID, "id", "", "Room ID") flag.StringVar(&cfg.provider, "provider", "", "Deprecated alias for -carrier") @@ -97,6 +101,10 @@ func parseFlags() config { flag.StringVar(&cfg.dnsServer, "dns", "1.1.1.1:53", "DNS server (default: Cloudflare 1.1.1.1)") flag.StringVar(&cfg.socksProxyAddr, "socks-proxy", "", "SOCKS5 proxy address (server only)") flag.IntVar(&cfg.socksProxyPort, "socks-proxy-port", 1080, "SOCKS5 proxy port (server only)") + flag.IntVar(&cfg.videoWidth, "video-w", 640, "Video logical width (videochannel only)") + flag.IntVar(&cfg.videoHeight, "video-h", 360, "Video logical height (videochannel only)") + flag.IntVar(&cfg.videoFPS, "video-fps", 25, "Video frames per second (videochannel only)") + flag.StringVar(&cfg.videoBitrate, "video-bitrate", "2048k", "Video bitrate (videochannel only)") flag.Parse() return cfg @@ -144,6 +152,10 @@ func toSessionConfig(cfg config) session.Config { DNSServer: cfg.dnsServer, SOCKSProxyAddr: cfg.socksProxyAddr, SOCKSProxyPort: cfg.socksProxyPort, + VideoWidth: cfg.videoWidth, + VideoHeight: cfg.videoHeight, + VideoFPS: cfg.videoFPS, + VideoBitrate: cfg.videoBitrate, } } diff --git a/go.mod b/go.mod index 3608bb5..cbd798f 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,13 @@ module github.com/openlibrecommunity/olcrtc go 1.25.0 require ( + github.com/boombuler/barcode v1.1.0 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 github.com/livekit/server-sdk-go/v2 v2.16.2 github.com/magefile/mage v1.17.1 + github.com/makiuchi-d/gozxing v0.1.1 + github.com/pion/rtp v1.10.1 github.com/pion/webrtc/v4 v4.2.11 golang.org/x/crypto v0.50.0 golang.org/x/mobile v0.0.0-20260410095206-2cfb76559b7b @@ -50,7 +53,6 @@ require ( github.com/pion/mdns/v2 v2.1.0 // indirect github.com/pion/randutil v0.1.0 // indirect github.com/pion/rtcp v1.2.16 // indirect - github.com/pion/rtp v1.10.1 // indirect github.com/pion/sctp v1.9.4 // indirect github.com/pion/sdp/v3 v3.0.18 // indirect github.com/pion/srtp/v3 v3.0.10 // indirect @@ -75,6 +77,7 @@ require ( golang.org/x/text v0.36.0 // indirect golang.org/x/time v0.15.0 // indirect golang.org/x/tools v0.44.0 // indirect + golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57 // indirect google.golang.org/grpc v1.79.1 // indirect diff --git a/go.sum b/go.sum index 4578330..33b2008 100644 --- a/go.sum +++ b/go.sum @@ -20,6 +20,8 @@ github.com/benbjohnson/clock v1.3.5 h1:VvXlSJBzZpA/zum6Sj74hxwYI2DIxRWuNIoXAzHZz github.com/benbjohnson/clock v1.3.5/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/bep/debounce v1.2.1 h1:v67fRdBA9UQu2NhLFXrSg0Brw7CexQekrBwDMM8bzeY= github.com/bep/debounce v1.2.1/go.mod h1:H8yggRPQKLUhUoqrJC1bO2xNya7vanpDl7xR3ISbCJ0= +github.com/boombuler/barcode v1.1.0 h1:ChaYjBR63fr4LFyGn8E8nt7dBSt3MiU3zMOZqFvVkHo= +github.com/boombuler/barcode v1.1.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/brianvoe/gofakeit/v6 v6.28.0 h1:Xib46XXuQfmlLS2EXRuJpqcw8St6qSZz75OUo0tgAW4= github.com/brianvoe/gofakeit/v6 v6.28.0/go.mod h1:Xj58BMSnFqcn/fAQeSK+/PLtC5kSb7FJIq4JyGa8vEs= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= @@ -105,6 +107,8 @@ github.com/livekit/server-sdk-go/v2 v2.16.2 h1:eQe24cka3X+5zUivezyL72nwtAJTWFXgi github.com/livekit/server-sdk-go/v2 v2.16.2/go.mod h1:/HOUG9AXJeCbMCdtw0dr37AB+3xXUlj/OLeXS/0p7rA= github.com/magefile/mage v1.17.1 h1:F1d2lnLSlbQDM0Plq6Ac4NtaHxkxTK8t5nrMY9SkoNA= github.com/magefile/mage v1.17.1/go.mod h1:Yj51kqllmsgFpvvSzgrZPK9WtluG3kUhFaBUVLo4feA= +github.com/makiuchi-d/gozxing v0.1.1 h1:xxqijhoedi+/lZlhINteGbywIrewVdVv2wl9r5O9S1I= +github.com/makiuchi-d/gozxing v0.1.1/go.mod h1:eRIHbOjX7QWxLIDJoQuMLhuXg9LAuw6znsUtRkNw9DU= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= github.com/moby/moby/api v1.52.0 h1:00BtlJY4MXkkt84WhUZPRqt5TvPbgig2FZvTbe3igYg= @@ -276,6 +280,8 @@ golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.44.0 h1:UP4ajHPIcuMjT1GqzDWRlalUEoY+uzoZKnhOjbIPD2c= golang.org/x/tools v0.44.0/go.mod h1:KA0AfVErSdxRZIsOVipbv3rQhVXTnlU6UhKxHd1seDI= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57 h1:JLQynH/LBHfCTSbDWl+py8C+Rg/k1OVH3xfcaiANuF0= google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57/go.mod h1:kSJwQxqmFXeo79zOmbrALdflXQeAYcUbgS7PbpMknCY= google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57 h1:mWPCjDEyshlQYzBpMNHaEof6UX1PmHcaUODUywQ0uac= diff --git a/internal/app/session/session.go b/internal/app/session/session.go index 3f3e4dd..8411f0e 100644 --- a/internal/app/session/session.go +++ b/internal/app/session/session.go @@ -14,6 +14,8 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/server" "github.com/openlibrecommunity/olcrtc/internal/transport" "github.com/openlibrecommunity/olcrtc/internal/transport/datachannel" + "github.com/openlibrecommunity/olcrtc/internal/transport/seichannel" + "github.com/openlibrecommunity/olcrtc/internal/transport/videochannel" ) var ( @@ -44,6 +46,10 @@ type Config struct { DNSServer string SOCKSProxyAddr string SOCKSProxyPort int + VideoWidth int + VideoHeight int + VideoFPS int + VideoBitrate string } // RegisterDefaults registers built-in providers and transports. @@ -51,6 +57,8 @@ func RegisterDefaults() { builtin.Register() link.Register("direct", direct.New) transport.Register("datachannel", datachannel.New) + transport.Register("videochannel", videochannel.New) + transport.Register("seichannel", seichannel.New) } // Validate verifies that the runtime config refers to registered components. @@ -116,6 +124,10 @@ func Run(ctx context.Context, cfg Config) error { cfg.DNSServer, cfg.SOCKSProxyAddr, cfg.SOCKSProxyPort, + cfg.VideoWidth, + cfg.VideoHeight, + cfg.VideoFPS, + cfg.VideoBitrate, ) case "cnc": return client.Run( @@ -129,6 +141,10 @@ func Run(ctx context.Context, cfg Config) error { cfg.DNSServer, "", "", + cfg.VideoWidth, + cfg.VideoHeight, + cfg.VideoFPS, + cfg.VideoBitrate, ) default: return ErrModeRequired diff --git a/internal/carrier/bytestream.go b/internal/carrier/bytestream.go index 02a584d..ebfe5dc 100644 --- a/internal/carrier/bytestream.go +++ b/internal/carrier/bytestream.go @@ -19,9 +19,17 @@ type ByteStream interface { CanSend() bool } -// VideoTrack is a carrier capability for publishing a local video track. +// VideoTrack is a carrier capability for bidirectional video transport. type VideoTrack interface { - AddTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) + Connect(ctx context.Context) error + Close() error + SetReconnectCallback(cb func()) + SetShouldReconnect(fn func() bool) + SetEndedCallback(cb func(string)) + WatchConnection(ctx context.Context) + CanSend() bool + AddTrack(track webrtc.TrackLocal) error + SetTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) } type legacySession struct { @@ -76,6 +84,30 @@ type legacyVideoTrack struct { provider provider.VideoTrackCapable } -func (v *legacyVideoTrack) AddTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) { +func (v *legacyVideoTrack) Connect(ctx context.Context) error { + return v.provider.(provider.Provider).Connect(ctx) +} +func (v *legacyVideoTrack) Close() error { return v.provider.(provider.Provider).Close() } +func (v *legacyVideoTrack) SetShouldReconnect(fn func() bool) { + v.provider.(provider.Provider).SetShouldReconnect(fn) +} +func (v *legacyVideoTrack) SetEndedCallback(cb func(string)) { + v.provider.(provider.Provider).SetEndedCallback(cb) +} +func (v *legacyVideoTrack) WatchConnection(ctx context.Context) { + v.provider.(provider.Provider).WatchConnection(ctx) +} +func (v *legacyVideoTrack) CanSend() bool { return v.provider.(provider.Provider).CanSend() } +func (v *legacyVideoTrack) AddTrack(track webrtc.TrackLocal) error { return v.provider.AddVideoTrack(track) } +func (v *legacyVideoTrack) SetTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) { + v.provider.SetVideoTrackHandler(cb) +} +func (v *legacyVideoTrack) SetReconnectCallback(cb func()) { + v.provider.(provider.Provider).SetReconnectCallback(func(_ *webrtc.DataChannel) { + if cb != nil { + cb() + } + }) +} diff --git a/internal/carrier/carrier.go b/internal/carrier/carrier.go index cbc8e38..98ed15e 100644 --- a/internal/carrier/carrier.go +++ b/internal/carrier/carrier.go @@ -13,7 +13,7 @@ var ( ErrCarrierNotFound = errors.New("carrier not found") // ErrByteStreamUnsupported is returned when a carrier cannot provide a byte stream. ErrByteStreamUnsupported = errors.New("carrier does not support byte stream") - // ErrVideoTrackUnsupported is returned when a carrier cannot publish video tracks. + // ErrVideoTrackUnsupported is returned when a carrier cannot exchange video tracks. ErrVideoTrackUnsupported = errors.New("carrier does not support video tracks") ) @@ -33,7 +33,7 @@ type ByteStreamCapable interface { OpenByteStream() (ByteStream, error) } -// VideoTrackCapable is implemented by carriers that can publish video tracks. +// VideoTrackCapable is implemented by carriers that can exchange video tracks. type VideoTrackCapable interface { OpenVideoTrack() (VideoTrack, error) } diff --git a/internal/client/client.go b/internal/client/client.go index 5e30e4e..9eadd0d 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -67,8 +67,12 @@ func Run( dnsServer, socksUser string, socksPass string, + videoWidth int, + videoHeight int, + videoFPS int, + videoBitrate string, ) error { - return RunWithReady(ctx, linkName, transportName, carrierName, roomURL, keyHex, localAddr, dnsServer, socksUser, socksPass, nil) + return RunWithReady(ctx, linkName, transportName, carrierName, roomURL, keyHex, localAddr, dnsServer, socksUser, socksPass, nil, videoWidth, videoHeight, videoFPS, videoBitrate) } // RunWithReady is like Run but accepts a callback that is called when the client is ready. @@ -84,6 +88,10 @@ func RunWithReady( _ string, _ string, onReady func(), + videoWidth int, + videoHeight int, + videoFPS int, + videoBitrate string, ) error { runCtx, cancel := context.WithCancel(ctx) defer cancel() @@ -111,7 +119,7 @@ func RunWithReady( const linkCount = 1 for i := range linkCount { - if err := c.addLink(runCtx, linkName, transportName, carrierName, roomURL, i, cancel, dnsServer, "", 0); err != nil { + if err := c.addLink(runCtx, linkName, transportName, carrierName, roomURL, i, cancel, dnsServer, "", 0, videoWidth, videoHeight, videoFPS, videoBitrate); err != nil { return fmt.Errorf("addLink failed: %w", err) } } @@ -198,16 +206,22 @@ func (c *Client) addLink( dnsServer, socksProxyAddr string, socksProxyPort int, + videoWidth, videoHeight, videoFPS int, + videoBitrate string, ) error { ln, err := link.New(ctx, linkName, link.Config{ - Transport: transportName, - Carrier: carrierName, - RoomURL: roomURL, - Name: names.Generate(), - OnData: c.onData, - DNSServer: dnsServer, - ProxyAddr: socksProxyAddr, - ProxyPort: socksProxyPort, + Transport: transportName, + Carrier: carrierName, + RoomURL: roomURL, + Name: names.Generate(), + OnData: c.onData, + DNSServer: dnsServer, + ProxyAddr: socksProxyAddr, + ProxyPort: socksProxyPort, + VideoWidth: videoWidth, + VideoHeight: videoHeight, + VideoFPS: videoFPS, + VideoBitrate: videoBitrate, }) if err != nil { return fmt.Errorf("failed to create link: %w", err) @@ -235,10 +249,7 @@ func (c *Client) addLink( ln.WatchConnection(ctx) }() - // Send initial reset to clean up any stale connections for this clientID on server - if err := c.mux.SendClientReset(); err != nil { - logger.Warnf("Failed to send initial client reset: %v", err) - } + c.sendClientResetAsync("initial") return nil } @@ -268,9 +279,17 @@ func (c *Client) handleLinkReconnect(linkID int) { }) c.mux.Reset() - if err := c.mux.SendClientReset(); err != nil { - logger.Warnf("Failed to send client reset after reconnect: %v", err) - } + c.sendClientResetAsync("reconnect") +} + +func (c *Client) sendClientResetAsync(source string) { + c.wg.Add(1) + go func() { + defer c.wg.Done() + if err := c.mux.SendClientReset(); err != nil { + logger.Warnf("Failed to send client reset after %s: %v", source, err) + } + }() } func (c *Client) acceptLoop(ctx context.Context, ln net.Listener) { diff --git a/internal/link/direct/direct.go b/internal/link/direct/direct.go index 40b318b..058089a 100644 --- a/internal/link/direct/direct.go +++ b/internal/link/direct/direct.go @@ -16,13 +16,17 @@ type directLink struct { // New creates a direct link that forwards bytes to the selected transport. func New(ctx context.Context, cfg link.Config) (link.Link, error) { tr, err := transport.New(ctx, cfg.Transport, transport.Config{ - Carrier: cfg.Carrier, - RoomURL: cfg.RoomURL, - Name: cfg.Name, - OnData: cfg.OnData, - DNSServer: cfg.DNSServer, - ProxyAddr: cfg.ProxyAddr, - ProxyPort: cfg.ProxyPort, + Carrier: cfg.Carrier, + RoomURL: cfg.RoomURL, + Name: cfg.Name, + OnData: cfg.OnData, + DNSServer: cfg.DNSServer, + ProxyAddr: cfg.ProxyAddr, + ProxyPort: cfg.ProxyPort, + VideoWidth: cfg.VideoWidth, + VideoHeight: cfg.VideoHeight, + VideoFPS: cfg.VideoFPS, + VideoBitrate: cfg.VideoBitrate, }) if err != nil { return nil, fmt.Errorf("create transport for direct link: %w", err) diff --git a/internal/link/link.go b/internal/link/link.go index bb86890..8c02198 100644 --- a/internal/link/link.go +++ b/internal/link/link.go @@ -25,14 +25,18 @@ type Link interface { // Config holds common link configuration. type Config struct { - Transport string - Carrier string - RoomURL string - Name string - OnData func([]byte) - DNSServer string - ProxyAddr string - ProxyPort int + Transport string + Carrier string + RoomURL string + Name string + OnData func([]byte) + DNSServer string + ProxyAddr string + ProxyPort int + VideoWidth int + VideoHeight int + VideoFPS int + VideoBitrate string } // Factory creates a link instance. diff --git a/internal/provider/jazz/peer.go b/internal/provider/jazz/peer.go index 04a6625..90c5228 100644 --- a/internal/provider/jazz/peer.go +++ b/internal/provider/jazz/peer.go @@ -44,6 +44,13 @@ type Peer struct { sendQueueClosed atomic.Bool onEnded func(string) sessionCloseCh chan struct{} + videoTrackMu sync.RWMutex + videoTracks []webrtc.TrackLocal + onVideoTrack func(*webrtc.TrackRemote, *webrtc.RTPReceiver) + subscriberReady atomic.Bool + publisherReady atomic.Bool + subscriberConn chan struct{} + publisherConn chan struct{} wg sync.WaitGroup groupID string } @@ -83,12 +90,55 @@ func NewPeer(ctx context.Context, roomID, name string, onData func([]byte)) (*Pe closeCh: make(chan struct{}), sessionCloseCh: make(chan struct{}), sendQueue: make(chan []byte, 5000), + subscriberConn: make(chan struct{}), + publisherConn: make(chan struct{}), }, nil } +func (p *Peer) resetMediaState() { + p.subscriberReady.Store(false) + p.publisherReady.Store(false) + p.subscriberConn = make(chan struct{}) + p.publisherConn = make(chan struct{}) +} + +func closeSignal(ch chan struct{}) { + select { + case <-ch: + default: + close(ch) + } +} + +func (p *Peer) hasLocalVideoTracks() bool { + p.videoTrackMu.RLock() + defer p.videoTrackMu.RUnlock() + return len(p.videoTracks) > 0 +} + +func (p *Peer) videoTrackHandler() func(*webrtc.TrackRemote, *webrtc.RTPReceiver) { + p.videoTrackMu.RLock() + defer p.videoTrackMu.RUnlock() + return p.onVideoTrack +} + +func (p *Peer) attachPendingVideoTracks() error { + p.videoTrackMu.RLock() + defer p.videoTrackMu.RUnlock() + + for _, track := range p.videoTracks { + if _, err := p.pcPub.AddTrack(track); err != nil { + return fmt.Errorf("failed to add track: %w", err) + } + } + + return nil +} + // Connect starts the WebRTC connection process. func (p *Peer) Connect(ctx context.Context) error { p.closed.Store(false) + p.resetMediaState() config := webrtc.Configuration{ ICEServers: []webrtc.ICEServer{}, @@ -107,21 +157,39 @@ func (p *Peer) Connect(ctx context.Context) error { if err != nil { return fmt.Errorf("create subscriber pc: %w", err) } + p.pcSub.OnConnectionStateChange(p.onSubscriberConnectionStateChange) + p.pcSub.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { + if track.Kind() != webrtc.RTPCodecTypeVideo { + return + } + + if cb := p.videoTrackHandler(); cb != nil { + cb(track, receiver) + } + }) p.pcPub, err = api.NewPeerConnection(config) if err != nil { return fmt.Errorf("create publisher pc: %w", err) } + p.pcPub.OnConnectionStateChange(p.onPublisherConnectionStateChange) - p.dc, err = p.pcPub.CreateDataChannel("_reliable", &webrtc.DataChannelInit{ - Ordered: func() *bool { v := true; return &v }(), - }) - if err != nil { - return fmt.Errorf("create datachannel: %w", err) + if err := p.attachPendingVideoTracks(); err != nil { + return err } - dcReady := make(chan struct{}) - p.setupDataChannelHandlers(dcReady) + var dcReady chan struct{} + if p.onData != nil { + p.dc, err = p.pcPub.CreateDataChannel("_reliable", &webrtc.DataChannelInit{ + Ordered: func() *bool { v := true; return &v }(), + }) + if err != nil { + return fmt.Errorf("create datachannel: %w", err) + } + + dcReady = make(chan struct{}) + p.setupDataChannelHandlers(dcReady) + } if err := p.dialWebSocket(); err != nil { return err @@ -137,14 +205,33 @@ func (p *Peer) Connect(ctx context.Context) error { p.handleSignaling(ctx) }() + if p.onData != nil { + select { + case <-dcReady: + return nil + case <-time.After(30 * time.Second): + return provider.ErrDataChannelTimeout + case <-ctx.Done(): + return fmt.Errorf("connect cancelled: %w", ctx.Err()) + } + } + + return p.waitForMediaReady(ctx, 30*time.Second) +} + +func (p *Peer) waitForMediaReady(ctx context.Context, timeout time.Duration) error { + timer := time.NewTimer(timeout) + defer timer.Stop() + select { - case <-dcReady: - return nil - case <-time.After(30 * time.Second): - return provider.ErrDataChannelTimeout + case <-p.subscriberConn: + case <-timer.C: + return fmt.Errorf("subscriber media timeout") case <-ctx.Done(): return fmt.Errorf("connect cancelled: %w", ctx.Err()) } + + return nil } func (p *Peer) dialWebSocket() error { @@ -224,12 +311,42 @@ func (p *Peer) setupDataChannelHandlers(dcReady chan struct{}) { return } - dc.OnMessage(func(msg webrtc.DataChannelMessage) { - p.handleIncomingMessage(msg.Data, "subscriber") - }) + if p.onData != nil { + dc.OnMessage(func(msg webrtc.DataChannelMessage) { + p.handleIncomingMessage(msg.Data, "subscriber") + }) + } }) } +func (p *Peer) onSubscriberConnectionStateChange(state webrtc.PeerConnectionState) { + if state == webrtc.PeerConnectionStateConnected { + p.subscriberReady.Store(true) + closeSignal(p.subscriberConn) + } else if state == webrtc.PeerConnectionStateDisconnected || + state == webrtc.PeerConnectionStateFailed || + state == webrtc.PeerConnectionStateClosed { + p.subscriberReady.Store(false) + if !p.closed.Load() && (state == webrtc.PeerConnectionStateDisconnected || state == webrtc.PeerConnectionStateFailed) { + p.queueReconnect() + } + } +} + +func (p *Peer) onPublisherConnectionStateChange(state webrtc.PeerConnectionState) { + if state == webrtc.PeerConnectionStateConnected { + p.publisherReady.Store(true) + closeSignal(p.publisherConn) + } else if state == webrtc.PeerConnectionStateDisconnected || + state == webrtc.PeerConnectionStateFailed || + state == webrtc.PeerConnectionStateClosed { + p.publisherReady.Store(false) + if !p.closed.Load() && (state == webrtc.PeerConnectionStateDisconnected || state == webrtc.PeerConnectionStateFailed) { + p.queueReconnect() + } + } +} + func (p *Peer) handleIncomingMessage(data []byte, source string) { logger.Verbosef("[Jazz] Received %d bytes on %s DC (raw)", len(data), source) @@ -535,20 +652,30 @@ func (p *Peer) Close() error { } var ( - // ErrPublisherNotInitialized is returned when the publisher peer connection is not set up. + // ErrPublisherNotInitialized is returned when the publisher peer connection is not set up. ErrPublisherNotInitialized = errors.New("publisher peer connection not initialized") ) // AddVideoTrack adds a video track to the publisher peer connection. -func (p *Peer) AddVideoTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) { +func (p *Peer) AddVideoTrack(track webrtc.TrackLocal) error { + p.videoTrackMu.Lock() + p.videoTracks = append(p.videoTracks, track) + p.videoTrackMu.Unlock() + if p.pcPub == nil { - return nil, ErrPublisherNotInitialized + return nil } - sender, err := p.pcPub.AddTrack(track) - if err != nil { - return nil, fmt.Errorf("failed to add track: %w", err) + if _, err := p.pcPub.AddTrack(track); err != nil { + return fmt.Errorf("failed to add track: %w", err) } - return sender, nil + return nil +} + +// SetVideoTrackHandler registers a callback for remote video tracks. +func (p *Peer) SetVideoTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) { + p.videoTrackMu.Lock() + defer p.videoTrackMu.Unlock() + p.onVideoTrack = cb } // SetReconnectCallback sets the callback for reconnection events. @@ -581,6 +708,12 @@ func (p *Peer) WatchConnection(ctx context.Context) { // CanSend checks if data can be sent. func (p *Peer) CanSend() bool { + if p.onData == nil { + if p.hasLocalVideoTracks() { + return !p.closed.Load() && p.subscriberReady.Load() && p.publisherReady.Load() + } + return !p.closed.Load() && p.subscriberReady.Load() + } if p.dc == nil || p.dc.ReadyState() != webrtc.DataChannelStateOpen { return false } diff --git a/internal/provider/jazz/provider.go b/internal/provider/jazz/provider.go index 23a5715..e5c8f8c 100644 --- a/internal/provider/jazz/provider.go +++ b/internal/provider/jazz/provider.go @@ -74,6 +74,11 @@ func (j *jazzProvider) GetBufferedAmount() uint64 { } // AddVideoTrack adds a video track to the jazz connection. -func (j *jazzProvider) AddVideoTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) { +func (j *jazzProvider) AddVideoTrack(track webrtc.TrackLocal) error { return j.peer.AddVideoTrack(track) } + +// SetVideoTrackHandler registers a callback for subscribed remote video tracks. +func (j *jazzProvider) SetVideoTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) { + j.peer.SetVideoTrackHandler(cb) +} diff --git a/internal/provider/provider.go b/internal/provider/provider.go index bbc7497..3820b85 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -37,9 +37,10 @@ type Provider interface { GetBufferedAmount() uint64 } -// VideoTrackCapable is implemented by providers that can publish video tracks. +// VideoTrackCapable is implemented by providers that can exchange video tracks. type VideoTrackCapable interface { - AddVideoTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) + AddVideoTrack(track webrtc.TrackLocal) error + SetVideoTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) } // Config holds common configuration for all providers. diff --git a/internal/provider/telemost/peer.go b/internal/provider/telemost/peer.go index 57cc785..72a5820 100644 --- a/internal/provider/telemost/peer.go +++ b/internal/provider/telemost/peer.go @@ -10,6 +10,7 @@ import ( "log" "math/rand/v2" "net/http" + "runtime" "strings" "sync" "sync/atomic" @@ -81,6 +82,13 @@ type Peer struct { onEnded func(string) trafficShape TrafficShape sessionCloseCh chan struct{} + videoTrackMu sync.RWMutex + videoTracks []webrtc.TrackLocal + onVideoTrack func(*webrtc.TrackRemote, *webrtc.RTPReceiver) + subscriberReady atomic.Bool + publisherReady atomic.Bool + subscriberConn chan struct{} + publisherConn chan struct{} wg sync.WaitGroup } @@ -132,6 +140,8 @@ func NewPeer(ctx context.Context, roomURL, name string, onData func([]byte)) (*P telemetryCh: make(chan struct{}, 1), sendQueue: make(chan []byte, 5000), ackWaiters: make(map[string]chan struct{}), + subscriberConn: make(chan struct{}), + publisherConn: make(chan struct{}), trafficShape: TrafficShape{ MaxMessageSize: realDataChannelMessageLimit, MinDelay: defaultSendDelayLow, @@ -182,6 +192,38 @@ func (p *Peer) resetSession() (chan struct{}, chan struct{}) { return p.keepAliveCh, p.sessionCloseCh } +func (p *Peer) resetMediaState() { + p.subscriberReady.Store(false) + p.publisherReady.Store(false) + p.subscriberConn = make(chan struct{}) + p.publisherConn = make(chan struct{}) +} + +func (p *Peer) hasLocalVideoTracks() bool { + p.videoTrackMu.RLock() + defer p.videoTrackMu.RUnlock() + return len(p.videoTracks) > 0 +} + +func (p *Peer) videoTrackHandler() func(*webrtc.TrackRemote, *webrtc.RTPReceiver) { + p.videoTrackMu.RLock() + defer p.videoTrackMu.RUnlock() + return p.onVideoTrack +} + +func (p *Peer) attachPendingVideoTracks() error { + p.videoTrackMu.RLock() + defer p.videoTrackMu.RUnlock() + + for _, track := range p.videoTracks { + if _, err := p.pcPub.AddTrack(track); err != nil { + return fmt.Errorf("add video track: %w", err) + } + } + + return nil +} + func (p *Peer) drainReconnectQueue() { for { select { @@ -195,6 +237,7 @@ func (p *Peer) drainReconnectQueue() { // Connect starts the WebRTC connection process. func (p *Peer) Connect(ctx context.Context) error { p.closed.Store(false) + p.resetMediaState() config := webrtc.Configuration{ ICEServers: []webrtc.ICEServer{{URLs: []string{"stun:stun.rtc.yandex.net:3478"}}}, @@ -205,15 +248,18 @@ func (p *Peer) Connect(ctx context.Context) error { return err } - var err error - p.dc, err = p.pcPub.CreateDataChannel("olcrtc", nil) - if err != nil { - return fmt.Errorf("create dc: %w", err) - } - - dcReady := make(chan struct{}) keepAliveCh, sessionCloseCh := p.resetSession() - p.setupDataChannelHandlers(dcReady, sessionCloseCh) + var dcReady chan struct{} + if p.onData != nil { + var err error + p.dc, err = p.pcPub.CreateDataChannel("olcrtc", nil) + if err != nil { + return fmt.Errorf("create dc: %w", err) + } + + dcReady = make(chan struct{}) + p.setupDataChannelHandlers(dcReady, sessionCloseCh) + } if err := p.dialWebSocket(); err != nil { return err @@ -222,14 +268,33 @@ func (p *Peer) Connect(ctx context.Context) error { p.setupICEHandlers() p.startBackgroundGoroutines(ctx, keepAliveCh) + if p.onData != nil { + select { + case <-dcReady: + return nil + case <-time.After(15 * time.Second): + return ErrDataChannelTimeout + case <-ctx.Done(): + return fmt.Errorf("connect context cancelled: %w", ctx.Err()) + } + } + + return p.waitForMediaReady(ctx, 20*time.Second) +} + +func (p *Peer) waitForMediaReady(ctx context.Context, timeout time.Duration) error { + timer := time.NewTimer(timeout) + defer timer.Stop() + select { - case <-dcReady: - return nil - case <-time.After(15 * time.Second): - return ErrDataChannelTimeout + case <-p.subscriberConn: + case <-timer.C: + return fmt.Errorf("subscriber media timeout") case <-ctx.Done(): return fmt.Errorf("connect context cancelled: %w", ctx.Err()) } + + return nil } func (p *Peer) setupPeerConnections(config webrtc.Configuration) error { @@ -244,13 +309,28 @@ func (p *Peer) setupPeerConnections(config webrtc.Configuration) error { if err != nil { return fmt.Errorf("new sub pc: %w", err) } - p.pcSub.OnConnectionStateChange(p.onConnectionStateChange) + p.pcSub.OnConnectionStateChange(p.onSubscriberConnectionStateChange) + p.pcSub.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { + if track.Kind() != webrtc.RTPCodecTypeVideo { + return + } + + logger.Infof("telemost remote video track: codec=%s stream=%s track=%s", track.Codec().MimeType, track.StreamID(), track.ID()) + + if cb := p.videoTrackHandler(); cb != nil { + cb(track, receiver) + } + }) p.pcPub, err = api.NewPeerConnection(config) if err != nil { return fmt.Errorf("new pub pc: %w", err) } - p.pcPub.OnConnectionStateChange(p.onConnectionStateChange) + p.pcPub.OnConnectionStateChange(p.onPublisherConnectionStateChange) + + if err := p.attachPendingVideoTracks(); err != nil { + return err + } return nil } @@ -262,6 +342,34 @@ func (p *Peer) onConnectionStateChange(state webrtc.PeerConnectionState) { } } +func (p *Peer) onSubscriberConnectionStateChange(state webrtc.PeerConnectionState) { + logger.Debugf("telemost subscriber state: %s", state.String()) + if state == webrtc.PeerConnectionStateConnected { + p.subscriberReady.Store(true) + closeSignal(p.subscriberConn) + } else if state == webrtc.PeerConnectionStateDisconnected || + state == webrtc.PeerConnectionStateFailed || + state == webrtc.PeerConnectionStateClosed { + p.subscriberReady.Store(false) + } + + p.onConnectionStateChange(state) +} + +func (p *Peer) onPublisherConnectionStateChange(state webrtc.PeerConnectionState) { + logger.Debugf("telemost publisher state: %s", state.String()) + if state == webrtc.PeerConnectionStateConnected { + p.publisherReady.Store(true) + closeSignal(p.publisherConn) + } else if state == webrtc.PeerConnectionStateDisconnected || + state == webrtc.PeerConnectionStateFailed || + state == webrtc.PeerConnectionStateClosed { + p.publisherReady.Store(false) + } + + p.onConnectionStateChange(state) +} + func (p *Peer) setupDataChannelHandlers(dcReady chan struct{}, sessionCloseCh chan struct{}) { p.dc.OnOpen(func() { numWorkers := 4 @@ -284,7 +392,9 @@ func (p *Peer) setupDataChannelHandlers(dcReady chan struct{}, sessionCloseCh ch p.dc.OnMessage(p.onDataChannelMessage) p.pcSub.OnDataChannel(func(dc *webrtc.DataChannel) { - dc.OnMessage(p.onDataChannelMessage) + if p.onData != nil { + dc.OnMessage(p.onDataChannelMessage) + } }) } @@ -361,38 +471,35 @@ func (p *Peer) sendHello() error { "uid": uuid.New().String(), "hello": map[string]interface{}{ "participantMeta": map[string]interface{}{ - "name": p.name, - "role": "SPEAKER", - "sendAudio": false, - "sendVideo": false, + "name": p.name, + "role": "SPEAKER", + "description": "", + "sendAudio": false, + "sendVideo": p.hasLocalVideoTracks(), }, "participantAttributes": map[string]interface{}{ - "name": p.name, - "role": "SPEAKER", - }, - "sendAudio": false, - "sendVideo": false, - "sendSharing": false, - "participantId": p.conn.PeerID, - "roomId": p.conn.RoomID, - "serviceName": "telemost", - "credentials": p.conn.Credentials, - "capabilitiesOffer": map[string]interface{}{ - "offerAnswerMode": []string{"SEPARATE"}, - "initialSubscriberOffer": []string{"ON_HELLO"}, - "slotsMode": []string{"FROM_CONTROLLER"}, - "simulcastMode": []string{"DISABLED"}, - "selfVadStatus": []string{"FROM_SERVER"}, - "dataChannelSharing": []string{"TO_RTP"}, + "name": p.name, + "role": "SPEAKER", + "description": "", }, + "sendAudio": false, + "sendVideo": p.hasLocalVideoTracks(), + "sendSharing": false, + "participantId": p.conn.PeerID, + "roomId": p.conn.RoomID, + "serviceName": "telemost", + "credentials": p.conn.Credentials, + "capabilitiesOffer": telemostCapabilitiesOffer(), "sdkInfo": map[string]interface{}{ - "implementation": "go", - "version": "1.0.0", - "userAgent": "OlcRTC-" + p.name, + "implementation": "browser", + "version": "5.27.0", + "userAgent": "Mozilla/5.0 (X11; Linux x86_64; rv:149.0) Gecko/20100101 Firefox/149.0", + "hwConcurrency": runtime.NumCPU(), }, - "sdkInitializationId": uuid.New().String(), - "disablePublisher": false, - "disableSubscriber": false, + "sdkInitializationId": uuid.New().String(), + "disablePublisher": !p.hasLocalVideoTracks(), + "disableSubscriber": false, + "disableSubscriberAudio": true, }, } @@ -445,6 +552,7 @@ func (p *Peer) handleMessageEvents(ctx context.Context, msg map[string]interface } if serverHello, ok := msg["serverHello"].(map[string]interface{}); ok { + p.applyServerHelloConfig(serverHello) p.startTelemetry(ctx, serverHello) p.sendAck(uid) } @@ -516,6 +624,13 @@ func (p *Peer) handleSdpOffer(offer map[string]interface{}, uid string) error { p.wsMu.Unlock() p.sendAck(uid) + + if p.onData == nil { + if err := p.sendSetSlots(); err != nil { + logger.Debugf("setSlots error: %v", err) + } + } + time.Sleep(300 * time.Millisecond) pubOffer, err := p.pcPub.CreateOffer(nil) @@ -531,14 +646,200 @@ func (p *Peer) handleSdpOffer(offer map[string]interface{}, uid string) error { _ = p.ws.WriteJSON(map[string]interface{}{ "uid": uuid.New().String(), "publisherSdpOffer": map[string]interface{}{ - "pcSeq": 1, - "sdp": pubOffer.SDP, + "pcSeq": 1, + "sdp": pubOffer.SDP, + "tracks": p.publisherTrackDescriptions(), }, }) p.wsMu.Unlock() return nil } +func (p *Peer) sendSetSlots() error { + p.wsMu.Lock() + defer p.wsMu.Unlock() + + return p.ws.WriteJSON(map[string]interface{}{ + "uid": uuid.New().String(), + "setSlots": map[string]interface{}{ + "slots": []map[string]int{ + {"width": 1280, "height": 720}, + {"width": 640, "height": 360}, + }, + "audioSlotsCount": 0, + "key": 1, + "shutdownAllVideo": nil, + "withSelfView": false, + "selfViewVisibility": "ON_LOADING_THEN_SHOW", + "gridConfig": map[string]interface{}{}, + }, + }) +} + +func (p *Peer) applyServerHelloConfig(serverHello map[string]interface{}) { + rawCfg, ok := serverHello["rtcConfiguration"].(map[string]interface{}) + if !ok { + return + } + + rawServers, ok := rawCfg["iceServers"].([]interface{}) + if !ok || len(rawServers) == 0 { + return + } + + iceServers := make([]webrtc.ICEServer, 0, len(rawServers)) + for _, rawServer := range rawServers { + server, ok := rawServer.(map[string]interface{}) + if !ok { + continue + } + + var urls []string + switch rawURLs := server["urls"].(type) { + case []interface{}: + for _, rawURL := range rawURLs { + if url, ok := rawURL.(string); ok && url != "" { + urls = append(urls, url) + } + } + case []string: + urls = append(urls, rawURLs...) + } + + if len(urls) == 0 { + continue + } + + ice := webrtc.ICEServer{URLs: urls} + if username, ok := server["username"].(string); ok { + ice.Username = username + } + if credential, ok := server["credential"].(string); ok { + ice.Credential = credential + } + iceServers = append(iceServers, ice) + } + + if len(iceServers) == 0 { + return + } + + cfg := webrtc.Configuration{ + ICEServers: iceServers, + SDPSemantics: webrtc.SDPSemanticsUnifiedPlan, + } + + if p.pcSub != nil { + _ = p.pcSub.SetConfiguration(cfg) + } + if p.pcPub != nil { + _ = p.pcPub.SetConfiguration(cfg) + } +} + +func (p *Peer) publisherTrackDescriptions() []map[string]interface{} { + if p.pcPub == nil { + return nil + } + + tracks := make([]map[string]interface{}, 0) + for _, transceiver := range p.pcPub.GetTransceivers() { + sender := transceiver.Sender() + if sender == nil { + continue + } + + track := sender.Track() + if track == nil { + continue + } + + kind := "VIDEO" + if track.Kind() == webrtc.RTPCodecTypeAudio { + kind = "AUDIO" + } + + tracks = append(tracks, map[string]interface{}{ + "mid": transceiver.Mid(), + "transceiverMid": transceiver.Mid(), + "kind": kind, + "priority": 0, + "label": track.ID(), + "codecs": map[string]interface{}{}, + "groupId": 1, + "description": "", + }) + } + + return tracks +} + +func telemostCapabilitiesOffer() map[string]interface{} { + return map[string]interface{}{ + "offerAnswerMode": []string{"SEPARATE"}, + "initialSubscriberOffer": []string{"ON_HELLO"}, + "slotsMode": []string{"FROM_CONTROLLER"}, + "simulcastMode": []string{"DISABLED", "STATIC"}, + "selfVadStatus": []string{"FROM_SERVER", "FROM_CLIENT"}, + "dataChannelSharing": []string{"TO_RTP"}, + "videoEncoderConfig": []string{"NO_CONFIG", "ONLY_INIT_CONFIG", "RUNTIME_CONFIG"}, + "dataChannelVideoCodec": []string{"VP8", "UNIQUE_CODEC_FROM_TRACK_DESCRIPTION"}, + "bandwidthLimitationReason": []string{ + "BANDWIDTH_REASON_DISABLED", + "BANDWIDTH_REASON_ENABLED", + }, + "sdkDefaultDeviceManagement": []string{ + "SDK_DEFAULT_DEVICE_MANAGEMENT_DISABLED", + "SDK_DEFAULT_DEVICE_MANAGEMENT_ENABLED", + }, + "joinOrderLayout": []string{"JOIN_ORDER_LAYOUT_DISABLED", "JOIN_ORDER_LAYOUT_ENABLED"}, + "pinLayout": []string{"PIN_LAYOUT_DISABLED"}, + "sendSelfViewVideoSlot": []string{ + "SEND_SELF_VIEW_VIDEO_SLOT_DISABLED", + "SEND_SELF_VIEW_VIDEO_SLOT_ENABLED", + }, + "serverLayoutTransition": []string{"SERVER_LAYOUT_TRANSITION_DISABLED"}, + "sdkPublisherOptimizeBitrate": []string{ + "SDK_PUBLISHER_OPTIMIZE_BITRATE_DISABLED", + "SDK_PUBLISHER_OPTIMIZE_BITRATE_FULL", + "SDK_PUBLISHER_OPTIMIZE_BITRATE_ONLY_SELF", + }, + "sdkNetworkLostDetection": []string{"SDK_NETWORK_LOST_DETECTION_DISABLED"}, + "sdkNetworkPathMonitor": []string{"SDK_NETWORK_PATH_MONITOR_DISABLED"}, + "publisherVp9": []string{"PUBLISH_VP9_DISABLED", "PUBLISH_VP9_ENABLED"}, + "svcMode": []string{"SVC_MODE_DISABLED", "SVC_MODE_L3T3", "SVC_MODE_L3T3_KEY"}, + "subscriberOfferAsyncAck": []string{"SUBSCRIBER_OFFER_ASYNC_ACK_DISABLED", "SUBSCRIBER_OFFER_ASYNC_ACK_ENABLED"}, + "androidBluetoothRoutingFix": []string{ + "ANDROID_BLUETOOTH_ROUTING_FIX_DISABLED", + }, + "fixedIceCandidatesPoolSize": []string{ + "FIXED_ICE_CANDIDATES_POOL_SIZE_DISABLED", + }, + "sdkAndroidTelecomIntegration": []string{ + "SDK_ANDROID_TELECOM_INTEGRATION_DISABLED", + }, + "setActiveCodecsMode": []string{ + "SET_ACTIVE_CODECS_MODE_DISABLED", + "SET_ACTIVE_CODECS_MODE_VIDEO_ONLY", + }, + "subscriberDtlsPassiveMode": []string{ + "SUBSCRIBER_DTLS_PASSIVE_MODE_DISABLED", + }, + "publisherOpusDred": []string{ + "PUBLISHER_OPUS_DRED_DISABLED", + }, + "publisherOpusLowBitrate": []string{ + "PUBLISHER_OPUS_LOW_BITRATE_DISABLED", + }, + "sdkAndroidDestroySessionOnTaskRemoved": []string{ + "SDK_ANDROID_DESTROY_SESSION_ON_TASK_REMOVED_DISABLED", + }, + "svcModes": []string{"FALSE"}, + "reportTelemetryModes": []string{"TRUE"}, + "keepDefaultDevicesModes": []string{"FALSE"}, + } +} + func (p *Peer) handleSdpAnswer(answer map[string]interface{}, uid string) { sdp, _ := answer["sdp"].(string) if err := p.pcPub.SetRemoteDescription(webrtc.SessionDescription{ @@ -1155,6 +1456,12 @@ func (p *Peer) monitorQueue(sessionCloseCh <-chan struct{}) { // CanSend checks if data can be sent. func (p *Peer) CanSend() bool { + if p.onData == nil { + if p.hasLocalVideoTracks() { + return !p.closed.Load() && p.subscriberReady.Load() && p.publisherReady.Load() + } + return !p.closed.Load() && p.subscriberReady.Load() + } if p.dc == nil || p.dc.ReadyState() != webrtc.DataChannelStateOpen { return false } @@ -1162,18 +1469,28 @@ func (p *Peer) CanSend() bool { } var ( - // ErrPublisherNotInitialized is returned when the publisher peer connection is not set up. + // ErrPublisherNotInitialized is returned when the publisher peer connection is not set up. ErrPublisherNotInitialized = errors.New("publisher peer connection not initialized") ) // AddVideoTrack adds a video track to the publisher peer connection. -func (p *Peer) AddVideoTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) { +func (p *Peer) AddVideoTrack(track webrtc.TrackLocal) error { + p.videoTrackMu.Lock() + p.videoTracks = append(p.videoTracks, track) + p.videoTrackMu.Unlock() + if p.pcPub == nil { - return nil, ErrPublisherNotInitialized + return nil } - sender, err := p.pcPub.AddTrack(track) - if err != nil { - return nil, fmt.Errorf("failed to add track: %w", err) + if _, err := p.pcPub.AddTrack(track); err != nil { + return fmt.Errorf("failed to add track: %w", err) } - return sender, nil + return nil +} + +// SetVideoTrackHandler registers a callback for remote video tracks. +func (p *Peer) SetVideoTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) { + p.videoTrackMu.Lock() + defer p.videoTrackMu.Unlock() + p.onVideoTrack = cb } diff --git a/internal/provider/telemost/provider.go b/internal/provider/telemost/provider.go index 90272f9..c9ee6f2 100644 --- a/internal/provider/telemost/provider.go +++ b/internal/provider/telemost/provider.go @@ -74,7 +74,11 @@ func (t *telemostProvider) GetBufferedAmount() uint64 { } // AddVideoTrack adds a video track to the telemost connection. -func (t *telemostProvider) AddVideoTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) { +func (t *telemostProvider) AddVideoTrack(track webrtc.TrackLocal) error { return t.peer.AddVideoTrack(track) } +// SetVideoTrackHandler registers a callback for subscribed remote video tracks. +func (t *telemostProvider) SetVideoTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) { + t.peer.SetVideoTrackHandler(cb) +} diff --git a/internal/provider/wbstream/peer.go b/internal/provider/wbstream/peer.go index 51d80cb..81225e7 100644 --- a/internal/provider/wbstream/peer.go +++ b/internal/provider/wbstream/peer.go @@ -24,8 +24,6 @@ var ( ErrSendQueueFull = errors.New("send queue full") // ErrLiveKitNotConnected is returned when the LiveKit room is not connected. ErrLiveKitNotConnected = errors.New("livekit room not connected") - // ErrVideoNotSupported is returned when video tracks are not supported by this provider. - ErrVideoNotSupported = errors.New("video tracks not supported yet in wbstream") ) // Peer represents a WB Stream WebRTC connection using LiveKit. @@ -41,6 +39,9 @@ type Peer struct { closed atomic.Bool done chan struct{} cancel context.CancelFunc + videoTrackMu sync.RWMutex + videoTracks []webrtc.TrackLocal + onVideoTrack func(*webrtc.TrackRemote, *webrtc.RTPReceiver) wg sync.WaitGroup } @@ -71,6 +72,18 @@ func (p *Peer) Connect(ctx context.Context) error { p.onData(data) } }, + OnTrackSubscribed: func(track *webrtc.TrackRemote, _ *lksdk.RemoteTrackPublication, _ *lksdk.RemoteParticipant) { + if track.Kind() != webrtc.RTPCodecTypeVideo { + return + } + + p.videoTrackMu.RLock() + cb := p.onVideoTrack + p.videoTrackMu.RUnlock() + if cb != nil { + cb(track, nil) + } + }, }, OnDisconnected: func() { if p.onEnded != nil { @@ -85,12 +98,30 @@ func (p *Peer) Connect(ctx context.Context) error { } p.room = room + if err := p.publishPendingTracks(); err != nil { + return err + } p.wg.Add(1) go p.processSendQueue() return nil } +func (p *Peer) publishPendingTracks() error { + p.videoTrackMu.RLock() + defer p.videoTrackMu.RUnlock() + + for _, track := range p.videoTracks { + if _, err := p.room.LocalParticipant.PublishTrack(track, &lksdk.TrackPublicationOptions{ + Name: "videochannel", + }); err != nil { + return fmt.Errorf("failed to publish track: %w", err) + } + } + + return nil +} + func (p *Peer) getRoomToken(ctx context.Context) (string, error) { accessToken, err := registerGuest(ctx, p.name) if err != nil { @@ -201,17 +232,27 @@ func (p *Peer) GetBufferedAmount() uint64 { } // AddVideoTrack adds a video track to the LiveKit room. -func (p *Peer) AddVideoTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) { +func (p *Peer) AddVideoTrack(track webrtc.TrackLocal) error { + p.videoTrackMu.Lock() + p.videoTracks = append(p.videoTracks, track) + p.videoTrackMu.Unlock() + if p.room == nil || p.room.LocalParticipant == nil { - return nil, ErrLiveKitNotConnected + return nil } - _, err := p.room.LocalParticipant.PublishTrack(track, &lksdk.TrackPublicationOptions{ - Name: "video", - }) - if err != nil { - return nil, fmt.Errorf("failed to publish track: %w", err) + if _, err := p.room.LocalParticipant.PublishTrack(track, &lksdk.TrackPublicationOptions{ + Name: "videochannel", + }); err != nil { + return fmt.Errorf("failed to publish track: %w", err) } - return nil, ErrVideoNotSupported + return nil +} + +// SetVideoTrackHandler registers a callback for remote video tracks. +func (p *Peer) SetVideoTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) { + p.videoTrackMu.Lock() + defer p.videoTrackMu.Unlock() + p.onVideoTrack = cb } diff --git a/internal/provider/wbstream/provider.go b/internal/provider/wbstream/provider.go index 79fec22..a6ebbaa 100644 --- a/internal/provider/wbstream/provider.go +++ b/internal/provider/wbstream/provider.go @@ -74,7 +74,11 @@ func (w *wbStreamProvider) GetBufferedAmount() uint64 { } // AddVideoTrack adds a video track to the wbstream connection. -func (w *wbStreamProvider) AddVideoTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) { +func (w *wbStreamProvider) AddVideoTrack(track webrtc.TrackLocal) error { return w.peer.AddVideoTrack(track) } +// SetVideoTrackHandler registers a callback for subscribed remote video tracks. +func (w *wbStreamProvider) SetVideoTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) { + w.peer.SetVideoTrackHandler(cb) +} diff --git a/internal/server/server.go b/internal/server/server.go index 6a8e25c..c88bb8a 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -76,6 +76,10 @@ func Run( dnsServer, socksProxyAddr string, socksProxyPort int, + videoWidth int, + videoHeight int, + videoFPS int, + videoBitrate string, ) error { runCtx, cancel := context.WithCancel(ctx) defer cancel() @@ -104,7 +108,7 @@ func Run( const linkCount = 1 for i := range linkCount { - if err := s.addLink(runCtx, linkName, transportName, carrierName, roomURL, i, cancel); err != nil { + if err := s.addLink(runCtx, linkName, transportName, carrierName, roomURL, i, cancel, videoWidth, videoHeight, videoFPS, videoBitrate); err != nil { return fmt.Errorf("addLink failed: %w", err) } } @@ -194,16 +198,22 @@ func (s *Server) addLink( roomURL string, linkID int, cancel context.CancelFunc, + videoWidth, videoHeight, videoFPS int, + videoBitrate string, ) error { ln, err := link.New(ctx, linkName, link.Config{ - Transport: transportName, - Carrier: carrierName, - RoomURL: roomURL, - Name: names.Generate(), - OnData: s.onData, - DNSServer: s.dnsServer, - ProxyAddr: s.socksProxyAddr, - ProxyPort: s.socksProxyPort, + Transport: transportName, + Carrier: carrierName, + RoomURL: roomURL, + Name: names.Generate(), + OnData: s.onData, + DNSServer: s.dnsServer, + ProxyAddr: s.socksProxyAddr, + ProxyPort: s.socksProxyPort, + VideoWidth: videoWidth, + VideoHeight: videoHeight, + VideoFPS: videoFPS, + VideoBitrate: videoBitrate, }) if err != nil { return fmt.Errorf("failed to create link: %w", err) diff --git a/internal/transport/seichannel/h264.go b/internal/transport/seichannel/h264.go new file mode 100644 index 0000000..1d6f993 --- /dev/null +++ b/internal/transport/seichannel/h264.go @@ -0,0 +1,188 @@ +package seichannel + +import ( + "bytes" + "encoding/hex" + "fmt" + + "github.com/pion/webrtc/v4/pkg/media/h264reader" +) + +var ( + videoSEIUUID = [16]byte{ + 0x5d, 0xc0, 0x3b, 0xa8, + 0x45, 0x0f, + 0x4b, 0x55, + 0x9a, 0x77, + 0x1f, 0x91, 0x6c, 0x5b, 0x07, 0x39, + } + baseSPS = mustDecodeHex("6742c00addec0440000003004000000300a3c489e0") + basePPS = mustDecodeHex("68ce0fc8") + baseIDR = mustDecodeHex("6588843a2628000902e0") +) + +func buildVideoAccessUnit(payload []byte) ([]byte, error) { + out := make([]byte, 0, len(baseSPS)+len(basePPS)+len(baseIDR)+64+len(payload)) + out = appendStartCode(out, baseSPS) + out = appendStartCode(out, basePPS) + if len(payload) > 0 { + sei, err := buildSEINAL(payload) + if err != nil { + return nil, err + } + out = appendStartCode(out, sei) + } + out = appendStartCode(out, baseIDR) + return out, nil +} + +func extractVideoPayloads(accessUnit []byte) ([][]byte, error) { + reader, err := h264reader.NewReaderWithOptions(bytes.NewReader(accessUnit), h264reader.WithIncludeSEI(true)) + if err != nil { + return nil, fmt.Errorf("create h264 reader: %w", err) + } + + payloads := make([][]byte, 0, 1) + for { + nal, readErr := reader.NextNAL() + if readErr != nil { + if len(payloads) == 0 { + return nil, nil + } + return payloads, nil + } + if nal == nil || nal.UnitType != h264reader.NalUnitTypeSEI || len(nal.Data) < 2 { + continue + } + + found, err := extractTransportSEI(nal.Data[1:]) + if err != nil { + continue + } + payloads = append(payloads, found...) + } +} + +func buildSEINAL(payload []byte) ([]byte, error) { + userData := make([]byte, 0, len(videoSEIUUID)+len(payload)) + userData = append(userData, videoSEIUUID[:]...) + userData = append(userData, payload...) + + rbsp := make([]byte, 0, len(userData)+8) + rbsp = appendSEIValue(rbsp, 5) + rbsp = appendSEIValue(rbsp, len(userData)) + rbsp = append(rbsp, userData...) + rbsp = append(rbsp, 0x80) + + out := []byte{0x06} + out = append(out, escapeRBSP(rbsp)...) + return out, nil +} + +func extractTransportSEI(rbsp []byte) ([][]byte, error) { + data := unescapeRBSP(rbsp) + out := make([][]byte, 0, 1) + + for pos := 0; pos < len(data); { + if data[pos] == 0x80 && pos == len(data)-1 { + break + } + + payloadType, next, err := consumeSEIValue(data, pos) + if err != nil { + return nil, err + } + pos = next + + payloadSize, next, err := consumeSEIValue(data, pos) + if err != nil { + return nil, err + } + pos = next + + if pos+payloadSize > len(data) { + return nil, fmt.Errorf("sei payload truncated") + } + + payload := data[pos : pos+payloadSize] + pos += payloadSize + + if payloadType != 5 || len(payload) < len(videoSEIUUID) { + continue + } + if !bytes.Equal(payload[:len(videoSEIUUID)], videoSEIUUID[:]) { + continue + } + + frame := make([]byte, len(payload)-len(videoSEIUUID)) + copy(frame, payload[len(videoSEIUUID):]) + out = append(out, frame) + } + + return out, nil +} + +func appendSEIValue(dst []byte, value int) []byte { + for value >= 0xff { + dst = append(dst, 0xff) + value -= 0xff + } + return append(dst, byte(value)) +} + +func consumeSEIValue(data []byte, pos int) (int, int, error) { + value := 0 + for { + if pos >= len(data) { + return 0, pos, fmt.Errorf("sei value truncated") + } + b := int(data[pos]) + pos++ + value += b + if b != 0xff { + return value, pos, nil + } + } +} + +func appendStartCode(dst, nalu []byte) []byte { + dst = append(dst, 0x00, 0x00, 0x00, 0x01) + return append(dst, nalu...) +} + +func escapeRBSP(rbsp []byte) []byte { + out := make([]byte, 0, len(rbsp)+8) + zeroCount := 0 + for _, b := range rbsp { + if zeroCount >= 2 && b <= 0x03 { + out = append(out, 0x03) + zeroCount = 0 + } + out = append(out, b) + if b == 0x00 { + zeroCount++ + } else { + zeroCount = 0 + } + } + return out +} + +func unescapeRBSP(rbsp []byte) []byte { + out := make([]byte, 0, len(rbsp)) + for i := 0; i < len(rbsp); i++ { + if i >= 2 && rbsp[i] == 0x03 && rbsp[i-1] == 0x00 && rbsp[i-2] == 0x00 { + continue + } + out = append(out, rbsp[i]) + } + return out +} + +func mustDecodeHex(value string) []byte { + data, err := hex.DecodeString(value) + if err != nil { + panic(err) + } + return data +} diff --git a/internal/transport/seichannel/transport.go b/internal/transport/seichannel/transport.go new file mode 100644 index 0000000..19bb40b --- /dev/null +++ b/internal/transport/seichannel/transport.go @@ -0,0 +1,534 @@ +// Package seichannel provides a byte transport over H264 SEI messages. +package seichannel + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + "hash/crc32" + "sync" + "sync/atomic" + "time" + + "github.com/openlibrecommunity/olcrtc/internal/carrier" + "github.com/openlibrecommunity/olcrtc/internal/transport" + "github.com/pion/rtp/codecs" + "github.com/pion/webrtc/v4" + "github.com/pion/webrtc/v4/pkg/media" + "github.com/pion/webrtc/v4/pkg/media/samplebuilder" +) + +const ( + defaultMaxPayloadSize = 7 * 1024 + defaultFragmentSize = 900 + defaultAckTimeout = 3 * time.Second + defaultFrameInterval = 50 * time.Millisecond + defaultConnectTimeout = 30 * time.Second + maxSendAttempts = 4 + sampleBuilderMaxLate = 128 + protocolMagic uint32 = 0x4f564331 // OVC1 + protocolVersion byte = 1 + frameTypeData byte = 1 + frameTypeAck byte = 2 +) + +var ( + // ErrVideoTrackUnsupported is returned when a carrier cannot expose video tracks. + ErrVideoTrackUnsupported = errors.New("carrier does not support video tracks") + // ErrAckTimeout is returned when the peer does not acknowledge a payload in time. + ErrAckTimeout = errors.New("seichannel ack timeout") + // ErrTransportClosed is returned when operations are attempted on a closed transport. + ErrTransportClosed = errors.New("seichannel transport closed") +) + +type transportFrame struct { + typ byte + seq uint32 + crc uint32 + totalLen uint32 + fragIdx uint16 + fragTotal uint16 + payload []byte +} + +type inboundMessage struct { + totalLen uint32 + crc uint32 + frags [][]byte + remain int +} + +type streamTransport struct { + stream carrier.VideoTrack + track *webrtc.TrackLocalStaticSample + onData func([]byte) + outbound chan []byte + outboundAck chan []byte + closeCh chan struct{} + writerDone chan struct{} + nextSeq atomic.Uint32 + closed atomic.Bool + writerUp atomic.Bool + sendMu sync.Mutex + startWriter sync.Once + ackMu sync.Mutex + ackWaiters map[uint32]chan uint32 + recvMu sync.Mutex + inbound map[uint32]*inboundMessage + delivered map[uint32]uint32 +} + +// New creates a seichannel transport backed by a carrier-specific provider. +func New(ctx context.Context, cfg transport.Config) (transport.Transport, error) { + session, err := carrier.New(ctx, cfg.Carrier, carrier.Config{ + RoomURL: cfg.RoomURL, + Name: cfg.Name, + OnData: nil, + DNSServer: cfg.DNSServer, + ProxyAddr: cfg.ProxyAddr, + ProxyPort: cfg.ProxyPort, + }) + if err != nil { + return nil, fmt.Errorf("create provider transport: %w", err) + } + + videoCapable, ok := session.(carrier.VideoTrackCapable) + if !ok { + return nil, ErrVideoTrackUnsupported + } + + stream, err := videoCapable.OpenVideoTrack() + if err != nil { + return nil, fmt.Errorf("open video track: %w", err) + } + + track, err := webrtc.NewTrackLocalStaticSample( + webrtc.RTPCodecCapability{ + MimeType: webrtc.MimeTypeH264, + ClockRate: 90000, + Channels: 0, + SDPFmtpLine: "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42c00a", + }, + "seichannel", + "olcrtc", + ) + if err != nil { + return nil, fmt.Errorf("create local video track: %w", err) + } + + tr := &streamTransport{ + stream: stream, + track: track, + onData: cfg.OnData, + outbound: make(chan []byte, 256), + outboundAck: make(chan []byte, 64), + closeCh: make(chan struct{}), + writerDone: make(chan struct{}), + ackWaiters: make(map[uint32]chan uint32), + inbound: make(map[uint32]*inboundMessage), + delivered: make(map[uint32]uint32), + } + + if err := stream.AddTrack(track); err != nil { + return nil, fmt.Errorf("attach local video track: %w", err) + } + stream.SetTrackHandler(tr.handleRemoteTrack) + + return tr, nil +} + +// Connect starts the transport connection. +func (p *streamTransport) Connect(ctx context.Context) error { + connectCtx, cancel := context.WithTimeout(ctx, defaultConnectTimeout) + defer cancel() + + if err := p.stream.Connect(connectCtx); err != nil { + return err + } + + p.startWriter.Do(func() { + p.writerUp.Store(true) + go p.writerLoop() + }) + + return nil +} + +// Send transmits data through the transport. +func (p *streamTransport) Send(data []byte) error { + if p.closed.Load() { + return ErrTransportClosed + } + + p.sendMu.Lock() + defer p.sendMu.Unlock() + + seq := p.nextSeq.Add(1) + crc := crc32.ChecksumIEEE(data) + fragments := fragmentPayload(data, defaultFragmentSize) + waiter := make(chan uint32, 1) + + p.ackMu.Lock() + p.ackWaiters[seq] = waiter + p.ackMu.Unlock() + defer func() { + p.ackMu.Lock() + delete(p.ackWaiters, seq) + p.ackMu.Unlock() + }() + + for attempt := 0; attempt < maxSendAttempts; attempt++ { + for idx, fragment := range fragments { + frame := encodeDataFrame(seq, crc, len(data), idx, len(fragments), fragment) + if err := p.enqueueFrame(frame, false); err != nil { + return err + } + } + + timer := time.NewTimer(defaultAckTimeout) + select { + case ackCRC := <-waiter: + timer.Stop() + if ackCRC == crc { + return nil + } + case <-timer.C: + case <-p.closeCh: + timer.Stop() + return ErrTransportClosed + } + } + + return ErrAckTimeout +} + +// Close terminates the transport. +func (p *streamTransport) Close() error { + if p.closed.CompareAndSwap(false, true) { + close(p.closeCh) + if p.writerUp.Load() { + <-p.writerDone + } + return p.stream.Close() + } + return nil +} + +// SetReconnectCallback registers reconnect handling. +func (p *streamTransport) SetReconnectCallback(cb func()) { + p.stream.SetReconnectCallback(cb) +} + +// SetShouldReconnect configures reconnect policy. +func (p *streamTransport) SetShouldReconnect(fn func() bool) { + p.stream.SetShouldReconnect(fn) +} + +// SetEndedCallback registers end-of-session handling. +func (p *streamTransport) SetEndedCallback(cb func(string)) { + p.stream.SetEndedCallback(cb) +} + +// WatchConnection monitors connection lifecycle. +func (p *streamTransport) WatchConnection(ctx context.Context) { + p.stream.WatchConnection(ctx) +} + +// CanSend reports whether transport is ready for sending. +func (p *streamTransport) CanSend() bool { + return !p.closed.Load() && p.stream.CanSend() +} + +// Features describes the current seichannel transport semantics. +func (p *streamTransport) Features() transport.Features { + return transport.Features{ + Reliable: true, + Ordered: true, + MessageOriented: true, + MaxPayloadSize: defaultMaxPayloadSize, + } +} + +func (p *streamTransport) writerLoop() { + defer close(p.writerDone) + + ticker := time.NewTicker(defaultFrameInterval) + defer ticker.Stop() + + idle, err := buildVideoAccessUnit(nil) + if err != nil { + return + } + + for { + select { + case <-p.closeCh: + return + case <-ticker.C: + payload, ok := p.nextOutboundFrame() + if !ok { + return + } + + sample := idle + if payload != nil { + sample, err = buildVideoAccessUnit(payload) + if err != nil { + continue + } + } + + _ = p.track.WriteSample(media.Sample{ + Data: sample, + Duration: defaultFrameInterval, + }) + } + } +} + +func (p *streamTransport) nextOutboundFrame() ([]byte, bool) { + select { + case <-p.closeCh: + return nil, false + case payload := <-p.outboundAck: + return payload, true + default: + } + + select { + case <-p.closeCh: + return nil, false + case payload := <-p.outboundAck: + return payload, true + case payload := <-p.outbound: + return payload, true + default: + return nil, true + } +} + +func (p *streamTransport) enqueueFrame(frame []byte, priority bool) error { + if p.closed.Load() { + return ErrTransportClosed + } + + ch := p.outbound + if priority { + ch = p.outboundAck + } + + select { + case <-p.closeCh: + return ErrTransportClosed + case ch <- frame: + return nil + } +} + +func (p *streamTransport) handleRemoteTrack(track *webrtc.TrackRemote, _ *webrtc.RTPReceiver) { + go func() { + sb := samplebuilder.New(sampleBuilderMaxLate, &codecs.H264Packet{}, track.Codec().ClockRate) + + popSamples := func() { + for sample := sb.Pop(); sample != nil; sample = sb.Pop() { + p.handleSample(sample.Data) + } + } + + for { + packet, _, err := track.ReadRTP() + if err != nil { + sb.Flush() + popSamples() + return + } + + sb.Push(packet) + popSamples() + } + }() +} + +func (p *streamTransport) handleSample(sample []byte) { + payloads, err := extractVideoPayloads(sample) + if err != nil { + return + } + + for _, payload := range payloads { + frame, err := decodeTransportFrame(payload) + if err != nil { + continue + } + + switch frame.typ { + case frameTypeAck: + p.resolveAck(frame.seq, frame.crc) + case frameTypeData: + p.handleInboundFrame(frame) + } + } +} + +func (p *streamTransport) handleInboundFrame(frame transportFrame) { + p.recvMu.Lock() + if crc, ok := p.delivered[frame.seq]; ok && crc == frame.crc { + p.recvMu.Unlock() + p.sendAck(frame.seq, frame.crc) + return + } + + msg, ok := p.inbound[frame.seq] + if !ok || msg.crc != frame.crc || msg.totalLen != frame.totalLen || len(msg.frags) != int(frame.fragTotal) { + msg = &inboundMessage{ + totalLen: frame.totalLen, + crc: frame.crc, + frags: make([][]byte, frame.fragTotal), + remain: int(frame.fragTotal), + } + p.inbound[frame.seq] = msg + } + + if int(frame.fragIdx) >= len(msg.frags) { + p.recvMu.Unlock() + return + } + + if msg.frags[frame.fragIdx] == nil { + chunk := make([]byte, len(frame.payload)) + copy(chunk, frame.payload) + msg.frags[frame.fragIdx] = chunk + msg.remain-- + } + + if msg.remain > 0 { + p.recvMu.Unlock() + return + } + + delete(p.inbound, frame.seq) + data := make([]byte, 0, msg.totalLen) + for _, frag := range msg.frags { + data = append(data, frag...) + } + + if uint32(len(data)) > msg.totalLen { + data = data[:msg.totalLen] + } + + if crc32.ChecksumIEEE(data) != msg.crc { + p.recvMu.Unlock() + return + } + + if len(p.delivered) > 256 { + p.delivered = make(map[uint32]uint32) + } + p.delivered[frame.seq] = msg.crc + p.recvMu.Unlock() + + if p.onData != nil { + p.onData(data) + } + p.sendAck(frame.seq, frame.crc) +} + +func (p *streamTransport) sendAck(seq, crc uint32) { + _ = p.enqueueFrame(encodeAckFrame(seq, crc), true) +} + +func (p *streamTransport) resolveAck(seq, crc uint32) { + p.ackMu.Lock() + waiter := p.ackWaiters[seq] + p.ackMu.Unlock() + + if waiter == nil { + return + } + + select { + case waiter <- crc: + default: + } +} + +func fragmentPayload(data []byte, maxSize int) [][]byte { + if len(data) == 0 { + return [][]byte{{}} + } + + out := make([][]byte, 0, (len(data)+maxSize-1)/maxSize) + for start := 0; start < len(data); start += maxSize { + end := start + maxSize + if end > len(data) { + end = len(data) + } + + chunk := make([]byte, end-start) + copy(chunk, data[start:end]) + out = append(out, chunk) + } + + return out +} + +func encodeDataFrame(seq, crc uint32, totalLen, fragIdx, fragTotal int, payload []byte) []byte { + out := make([]byte, 22+len(payload)) + binary.BigEndian.PutUint32(out[0:4], protocolMagic) + out[4] = protocolVersion + out[5] = frameTypeData + binary.BigEndian.PutUint32(out[6:10], seq) + binary.BigEndian.PutUint32(out[10:14], crc) + binary.BigEndian.PutUint32(out[14:18], uint32(totalLen)) + binary.BigEndian.PutUint16(out[18:20], uint16(fragIdx)) + binary.BigEndian.PutUint16(out[20:22], uint16(fragTotal)) + copy(out[22:], payload) + return out +} + +func encodeAckFrame(seq, crc uint32) []byte { + out := make([]byte, 14) + binary.BigEndian.PutUint32(out[0:4], protocolMagic) + out[4] = protocolVersion + out[5] = frameTypeAck + binary.BigEndian.PutUint32(out[6:10], seq) + binary.BigEndian.PutUint32(out[10:14], crc) + return out +} + +func decodeTransportFrame(data []byte) (transportFrame, error) { + if len(data) < 6 { + return transportFrame{}, fmt.Errorf("frame too short") + } + if binary.BigEndian.Uint32(data[0:4]) != protocolMagic { + return transportFrame{}, fmt.Errorf("unexpected frame magic") + } + if data[4] != protocolVersion { + return transportFrame{}, fmt.Errorf("unexpected frame version") + } + + frame := transportFrame{typ: data[5]} + switch frame.typ { + case frameTypeAck: + if len(data) < 14 { + return transportFrame{}, fmt.Errorf("ack too short") + } + frame.seq = binary.BigEndian.Uint32(data[6:10]) + frame.crc = binary.BigEndian.Uint32(data[10:14]) + return frame, nil + case frameTypeData: + if len(data) < 22 { + return transportFrame{}, fmt.Errorf("data too short") + } + frame.seq = binary.BigEndian.Uint32(data[6:10]) + frame.crc = binary.BigEndian.Uint32(data[10:14]) + frame.totalLen = binary.BigEndian.Uint32(data[14:18]) + frame.fragIdx = binary.BigEndian.Uint16(data[18:20]) + frame.fragTotal = binary.BigEndian.Uint16(data[20:22]) + frame.payload = append([]byte(nil), data[22:]...) + return frame, nil + default: + return transportFrame{}, fmt.Errorf("unexpected frame type") + } +} diff --git a/internal/transport/seichannel/transport_test.go b/internal/transport/seichannel/transport_test.go new file mode 100644 index 0000000..82d7c25 --- /dev/null +++ b/internal/transport/seichannel/transport_test.go @@ -0,0 +1,42 @@ +package seichannel + +import ( + "bytes" + "testing" +) + +func TestSEIRoundTrip(t *testing.T) { + payload := []byte("hello over seichannel") + accessUnit, err := buildVideoAccessUnit(payload) + if err != nil { + t.Fatalf("buildVideoAccessUnit failed: %v", err) + } + + got, err := extractVideoPayloads(accessUnit) + if err != nil { + t.Fatalf("extractVideoPayloads failed: %v", err) + } + if len(got) != 1 { + t.Fatalf("expected 1 payload, got %d", len(got)) + } + if !bytes.Equal(got[0], payload) { + t.Fatalf("payload mismatch: got=%q want=%q", got[0], payload) + } +} + +func TestTransportFrameRoundTrip(t *testing.T) { + encoded := encodeDataFrame(42, 0xdeadbeef, 1024, 1, 3, []byte("chunk")) + decoded, err := decodeTransportFrame(encoded) + if err != nil { + t.Fatalf("decodeTransportFrame failed: %v", err) + } + if decoded.typ != frameTypeData || decoded.seq != 42 || decoded.crc != 0xdeadbeef { + t.Fatalf("unexpected frame header: %+v", decoded) + } + if decoded.totalLen != 1024 || decoded.fragIdx != 1 || decoded.fragTotal != 3 { + t.Fatalf("unexpected fragmentation fields: %+v", decoded) + } + if !bytes.Equal(decoded.payload, []byte("chunk")) { + t.Fatalf("payload mismatch: got=%q", decoded.payload) + } +} diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 74c8c4c..a4906b3 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -34,13 +34,17 @@ type Transport interface { // Config holds common transport configuration. type Config struct { - Carrier string - RoomURL string - Name string - OnData func([]byte) - DNSServer string - ProxyAddr string - ProxyPort int + Carrier string + RoomURL string + Name string + OnData func([]byte) + DNSServer string + ProxyAddr string + ProxyPort int + VideoWidth int + VideoHeight int + VideoFPS int + VideoBitrate string } // Factory creates a transport instance. diff --git a/internal/transport/videochannel/ffmpeg.go b/internal/transport/videochannel/ffmpeg.go new file mode 100644 index 0000000..9013761 --- /dev/null +++ b/internal/transport/videochannel/ffmpeg.go @@ -0,0 +1,444 @@ +package videochannel + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "os/exec" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/pion/rtp" + "github.com/pion/rtp/codecs" + "github.com/pion/webrtc/v4" + "github.com/pion/webrtc/v4/pkg/media/ivfreader" +) + +const ( + ffmpegFrameTimeout = 10 * time.Second +) + +var ( + // ErrFFmpegUnavailable is returned when ffmpeg is not available on PATH. + ErrFFmpegUnavailable = errors.New("ffmpeg is required for videochannel") + // ErrUnsupportedVideoCodec is returned when videochannel cannot decode the negotiated codec. + ErrUnsupportedVideoCodec = errors.New("unsupported video codec") +) + +type codecSpec struct { + mimeType string + fourCC string + encoder string + capability webrtc.RTPCodecCapability + depacketizer func() rtp.Depacketizer + encodeArgs []string +} + +func codecSpecForCarrier(carrier string) codecSpec { + return vp8CodecSpec() +} + +func codecSpecForMime(mimeType string) (codecSpec, bool) { + switch strings.ToLower(mimeType) { + case strings.ToLower(webrtc.MimeTypeVP9): + return vp9CodecSpec(), true + case strings.ToLower(webrtc.MimeTypeVP8): + return vp8CodecSpec(), true + default: + return codecSpec{}, false + } +} + +func vp9CodecSpec() codecSpec { + return codecSpec{ + mimeType: webrtc.MimeTypeVP9, + fourCC: "VP90", + encoder: "libvpx-vp9", + capability: webrtc.RTPCodecCapability{ + MimeType: webrtc.MimeTypeVP9, + ClockRate: 90000, + }, + depacketizer: func() rtp.Depacketizer { return &codecs.VP9Packet{} }, + encodeArgs: []string{ + "-c:v", "libvpx-vp9", + "-deadline", "realtime", + "-cpu-used", "8", + "-row-mt", "1", + "-tile-columns", "2", + "-frame-parallel", "1", + "-lag-in-frames", "0", + "-auto-alt-ref", "0", + "-error-resilient", "1", + "-static-thresh", "0", + "-g", "1", + "-pix_fmt", "yuv420p", + "-crf", "34", + "-b:v", defaultVideoBitrate, + }, + } +} + +func vp8CodecSpec() codecSpec { + return codecSpec{ + mimeType: webrtc.MimeTypeVP8, + fourCC: "VP80", + encoder: "libvpx", + capability: webrtc.RTPCodecCapability{ + MimeType: webrtc.MimeTypeVP8, + ClockRate: 90000, + }, + depacketizer: func() rtp.Depacketizer { return &codecs.VP8Packet{} }, + encodeArgs: []string{ + "-c:v", "libvpx", + "-deadline", "realtime", + "-cpu-used", "8", + "-lag-in-frames", "0", + "-error-resilient", "1", + "-static-thresh", "0", + "-g", "1", + "-pix_fmt", "yuv420p", + "-crf", "24", + "-b:v", defaultVideoBitrate, + }, + } +} + +type ffmpegEncoder struct { + cmd *exec.Cmd + stdin io.WriteCloser + stderr *bytes.Buffer + frames chan []byte + closed atomic.Bool + closeOnce sync.Once + errMu sync.Mutex + err error +} + +func newFFmpegEncoder(spec codecSpec, width, height, fps int, bitrate string) (*ffmpegEncoder, error) { + if _, err := exec.LookPath("ffmpeg"); err != nil { + return nil, ErrFFmpegUnavailable + } + + args := []string{ + "-loglevel", "error", + "-f", "rawvideo", + "-pix_fmt", "gray", + "-video_size", fmt.Sprintf("%dx%d", width, height), + "-framerate", fmt.Sprintf("%d", fps), + "-i", "pipe:0", + "-an", + } + args = append(args, spec.encodeArgs...) + // Replace default bitrate if provided + for i, arg := range args { + if arg == "-b:v" && i+1 < len(args) && bitrate != "" { + args[i+1] = bitrate + } + } + args = append(args, "-f", "ivf", "pipe:1") + + cmd := exec.Command("ffmpeg", args...) + stdin, err := cmd.StdinPipe() + if err != nil { + return nil, fmt.Errorf("encoder stdin: %w", err) + } + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("encoder stdout: %w", err) + } + stderr := &bytes.Buffer{} + cmd.Stderr = stderr + + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("start encoder: %w", err) + } + + enc := &ffmpegEncoder{ + cmd: cmd, + stdin: stdin, + stderr: stderr, + frames: make(chan []byte, 8), + } + + go enc.readIVF(stdout) + return enc, nil +} + +func (e *ffmpegEncoder) EncodeFrame(frame []byte) ([]byte, error) { + if len(frame) != logicalFrameBytes { + return nil, fmt.Errorf("unexpected encoder frame size: %d", len(frame)) + } + if err := e.processErr(); err != nil { + return nil, err + } + + if err := writeAll(e.stdin, frame); err != nil { + return nil, fmt.Errorf("write encoder frame: %w", err) + } + + select { + case sample, ok := <-e.frames: + if !ok { + return nil, e.processErr() + } + return sample, nil + case <-time.After(ffmpegFrameTimeout): + if err := e.processErr(); err != nil { + return nil, err + } + return nil, fmt.Errorf("encoder timeout") + } +} + +func (e *ffmpegEncoder) Close() error { + e.closeOnce.Do(func() { + e.closed.Store(true) + _ = e.stdin.Close() + if e.cmd.Process != nil { + _ = e.cmd.Process.Kill() + } + _ = e.cmd.Wait() + }) + return nil +} + +func (e *ffmpegEncoder) readIVF(stdout io.Reader) { + defer close(e.frames) + + reader, _, err := ivfreader.NewWith(stdout) + if err != nil { + e.setErr(fmt.Errorf("encoder ivf header: %w", err)) + return + } + + for { + frame, _, err := reader.ParseNextFrame() + if err != nil { + if !e.closed.Load() { + e.setErr(fmt.Errorf("encoder ivf read: %w", err)) + } + return + } + + copyFrame := append([]byte(nil), frame...) + if e.closed.Load() { + return + } + e.frames <- copyFrame + } +} + +func (e *ffmpegEncoder) setErr(err error) { + if err == nil { + return + } + e.errMu.Lock() + defer e.errMu.Unlock() + if e.err == nil { + e.err = withStderr(err, e.stderr) + } +} + +func (e *ffmpegEncoder) processErr() error { + e.errMu.Lock() + defer e.errMu.Unlock() + if e.err != nil { + return e.err + } + if e.closed.Load() { + return ErrTransportClosed + } + return nil +} + +type ffmpegDecoder struct { + cmd *exec.Cmd + stdin io.WriteCloser + stderr *bytes.Buffer + frames chan []byte + pts uint64 + closed atomic.Bool + closeOnce sync.Once + errMu sync.Mutex + err error +} + +func newFFmpegDecoder(spec codecSpec, width, height, fps int) (*ffmpegDecoder, error) { + if _, err := exec.LookPath("ffmpeg"); err != nil { + return nil, ErrFFmpegUnavailable + } + + args := []string{ + "-loglevel", "info", + "-flags", "low_delay", + "-vcodec", strings.ToLower(strings.TrimPrefix(spec.mimeType, "video/")), + "-i", "pipe:0", + "-an", + "-vf", fmt.Sprintf("scale=%d:%d:flags=neighbor,format=gray", width, height), + "-pix_fmt", "gray", + "-f", "rawvideo", + "pipe:1", + } + + cmd := exec.Command("ffmpeg", args...) + stdin, err := cmd.StdinPipe() + if err != nil { + return nil, fmt.Errorf("decoder stdin: %w", err) + } + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("decoder stdout: %w", err) + } + stderr := &bytes.Buffer{} + cmd.Stderr = stderr + + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("start decoder: %w", err) + } + + dec := &ffmpegDecoder{ + cmd: cmd, + stdin: stdin, + stderr: stderr, + frames: make(chan []byte, 32), + } + + if err := writeIVFHeader(stdin, spec.fourCC, width, height, fps); err != nil { + _ = dec.Close() + return nil, fmt.Errorf("decoder ivf header: %w", err) + } + + go dec.readRawFrames(stdout, width, height) + return dec, nil +} + +func (d *ffmpegDecoder) PushSample(sample []byte) error { + if err := d.processErr(); err != nil { + return err + } + + if err := writeIVFFrame(d.stdin, d.pts, sample); err != nil { + return fmt.Errorf("write decoder frame: %w", err) + } + d.pts++ + return nil +} + +func (d *ffmpegDecoder) PopFrame() ([]byte, error) { + select { + case frame, ok := <-d.frames: + if !ok { + return nil, d.processErr() + } + return frame, nil + case <-time.After(10 * time.Second): + return nil, fmt.Errorf("pop frame timeout") + } +} + +func (d *ffmpegDecoder) Close() error { + d.closeOnce.Do(func() { + d.closed.Store(true) + _ = d.stdin.Close() + if d.cmd.Process != nil { + _ = d.cmd.Process.Kill() + } + _ = d.cmd.Wait() + }) + return nil +} + +func (d *ffmpegDecoder) readRawFrames(stdout io.Reader, width, height int) { + defer close(d.frames) + + logicalFrameBytes := width * height + buf := make([]byte, logicalFrameBytes) + for { + if _, err := io.ReadFull(stdout, buf); err != nil { + if !d.closed.Load() { + d.setErr(fmt.Errorf("decoder raw read: %w", err)) + } + return + } + + copyFrame := append([]byte(nil), buf...) + if d.closed.Load() { + return + } + d.frames <- copyFrame + } +} + +func (d *ffmpegDecoder) setErr(err error) { + if err == nil { + return + } + d.errMu.Lock() + defer d.errMu.Unlock() + if d.err == nil { + d.err = withStderr(err, d.stderr) + } +} + +func (d *ffmpegDecoder) processErr() error { + d.errMu.Lock() + defer d.errMu.Unlock() + if d.err != nil { + return d.err + } + if d.closed.Load() { + return ErrTransportClosed + } + return nil +} + +func withStderr(err error, stderr *bytes.Buffer) error { + if err == nil { + return nil + } + msg := strings.TrimSpace(stderr.String()) + if msg == "" { + return err + } + return fmt.Errorf("%w: %s", err, msg) +} + +func writeIVFHeader(w io.Writer, fourCC string, width, height, frameRate int) error { + header := make([]byte, 32) + copy(header[0:4], []byte("DKIF")) + binary.LittleEndian.PutUint16(header[4:6], 0) + binary.LittleEndian.PutUint16(header[6:8], 32) + copy(header[8:12], []byte(fourCC)) + binary.LittleEndian.PutUint16(header[12:14], uint16(width)) + binary.LittleEndian.PutUint16(header[14:16], uint16(height)) + binary.LittleEndian.PutUint32(header[16:20], uint32(frameRate)) + binary.LittleEndian.PutUint32(header[20:24], 1) + binary.LittleEndian.PutUint32(header[24:28], 0) + binary.LittleEndian.PutUint32(header[28:32], 0) + return writeAll(w, header) +} + +func writeIVFFrame(w io.Writer, pts uint64, frame []byte) error { + header := make([]byte, 12) + binary.LittleEndian.PutUint32(header[0:4], uint32(len(frame))) + binary.LittleEndian.PutUint64(header[4:12], pts) + if err := writeAll(w, header); err != nil { + return err + } + return writeAll(w, frame) +} + +func writeAll(w io.Writer, data []byte) error { + for len(data) > 0 { + n, err := w.Write(data) + if err != nil { + return err + } + data = data[n:] + } + return nil +} diff --git a/internal/transport/videochannel/frame.go b/internal/transport/videochannel/frame.go new file mode 100644 index 0000000..cf7f198 --- /dev/null +++ b/internal/transport/videochannel/frame.go @@ -0,0 +1,110 @@ +package videochannel + +import ( + "encoding/binary" + "fmt" +) + +const ( + protocolMagic uint32 = 0x4f565632 // OVV2 + protocolVersion byte = 1 + frameTypeData byte = 1 + frameTypeAck byte = 2 +) + +type transportFrame struct { + typ byte + seq uint32 + crc uint32 + totalLen uint32 + fragIdx uint16 + fragTotal uint16 + payload []byte +} + +type inboundMessage struct { + totalLen uint32 + crc uint32 + frags [][]byte + remain int +} + +func fragmentPayload(data []byte, maxSize int) [][]byte { + if len(data) == 0 { + return [][]byte{{}} + } + + out := make([][]byte, 0, (len(data)+maxSize-1)/maxSize) + for start := 0; start < len(data); start += maxSize { + end := start + maxSize + if end > len(data) { + end = len(data) + } + + chunk := make([]byte, end-start) + copy(chunk, data[start:end]) + out = append(out, chunk) + } + + return out +} + +func encodeDataFrame(seq, crc uint32, totalLen, fragIdx, fragTotal int, payload []byte) []byte { + out := make([]byte, 22+len(payload)) + binary.BigEndian.PutUint32(out[0:4], protocolMagic) + out[4] = protocolVersion + out[5] = frameTypeData + binary.BigEndian.PutUint32(out[6:10], seq) + binary.BigEndian.PutUint32(out[10:14], crc) + binary.BigEndian.PutUint32(out[14:18], uint32(totalLen)) + binary.BigEndian.PutUint16(out[18:20], uint16(fragIdx)) + binary.BigEndian.PutUint16(out[20:22], uint16(fragTotal)) + copy(out[22:], payload) + return out +} + +func encodeAckFrame(seq, crc uint32) []byte { + out := make([]byte, 14) + binary.BigEndian.PutUint32(out[0:4], protocolMagic) + out[4] = protocolVersion + out[5] = frameTypeAck + binary.BigEndian.PutUint32(out[6:10], seq) + binary.BigEndian.PutUint32(out[10:14], crc) + return out +} + +func decodeTransportFrame(data []byte) (transportFrame, error) { + if len(data) < 6 { + return transportFrame{}, fmt.Errorf("frame too short") + } + if binary.BigEndian.Uint32(data[0:4]) != protocolMagic { + return transportFrame{}, fmt.Errorf("unexpected frame magic") + } + if data[4] != protocolVersion { + return transportFrame{}, fmt.Errorf("unexpected frame version") + } + + frame := transportFrame{typ: data[5]} + switch frame.typ { + case frameTypeAck: + if len(data) < 14 { + return transportFrame{}, fmt.Errorf("ack too short") + } + frame.seq = binary.BigEndian.Uint32(data[6:10]) + frame.crc = binary.BigEndian.Uint32(data[10:14]) + return frame, nil + case frameTypeData: + if len(data) < 22 { + return transportFrame{}, fmt.Errorf("data too short") + } + frame.seq = binary.BigEndian.Uint32(data[6:10]) + frame.crc = binary.BigEndian.Uint32(data[10:14]) + frame.totalLen = binary.BigEndian.Uint32(data[14:18]) + frame.fragIdx = binary.BigEndian.Uint16(data[18:20]) + frame.fragTotal = binary.BigEndian.Uint16(data[20:22]) + frame.payload = append([]byte(nil), data[22:]...) + return frame, nil + default: + return transportFrame{}, fmt.Errorf("unexpected frame type") + } +} diff --git a/internal/transport/videochannel/transport.go b/internal/transport/videochannel/transport.go new file mode 100644 index 0000000..aa71457 --- /dev/null +++ b/internal/transport/videochannel/transport.go @@ -0,0 +1,478 @@ +// Package videochannel provides a byte transport over a visual video stream. +package videochannel + +import ( + "context" + "errors" + "fmt" + "hash/crc32" + "sync" + "sync/atomic" + "time" + + "github.com/openlibrecommunity/olcrtc/internal/carrier" + "github.com/openlibrecommunity/olcrtc/internal/logger" + "github.com/openlibrecommunity/olcrtc/internal/transport" + "github.com/pion/webrtc/v4" + "github.com/pion/webrtc/v4/pkg/media" + "github.com/pion/webrtc/v4/pkg/media/samplebuilder" +) + +const ( + defaultMaxPayloadSize = 16 * 1024 + defaultFragmentSize = 256 + defaultAckTimeout = 1 * time.Second + defaultFrameInterval = 40 * time.Millisecond + defaultConnectTimeout = 30 * time.Second + maxSendAttempts = 20 + sampleBuilderMaxLate = 128 +) + +var ( + // ErrVideoTrackUnsupported is returned when a carrier cannot expose video tracks. + ErrVideoTrackUnsupported = errors.New("carrier does not support video tracks") + // ErrAckTimeout is returned when the peer does not acknowledge a payload in time. + ErrAckTimeout = errors.New("videochannel ack timeout") + // ErrTransportClosed is returned when operations are attempted on a closed transport. + ErrTransportClosed = errors.New("videochannel transport closed") +) + +type streamTransport struct { + stream carrier.VideoTrack + track *webrtc.TrackLocalStaticSample + codec codecSpec + encoder *ffmpegEncoder + onData func([]byte) + outbound chan []byte + outboundAck chan []byte + closeCh chan struct{} + writerDone chan struct{} + nextSeq atomic.Uint32 + closed atomic.Bool + writerUp atomic.Bool + sendMu sync.Mutex + startWriter sync.Once + ackMu sync.Mutex + ackWaiters map[uint32]chan uint32 + recvMu sync.Mutex + inbound map[uint32]*inboundMessage + delivered map[uint32]uint32 + videoW int + videoH int + videoFPS int + videoBitrate string +} + +// New creates a visual videochannel transport backed by a carrier-specific provider. +func New(ctx context.Context, cfg transport.Config) (transport.Transport, error) { + session, err := carrier.New(ctx, cfg.Carrier, carrier.Config{ + RoomURL: cfg.RoomURL, + Name: cfg.Name, + OnData: nil, + DNSServer: cfg.DNSServer, + ProxyAddr: cfg.ProxyAddr, + ProxyPort: cfg.ProxyPort, + }) + if err != nil { + return nil, fmt.Errorf("create provider transport: %w", err) + } + + videoCapable, ok := session.(carrier.VideoTrackCapable) + if !ok { + return nil, ErrVideoTrackUnsupported + } + + stream, err := videoCapable.OpenVideoTrack() + if err != nil { + return nil, fmt.Errorf("open video track: %w", err) + } + + codec := codecSpecForCarrier(cfg.Carrier) + track, err := webrtc.NewTrackLocalStaticSample(codec.capability, "videochannel", "olcrtc") + if err != nil { + return nil, fmt.Errorf("create local video track: %w", err) + } + + tr := &streamTransport{ + stream: stream, + track: track, + codec: codec, + onData: cfg.OnData, + outbound: make(chan []byte, 256), + outboundAck: make(chan []byte, 64), + closeCh: make(chan struct{}), + writerDone: make(chan struct{}), + ackWaiters: make(map[uint32]chan uint32), + inbound: make(map[uint32]*inboundMessage), + delivered: make(map[uint32]uint32), + videoW: cfg.VideoWidth, + videoH: cfg.VideoHeight, + videoFPS: cfg.VideoFPS, + videoBitrate: cfg.VideoBitrate, + } + + if err := stream.AddTrack(track); err != nil { + return nil, fmt.Errorf("attach local video track: %w", err) + } + stream.SetTrackHandler(tr.handleRemoteTrack) + + return tr, nil +} + +// Connect starts the transport connection. +func (p *streamTransport) Connect(ctx context.Context) error { + connectCtx, cancel := context.WithTimeout(ctx, defaultConnectTimeout) + defer cancel() + + encoder, err := newFFmpegEncoder(p.codec, p.videoW, p.videoH, p.videoFPS, p.videoBitrate) + if err != nil { + return err + } + + if err := p.stream.Connect(connectCtx); err != nil { + _ = encoder.Close() + return err + } + + p.encoder = encoder + p.startWriter.Do(func() { + p.writerUp.Store(true) + go p.writerLoop() + }) + + return nil +} + +// Send transmits data through the transport. +func (p *streamTransport) Send(data []byte) error { + if p.closed.Load() { + return ErrTransportClosed + } + + p.sendMu.Lock() + defer p.sendMu.Unlock() + + seq := p.nextSeq.Add(1) + crc := crc32.ChecksumIEEE(data) + fragments := fragmentPayload(data, defaultFragmentSize) + waiter := make(chan uint32, 1) + + p.ackMu.Lock() + p.ackWaiters[seq] = waiter + p.ackMu.Unlock() + defer func() { + p.ackMu.Lock() + delete(p.ackWaiters, seq) + p.ackMu.Unlock() + }() + + for attempt := 0; attempt < maxSendAttempts; attempt++ { + for idx, fragment := range fragments { + frame := encodeDataFrame(seq, crc, len(data), idx, len(fragments), fragment) + if err := p.enqueueFrame(frame, false); err != nil { + return err + } + } + + timer := time.NewTimer(defaultAckTimeout) + select { + case ackCRC := <-waiter: + timer.Stop() + if ackCRC == crc { + return nil + } + case <-timer.C: + case <-p.closeCh: + timer.Stop() + return ErrTransportClosed + } + } + + return ErrAckTimeout +} + +// Close terminates the transport. +func (p *streamTransport) Close() error { + if p.closed.CompareAndSwap(false, true) { + close(p.closeCh) + if p.encoder != nil { + _ = p.encoder.Close() + } + if p.writerUp.Load() { + <-p.writerDone + } + return p.stream.Close() + } + return nil +} + +// SetReconnectCallback registers reconnect handling. +func (p *streamTransport) SetReconnectCallback(cb func()) { + p.stream.SetReconnectCallback(cb) +} + +// SetShouldReconnect configures reconnect policy. +func (p *streamTransport) SetShouldReconnect(fn func() bool) { + p.stream.SetShouldReconnect(fn) +} + +// SetEndedCallback registers end-of-session handling. +func (p *streamTransport) SetEndedCallback(cb func(string)) { + p.stream.SetEndedCallback(cb) +} + +// WatchConnection monitors connection lifecycle. +func (p *streamTransport) WatchConnection(ctx context.Context) { + p.stream.WatchConnection(ctx) +} + +// CanSend reports whether transport is ready for sending. +func (p *streamTransport) CanSend() bool { + return !p.closed.Load() && p.stream.CanSend() +} + +// Features describes the current videochannel transport semantics. +func (p *streamTransport) Features() transport.Features { + return transport.Features{ + Reliable: true, + Ordered: true, + MessageOriented: true, + MaxPayloadSize: defaultMaxPayloadSize, + } +} + +func (p *streamTransport) writerLoop() { + defer close(p.writerDone) + defer func() { + if p.encoder != nil { + _ = p.encoder.Close() + } + }() + + ticker := time.NewTicker(time.Second / time.Duration(p.videoFPS)) + defer ticker.Stop() + + for { + select { + case <-p.closeCh: + return + case <-ticker.C: + payload, ok := p.nextOutboundFrame() + if !ok { + return + } + + rawFrame, err := renderVisualFrame(payload, p.videoW, p.videoH) + if err != nil { + logger.Debugf("videochannel render error: %v", err) + continue + } + + sample, err := p.encoder.EncodeFrame(rawFrame) + if err != nil { + logger.Warnf("videochannel encoder error: %v", err) + continue + } + + _ = p.track.WriteSample(media.Sample{ + Data: sample, + Duration: time.Second / time.Duration(p.videoFPS), + }) + } + } +} + +func (p *streamTransport) nextOutboundFrame() ([]byte, bool) { + select { + case <-p.closeCh: + return nil, false + case payload := <-p.outboundAck: + return payload, true + default: + } + + select { + case <-p.closeCh: + return nil, false + case payload := <-p.outboundAck: + return payload, true + case payload := <-p.outbound: + return payload, true + default: + return nil, true + } +} + +func (p *streamTransport) enqueueFrame(frame []byte, priority bool) error { + if p.closed.Load() { + return ErrTransportClosed + } + + ch := p.outbound + if priority { + ch = p.outboundAck + } + + select { + case <-p.closeCh: + return ErrTransportClosed + case ch <- frame: + return nil + } +} + +func (p *streamTransport) handleRemoteTrack(track *webrtc.TrackRemote, _ *webrtc.RTPReceiver) { + codec, ok := codecSpecForMime(track.Codec().MimeType) + if !ok { + logger.Warnf("videochannel unsupported remote codec: %s", track.Codec().MimeType) + return + } + + decoder, err := newFFmpegDecoder(codec, p.videoW, p.videoH, p.videoFPS) + if err != nil { + logger.Warnf("videochannel decoder init failed: %v", err) + return + } + + go func() { + defer func() { _ = decoder.Close() }() + for { + frame, err := decoder.PopFrame() + if err != nil { + if !errors.Is(err, ErrTransportClosed) { + logger.Warnf("videochannel decoder pop error: %v", err) + } + return + } + p.handleFrame(frame) + } + }() + + go func() { + sb := samplebuilder.New(sampleBuilderMaxLate, codec.depacketizer(), track.Codec().ClockRate) + for { + packet, _, err := track.ReadRTP() + if err != nil { + sb.Flush() + return + } + + sb.Push(packet) + for sample := sb.Pop(); sample != nil; sample = sb.Pop() { + if err := decoder.PushSample(sample.Data); err != nil { + logger.Warnf("videochannel decoder push error: %v", err) + return + } + } + } + }() +} + +func (p *streamTransport) handleFrame(frame []byte) { + payload, err := extractVisualPayload(frame, p.videoW, p.videoH) + if err != nil || len(payload) == 0 { + if err != nil { + logger.Debugf("videochannel extract visual payload error: %v", err) + } + return + } + + logger.Debugf("videochannel extracted visual payload: len=%d", len(payload)) + + decoded, err := decodeTransportFrame(payload) + if err != nil { + logger.Debugf("videochannel decode transport frame error: %v", err) + return + } + + logger.Debugf("videochannel transport frame: type=%d seq=%d crc=%x", decoded.typ, decoded.seq, decoded.crc) + + switch decoded.typ { + case frameTypeAck: + p.resolveAck(decoded.seq, decoded.crc) + case frameTypeData: + p.handleInboundFrame(decoded) + } +} + +func (p *streamTransport) handleInboundFrame(frame transportFrame) { + p.recvMu.Lock() + if crc, ok := p.delivered[frame.seq]; ok && crc == frame.crc { + p.recvMu.Unlock() + p.sendAck(frame.seq, frame.crc) + return + } + + msg, ok := p.inbound[frame.seq] + if !ok || msg.crc != frame.crc || msg.totalLen != frame.totalLen || len(msg.frags) != int(frame.fragTotal) { + msg = &inboundMessage{ + totalLen: frame.totalLen, + crc: frame.crc, + frags: make([][]byte, frame.fragTotal), + remain: int(frame.fragTotal), + } + p.inbound[frame.seq] = msg + } + + if int(frame.fragIdx) >= len(msg.frags) { + p.recvMu.Unlock() + return + } + + if msg.frags[frame.fragIdx] == nil { + chunk := make([]byte, len(frame.payload)) + copy(chunk, frame.payload) + msg.frags[frame.fragIdx] = chunk + msg.remain-- + } + + if msg.remain > 0 { + p.recvMu.Unlock() + return + } + + delete(p.inbound, frame.seq) + data := make([]byte, 0, msg.totalLen) + for _, frag := range msg.frags { + data = append(data, frag...) + } + + if uint32(len(data)) > msg.totalLen { + data = data[:msg.totalLen] + } + + if crc32.ChecksumIEEE(data) != msg.crc { + p.recvMu.Unlock() + return + } + + if len(p.delivered) > 256 { + p.delivered = make(map[uint32]uint32) + } + p.delivered[frame.seq] = msg.crc + p.recvMu.Unlock() + + if p.onData != nil { + p.onData(data) + } + p.sendAck(frame.seq, frame.crc) +} + +func (p *streamTransport) sendAck(seq, crc uint32) { + _ = p.enqueueFrame(encodeAckFrame(seq, crc), true) +} + +func (p *streamTransport) resolveAck(seq, crc uint32) { + p.ackMu.Lock() + waiter := p.ackWaiters[seq] + p.ackMu.Unlock() + + if waiter == nil { + return + } + + select { + case waiter <- crc: + default: + } +} diff --git a/internal/transport/videochannel/transport_test.go b/internal/transport/videochannel/transport_test.go new file mode 100644 index 0000000..d7fd114 --- /dev/null +++ b/internal/transport/videochannel/transport_test.go @@ -0,0 +1,51 @@ +package videochannel + +import ( + "bytes" + "testing" +) + +func TestVisualRoundTrip(t *testing.T) { + payload := []byte("hello over visual videochannel") + frame, err := renderVisualFrame(payload) + if err != nil { + t.Fatalf("renderVisualFrame failed: %v", err) + } + + got, err := extractVisualPayload(frame) + if err != nil { + t.Fatalf("extractVisualPayload failed: %v", err) + } + if !bytes.Equal(got, payload) { + t.Fatalf("payload mismatch: got=%q want=%q", got, payload) + } +} + +func TestIdleFrameIgnored(t *testing.T) { + frame, err := renderVisualFrame(nil) + if err != nil { + t.Fatalf("renderVisualFrame failed: %v", err) + } + + got, err := extractVisualPayload(frame) + if err == nil && len(got) != 0 { + t.Fatalf("expected idle frame to be ignored, got=%q", got) + } +} + +func TestTransportFrameRoundTrip(t *testing.T) { + encoded := encodeDataFrame(42, 0xdeadbeef, 1024, 1, 3, []byte("chunk")) + decoded, err := decodeTransportFrame(encoded) + if err != nil { + t.Fatalf("decodeTransportFrame failed: %v", err) + } + if decoded.typ != frameTypeData || decoded.seq != 42 || decoded.crc != 0xdeadbeef { + t.Fatalf("unexpected frame header: %+v", decoded) + } + if decoded.totalLen != 1024 || decoded.fragIdx != 1 || decoded.fragTotal != 3 { + t.Fatalf("unexpected fragmentation fields: %+v", decoded) + } + if !bytes.Equal(decoded.payload, []byte("chunk")) { + t.Fatalf("payload mismatch: got=%q", decoded.payload) + } +} diff --git a/internal/transport/videochannel/visual.go b/internal/transport/videochannel/visual.go new file mode 100644 index 0000000..15e52aa --- /dev/null +++ b/internal/transport/videochannel/visual.go @@ -0,0 +1,112 @@ +package videochannel + +import ( + "encoding/base64" + "fmt" + "image" + "strings" + + barcodedm "github.com/boombuler/barcode/datamatrix" + "github.com/makiuchi-d/gozxing" + zxingdm "github.com/makiuchi-d/gozxing/datamatrix" +) + +const ( + quietZone = 10 +) + +func renderVisualFrame(payload []byte, width, height int) ([]byte, error) { + logicalFrameBytes := width * height + frame := make([]byte, logicalFrameBytes) + for i := range frame { + frame[i] = 0xff // White background + } + + if len(payload) == 0 { + return frame, nil + } + + encoded := base64.StdEncoding.EncodeToString(payload) + dm, err := barcodedm.Encode(encoded) + if err != nil { + return nil, fmt.Errorf("datamatrix encode: %w", err) + } + + // Use strict integer scaling to keep edges sharp + bounds := dm.Bounds() + dmW := bounds.Dx() + dmH := bounds.Dy() + + scaleW := (width - (quietZone * 2)) / dmW + scaleH := (height - (quietZone * 2)) / dmH + scale := scaleW + if scaleH < scale { + scale = scaleH + } + if scale < 1 { + scale = 1 + } + + totalW := dmW * scale + totalH := dmH * scale + offsetX := (width - totalW) / 2 + offsetY := (height - totalH) / 2 + + for y := 0; y < dmH; y++ { + for x := 0; x < dmW; x++ { + r, _, _, _ := dm.At(bounds.Min.X+x, bounds.Min.Y+y).RGBA() + if r < 0x8000 { + // Fill scale x scale block + for sy := 0; sy < scale; sy++ { + for sx := 0; sx < scale; sx++ { + pixelX := offsetX + (x * scale) + sx + pixelY := offsetY + (y * scale) + sy + if pixelX < width && pixelY < height { + frame[pixelY*width+pixelX] = 0x00 + } + } + } + } + } + } + + return frame, nil +} + +func extractVisualPayload(frame []byte, width, height int) ([]byte, error) { + logicalFrameBytes := width * height + if len(frame) != logicalFrameBytes { + return nil, fmt.Errorf("unexpected frame size: %d (expected %dx%d=%d)", len(frame), width, height, logicalFrameBytes) + } + + img := image.NewGray(image.Rect(0, 0, width, height)) + copy(img.Pix, frame) + + source := gozxing.NewLuminanceSourceFromImage(img) + // HybridBinarizer is good for noisy images + binarizer := gozxing.NewHybridBinarizer(source) + bmp, err := gozxing.NewBinaryBitmap(binarizer) + if err != nil { + return nil, fmt.Errorf("bitmap: %w", err) + } + + reader := zxingdm.NewDataMatrixReader() + hints := make(map[gozxing.DecodeHintType]interface{}) + hints[gozxing.DecodeHintType_TRY_HARDER] = true + hints[gozxing.DecodeHintType_PURE_BARCODE] = true + + result, err := reader.Decode(bmp, hints) + if err != nil { + if strings.Contains(err.Error(), "NotFoundException") { + return nil, nil + } + return nil, fmt.Errorf("decode: %w", err) + } + + decoded, err := base64.StdEncoding.DecodeString(result.GetText()) + if err != nil { + return nil, fmt.Errorf("base64 decode: %w", err) + } + + return decoded, nil +}