From dd606ddfb2a78a6952682f9b49cf4ebbd577e30a Mon Sep 17 00:00:00 2001 From: zarazaex69 Date: Sun, 3 May 2026 06:10:48 +0300 Subject: [PATCH] fix: fix all golangci errors --- cmd/olcrtc/main.go | 49 ++-- internal/app/session/session.go | 234 +++++++++------ internal/carrier/bytestream.go | 64 ++-- internal/carrier/carrier.go | 1 + internal/client/client.go | 167 ++++++----- internal/link/direct/direct.go | 52 ++-- internal/link/link.go | 1 + internal/muxconn/conn.go | 5 +- internal/provider/jazz/peer.go | 125 +++++--- internal/provider/provider.go | 2 - internal/provider/telemost/peer.go | 116 +++++--- internal/server/server.go | 30 +- internal/transport/datachannel/transport.go | 15 +- internal/transport/seichannel/h264.go | 40 ++- internal/transport/seichannel/transport.go | 95 +++--- .../transport/seichannel/transport_test.go | 5 +- internal/transport/transport.go | 1 + internal/transport/videochannel/ffmpeg.go | 216 +++++++------- internal/transport/videochannel/frame.go | 35 ++- internal/transport/videochannel/transport.go | 275 ++++++++++-------- internal/transport/videochannel/visual.go | 28 +- internal/transport/vp8channel/kcp.go | 9 +- internal/transport/vp8channel/kcpconn.go | 32 +- internal/transport/vp8channel/transport.go | 130 +++++---- .../transport/vp8channel/transport_test.go | 105 +++---- 25 files changed, 1072 insertions(+), 760 deletions(-) diff --git a/cmd/olcrtc/main.go b/cmd/olcrtc/main.go index a91c40e..03cec2a 100644 --- a/cmd/olcrtc/main.go +++ b/cmd/olcrtc/main.go @@ -3,6 +3,7 @@ package main import ( "context" + "errors" "flag" "fmt" "os" @@ -16,6 +17,9 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/names" ) +// ErrDataDirRequired is returned when no data directory is specified. +var ErrDataDirRequired = errors.New("data directory required (use -data data)") + type config struct { mode string link string @@ -59,11 +63,11 @@ func run() error { configureLogging(cfg.debug) if err := session.Validate(toSessionConfig(cfg)); err != nil { - return err + return fmt.Errorf("validate config: %w", err) } if cfg.dataDir == "" { - return fmt.Errorf("data directory required (use -data data)") + return ErrDataDirRequired } dataDir, err := resolveDataDir(cfg.dataDir) @@ -119,10 +123,13 @@ func parseFlags() config { flag.StringVar(&cfg.videoBitrate, "video-bitrate", "", "Video bitrate (videochannel only)") flag.StringVar(&cfg.videoHW, "video-hw", "", "Hardware acceleration (none, nvenc)") flag.IntVar(&cfg.videoQRSize, "video-qr-size", 0, "Video QR code fragment size (videochannel only)") - flag.StringVar(&cfg.videoQRRecovery, "video-qr-recovery", "low", "QR error correction: low (7%), medium (15%), high (25%), highest (30%)") + flag.StringVar(&cfg.videoQRRecovery, "video-qr-recovery", "low", + "QR error correction: low (7%), medium (15%), high (25%), highest (30%)") flag.StringVar(&cfg.videoCodec, "video-codec", "qrcode", "Visual codec: qrcode or tile") - flag.IntVar(&cfg.videoTileModule, "video-tile-module", 0, "Tile module size in pixels 1..270 (videochannel tile only, default 4)") - flag.IntVar(&cfg.videoTileRS, "video-tile-rs", 0, "Tile Reed-Solomon parity percent 0..200 (videochannel tile only, default 20)") + flag.IntVar(&cfg.videoTileModule, "video-tile-module", 0, + "Tile module size in pixels 1..270 (videochannel tile only, default 4)") + flag.IntVar(&cfg.videoTileRS, "video-tile-rs", 0, + "Tile Reed-Solomon parity percent 0..200 (videochannel tile only, default 20)") flag.IntVar(&cfg.vp8FPS, "vp8-fps", 0, "VP8 frames per second (vp8channel only, default 25)") flag.IntVar(&cfg.vp8BatchSize, "vp8-batch", 0, "VP8 frames per tick (vp8channel only, default 1)") flag.Parse() @@ -161,22 +168,22 @@ func loadNames(dataDir string) error { func toSessionConfig(cfg config) session.Config { return session.Config{ - Mode: cfg.mode, - Link: cfg.link, - Transport: cfg.transport, - Carrier: firstNonEmpty(cfg.carrier, cfg.provider), - RoomID: cfg.roomID, - KeyHex: cfg.keyHex, - SOCKSHost: cfg.socksHost, - SOCKSPort: cfg.socksPort, - DNSServer: cfg.dnsServer, - SOCKSProxyAddr: cfg.socksProxyAddr, - SOCKSProxyPort: cfg.socksProxyPort, - VideoWidth: cfg.videoWidth, - VideoHeight: cfg.videoHeight, - VideoFPS: cfg.videoFPS, - VideoBitrate: cfg.videoBitrate, - VideoHW: cfg.videoHW, + Mode: cfg.mode, + Link: cfg.link, + Transport: cfg.transport, + Carrier: firstNonEmpty(cfg.carrier, cfg.provider), + RoomID: cfg.roomID, + KeyHex: cfg.keyHex, + SOCKSHost: cfg.socksHost, + SOCKSPort: cfg.socksPort, + DNSServer: cfg.dnsServer, + SOCKSProxyAddr: cfg.socksProxyAddr, + SOCKSProxyPort: cfg.socksProxyPort, + VideoWidth: cfg.videoWidth, + VideoHeight: cfg.videoHeight, + VideoFPS: cfg.videoFPS, + VideoBitrate: cfg.videoBitrate, + VideoHW: cfg.videoHW, VideoQRSize: cfg.videoQRSize, VideoQRRecovery: cfg.videoQRRecovery, VideoCodec: cfg.videoCodec, diff --git a/internal/app/session/session.go b/internal/app/session/session.go index e63af56..3ea732e 100644 --- a/internal/app/session/session.go +++ b/internal/app/session/session.go @@ -19,13 +19,19 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/transport/vp8channel" ) +const ( + modeSRV = "srv" + modeCNC = "cnc" +) + var ( // ErrRoomIDRequired indicates that a room id is required for the selected carrier. ErrRoomIDRequired = errors.New("room ID required (use -id )") // ErrModeRequired indicates that mode is not one of the supported values. ErrModeRequired = errors.New("mode required (use -mode srv or -mode cnc)") // ErrCarrierRequired indicates that no carrier was selected. - ErrCarrierRequired = errors.New("carrier required (use -carrier telemost, -carrier jazz or -carrier wbstream)") + ErrCarrierRequired = errors.New( + "carrier required (use -carrier telemost, -carrier jazz or -carrier wbstream)") // ErrUnsupportedCarrier indicates that carrier is not registered. ErrUnsupportedCarrier = errors.New("unsupported carrier") // ErrUnsupportedLink indicates that link is not registered. @@ -36,26 +42,40 @@ var ( // ErrLinkRequired indicates that link is not provided. ErrLinkRequired = errors.New("link required (use -link direct)") // ErrTransportRequired indicates that transport is not provided. - ErrTransportRequired = errors.New("transport required (use -transport datachannel, -transport videochannel, -transport seichannel or -transport vp8channel)") + ErrTransportRequired = errors.New( + "transport required (use -transport datachannel, -transport videochannel, " + + "-transport seichannel or -transport vp8channel)") // ErrKeyRequired indicates that encryption key is not provided. ErrKeyRequired = errors.New("key required (use -key )") // ErrDNSServerRequired indicates that dns server is not provided. ErrDNSServerRequired = errors.New("dns server required (use -dns 1.1.1.1:53)") - // Videochannel errors - ErrVideoWidthRequired = errors.New("video width required for videochannel (use -video-w)") - ErrVideoHeightRequired = errors.New("video height required for videochannel (use -video-h)") - ErrVideoFPSRequired = errors.New("video fps required for videochannel (use -video-fps)") - ErrVideoBitrateRequired = errors.New("video bitrate required for videochannel (use -video-bitrate)") - ErrVideoHWRequired = errors.New("video hardware acceleration required for videochannel (use -video-hw none/nvenc)") - ErrVideoCodecInvalid = errors.New("invalid video codec for videochannel (use -video-codec qrcode or -video-codec tile)") + // ErrVideoWidthRequired indicates that video width is required for videochannel. + ErrVideoWidthRequired = errors.New("video width required for videochannel (use -video-w)") + // ErrVideoHeightRequired indicates that video height is required for videochannel. + ErrVideoHeightRequired = errors.New("video height required for videochannel (use -video-h)") + // ErrVideoFPSRequired indicates that video fps is required for videochannel. + ErrVideoFPSRequired = errors.New("video fps required for videochannel (use -video-fps)") + // ErrVideoBitrateRequired indicates that video bitrate is required for videochannel. + ErrVideoBitrateRequired = errors.New( + "video bitrate required for videochannel (use -video-bitrate)") + // ErrVideoHWRequired indicates that video hardware acceleration is required. + ErrVideoHWRequired = errors.New( + "video hardware acceleration required for videochannel (use -video-hw none/nvenc)") + // ErrVideoCodecInvalid indicates that the video codec is not valid. + ErrVideoCodecInvalid = errors.New( + "invalid video codec for videochannel (use -video-codec qrcode or -video-codec tile)") + // ErrTileCodecDimensions indicates that tile codec requires 1080x1080 dimensions. + ErrTileCodecDimensions = errors.New("tile codec requires -video-w 1080 -video-h 1080") - // VP8channel errors - ErrVP8FPSRequired = errors.New("vp8 fps required for vp8channel (use -vp8-fps)") + // ErrVP8FPSRequired indicates that vp8 fps is required for vp8channel. + ErrVP8FPSRequired = errors.New("vp8 fps required for vp8channel (use -vp8-fps)") + // ErrVP8BatchSizeRequired indicates that vp8 batch size is required for vp8channel. ErrVP8BatchSizeRequired = errors.New("vp8 batch size required for vp8channel (use -vp8-batch)") - // CNC errors + // ErrSOCKSHostRequired indicates that socks host is required for cnc mode. ErrSOCKSHostRequired = errors.New("socks host required for cnc mode (use -socks-host)") + // ErrSOCKSPortRequired indicates that socks port is required for cnc mode. ErrSOCKSPortRequired = errors.New("socks port required for cnc mode (use -socks-port)") ) @@ -98,115 +118,143 @@ func RegisterDefaults() { // Validate verifies that the runtime config refers to registered components and all required fields are present. func Validate(cfg Config) error { - availableCarriers := carrier.Available() - validCarrier := false - for _, c := range availableCarriers { - if cfg.Carrier == c { - validCarrier = true - break - } + if err := validateMode(cfg); err != nil { + return err } - - availableTransports := transport.Available() - validTransport := false - for _, t := range availableTransports { - if cfg.Transport == t { - validTransport = true - break - } + if err := validateCarrier(cfg); err != nil { + return err } - - availableLinks := link.Available() - validLink := false - for _, l := range availableLinks { - if cfg.Link == l { - validLink = true - break - } + if err := validateLink(cfg); err != nil { + return err } + if err := validateTransportRegistration(cfg); err != nil { + return err + } + if err := validateCommon(cfg); err != nil { + return err + } + if err := validateTransportConfig(cfg); err != nil { + return err + } + return validateModeConfig(cfg) +} - if cfg.Mode == "" { - return ErrModeRequired - } - if cfg.Mode != "srv" && cfg.Mode != "cnc" { +func validateMode(cfg Config) error { + if cfg.Mode == "" || (cfg.Mode != modeSRV && cfg.Mode != modeCNC) { return ErrModeRequired } + return nil +} +func validateCarrier(cfg Config) error { if cfg.Carrier == "" { return ErrCarrierRequired } - if !validCarrier { - return fmt.Errorf("%w: %s (available: %v)", ErrUnsupportedCarrier, cfg.Carrier, availableCarriers) + for _, c := range carrier.Available() { + if cfg.Carrier == c { + return nil + } } + return fmt.Errorf("%w: %s (available: %v)", ErrUnsupportedCarrier, cfg.Carrier, carrier.Available()) +} +func validateLink(cfg Config) error { if cfg.Link == "" { return ErrLinkRequired } - if !validLink { - return fmt.Errorf("%w: %s (available: %v)", ErrUnsupportedLink, cfg.Link, availableLinks) + for _, l := range link.Available() { + if cfg.Link == l { + return nil + } } + return fmt.Errorf("%w: %s (available: %v)", ErrUnsupportedLink, cfg.Link, link.Available()) +} +func validateTransportRegistration(cfg Config) error { if cfg.Transport == "" { return ErrTransportRequired } - if !validTransport { - return fmt.Errorf("%w: %s (available: %v)", ErrUnsupportedTransport, cfg.Transport, availableTransports) + for _, t := range transport.Available() { + if cfg.Transport == t { + return nil + } } + return fmt.Errorf("%w: %s (available: %v)", ErrUnsupportedTransport, cfg.Transport, transport.Available()) +} +func validateCommon(cfg Config) error { if cfg.RoomID == "" && cfg.Carrier != "jazz" { return ErrRoomIDRequired } - if cfg.KeyHex == "" { return ErrKeyRequired } - if cfg.DNSServer == "" { return ErrDNSServerRequired } + return nil +} - if cfg.Transport == "videochannel" { - if cfg.VideoWidth == 0 { - return ErrVideoWidthRequired - } - if cfg.VideoHeight == 0 { - return ErrVideoHeightRequired - } - if cfg.VideoFPS == 0 { - return ErrVideoFPSRequired - } - if cfg.VideoBitrate == "" { - return ErrVideoBitrateRequired - } - if cfg.VideoHW == "" { - return ErrVideoHWRequired - } - if cfg.VideoCodec != "" && cfg.VideoCodec != "qrcode" && cfg.VideoCodec != "tile" { - return ErrVideoCodecInvalid - } - if cfg.VideoCodec == "tile" && (cfg.VideoWidth != 1080 || cfg.VideoHeight != 1080) { - return errors.New("tile codec requires -video-w 1080 -video-h 1080") - } +func validateTransportConfig(cfg Config) error { + switch cfg.Transport { + case "videochannel": + return validateVideoChannel(cfg) + case "vp8channel": + return validateVP8Channel(cfg) + default: + return nil } +} - if cfg.Transport == "vp8channel" { - if cfg.VP8FPS == 0 { - return ErrVP8FPSRequired - } - if cfg.VP8BatchSize == 0 { - return ErrVP8BatchSizeRequired - } +func validateVideoCodec(cfg Config) error { + if cfg.VideoCodec != "" && cfg.VideoCodec != "qrcode" && cfg.VideoCodec != "tile" { + return ErrVideoCodecInvalid } - - if cfg.Mode == "cnc" { - if cfg.SOCKSHost == "" { - return ErrSOCKSHostRequired - } - if cfg.SOCKSPort == 0 { - return ErrSOCKSPortRequired - } + if cfg.VideoCodec == "tile" && (cfg.VideoWidth != 1080 || cfg.VideoHeight != 1080) { + return ErrTileCodecDimensions } + return nil +} +func validateVideoChannel(cfg Config) error { + if cfg.VideoWidth == 0 { + return ErrVideoWidthRequired + } + if cfg.VideoHeight == 0 { + return ErrVideoHeightRequired + } + if cfg.VideoFPS == 0 { + return ErrVideoFPSRequired + } + if cfg.VideoBitrate == "" { + return ErrVideoBitrateRequired + } + if cfg.VideoHW == "" { + return ErrVideoHWRequired + } + return validateVideoCodec(cfg) +} + +func validateVP8Channel(cfg Config) error { + if cfg.VP8FPS == 0 { + return ErrVP8FPSRequired + } + if cfg.VP8BatchSize == 0 { + return ErrVP8BatchSizeRequired + } + return nil +} + +func validateModeConfig(cfg Config) error { + if cfg.Mode != modeCNC { + return nil + } + if cfg.SOCKSHost == "" { + return ErrSOCKSHostRequired + } + if cfg.SOCKSPort == 0 { + return ErrSOCKSPortRequired + } return nil } @@ -215,8 +263,8 @@ func Run(ctx context.Context, cfg Config) error { roomURL := buildRoomURL(cfg.Carrier, cfg.RoomID) switch cfg.Mode { - case "srv": - return server.Run( + case modeSRV: + if err := server.Run( ctx, cfg.Link, cfg.Transport, @@ -238,9 +286,12 @@ func Run(ctx context.Context, cfg Config) error { cfg.VideoTileRS, cfg.VP8FPS, cfg.VP8BatchSize, - ) - case "cnc": - return client.Run( + ); err != nil { + return fmt.Errorf("server: %w", err) + } + return nil + case modeCNC: + if err := client.Run( ctx, cfg.Link, cfg.Transport, @@ -263,7 +314,10 @@ func Run(ctx context.Context, cfg Config) error { cfg.VideoTileRS, cfg.VP8FPS, cfg.VP8BatchSize, - ) + ); err != nil { + return fmt.Errorf("client: %w", err) + } + return nil default: return ErrModeRequired } diff --git a/internal/carrier/bytestream.go b/internal/carrier/bytestream.go index ebfe5dc..e394159 100644 --- a/internal/carrier/bytestream.go +++ b/internal/carrier/bytestream.go @@ -2,6 +2,7 @@ package carrier import ( "context" + "fmt" "github.com/openlibrecommunity/olcrtc/internal/provider" "github.com/pion/webrtc/v4" @@ -32,6 +33,11 @@ type VideoTrack interface { SetTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) } +type videoTrackProvider interface { + provider.Provider + provider.VideoTrackCapable +} + type legacySession struct { provider provider.Provider } @@ -39,7 +45,7 @@ type legacySession struct { // Capabilities reports the transport primitives supported by the legacy carrier. func (s *legacySession) Capabilities() Capabilities { caps := Capabilities{ByteStream: true} - _, caps.VideoTrack = s.provider.(provider.VideoTrackCapable) + _, caps.VideoTrack = s.provider.(videoTrackProvider) return caps } @@ -50,20 +56,35 @@ func (s *legacySession) OpenByteStream() (ByteStream, error) { // OpenVideoTrack adapts a legacy provider to the generic video track capability. func (s *legacySession) OpenVideoTrack() (VideoTrack, error) { - publisher, ok := s.provider.(provider.VideoTrackCapable) + vtp, ok := s.provider.(videoTrackProvider) if !ok { return nil, ErrVideoTrackUnsupported } - return &legacyVideoTrack{provider: publisher}, nil + return &legacyVideoTrack{provider: vtp}, nil } type legacyByteStream struct { provider provider.Provider } -func (p *legacyByteStream) Connect(ctx context.Context) error { return p.provider.Connect(ctx) } -func (p *legacyByteStream) Send(data []byte) error { return p.provider.Send(data) } -func (p *legacyByteStream) Close() error { return p.provider.Close() } +func (p *legacyByteStream) Connect(ctx context.Context) error { + if err := p.provider.Connect(ctx); err != nil { + return fmt.Errorf("connect: %w", err) + } + return nil +} +func (p *legacyByteStream) Send(data []byte) error { + if err := p.provider.Send(data); err != nil { + return fmt.Errorf("send: %w", err) + } + return nil +} +func (p *legacyByteStream) Close() error { + if err := p.provider.Close(); err != nil { + return fmt.Errorf("close: %w", err) + } + return nil +} func (p *legacyByteStream) SetReconnectCallback(cb func()) { p.provider.SetReconnectCallback(func(_ *webrtc.DataChannel) { @@ -81,31 +102,38 @@ func (p *legacyByteStream) WatchConnection(ctx context.Context) { func (p *legacyByteStream) CanSend() bool { return p.provider.CanSend() } type legacyVideoTrack struct { - provider provider.VideoTrackCapable + provider videoTrackProvider } func (v *legacyVideoTrack) Connect(ctx context.Context) error { - return v.provider.(provider.Provider).Connect(ctx) + if err := v.provider.Connect(ctx); err != nil { + return fmt.Errorf("connect: %w", err) + } + return nil } -func (v *legacyVideoTrack) Close() error { return v.provider.(provider.Provider).Close() } -func (v *legacyVideoTrack) SetShouldReconnect(fn func() bool) { - v.provider.(provider.Provider).SetShouldReconnect(fn) -} -func (v *legacyVideoTrack) SetEndedCallback(cb func(string)) { - v.provider.(provider.Provider).SetEndedCallback(cb) +func (v *legacyVideoTrack) Close() error { + if err := v.provider.Close(); err != nil { + return fmt.Errorf("close: %w", err) + } + return nil } +func (v *legacyVideoTrack) SetShouldReconnect(fn func() bool) { v.provider.SetShouldReconnect(fn) } +func (v *legacyVideoTrack) SetEndedCallback(cb func(string)) { v.provider.SetEndedCallback(cb) } func (v *legacyVideoTrack) WatchConnection(ctx context.Context) { - v.provider.(provider.Provider).WatchConnection(ctx) + v.provider.WatchConnection(ctx) } -func (v *legacyVideoTrack) CanSend() bool { return v.provider.(provider.Provider).CanSend() } +func (v *legacyVideoTrack) CanSend() bool { return v.provider.CanSend() } func (v *legacyVideoTrack) AddTrack(track webrtc.TrackLocal) error { - return v.provider.AddVideoTrack(track) + if err := v.provider.AddVideoTrack(track); err != nil { + return fmt.Errorf("add track: %w", err) + } + return nil } func (v *legacyVideoTrack) SetTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) { v.provider.SetVideoTrackHandler(cb) } func (v *legacyVideoTrack) SetReconnectCallback(cb func()) { - v.provider.(provider.Provider).SetReconnectCallback(func(_ *webrtc.DataChannel) { + v.provider.SetReconnectCallback(func(_ *webrtc.DataChannel) { if cb != nil { cb() } diff --git a/internal/carrier/carrier.go b/internal/carrier/carrier.go index 98ed15e..ab98945 100644 --- a/internal/carrier/carrier.go +++ b/internal/carrier/carrier.go @@ -51,6 +51,7 @@ type Config struct { // Factory creates a new carrier session. type Factory func(ctx context.Context, cfg Config) (Session, error) +//nolint:gochecknoglobals var registry = make(map[string]Factory) // Register adds a carrier factory to the registry. diff --git a/internal/client/client.go b/internal/client/client.go index ef8ab6e..c87049d 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -26,15 +26,25 @@ var ( ErrConnectFailed = errors.New("tunnel connection failed") // ErrProxyAuth is returned when SOCKS proxy authentication fails. ErrProxyAuth = errors.New("SOCKS proxy auth failed") + // ErrKeySize is returned when the encryption key is not 32 bytes. + ErrKeySize = errors.New("key must be 32 bytes") + // ErrInvalidSOCKSVersion is returned when the SOCKS version is not 5. + ErrInvalidSOCKSVersion = errors.New("invalid socks version") + // ErrUnsupportedSOCKSCommand is returned for unsupported SOCKS commands. + ErrUnsupportedSOCKSCommand = errors.New("unsupported socks command") + // ErrUnsupportedAddressType is returned for unsupported SOCKS address types. + ErrUnsupportedAddressType = errors.New("unsupported address type") + // ErrRemoteNotReady is returned when the server-side stream fails to signal readiness. + ErrRemoteNotReady = errors.New("remote not ready") ) // Client handles local SOCKS5 connections and tunnels them to the server. type Client struct { - ln link.Link - cipher *crypto.Cipher - conn *muxconn.Conn - session *smux.Session - sessMu sync.RWMutex + ln link.Link + cipher *crypto.Cipher + conn *muxconn.Conn + session *smux.Session + sessMu sync.RWMutex dnsServer string } @@ -63,7 +73,13 @@ func Run( vp8FPS int, vp8BatchSize int, ) error { - return RunWithReady(ctx, linkName, transportName, carrierName, roomURL, keyHex, localAddr, dnsServer, socksUser, socksPass, nil, videoWidth, videoHeight, videoFPS, videoBitrate, videoHW, videoQRSize, videoQRRecovery, videoCodec, videoTileModule, videoTileRS, vp8FPS, vp8BatchSize) + return RunWithReady( + ctx, linkName, transportName, carrierName, roomURL, keyHex, localAddr, + dnsServer, socksUser, socksPass, nil, + videoWidth, videoHeight, videoFPS, videoBitrate, videoHW, + videoQRSize, videoQRRecovery, videoCodec, videoTileModule, videoTileRS, + vp8FPS, vp8BatchSize, + ) } // RunWithReady is like Run but accepts a callback that is called when the client is ready. @@ -118,7 +134,7 @@ func RunWithReady( if err != nil { return fmt.Errorf("failed to listen on %s: %w", localAddr, err) } - defer listener.Close() + defer func() { _ = listener.Close() }() logger.Infof("SOCKS5 server listening on %s", localAddr) @@ -126,17 +142,10 @@ func RunWithReady( onReady() } - errCh := make(chan error, 1) - go func() { - errCh <- c.acceptLoop(runCtx, listener) - }() + go c.acceptLoop(runCtx, listener) - select { - case <-runCtx.Done(): - return nil - case err := <-errCh: - return err - } + <-runCtx.Done() + return nil } func (c *Client) bringUpLink( @@ -227,8 +236,6 @@ func (c *Client) handleReconnect() { c.conn = nil } c.sessMu.Unlock() - // New SOCKS5 connections will fail until the link comes back up; the - // caller will reissue them. Existing streams die with the smux session. c.conn = muxconn.New(c.ln, c.cipher) sess, err := smux.Client(c.conn, smuxConfig()) if err != nil { @@ -260,7 +267,7 @@ func setupCipher(keyHex string) (*crypto.Cipher, error) { return nil, fmt.Errorf("failed to decode key: %w", err) } if len(key) != 32 { - return nil, fmt.Errorf("key must be 32 bytes, got %d", len(key)) + return nil, fmt.Errorf("%w: got %d", ErrKeySize, len(key)) } cipher, err := crypto.NewCipher(string(key)) @@ -279,13 +286,13 @@ func (c *Client) onData(data []byte) { } } -func (c *Client) acceptLoop(ctx context.Context, ln net.Listener) error { +func (c *Client) acceptLoop(ctx context.Context, ln net.Listener) { for { conn, err := ln.Accept() if err != nil { select { case <-ctx.Done(): - return nil + return default: logger.Warnf("Accept error: %v", err) continue @@ -295,8 +302,8 @@ func (c *Client) acceptLoop(ctx context.Context, ln net.Listener) error { } } -func (c *Client) handleSocks5(ctx context.Context, conn net.Conn) { - defer conn.Close() +func (c *Client) handleSocks5(_ context.Context, conn net.Conn) { + defer func() { _ = conn.Close() }() if err := c.socks5Handshake(conn); err != nil { return @@ -315,38 +322,25 @@ func (c *Client) handleSocks5(ctx context.Context, conn net.Conn) { return } + c.tunnel(conn, sess, targetAddr, targetPort) +} + +func (c *Client) tunnel(conn net.Conn, sess *smux.Session, targetAddr string, targetPort int) { stream, err := sess.OpenStream() if err != nil { logger.Warnf("OpenStream failed: %v", err) _, _ = conn.Write(replyHostUnreachable()) return } - defer stream.Close() + defer func() { _ = stream.Close() }() logger.Infof("sid=%d tunnel to %s:%d", stream.ID(), targetAddr, targetPort) - connectReq, _ := json.Marshal(map[string]any{ - "cmd": "connect", - "addr": targetAddr, - "port": targetPort, - }) - - _ = stream.SetWriteDeadline(time.Now().Add(10 * time.Second)) - if _, err := stream.Write(connectReq); err != nil { - logger.Warnf("sid=%d connect req failed: %v", stream.ID(), err) + if err := c.sendConnectRequest(stream, targetAddr, targetPort); err != nil { + logger.Warnf("sid=%d connect failed: %v", stream.ID(), err) _, _ = conn.Write(replyHostUnreachable()) return } - _ = stream.SetWriteDeadline(time.Time{}) - - ack := make([]byte, 1) - _ = stream.SetReadDeadline(time.Now().Add(15 * time.Second)) - if _, err := io.ReadFull(stream, ack); err != nil || ack[0] != 0x00 { - logger.Warnf("sid=%d remote ready failed: err=%v ack=%v", stream.ID(), err, ack) - _, _ = conn.Write(replyHostUnreachable()) - return - } - _ = stream.SetReadDeadline(time.Time{}) if _, err := conn.Write(replySuccess()); err != nil { return @@ -357,24 +351,47 @@ func (c *Client) handleSocks5(ctx context.Context, conn net.Conn) { _ = stream.Close() }() _, _ = io.Copy(conn, stream) +} - _ = ctx // keep signature +func (c *Client) sendConnectRequest(stream *smux.Stream, targetAddr string, targetPort int) error { + connectReq, err := json.Marshal(map[string]any{ + "cmd": "connect", + "addr": targetAddr, + "port": targetPort, + }) + if err != nil { + return fmt.Errorf("sid=%d marshal connect req: %w", stream.ID(), err) + } + + _ = stream.SetWriteDeadline(time.Now().Add(10 * time.Second)) + if _, err := stream.Write(connectReq); err != nil { + return fmt.Errorf("sid=%d write connect req: %w", stream.ID(), err) + } + _ = stream.SetWriteDeadline(time.Time{}) + + ack := make([]byte, 1) + _ = stream.SetReadDeadline(time.Now().Add(15 * time.Second)) + if _, err := io.ReadFull(stream, ack); err != nil || ack[0] != 0x00 { + return fmt.Errorf("sid=%d: %w (read_err=%w ack=%v)", stream.ID(), ErrRemoteNotReady, err, ack) + } + _ = stream.SetReadDeadline(time.Time{}) + return nil } func (c *Client) socks5Handshake(conn net.Conn) error { buf := make([]byte, 2) if _, err := io.ReadFull(conn, buf); err != nil { - return err + return fmt.Errorf("read socks5 header: %w", err) } if buf[0] != 5 { - return fmt.Errorf("invalid socks version: %d", buf[0]) + return fmt.Errorf("%w: %d", ErrInvalidSOCKSVersion, buf[0]) } methods := make([]byte, buf[1]) if _, err := io.ReadFull(conn, methods); err != nil { - return err + return fmt.Errorf("read socks5 methods: %w", err) } if _, err := conn.Write([]byte{5, 0}); err != nil { - return err + return fmt.Errorf("write socks5 auth: %w", err) } return nil } @@ -382,43 +399,49 @@ func (c *Client) socks5Handshake(conn net.Conn) error { func (c *Client) socks5Request(conn net.Conn) (string, int, error) { header := make([]byte, 4) if _, err := io.ReadFull(conn, header); err != nil { - return "", 0, err + return "", 0, fmt.Errorf("read socks5 request: %w", err) } if header[1] != 1 { - return "", 0, fmt.Errorf("unsupported socks command: %d", header[1]) + return "", 0, fmt.Errorf("%w: %d", ErrUnsupportedSOCKSCommand, header[1]) } - var addr string - switch header[3] { - case 1: // IPv4 - buf := make([]byte, 4) - if _, err := io.ReadFull(conn, buf); err != nil { - return "", 0, err - } - addr = net.IP(buf).String() - case 3: // Domain - lenBuf := make([]byte, 1) - if _, err := io.ReadFull(conn, lenBuf); err != nil { - return "", 0, err - } - buf := make([]byte, lenBuf[0]) - if _, err := io.ReadFull(conn, buf); err != nil { - return "", 0, err - } - addr = string(buf) - default: - return "", 0, fmt.Errorf("unsupported address type: %d", header[3]) + addr, err := c.readSocks5Addr(conn, header[3]) + if err != nil { + return "", 0, err } portBuf := make([]byte, 2) if _, err := io.ReadFull(conn, portBuf); err != nil { - return "", 0, err + return "", 0, fmt.Errorf("read socks5 port: %w", err) } port := int(binary.BigEndian.Uint16(portBuf)) return addr, port, nil } +func (c *Client) readSocks5Addr(conn net.Conn, addrType byte) (string, error) { + switch addrType { + case 1: // IPv4 + buf := make([]byte, 4) + if _, err := io.ReadFull(conn, buf); err != nil { + return "", fmt.Errorf("read socks5 ipv4: %w", err) + } + return net.IP(buf).String(), nil + case 3: // Domain + lenBuf := make([]byte, 1) + if _, err := io.ReadFull(conn, lenBuf); err != nil { + return "", fmt.Errorf("read socks5 domain len: %w", err) + } + buf := make([]byte, lenBuf[0]) + if _, err := io.ReadFull(conn, buf); err != nil { + return "", fmt.Errorf("read socks5 domain: %w", err) + } + return string(buf), nil + default: + return "", fmt.Errorf("%w: %d", ErrUnsupportedAddressType, addrType) + } +} + func replySuccess() []byte { return []byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0} } diff --git a/internal/link/direct/direct.go b/internal/link/direct/direct.go index 0bb9cea..715df06 100644 --- a/internal/link/direct/direct.go +++ b/internal/link/direct/direct.go @@ -16,25 +16,25 @@ type directLink struct { // New creates a direct link that forwards bytes to the selected transport. func New(ctx context.Context, cfg link.Config) (link.Link, error) { tr, err := transport.New(ctx, cfg.Transport, transport.Config{ - Carrier: cfg.Carrier, - RoomURL: cfg.RoomURL, - Name: cfg.Name, - OnData: cfg.OnData, - DNSServer: cfg.DNSServer, - ProxyAddr: cfg.ProxyAddr, - ProxyPort: cfg.ProxyPort, - VideoWidth: cfg.VideoWidth, - VideoHeight: cfg.VideoHeight, - VideoFPS: cfg.VideoFPS, - VideoBitrate: cfg.VideoBitrate, - VideoHW: cfg.VideoHW, + Carrier: cfg.Carrier, + RoomURL: cfg.RoomURL, + Name: cfg.Name, + OnData: cfg.OnData, + DNSServer: cfg.DNSServer, + ProxyAddr: cfg.ProxyAddr, + ProxyPort: cfg.ProxyPort, + VideoWidth: cfg.VideoWidth, + VideoHeight: cfg.VideoHeight, + VideoFPS: cfg.VideoFPS, + VideoBitrate: cfg.VideoBitrate, + VideoHW: cfg.VideoHW, VideoQRSize: cfg.VideoQRSize, VideoQRRecovery: cfg.VideoQRRecovery, VideoCodec: cfg.VideoCodec, VideoTileModule: cfg.VideoTileModule, VideoTileRS: cfg.VideoTileRS, - VP8FPS: cfg.VP8FPS, - VP8BatchSize: cfg.VP8BatchSize, + VP8FPS: cfg.VP8FPS, + VP8BatchSize: cfg.VP8BatchSize, }) if err != nil { return nil, fmt.Errorf("create transport for direct link: %w", err) @@ -43,9 +43,27 @@ func New(ctx context.Context, cfg link.Config) (link.Link, error) { return &directLink{transport: tr}, nil } -func (d *directLink) Connect(ctx context.Context) error { return d.transport.Connect(ctx) } -func (d *directLink) Send(data []byte) error { return d.transport.Send(data) } -func (d *directLink) Close() error { return d.transport.Close() } +func (d *directLink) Connect(ctx context.Context) error { + if err := d.transport.Connect(ctx); err != nil { + return fmt.Errorf("transport connect: %w", err) + } + return nil +} + +func (d *directLink) Send(data []byte) error { + if err := d.transport.Send(data); err != nil { + return fmt.Errorf("transport send: %w", err) + } + return nil +} + +func (d *directLink) Close() error { + if err := d.transport.Close(); err != nil { + return fmt.Errorf("transport close: %w", err) + } + return nil +} + func (d *directLink) SetReconnectCallback(cb func()) { d.transport.SetReconnectCallback(cb) } func (d *directLink) SetShouldReconnect(fn func() bool) { d.transport.SetShouldReconnect(fn) } func (d *directLink) SetEndedCallback(cb func(string)) { d.transport.SetEndedCallback(cb) } diff --git a/internal/link/link.go b/internal/link/link.go index 704e388..fc30a4e 100644 --- a/internal/link/link.go +++ b/internal/link/link.go @@ -50,6 +50,7 @@ type Config struct { // Factory creates a link instance. type Factory func(ctx context.Context, cfg Config) (Link, error) +//nolint:gochecknoglobals var registry = make(map[string]Factory) // Register adds a link factory to the registry. diff --git a/internal/muxconn/conn.go b/internal/muxconn/conn.go index b895610..3f651a5 100644 --- a/internal/muxconn/conn.go +++ b/internal/muxconn/conn.go @@ -17,6 +17,7 @@ package muxconn import ( "errors" + "fmt" "io" "sync" "time" @@ -92,10 +93,10 @@ func (c *Conn) Write(p []byte) (int, error) { enc, err := c.cipher.Encrypt(p) if err != nil { - return 0, err + return 0, fmt.Errorf("encrypt: %w", err) } if err := c.ln.Send(enc); err != nil { - return 0, err + return 0, fmt.Errorf("send: %w", err) } return len(p), nil } diff --git a/internal/provider/jazz/peer.go b/internal/provider/jazz/peer.go index 90c5228..da66cd2 100644 --- a/internal/provider/jazz/peer.go +++ b/internal/provider/jazz/peer.go @@ -24,6 +24,13 @@ const ( sendDelay = 2 * time.Millisecond ) +var ( + // ErrPublisherNotInitialized is returned when the publisher peer connection is not set up. + ErrPublisherNotInitialized = errors.New("publisher peer connection not initialized") + // ErrSubscriberMediaTimeout is returned when the subscriber media is not ready within the timeout period. + ErrSubscriberMediaTimeout = errors.New("subscriber media timeout") +) + // Peer represents a SaluteJazz WebRTC connection. type Peer struct { name string @@ -135,23 +142,23 @@ func (p *Peer) attachPendingVideoTracks() error { return nil } -// Connect starts the WebRTC connection process. -func (p *Peer) Connect(ctx context.Context) error { - p.closed.Store(false) - p.resetMediaState() - - config := webrtc.Configuration{ +func defaultWebRTCConfig() webrtc.Configuration { + return webrtc.Configuration{ ICEServers: []webrtc.ICEServer{}, SDPSemantics: webrtc.SDPSemanticsUnifiedPlan, BundlePolicy: webrtc.BundlePolicyMaxBundle, } +} - settingEngine := webrtc.SettingEngine{} +func (p *Peer) buildAPI() *webrtc.API { + se := webrtc.SettingEngine{} if protect.Protector != nil { - settingEngine.SetICEProxyDialer(protect.NewProxyDialer()) + se.SetICEProxyDialer(protect.NewProxyDialer()) } - api := webrtc.NewAPI(webrtc.WithSettingEngine(settingEngine)) + return webrtc.NewAPI(webrtc.WithSettingEngine(se)) +} +func (p *Peer) createPeerConnections(api *webrtc.API, config webrtc.Configuration) error { var err error p.pcSub, err = api.NewPeerConnection(config) if err != nil { @@ -162,7 +169,6 @@ func (p *Peer) Connect(ctx context.Context) error { if track.Kind() != webrtc.RTPCodecTypeVideo { return } - if cb := p.videoTrackHandler(); cb != nil { cb(track, receiver) } @@ -173,28 +179,63 @@ func (p *Peer) Connect(ctx context.Context) error { return fmt.Errorf("create publisher pc: %w", err) } p.pcPub.OnConnectionStateChange(p.onPublisherConnectionStateChange) + return nil +} +func (p *Peer) createDataChannel() (chan struct{}, error) { + var err error + p.dc, err = p.pcPub.CreateDataChannel("_reliable", &webrtc.DataChannelInit{ + Ordered: func() *bool { v := true; return &v }(), + }) + if err != nil { + return nil, fmt.Errorf("create datachannel: %w", err) + } + dcReady := make(chan struct{}) + p.setupDataChannelHandlers(dcReady) + return dcReady, nil +} + +func (p *Peer) waitForReady(ctx context.Context, dcReady chan struct{}) error { + if dcReady != nil { + select { + case <-dcReady: + return nil + case <-time.After(30 * time.Second): + return provider.ErrDataChannelTimeout + case <-ctx.Done(): + return fmt.Errorf("connect cancelled: %w", ctx.Err()) + } + } + return p.waitForMediaReady(ctx, 30*time.Second) +} + +// Connect starts the WebRTC connection process. +func (p *Peer) Connect(ctx context.Context) error { + p.closed.Store(false) + p.resetMediaState() + + api := p.buildAPI() + config := defaultWebRTCConfig() + + if err := p.createPeerConnections(api, config); err != nil { + return err + } if err := p.attachPendingVideoTracks(); err != nil { return err } var dcReady chan struct{} if p.onData != nil { - p.dc, err = p.pcPub.CreateDataChannel("_reliable", &webrtc.DataChannelInit{ - Ordered: func() *bool { v := true; return &v }(), - }) + var err error + dcReady, err = p.createDataChannel() if err != nil { - return fmt.Errorf("create datachannel: %w", err) + return err } - - dcReady = make(chan struct{}) - p.setupDataChannelHandlers(dcReady) } if err := p.dialWebSocket(); err != nil { return err } - if err := p.sendJoin(); err != nil { return err } @@ -205,18 +246,7 @@ func (p *Peer) Connect(ctx context.Context) error { p.handleSignaling(ctx) }() - if p.onData != nil { - select { - case <-dcReady: - return nil - case <-time.After(30 * time.Second): - return provider.ErrDataChannelTimeout - case <-ctx.Done(): - return fmt.Errorf("connect cancelled: %w", ctx.Err()) - } - } - - return p.waitForMediaReady(ctx, 30*time.Second) + return p.waitForReady(ctx, dcReady) } func (p *Peer) waitForMediaReady(ctx context.Context, timeout time.Duration) error { @@ -226,7 +256,7 @@ func (p *Peer) waitForMediaReady(ctx context.Context, timeout time.Duration) err select { case <-p.subscriberConn: case <-timer.C: - return fmt.Errorf("subscriber media timeout") + return ErrSubscriberMediaTimeout case <-ctx.Done(): return fmt.Errorf("connect cancelled: %w", ctx.Err()) } @@ -320,30 +350,38 @@ func (p *Peer) setupDataChannelHandlers(dcReady chan struct{}) { } func (p *Peer) onSubscriberConnectionStateChange(state webrtc.PeerConnectionState) { - if state == webrtc.PeerConnectionStateConnected { + switch state { + case webrtc.PeerConnectionStateConnected: p.subscriberReady.Store(true) closeSignal(p.subscriberConn) - } else if state == webrtc.PeerConnectionStateDisconnected || - state == webrtc.PeerConnectionStateFailed || - state == webrtc.PeerConnectionStateClosed { + case webrtc.PeerConnectionStateDisconnected, webrtc.PeerConnectionStateFailed: p.subscriberReady.Store(false) - if !p.closed.Load() && (state == webrtc.PeerConnectionStateDisconnected || state == webrtc.PeerConnectionStateFailed) { + if !p.closed.Load() { p.queueReconnect() } + case webrtc.PeerConnectionStateClosed: + p.subscriberReady.Store(false) + case webrtc.PeerConnectionStateUnknown, + webrtc.PeerConnectionStateNew, + webrtc.PeerConnectionStateConnecting: } } func (p *Peer) onPublisherConnectionStateChange(state webrtc.PeerConnectionState) { - if state == webrtc.PeerConnectionStateConnected { + switch state { + case webrtc.PeerConnectionStateConnected: p.publisherReady.Store(true) closeSignal(p.publisherConn) - } else if state == webrtc.PeerConnectionStateDisconnected || - state == webrtc.PeerConnectionStateFailed || - state == webrtc.PeerConnectionStateClosed { + case webrtc.PeerConnectionStateDisconnected, webrtc.PeerConnectionStateFailed: p.publisherReady.Store(false) - if !p.closed.Load() && (state == webrtc.PeerConnectionStateDisconnected || state == webrtc.PeerConnectionStateFailed) { + if !p.closed.Load() { p.queueReconnect() } + case webrtc.PeerConnectionStateClosed: + p.publisherReady.Store(false) + case webrtc.PeerConnectionStateUnknown, + webrtc.PeerConnectionStateNew, + webrtc.PeerConnectionStateConnecting: } } @@ -651,11 +689,6 @@ func (p *Peer) Close() error { return nil } -var ( - // ErrPublisherNotInitialized is returned when the publisher peer connection is not set up. - ErrPublisherNotInitialized = errors.New("publisher peer connection not initialized") -) - // AddVideoTrack adds a video track to the publisher peer connection. func (p *Peer) AddVideoTrack(track webrtc.TrackLocal) error { p.videoTrackMu.Lock() diff --git a/internal/provider/provider.go b/internal/provider/provider.go index 3820b85..031835d 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -22,8 +22,6 @@ var ( ) // Provider defines the standard interface for WebRTC connection handlers. -// -//nolint:interfacebloat // All methods are necessary for provider abstraction. type Provider interface { Connect(ctx context.Context) error Send(data []byte) error diff --git a/internal/provider/telemost/peer.go b/internal/provider/telemost/peer.go index 0b36495..16182bb 100644 --- a/internal/provider/telemost/peer.go +++ b/internal/provider/telemost/peer.go @@ -42,6 +42,8 @@ var ( ErrSessionClosed = errors.New("session closed") // ErrPeerClosed is returned when the peer is closed. ErrPeerClosed = errors.New("peer closed") + // ErrSubscriberMediaTimeout is returned when subscriber media is not ready within the timeout period. + ErrSubscriberMediaTimeout = errors.New("subscriber media timeout") ) // TrafficShape defines the parameters for outgoing traffic control. @@ -288,7 +290,7 @@ func (p *Peer) waitForMediaReady(ctx context.Context, timeout time.Duration) err select { case <-p.subscriberConn: case <-timer.C: - return fmt.Errorf("subscriber media timeout") + return ErrSubscriberMediaTimeout case <-ctx.Done(): return fmt.Errorf("connect context cancelled: %w", ctx.Err()) } @@ -314,7 +316,8 @@ func (p *Peer) setupPeerConnections(config webrtc.Configuration) error { return } - logger.Infof("telemost remote video track: codec=%s stream=%s track=%s", track.Codec().MimeType, track.StreamID(), track.ID()) + logger.Infof("telemost remote video track: codec=%s stream=%s track=%s", + track.Codec().MimeType, track.StreamID(), track.ID()) if cb := p.videoTrackHandler(); cb != nil { cb(track, receiver) @@ -342,29 +345,35 @@ func (p *Peer) onConnectionStateChange(state webrtc.PeerConnectionState) { func (p *Peer) onSubscriberConnectionStateChange(state webrtc.PeerConnectionState) { logger.Debugf("telemost subscriber state: %s", state.String()) - if state == webrtc.PeerConnectionStateConnected { + switch state { + case webrtc.PeerConnectionStateConnected: p.subscriberReady.Store(true) closeSignal(p.subscriberConn) - } else if state == webrtc.PeerConnectionStateDisconnected || - state == webrtc.PeerConnectionStateFailed || - state == webrtc.PeerConnectionStateClosed { + case webrtc.PeerConnectionStateDisconnected, + webrtc.PeerConnectionStateFailed, + webrtc.PeerConnectionStateClosed: p.subscriberReady.Store(false) + case webrtc.PeerConnectionStateUnknown, + webrtc.PeerConnectionStateNew, + webrtc.PeerConnectionStateConnecting: } - p.onConnectionStateChange(state) } func (p *Peer) onPublisherConnectionStateChange(state webrtc.PeerConnectionState) { logger.Debugf("telemost publisher state: %s", state.String()) - if state == webrtc.PeerConnectionStateConnected { + switch state { + case webrtc.PeerConnectionStateConnected: p.publisherReady.Store(true) closeSignal(p.publisherConn) - } else if state == webrtc.PeerConnectionStateDisconnected || - state == webrtc.PeerConnectionStateFailed || - state == webrtc.PeerConnectionStateClosed { + case webrtc.PeerConnectionStateDisconnected, + webrtc.PeerConnectionStateFailed, + webrtc.PeerConnectionStateClosed: p.publisherReady.Store(false) + case webrtc.PeerConnectionStateUnknown, + webrtc.PeerConnectionStateNew, + webrtc.PeerConnectionStateConnecting: } - p.onConnectionStateChange(state) } @@ -656,7 +665,7 @@ func (p *Peer) sendSetSlots() error { p.wsMu.Lock() defer p.wsMu.Unlock() - return p.ws.WriteJSON(map[string]interface{}{ + if err := p.ws.WriteJSON(map[string]interface{}{ "uid": uuid.New().String(), "setSlots": map[string]interface{}{ "slots": []map[string]int{ @@ -670,7 +679,52 @@ func (p *Peer) sendSetSlots() error { "selfViewVisibility": "ON_LOADING_THEN_SHOW", "gridConfig": map[string]interface{}{}, }, - }) + }); err != nil { + return fmt.Errorf("write set slots: %w", err) + } + return nil +} + +func isNonTURNURL(url string) bool { + return url != "" && !strings.HasPrefix(url, "turn:") && !strings.HasPrefix(url, "turns:") +} + +func parseICEURLs(server map[string]interface{}) []string { + var urls []string + switch rawURLs := server["urls"].(type) { + case []interface{}: + for _, rawURL := range rawURLs { + if url, ok := rawURL.(string); ok && isNonTURNURL(url) { + urls = append(urls, url) + } + } + case []string: + for _, url := range rawURLs { + if isNonTURNURL(url) { + urls = append(urls, url) + } + } + } + return urls +} + +func parseICEServer(rawServer interface{}) (webrtc.ICEServer, bool) { + server, ok := rawServer.(map[string]interface{}) + if !ok { + return webrtc.ICEServer{}, false + } + urls := parseICEURLs(server) + if len(urls) == 0 { + return webrtc.ICEServer{}, false + } + ice := webrtc.ICEServer{URLs: urls} + if username, ok := server["username"].(string); ok { + ice.Username = username + } + if credential, ok := server["credential"].(string); ok { + ice.Credential = credential + } + return ice, true } func (p *Peer) applyServerHelloConfig(serverHello map[string]interface{}) { @@ -686,39 +740,9 @@ func (p *Peer) applyServerHelloConfig(serverHello map[string]interface{}) { iceServers := make([]webrtc.ICEServer, 0, len(rawServers)) for _, rawServer := range rawServers { - server, ok := rawServer.(map[string]interface{}) - if !ok { - continue + if ice, ok := parseICEServer(rawServer); ok { + iceServers = append(iceServers, ice) } - - var urls []string - switch rawURLs := server["urls"].(type) { - case []interface{}: - for _, rawURL := range rawURLs { - if url, ok := rawURL.(string); ok && url != "" && !strings.HasPrefix(url, "turn:") && !strings.HasPrefix(url, "turns:") { - urls = append(urls, url) - } - } - case []string: - for _, url := range rawURLs { - if !strings.HasPrefix(url, "turn:") && !strings.HasPrefix(url, "turns:") { - urls = append(urls, url) - } - } - } - - if len(urls) == 0 { - continue - } - - ice := webrtc.ICEServer{URLs: urls} - if username, ok := server["username"].(string); ok { - ice.Username = username - } - if credential, ok := server["credential"].(string); ok { - ice.Credential = credential - } - iceServers = append(iceServers, ice) } if len(iceServers) == 0 { diff --git a/internal/server/server.go b/internal/server/server.go index 5e42a69..455ee27 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -22,6 +22,8 @@ import ( ) var ( + // ErrKeyRequired is returned when no encryption key is provided. + ErrKeyRequired = errors.New("key required (use -key )") // ErrKeySize is returned when the encryption key is not 32 bytes. ErrKeySize = errors.New("key must be 32 bytes") // ErrSocks5AuthFailed is returned when SOCKS5 authentication fails. @@ -100,17 +102,17 @@ func Run( return err } - err = s.serve(runCtx) + s.serve(runCtx) s.shutdown() s.wg.Wait() - return err + return nil } func setupCipher(keyHex string) (*crypto.Cipher, error) { if keyHex == "" { - return nil, errors.New("key required (use -key )") + return nil, ErrKeyRequired } key, err := hex.DecodeString(keyHex) @@ -252,10 +254,12 @@ func (s *Server) onData(data []byte) { // serve drives the smux Accept loop, spawning a tunnel per inbound stream. // The loop tolerates session bounces (reconnects) by waiting until a fresh // session is installed instead of terminating the server. -func (s *Server) serve(ctx context.Context) error { +func (s *Server) serve(ctx context.Context) { for { - if ctx.Err() != nil { - return nil + select { + case <-ctx.Done(): + return + default: } s.sessMu.RLock() @@ -264,7 +268,7 @@ func (s *Server) serve(ctx context.Context) error { if sess == nil { select { case <-ctx.Done(): - return nil + return case <-time.After(50 * time.Millisecond): continue } @@ -272,10 +276,10 @@ func (s *Server) serve(ctx context.Context) error { stream, err := sess.AcceptStream() if err != nil { - // Session is torn down (reconnect or close). If we're shutting - // down, exit; otherwise wait for a new session and retry. - if ctx.Err() != nil { - return nil + select { + case <-ctx.Done(): + return + default: } logger.Infof("AcceptStream returned %v — waiting for new session", err) time.Sleep(100 * time.Millisecond) @@ -305,7 +309,7 @@ func (s *Server) shutdown() { } func (s *Server) handleStream(_ context.Context, stream *smux.Stream) { - defer stream.Close() + defer func() { _ = stream.Close() }() // Read the connect JSON. The client writes the whole JSON in one // stream.Write so it usually arrives intact; tolerate fragmentation @@ -356,7 +360,7 @@ func (s *Server) dispatch(stream *smux.Stream, req ConnectRequest) { logger.Infof("sid=%d dial %s failed (%v): %v", stream.ID(), addr, dialElapsed, err) return } - defer conn.Close() + defer func() { _ = conn.Close() }() logger.Infof("sid=%d connected %s in %v", stream.ID(), addr, dialElapsed) diff --git a/internal/transport/datachannel/transport.go b/internal/transport/datachannel/transport.go index 8b61848..965bb04 100644 --- a/internal/transport/datachannel/transport.go +++ b/internal/transport/datachannel/transport.go @@ -44,17 +44,26 @@ func New(ctx context.Context, cfg transport.Config) (transport.Transport, error) // Connect starts the transport connection. func (p *streamTransport) Connect(ctx context.Context) error { - return p.stream.Connect(ctx) + if err := p.stream.Connect(ctx); err != nil { + return fmt.Errorf("stream connect: %w", err) + } + return nil } // Send transmits data through the transport. func (p *streamTransport) Send(data []byte) error { - return p.stream.Send(data) + if err := p.stream.Send(data); err != nil { + return fmt.Errorf("stream send: %w", err) + } + return nil } // Close terminates the transport. func (p *streamTransport) Close() error { - return p.stream.Close() + if err := p.stream.Close(); err != nil { + return fmt.Errorf("stream close: %w", err) + } + return nil } // SetReconnectCallback registers reconnect handling. diff --git a/internal/transport/seichannel/h264.go b/internal/transport/seichannel/h264.go index 1d6f993..200a839 100644 --- a/internal/transport/seichannel/h264.go +++ b/internal/transport/seichannel/h264.go @@ -3,11 +3,20 @@ package seichannel import ( "bytes" "encoding/hex" + "errors" "fmt" "github.com/pion/webrtc/v4/pkg/media/h264reader" ) +var ( + // ErrSEIPayloadTruncated is returned when the SEI payload is shorter than expected. + ErrSEIPayloadTruncated = errors.New("sei payload truncated") + // ErrSEIValueTruncated is returned when reading a SEI length-value runs past the buffer. + ErrSEIValueTruncated = errors.New("sei value truncated") +) + +//nolint:gochecknoglobals var ( videoSEIUUID = [16]byte{ 0x5d, 0xc0, 0x3b, 0xa8, @@ -21,19 +30,16 @@ var ( baseIDR = mustDecodeHex("6588843a2628000902e0") ) -func buildVideoAccessUnit(payload []byte) ([]byte, error) { +func buildVideoAccessUnit(payload []byte) []byte { out := make([]byte, 0, len(baseSPS)+len(basePPS)+len(baseIDR)+64+len(payload)) out = appendStartCode(out, baseSPS) out = appendStartCode(out, basePPS) if len(payload) > 0 { - sei, err := buildSEINAL(payload) - if err != nil { - return nil, err - } + sei := buildSEINAL(payload) out = appendStartCode(out, sei) } out = appendStartCode(out, baseIDR) - return out, nil + return out } func extractVideoPayloads(accessUnit []byte) ([][]byte, error) { @@ -63,7 +69,7 @@ func extractVideoPayloads(accessUnit []byte) ([][]byte, error) { } } -func buildSEINAL(payload []byte) ([]byte, error) { +func buildSEINAL(payload []byte) []byte { userData := make([]byte, 0, len(videoSEIUUID)+len(payload)) userData = append(userData, videoSEIUUID[:]...) userData = append(userData, payload...) @@ -74,9 +80,11 @@ func buildSEINAL(payload []byte) ([]byte, error) { rbsp = append(rbsp, userData...) rbsp = append(rbsp, 0x80) - out := []byte{0x06} - out = append(out, escapeRBSP(rbsp)...) - return out, nil + escaped := escapeRBSP(rbsp) + out := make([]byte, 0, 1+len(escaped)) + out = append(out, 0x06) + out = append(out, escaped...) + return out } func extractTransportSEI(rbsp []byte) ([][]byte, error) { @@ -101,7 +109,7 @@ func extractTransportSEI(rbsp []byte) ([][]byte, error) { pos = next if pos+payloadSize > len(data) { - return nil, fmt.Errorf("sei payload truncated") + return nil, ErrSEIPayloadTruncated } payload := data[pos : pos+payloadSize] @@ -127,14 +135,14 @@ func appendSEIValue(dst []byte, value int) []byte { dst = append(dst, 0xff) value -= 0xff } - return append(dst, byte(value)) + return append(dst, byte(value)) //nolint:gosec } func consumeSEIValue(data []byte, pos int) (int, int, error) { value := 0 for { if pos >= len(data) { - return 0, pos, fmt.Errorf("sei value truncated") + return 0, pos, ErrSEIValueTruncated } b := int(data[pos]) pos++ @@ -170,11 +178,11 @@ func escapeRBSP(rbsp []byte) []byte { func unescapeRBSP(rbsp []byte) []byte { out := make([]byte, 0, len(rbsp)) - for i := 0; i < len(rbsp); i++ { - if i >= 2 && rbsp[i] == 0x03 && rbsp[i-1] == 0x00 && rbsp[i-2] == 0x00 { + for i, b := range rbsp { + if i >= 2 && b == 0x03 && rbsp[i-1] == 0x00 && rbsp[i-2] == 0x00 { continue } - out = append(out, rbsp[i]) + out = append(out, b) } return out } diff --git a/internal/transport/seichannel/transport.go b/internal/transport/seichannel/transport.go index 19bb40b..663454f 100644 --- a/internal/transport/seichannel/transport.go +++ b/internal/transport/seichannel/transport.go @@ -40,6 +40,18 @@ var ( ErrAckTimeout = errors.New("seichannel ack timeout") // ErrTransportClosed is returned when operations are attempted on a closed transport. ErrTransportClosed = errors.New("seichannel transport closed") + // ErrFrameTooShort is returned when the received frame is too short to decode. + ErrFrameTooShort = errors.New("frame too short") + // ErrUnexpectedMagic is returned when the frame magic bytes do not match. + ErrUnexpectedMagic = errors.New("unexpected frame magic") + // ErrUnexpectedVersion is returned when the frame protocol version does not match. + ErrUnexpectedVersion = errors.New("unexpected frame version") + // ErrAckTooShort is returned when the ack frame is shorter than expected. + ErrAckTooShort = errors.New("ack frame too short") + // ErrDataTooShort is returned when the data frame is shorter than expected. + ErrDataTooShort = errors.New("data frame too short") + // ErrUnexpectedFrameType is returned for unknown frame type bytes. + ErrUnexpectedFrameType = errors.New("unexpected frame type") ) type transportFrame struct { @@ -144,7 +156,7 @@ func (p *streamTransport) Connect(ctx context.Context) error { defer cancel() if err := p.stream.Connect(connectCtx); err != nil { - return err + return fmt.Errorf("connect stream: %w", err) } p.startWriter.Do(func() { @@ -178,7 +190,7 @@ func (p *streamTransport) Send(data []byte) error { p.ackMu.Unlock() }() - for attempt := 0; attempt < maxSendAttempts; attempt++ { + for range maxSendAttempts { for idx, fragment := range fragments { frame := encodeDataFrame(seq, crc, len(data), idx, len(fragments), fragment) if err := p.enqueueFrame(frame, false); err != nil { @@ -210,7 +222,9 @@ func (p *streamTransport) Close() error { if p.writerUp.Load() { <-p.writerDone } - return p.stream.Close() + if err := p.stream.Close(); err != nil { + return fmt.Errorf("close stream: %w", err) + } } return nil } @@ -256,10 +270,7 @@ func (p *streamTransport) writerLoop() { ticker := time.NewTicker(defaultFrameInterval) defer ticker.Stop() - idle, err := buildVideoAccessUnit(nil) - if err != nil { - return - } + idle := buildVideoAccessUnit(nil) for { select { @@ -273,10 +284,7 @@ func (p *streamTransport) writerLoop() { sample := idle if payload != nil { - sample, err = buildVideoAccessUnit(payload) - if err != nil { - continue - } + sample = buildVideoAccessUnit(payload) } _ = p.track.WriteSample(media.Sample{ @@ -371,14 +379,7 @@ func (p *streamTransport) handleSample(sample []byte) { } } -func (p *streamTransport) handleInboundFrame(frame transportFrame) { - p.recvMu.Lock() - if crc, ok := p.delivered[frame.seq]; ok && crc == frame.crc { - p.recvMu.Unlock() - p.sendAck(frame.seq, frame.crc) - return - } - +func (p *streamTransport) upsertInbound(frame transportFrame) (*inboundMessage, bool) { msg, ok := p.inbound[frame.seq] if !ok || msg.crc != frame.crc || msg.totalLen != frame.totalLen || len(msg.frags) != int(frame.fragTotal) { msg = &inboundMessage{ @@ -389,33 +390,45 @@ func (p *streamTransport) handleInboundFrame(frame transportFrame) { } p.inbound[frame.seq] = msg } - if int(frame.fragIdx) >= len(msg.frags) { - p.recvMu.Unlock() - return + return nil, false } - if msg.frags[frame.fragIdx] == nil { chunk := make([]byte, len(frame.payload)) copy(chunk, frame.payload) msg.frags[frame.fragIdx] = chunk msg.remain-- } + return msg, msg.remain == 0 +} - if msg.remain > 0 { +func (p *streamTransport) assembleMessage(msg *inboundMessage) []byte { + data := make([]byte, 0, msg.totalLen) + for _, frag := range msg.frags { + data = append(data, frag...) + } + if uint32(len(data)) > msg.totalLen { //nolint:gosec + data = data[:msg.totalLen] + } + return data +} + +func (p *streamTransport) handleInboundFrame(frame transportFrame) { + p.recvMu.Lock() + if crc, ok := p.delivered[frame.seq]; ok && crc == frame.crc { + p.recvMu.Unlock() + p.sendAck(frame.seq, frame.crc) + return + } + + msg, complete := p.upsertInbound(frame) + if msg == nil || !complete { p.recvMu.Unlock() return } delete(p.inbound, frame.seq) - data := make([]byte, 0, msg.totalLen) - for _, frag := range msg.frags { - data = append(data, frag...) - } - - if uint32(len(data)) > msg.totalLen { - data = data[:msg.totalLen] - } + data := p.assembleMessage(msg) if crc32.ChecksumIEEE(data) != msg.crc { p.recvMu.Unlock() @@ -480,9 +493,9 @@ func encodeDataFrame(seq, crc uint32, totalLen, fragIdx, fragTotal int, payload out[5] = frameTypeData binary.BigEndian.PutUint32(out[6:10], seq) binary.BigEndian.PutUint32(out[10:14], crc) - binary.BigEndian.PutUint32(out[14:18], uint32(totalLen)) - binary.BigEndian.PutUint16(out[18:20], uint16(fragIdx)) - binary.BigEndian.PutUint16(out[20:22], uint16(fragTotal)) + binary.BigEndian.PutUint32(out[14:18], uint32(totalLen)) //nolint:gosec + binary.BigEndian.PutUint16(out[18:20], uint16(fragIdx)) //nolint:gosec + binary.BigEndian.PutUint16(out[20:22], uint16(fragTotal)) //nolint:gosec copy(out[22:], payload) return out } @@ -499,27 +512,27 @@ func encodeAckFrame(seq, crc uint32) []byte { func decodeTransportFrame(data []byte) (transportFrame, error) { if len(data) < 6 { - return transportFrame{}, fmt.Errorf("frame too short") + return transportFrame{}, ErrFrameTooShort } if binary.BigEndian.Uint32(data[0:4]) != protocolMagic { - return transportFrame{}, fmt.Errorf("unexpected frame magic") + return transportFrame{}, ErrUnexpectedMagic } if data[4] != protocolVersion { - return transportFrame{}, fmt.Errorf("unexpected frame version") + return transportFrame{}, ErrUnexpectedVersion } frame := transportFrame{typ: data[5]} switch frame.typ { case frameTypeAck: if len(data) < 14 { - return transportFrame{}, fmt.Errorf("ack too short") + return transportFrame{}, ErrAckTooShort } frame.seq = binary.BigEndian.Uint32(data[6:10]) frame.crc = binary.BigEndian.Uint32(data[10:14]) return frame, nil case frameTypeData: if len(data) < 22 { - return transportFrame{}, fmt.Errorf("data too short") + return transportFrame{}, ErrDataTooShort } frame.seq = binary.BigEndian.Uint32(data[6:10]) frame.crc = binary.BigEndian.Uint32(data[10:14]) @@ -529,6 +542,6 @@ func decodeTransportFrame(data []byte) (transportFrame, error) { frame.payload = append([]byte(nil), data[22:]...) return frame, nil default: - return transportFrame{}, fmt.Errorf("unexpected frame type") + return transportFrame{}, ErrUnexpectedFrameType } } diff --git a/internal/transport/seichannel/transport_test.go b/internal/transport/seichannel/transport_test.go index 82d7c25..26556c4 100644 --- a/internal/transport/seichannel/transport_test.go +++ b/internal/transport/seichannel/transport_test.go @@ -7,10 +7,7 @@ import ( func TestSEIRoundTrip(t *testing.T) { payload := []byte("hello over seichannel") - accessUnit, err := buildVideoAccessUnit(payload) - if err != nil { - t.Fatalf("buildVideoAccessUnit failed: %v", err) - } + accessUnit := buildVideoAccessUnit(payload) got, err := extractVideoPayloads(accessUnit) if err != nil { diff --git a/internal/transport/transport.go b/internal/transport/transport.go index aa0b464..f737021 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -58,6 +58,7 @@ type Config struct { // Factory creates a transport instance. type Factory func(ctx context.Context, cfg Config) (Transport, error) +//nolint:gochecknoglobals var registry = make(map[string]Factory) // Register adds a transport factory to the registry. diff --git a/internal/transport/videochannel/ffmpeg.go b/internal/transport/videochannel/ffmpeg.go index 6b0318a..7099200 100644 --- a/internal/transport/videochannel/ffmpeg.go +++ b/internal/transport/videochannel/ffmpeg.go @@ -2,11 +2,13 @@ package videochannel import ( "bytes" + "context" "encoding/binary" "errors" "fmt" "io" "os/exec" + "strconv" "strings" "sync" "sync/atomic" @@ -27,6 +29,12 @@ var ( ErrFFmpegUnavailable = errors.New("ffmpeg is required for videochannel") // ErrUnsupportedVideoCodec is returned when videochannel cannot decode the negotiated codec. ErrUnsupportedVideoCodec = errors.New("unsupported video codec") + // ErrEncoderTimeout is returned when the encoder does not produce a frame within the deadline. + ErrEncoderTimeout = errors.New("encoder timeout") + // ErrPopFrameTimeout is returned when no decoded frame is available within the deadline. + ErrPopFrameTimeout = errors.New("pop frame timeout") + // ErrUnexpectedFrameSize is returned when the raw frame size does not match expectations. + ErrUnexpectedFrameSize = errors.New("unexpected encoder frame size") ) type codecSpec struct { @@ -38,8 +46,7 @@ type codecSpec struct { encodeArgs []string } -func codecSpecForCarrier(carrier string) codecSpec { - // Natural default for most WebRTC providers +func codecSpecForCarrier(_ string) codecSpec { return vp8CodecSpec() } @@ -120,6 +127,49 @@ func vp8CodecSpec() codecSpec { } } +func resolveEncoderCodec(spec codecSpec, hw string) string { + if hw != "nvenc" { + return spec.encoder + } + switch spec.mimeType { + case webrtc.MimeTypeH264: + return "h264_nvenc" + case webrtc.MimeTypeVP8: + return "vp8_nvenc" + case webrtc.MimeTypeVP9: + return "vp9_nvenc" + case webrtc.MimeTypeAV1: + return "av1_nvenc" + default: + return spec.encoder + } +} + +func buildEncoderArgs(spec codecSpec, vcodec string, width, height, fps int, bitrate string) []string { + args := []string{ + "-loglevel", "error", "-threads", "1", + "-f", "rawvideo", + "-pix_fmt", "gray", + "-video_size", strconv.Itoa(width) + "x" + strconv.Itoa(height), + "-framerate", strconv.Itoa(fps), + "-i", "pipe:0", + "-an", + } + + if strings.HasSuffix(vcodec, "_nvenc") { + args = append(args, "-c:v", vcodec, "-preset", "p1", "-tune", "ull", "-rc", "vbr") + } else { + args = append(args, spec.encodeArgs...) + } + + args = append(args, "-g", "1", "-pix_fmt", "yuv420p", "-b:v", bitrate) + + if spec.mimeType == webrtc.MimeTypeH264 { + return append(args, "-f", "h264", "pipe:1") + } + return append(args, "-f", "ivf", "pipe:1") +} + type ffmpegEncoder struct { cmd *exec.Cmd stdin io.WriteCloser @@ -134,62 +184,20 @@ type ffmpegEncoder struct { err error } -func newFFmpegEncoder(spec codecSpec, width, height, fps int, bitrate, hw string) (*ffmpegEncoder, error) { +func newFFmpegEncoder( + ctx context.Context, + spec codecSpec, + width, height, fps int, + bitrate, hw string, +) (*ffmpegEncoder, error) { if _, err := exec.LookPath("ffmpeg"); err != nil { return nil, ErrFFmpegUnavailable } - args := []string{"-loglevel", "error", "-threads", "1"} + vcodec := resolveEncoderCodec(spec, hw) + args := buildEncoderArgs(spec, vcodec, width, height, fps, bitrate) - // Determine encoder binary based on HW flag - vcodec := spec.encoder - if hw == "nvenc" { - switch spec.mimeType { - case webrtc.MimeTypeH264: - vcodec = "h264_nvenc" - case webrtc.MimeTypeVP8: - vcodec = "vp8_nvenc" - case webrtc.MimeTypeVP9: - vcodec = "vp9_nvenc" - case webrtc.MimeTypeAV1: - vcodec = "av1_nvenc" - } - } - - inputPixFmt := "gray" - frameSize := width * height - - args = append(args, - "-f", "rawvideo", - "-pix_fmt", inputPixFmt, - "-video_size", fmt.Sprintf("%dx%d", width, height), - "-framerate", fmt.Sprintf("%d", fps), - "-i", "pipe:0", - "-an", - ) - - // Apply hardware specific flags if using NVENC - if strings.HasSuffix(vcodec, "_nvenc") { - args = append(args, - "-c:v", vcodec, - "-preset", "p1", - "-tune", "ull", - "-rc", "vbr", - ) - } else { - // Use software encoder args from spec - args = append(args, spec.encodeArgs...) - } - - args = append(args, "-g", "1", "-pix_fmt", "yuv420p", "-b:v", bitrate) - - if spec.mimeType == webrtc.MimeTypeH264 { - args = append(args, "-f", "h264", "pipe:1") - } else { - args = append(args, "-f", "ivf", "pipe:1") - } - - cmd := exec.Command("ffmpeg", args...) + cmd := exec.CommandContext(ctx, "ffmpeg", args...) //nolint:gosec stdin, err := cmd.StdinPipe() if err != nil { return nil, fmt.Errorf("encoder stdin: %w", err) @@ -212,7 +220,7 @@ func newFFmpegEncoder(spec codecSpec, width, height, fps int, bitrate, hw string frames: make(chan []byte, 8), width: width, height: height, - frameSize: frameSize, + frameSize: width * height, } if spec.mimeType == webrtc.MimeTypeH264 { @@ -225,7 +233,7 @@ func newFFmpegEncoder(spec codecSpec, width, height, fps int, bitrate, hw string func (e *ffmpegEncoder) EncodeFrame(frame []byte) ([]byte, error) { if len(frame) != e.frameSize { - return nil, fmt.Errorf("unexpected encoder frame size: %d (expected %d)", len(frame), e.frameSize) + return nil, fmt.Errorf("%w: got %d expected %d", ErrUnexpectedFrameSize, len(frame), e.frameSize) } if err := e.processErr(); err != nil { return nil, err @@ -244,7 +252,7 @@ func (e *ffmpegEncoder) EncodeFrame(frame []byte) ([]byte, error) { if err := e.processErr(); err != nil { return nil, err } - return nil, fmt.Errorf("encoder timeout") + return nil, ErrEncoderTimeout } } @@ -327,6 +335,43 @@ func (e *ffmpegEncoder) processErr() error { return nil } +func resolveDecoderName(spec codecSpec, hw string) string { + if hw != "nvenc" { + return strings.ToLower(strings.TrimPrefix(spec.mimeType, "video/")) + } + switch spec.mimeType { + case webrtc.MimeTypeH264: + return "h264_cuvid" + case webrtc.MimeTypeVP8: + return "vp8_cuvid" + case webrtc.MimeTypeVP9: + return "vp9_cuvid" + default: + return strings.ToLower(strings.TrimPrefix(spec.mimeType, "video/")) + } +} + +func buildDecoderArgs(spec codecSpec, decoderName string, width, height int, outputPixFmt string) []string { + args := []string{"-loglevel", "error", "-threads", "1"} + if spec.mimeType == webrtc.MimeTypeH264 { + args = append(args, "-f", "h264") + } else { + args = append(args, "-f", "ivf") + } + + vfFilter := fmt.Sprintf("scale=%d:%d:flags=neighbor,format=%s", width, height, outputPixFmt) + return append(args, + "-flags", "low_delay", + "-vcodec", decoderName, + "-i", "pipe:0", + "-an", + "-vf", vfFilter, + "-pix_fmt", outputPixFmt, + "-f", "rawvideo", + "pipe:1", + ) +} + type ffmpegDecoder struct { cmd *exec.Cmd stdin io.WriteCloser @@ -341,46 +386,20 @@ type ffmpegDecoder struct { err error } -func newFFmpegDecoder(spec codecSpec, width, height, fps int, hw string) (*ffmpegDecoder, error) { +func newFFmpegDecoder( + ctx context.Context, + spec codecSpec, + width, height, fps int, + hw string, +) (*ffmpegDecoder, error) { if _, err := exec.LookPath("ffmpeg"); err != nil { return nil, ErrFFmpegUnavailable } - decoderName := strings.ToLower(strings.TrimPrefix(spec.mimeType, "video/")) - if hw == "nvenc" { - switch spec.mimeType { - case webrtc.MimeTypeH264: - decoderName = "h264_cuvid" - case webrtc.MimeTypeVP8: - decoderName = "vp8_cuvid" - case webrtc.MimeTypeVP9: - decoderName = "vp9_cuvid" - } - } + decoderName := resolveDecoderName(spec, hw) + args := buildDecoderArgs(spec, decoderName, width, height, "gray") - outputPixFmt := "gray" - frameSize := width * height - - args := []string{"-loglevel", "error", "-threads", "1"} - if spec.mimeType == webrtc.MimeTypeH264 { - args = append(args, "-f", "h264") - } else { - args = append(args, "-f", "ivf") - } - - vfFilter := fmt.Sprintf("scale=%d:%d:flags=neighbor,format=%s", width, height, outputPixFmt) - args = append(args, - "-flags", "low_delay", - "-vcodec", decoderName, - "-i", "pipe:0", - "-an", - "-vf", vfFilter, - "-pix_fmt", outputPixFmt, - "-f", "rawvideo", - "pipe:1", - ) - - cmd := exec.Command("ffmpeg", args...) + cmd := exec.CommandContext(ctx, "ffmpeg", args...) //nolint:gosec stdin, err := cmd.StdinPipe() if err != nil { return nil, fmt.Errorf("decoder stdin: %w", err) @@ -402,7 +421,7 @@ func newFFmpegDecoder(spec codecSpec, width, height, fps int, hw string) (*ffmpe stderr: stderr, frames: make(chan []byte, 32), mimeType: spec.mimeType, - frameSize: frameSize, + frameSize: width * height, } if spec.mimeType != webrtc.MimeTypeH264 { @@ -441,7 +460,7 @@ func (d *ffmpegDecoder) PopFrame() ([]byte, error) { } return frame, nil case <-time.After(10 * time.Second): - return nil, fmt.Errorf("pop frame timeout") + return nil, ErrPopFrameTimeout } } @@ -515,9 +534,9 @@ func writeIVFHeader(w io.Writer, fourCC string, width, height, frameRate int) er binary.LittleEndian.PutUint16(header[4:6], 0) binary.LittleEndian.PutUint16(header[6:8], 32) copy(header[8:12], []byte(fourCC)) - binary.LittleEndian.PutUint16(header[12:14], uint16(width)) - binary.LittleEndian.PutUint16(header[14:16], uint16(height)) - binary.LittleEndian.PutUint32(header[16:20], uint32(frameRate)) + binary.LittleEndian.PutUint16(header[12:14], uint16(width)) //nolint:gosec + binary.LittleEndian.PutUint16(header[14:16], uint16(height)) //nolint:gosec + binary.LittleEndian.PutUint32(header[16:20], uint32(frameRate)) //nolint:gosec binary.LittleEndian.PutUint32(header[20:24], 1) binary.LittleEndian.PutUint32(header[24:28], 0) binary.LittleEndian.PutUint32(header[28:32], 0) @@ -526,7 +545,7 @@ func writeIVFHeader(w io.Writer, fourCC string, width, height, frameRate int) er func writeIVFFrame(w io.Writer, pts uint64, frame []byte) error { header := make([]byte, 12) - binary.LittleEndian.PutUint32(header[0:4], uint32(len(frame))) + binary.LittleEndian.PutUint32(header[0:4], uint32(len(frame))) //nolint:gosec binary.LittleEndian.PutUint64(header[4:12], pts) if err := writeAll(w, header); err != nil { return err @@ -538,9 +557,10 @@ func writeAll(w io.Writer, data []byte) error { for len(data) > 0 { n, err := w.Write(data) if err != nil { - return err + return fmt.Errorf("write: %w", err) } data = data[n:] } return nil } + diff --git a/internal/transport/videochannel/frame.go b/internal/transport/videochannel/frame.go index cf7f198..8f3e316 100644 --- a/internal/transport/videochannel/frame.go +++ b/internal/transport/videochannel/frame.go @@ -2,7 +2,7 @@ package videochannel import ( "encoding/binary" - "fmt" + "errors" ) const ( @@ -12,6 +12,21 @@ const ( frameTypeAck byte = 2 ) +var ( + // ErrFrameTooShort is returned when the received frame is too short to decode. + ErrFrameTooShort = errors.New("frame too short") + // ErrUnexpectedMagic is returned when the frame magic bytes do not match. + ErrUnexpectedMagic = errors.New("unexpected frame magic") + // ErrUnexpectedVersion is returned when the frame protocol version does not match. + ErrUnexpectedVersion = errors.New("unexpected frame version") + // ErrAckTooShort is returned when the ack frame is shorter than expected. + ErrAckTooShort = errors.New("ack frame too short") + // ErrDataTooShort is returned when the data frame is shorter than expected. + ErrDataTooShort = errors.New("data frame too short") + // ErrUnexpectedFrameType is returned for unknown frame type bytes. + ErrUnexpectedFrameType = errors.New("unexpected frame type") +) + type transportFrame struct { typ byte seq uint32 @@ -56,9 +71,9 @@ func encodeDataFrame(seq, crc uint32, totalLen, fragIdx, fragTotal int, payload out[5] = frameTypeData binary.BigEndian.PutUint32(out[6:10], seq) binary.BigEndian.PutUint32(out[10:14], crc) - binary.BigEndian.PutUint32(out[14:18], uint32(totalLen)) - binary.BigEndian.PutUint16(out[18:20], uint16(fragIdx)) - binary.BigEndian.PutUint16(out[20:22], uint16(fragTotal)) + binary.BigEndian.PutUint32(out[14:18], uint32(totalLen)) //nolint:gosec + binary.BigEndian.PutUint16(out[18:20], uint16(fragIdx)) //nolint:gosec + binary.BigEndian.PutUint16(out[20:22], uint16(fragTotal)) //nolint:gosec copy(out[22:], payload) return out } @@ -75,27 +90,27 @@ func encodeAckFrame(seq, crc uint32) []byte { func decodeTransportFrame(data []byte) (transportFrame, error) { if len(data) < 6 { - return transportFrame{}, fmt.Errorf("frame too short") + return transportFrame{}, ErrFrameTooShort } if binary.BigEndian.Uint32(data[0:4]) != protocolMagic { - return transportFrame{}, fmt.Errorf("unexpected frame magic") + return transportFrame{}, ErrUnexpectedMagic } if data[4] != protocolVersion { - return transportFrame{}, fmt.Errorf("unexpected frame version") + return transportFrame{}, ErrUnexpectedVersion } frame := transportFrame{typ: data[5]} switch frame.typ { case frameTypeAck: if len(data) < 14 { - return transportFrame{}, fmt.Errorf("ack too short") + return transportFrame{}, ErrAckTooShort } frame.seq = binary.BigEndian.Uint32(data[6:10]) frame.crc = binary.BigEndian.Uint32(data[10:14]) return frame, nil case frameTypeData: if len(data) < 22 { - return transportFrame{}, fmt.Errorf("data too short") + return transportFrame{}, ErrDataTooShort } frame.seq = binary.BigEndian.Uint32(data[6:10]) frame.crc = binary.BigEndian.Uint32(data[10:14]) @@ -105,6 +120,6 @@ func decodeTransportFrame(data []byte) (transportFrame, error) { frame.payload = append([]byte(nil), data[22:]...) return frame, nil default: - return transportFrame{}, fmt.Errorf("unexpected frame type") + return transportFrame{}, ErrUnexpectedFrameType } } diff --git a/internal/transport/videochannel/transport.go b/internal/transport/videochannel/transport.go index 8411217..7d367c9 100644 --- a/internal/transport/videochannel/transport.go +++ b/internal/transport/videochannel/transport.go @@ -70,9 +70,8 @@ type streamTransport struct { videoCodec string videoTileModule int videoTileRS int + runCtx context.Context //nolint:containedctx - // cached encoded idle frame — rendered and encoded once, reused on every tick - // where the outbound queue is empty to avoid re-encoding an identical blank frame. idleFrame []byte idleFrameMu sync.Mutex } @@ -144,6 +143,7 @@ func New(ctx context.Context, cfg transport.Config) (transport.Transport, error) videoCodec: cfg.VideoCodec, videoTileModule: tileModule, videoTileRS: tileRS, + runCtx: ctx, } if err := stream.AddTrack(track); err != nil { @@ -159,14 +159,14 @@ func (p *streamTransport) Connect(ctx context.Context) error { connectCtx, cancel := context.WithTimeout(ctx, defaultConnectTimeout) defer cancel() - encoder, err := newFFmpegEncoder(p.codec, p.videoW, p.videoH, p.videoFPS, p.videoBitrate, p.videoHW) + encoder, err := newFFmpegEncoder(ctx, p.codec, p.videoW, p.videoH, p.videoFPS, p.videoBitrate, p.videoHW) if err != nil { - return err + return fmt.Errorf("new encoder: %w", err) } if err := p.stream.Connect(connectCtx); err != nil { _ = encoder.Close() - return err + return fmt.Errorf("connect stream: %w", err) } p.encoderMu.Lock() @@ -212,7 +212,7 @@ func (p *streamTransport) Send(data []byte) error { p.ackMu.Unlock() }() - for attempt := 0; attempt < maxSendAttempts; attempt++ { + for range maxSendAttempts { for idx, fragment := range fragments { frame := encodeDataFrame(seq, crc, len(data), idx, len(fragments), fragment) if err := p.enqueueFrame(frame, false); err != nil { @@ -257,7 +257,9 @@ func (p *streamTransport) Close() error { if p.writerUp.Load() { <-p.writerDone } - return p.stream.Close() + if err := p.stream.Close(); err != nil { + return fmt.Errorf("close stream: %w", err) + } } return nil } @@ -301,6 +303,47 @@ func (p *streamTransport) Features() transport.Features { } } +func (p *streamTransport) writeIdleFrame(enc *ffmpegEncoder, frameDuration time.Duration) { + p.idleFrameMu.Lock() + cached := p.idleFrame + p.idleFrameMu.Unlock() + + if cached == nil { + rawFrame, err := p.renderFrame(nil) + if err != nil { + logger.Debugf("videochannel render idle error: %v", err) + return + } + sample, err := enc.EncodeFrame(rawFrame) + if err != nil { + logger.Warnf("videochannel encoder idle error: %v", err) + return + } + p.idleFrameMu.Lock() + p.idleFrame = sample + p.idleFrameMu.Unlock() + cached = sample + } + + _ = p.track.WriteSample(media.Sample{Data: cached, Duration: frameDuration}) +} + +func (p *streamTransport) writePayloadFrame(enc *ffmpegEncoder, payload []byte, frameDuration time.Duration) { + rawFrame, err := p.renderFrame(payload) + if err != nil { + logger.Debugf("videochannel render error: %v", err) + return + } + + sample, err := enc.EncodeFrame(rawFrame) + if err != nil { + logger.Warnf("videochannel encoder error: %v", err) + return + } + + _ = p.track.WriteSample(media.Sample{Data: sample, Duration: frameDuration}) +} + func (p *streamTransport) writerLoop() { defer close(p.writerDone) defer func() { @@ -334,58 +377,24 @@ func (p *streamTransport) writerLoop() { continue } - // idle frame: payload is nil — reuse previously encoded sample to avoid - // re-rendering and re-encoding an identical blank frame every tick. if payload == nil { - p.idleFrameMu.Lock() - cached := p.idleFrame - p.idleFrameMu.Unlock() - - if cached == nil { - // first time — render + encode once, then cache - rawFrame, err := renderVisualFrame(nil, p.videoW, p.videoH, p.videoCodec, p.videoQRRecovery, p.videoTileModule, p.videoTileRS) - if err != nil { - logger.Debugf("videochannel render idle error: %v", err) - continue - } - sample, err := enc.EncodeFrame(rawFrame) - if err != nil { - logger.Warnf("videochannel encoder idle error: %v", err) - continue - } - p.idleFrameMu.Lock() - p.idleFrame = sample - p.idleFrameMu.Unlock() - cached = sample - } - - _ = p.track.WriteSample(media.Sample{ - Data: cached, - Duration: frameDuration, - }) - continue + p.writeIdleFrame(enc, frameDuration) + } else { + p.writePayloadFrame(enc, payload, frameDuration) } - - rawFrame, err := renderVisualFrame(payload, p.videoW, p.videoH, p.videoCodec, p.videoQRRecovery, p.videoTileModule, p.videoTileRS) - if err != nil { - logger.Debugf("videochannel render error: %v", err) - continue - } - - sample, err := enc.EncodeFrame(rawFrame) - if err != nil { - logger.Warnf("videochannel encoder error: %v", err) - continue - } - - _ = p.track.WriteSample(media.Sample{ - Data: sample, - Duration: frameDuration, - }) } } } +func (p *streamTransport) renderFrame(payload []byte) ([]byte, error) { + return renderVisualFrame( + payload, + p.videoW, p.videoH, + p.videoCodec, p.videoQRRecovery, + p.videoTileModule, p.videoTileRS, + ) +} + func (p *streamTransport) nextOutboundFrame() ([]byte, bool) { select { case <-p.closeCh: @@ -425,6 +434,61 @@ func (p *streamTransport) enqueueFrame(frame []byte, priority bool) error { } } +func (p *streamTransport) popDecoderFrames(decoder *ffmpegDecoder) { + defer func() { + p.decoderMu.Lock() + if p.decoder == decoder { + p.decoder = nil + } + p.decoderMu.Unlock() + _ = decoder.Close() + }() + + for { + select { + case <-p.closeCh: + return + default: + } + + frame, err := decoder.PopFrame() + if err != nil { + if !errors.Is(err, ErrTransportClosed) && !p.closed.Load() { + logger.Warnf("videochannel decoder pop error: %v", err) + } + return + } + p.handleFrame(frame) + } +} + +func (p *streamTransport) readDecoderInput(track *webrtc.TrackRemote, decoder *ffmpegDecoder, codec codecSpec) { + sb := samplebuilder.New(sampleBuilderMaxLate, codec.depacketizer(), track.Codec().ClockRate) + for { + select { + case <-p.closeCh: + return + default: + } + + packet, _, err := track.ReadRTP() + if err != nil { + sb.Flush() + return + } + + sb.Push(packet) + for sample := sb.Pop(); sample != nil; sample = sb.Pop() { + if err := decoder.PushSample(sample.Data); err != nil { + if !p.closed.Load() { + logger.Warnf("videochannel decoder push error: %v", err) + } + return + } + } + } +} + func (p *streamTransport) handleRemoteTrack(track *webrtc.TrackRemote, _ *webrtc.RTPReceiver) { codec, ok := codecSpecForMime(track.Codec().MimeType) if !ok { @@ -432,7 +496,7 @@ func (p *streamTransport) handleRemoteTrack(track *webrtc.TrackRemote, _ *webrtc return } - decoder, err := newFFmpegDecoder(codec, p.videoW, p.videoH, p.videoFPS, p.videoHW) + decoder, err := newFFmpegDecoder(p.runCtx, codec, p.videoW, p.videoH, p.videoFPS, p.videoHW) if err != nil { logger.Warnf("videochannel decoder init failed: %v", err) return @@ -450,60 +514,8 @@ func (p *streamTransport) handleRemoteTrack(track *webrtc.TrackRemote, _ *webrtc p.decoder = decoder p.decoderMu.Unlock() - go func() { - defer func() { - p.decoderMu.Lock() - if p.decoder == decoder { - p.decoder = nil - } - p.decoderMu.Unlock() - _ = decoder.Close() - }() - - for { - select { - case <-p.closeCh: - return - default: - } - - frame, err := decoder.PopFrame() - if err != nil { - if !errors.Is(err, ErrTransportClosed) && !p.closed.Load() { - logger.Warnf("videochannel decoder pop error: %v", err) - } - return - } - p.handleFrame(frame) - } - }() - - go func() { - sb := samplebuilder.New(sampleBuilderMaxLate, codec.depacketizer(), track.Codec().ClockRate) - for { - select { - case <-p.closeCh: - return - default: - } - - packet, _, err := track.ReadRTP() - if err != nil { - sb.Flush() - return - } - - sb.Push(packet) - for sample := sb.Pop(); sample != nil; sample = sb.Pop() { - if err := decoder.PushSample(sample.Data); err != nil { - if !p.closed.Load() { - logger.Warnf("videochannel decoder push error: %v", err) - } - return - } - } - } - }() + go p.popDecoderFrames(decoder) + go p.readDecoderInput(track, decoder, codec) } func (p *streamTransport) handleFrame(frame []byte) { @@ -531,14 +543,7 @@ func (p *streamTransport) handleFrame(frame []byte) { } } -func (p *streamTransport) handleInboundFrame(frame transportFrame) { - p.recvMu.Lock() - if crc, ok := p.delivered[frame.seq]; ok && crc == frame.crc { - p.recvMu.Unlock() - p.sendAck(frame.seq, frame.crc) - return - } - +func (p *streamTransport) upsertInbound(frame transportFrame) (*inboundMessage, bool) { msg, ok := p.inbound[frame.seq] if !ok || msg.crc != frame.crc || msg.totalLen != frame.totalLen || len(msg.frags) != int(frame.fragTotal) { msg = &inboundMessage{ @@ -549,33 +554,45 @@ func (p *streamTransport) handleInboundFrame(frame transportFrame) { } p.inbound[frame.seq] = msg } - if int(frame.fragIdx) >= len(msg.frags) { - p.recvMu.Unlock() - return + return nil, false } - if msg.frags[frame.fragIdx] == nil { chunk := make([]byte, len(frame.payload)) copy(chunk, frame.payload) msg.frags[frame.fragIdx] = chunk msg.remain-- } + return msg, msg.remain == 0 +} - if msg.remain > 0 { +func (p *streamTransport) assembleMessage(msg *inboundMessage) []byte { + data := make([]byte, 0, msg.totalLen) + for _, frag := range msg.frags { + data = append(data, frag...) + } + if uint32(len(data)) > msg.totalLen { //nolint:gosec + data = data[:msg.totalLen] + } + return data +} + +func (p *streamTransport) handleInboundFrame(frame transportFrame) { + p.recvMu.Lock() + if crc, ok := p.delivered[frame.seq]; ok && crc == frame.crc { + p.recvMu.Unlock() + p.sendAck(frame.seq, frame.crc) + return + } + + msg, complete := p.upsertInbound(frame) + if msg == nil || !complete { p.recvMu.Unlock() return } delete(p.inbound, frame.seq) - data := make([]byte, 0, msg.totalLen) - for _, frag := range msg.frags { - data = append(data, frag...) - } - - if uint32(len(data)) > msg.totalLen { - data = data[:msg.totalLen] - } + data := p.assembleMessage(msg) if crc32.ChecksumIEEE(data) != msg.crc { p.recvMu.Unlock() diff --git a/internal/transport/videochannel/visual.go b/internal/transport/videochannel/visual.go index 82491ed..cff3964 100644 --- a/internal/transport/videochannel/visual.go +++ b/internal/transport/videochannel/visual.go @@ -1,6 +1,7 @@ package videochannel import ( + "errors" "fmt" "strings" @@ -8,6 +9,9 @@ import ( grtile "github.com/zarazaex69/gr/tile" ) +// ErrUnexpectedQRFrameSize is returned when the decoded frame size does not match the expected dimensions. +var ErrUnexpectedQRFrameSize = errors.New("unexpected qr frame size") + func eccLevel(level string) grqr.ECCLevel { switch level { case "medium": @@ -21,7 +25,12 @@ func eccLevel(level string) grqr.ECCLevel { } } -func renderVisualFrame(payload []byte, width, height int, codec, recoveryLevel string, tileModule, tileRS int) ([]byte, error) { +func renderVisualFrame( + payload []byte, + width, height int, + codec, recoveryLevel string, + tileModule, tileRS int, +) ([]byte, error) { if codec == "tile" { return renderTileFrame(payload, tileModule, tileRS) } @@ -47,7 +56,11 @@ func renderQRFrame(payload []byte, width, height int, recoveryLevel string) ([]b return nil, fmt.Errorf("qr codec: %w", err) } - return c.Encode(payload) + result, err := c.Encode(payload) + if err != nil { + return nil, fmt.Errorf("qr encode: %w", err) + } + return result, nil } func renderTileFrame(payload []byte, tileModule, tileRS int) ([]byte, error) { @@ -64,7 +77,11 @@ func renderTileFrame(payload []byte, tileModule, tileRS int) ([]byte, error) { return nil, fmt.Errorf("tile codec: %w", err) } - return c.Encode(payload, 0, 1) + result, err := c.Encode(payload, 0, 1) + if err != nil { + return nil, fmt.Errorf("tile encode: %w", err) + } + return result, nil } func extractVisualPayload(frame []byte, width, height int, codec string, tileModule, tileRS int) ([]byte, error) { @@ -76,7 +93,8 @@ func extractVisualPayload(frame []byte, width, height int, codec string, tileMod func extractQRPayload(frame []byte, width, height int) ([]byte, error) { if len(frame) != width*height { - return nil, fmt.Errorf("unexpected frame size: %d (expected %dx%d=%d)", len(frame), width, height, width*height) + return nil, fmt.Errorf("%w: got %d expected %dx%d=%d", + ErrUnexpectedQRFrameSize, len(frame), width, height, width*height) } c, err := grqr.New(grqr.Config{ @@ -111,7 +129,7 @@ func extractTilePayload(frame []byte, tileModule, tileRS int) ([]byte, error) { result, err := c.Decode(frame) if err != nil { - return nil, nil + return nil, nil //nolint:nilerr } return result.Payload, nil diff --git a/internal/transport/vp8channel/kcp.go b/internal/transport/vp8channel/kcp.go index 3c16ebf..d140999 100644 --- a/internal/transport/vp8channel/kcp.go +++ b/internal/transport/vp8channel/kcp.go @@ -1,3 +1,4 @@ +// Package vp8channel provides byte transport over VP8 video frames using KCP. package vp8channel import ( @@ -58,7 +59,7 @@ type kcpRuntime struct { func startKCP(out chan<- []byte, onData func([]byte)) (*kcpRuntime, error) { c := newKCPConn(out, inboundQueueSize) - sess, err := kcp.NewConn3(kcpConvID, fakeAddr, nil, 0, 0, c) + sess, err := kcp.NewConn3(kcpConvID, fakeUDPAddr(), nil, 0, 0, c) if err != nil { _ = c.Close() return nil, fmt.Errorf("kcp new conn: %w", err) @@ -71,7 +72,6 @@ func startKCP(out chan<- []byte, onData func([]byte)) (*kcpRuntime, error) { sess.SetNoDelay(1, 10, 2, 1) sess.SetWindowSize(kcpSndWnd, kcpRcvWnd) sess.SetMtu(kcpMTU) - sess.SetStreamMode(true) // see kcpLenPrefix comment above sess.SetACKNoDelay(true) sess.SetWriteDelay(false) @@ -127,16 +127,17 @@ func (r *kcpRuntime) send(msg []byte) error { return ErrKCPMessageTooLarge } var hdr [kcpLenPrefix]byte + //nolint:gosec binary.BigEndian.PutUint32(hdr[:], uint32(len(msg))) r.writeMu.Lock() defer r.writeMu.Unlock() if _, err := r.sess.Write(hdr[:]); err != nil { - return err + return fmt.Errorf("kcp write header: %w", err) } if _, err := r.sess.Write(msg); err != nil { - return err + return fmt.Errorf("kcp write payload: %w", err) } return nil } diff --git a/internal/transport/vp8channel/kcpconn.go b/internal/transport/vp8channel/kcpconn.go index 8b76a22..106561b 100644 --- a/internal/transport/vp8channel/kcpconn.go +++ b/internal/transport/vp8channel/kcpconn.go @@ -6,10 +6,9 @@ import ( "time" ) -// fakeAddr is a placeholder address used by the KCP session. The underlying -// "packet conn" is a point-to-point pipe over the VP8 carrier and has no real -// notion of an address, but kcp-go's API requires one. -var fakeAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1} +func fakeUDPAddr() *net.UDPAddr { + return &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1} +} // kcpConn is a net.PacketConn implementation that bridges kcp-go on top of // the vp8channel byte-message carrier. @@ -62,7 +61,7 @@ func (c *kcpConn) ReadFrom(p []byte) (int, net.Addr, error) { if !deadline.IsZero() { d := time.Until(deadline) if d <= 0 { - return 0, nil, errTimeout{} + return 0, nil, TimeoutError{} } t := time.NewTimer(d) defer t.Stop() @@ -72,11 +71,11 @@ func (c *kcpConn) ReadFrom(p []byte) (int, net.Addr, error) { select { case msg := <-c.in: n := copy(p, msg) - return n, fakeAddr, nil + return n, fakeUDPAddr(), nil case <-c.closed: return 0, nil, net.ErrClosed case <-timerC: - return 0, nil, errTimeout{} + return 0, nil, TimeoutError{} } } @@ -92,7 +91,7 @@ func (c *kcpConn) WriteTo(p []byte, _ net.Addr) (int, error) { if !deadline.IsZero() { d := time.Until(deadline) if d <= 0 { - return 0, errTimeout{} + return 0, TimeoutError{} } t := time.NewTimer(d) defer t.Stop() @@ -105,7 +104,7 @@ func (c *kcpConn) WriteTo(p []byte, _ net.Addr) (int, error) { case <-c.closed: return 0, net.ErrClosed case <-timerC: - return 0, errTimeout{} + return 0, TimeoutError{} } } @@ -114,7 +113,7 @@ func (c *kcpConn) Close() error { return nil } -func (c *kcpConn) LocalAddr() net.Addr { return fakeAddr } +func (c *kcpConn) LocalAddr() net.Addr { return fakeUDPAddr() } func (c *kcpConn) SetDeadline(t time.Time) error { _ = c.SetReadDeadline(t) @@ -136,8 +135,13 @@ func (c *kcpConn) SetWriteDeadline(t time.Time) error { return nil } -type errTimeout struct{} +// TimeoutError is a net.Error indicating a deadline exceeded. +type TimeoutError struct{} -func (errTimeout) Error() string { return "i/o timeout" } -func (errTimeout) Timeout() bool { return true } -func (errTimeout) Temporary() bool { return true } +func (TimeoutError) Error() string { return "i/o timeout" } + +// Timeout reports that this error is a timeout. +func (TimeoutError) Timeout() bool { return true } + +// Temporary reports that this error is temporary. +func (TimeoutError) Temporary() bool { return true } diff --git a/internal/transport/vp8channel/transport.go b/internal/transport/vp8channel/transport.go index 0df9f18b..a10c4ca 100644 --- a/internal/transport/vp8channel/transport.go +++ b/internal/transport/vp8channel/transport.go @@ -27,14 +27,13 @@ const ( ) var ( + // ErrVideoTrackUnsupported is returned when a carrier cannot expose video tracks. ErrVideoTrackUnsupported = errors.New("carrier does not support video tracks") - ErrTransportClosed = errors.New("vp8channel transport closed") + // ErrTransportClosed is returned when operations are attempted on a closed transport. + ErrTransportClosed = errors.New("vp8channel transport closed") ) -// vp8Keepalive is a minimal VP8 keyframe used as idle filler so that the SFU -// keeps the track flowing when KCP has nothing to send. It is never delivered -// to KCP because KCP packets always start with the convid (0xC0FFEE01 LE) -// and would never collide with this keyframe payload. +//nolint:gochecknoglobals var vp8Keepalive = []byte{ 0x30, 0x01, 0x00, 0x9d, 0x01, 0x2a, 0x10, 0x00, 0x10, 0x00, 0x00, 0x47, 0x08, 0x85, 0x85, 0x88, @@ -64,6 +63,7 @@ type streamTransport struct { kcpMu sync.RWMutex } +// New creates a vp8channel transport backed by a carrier-specific provider. func New(ctx context.Context, cfg transport.Config) (transport.Transport, error) { session, err := carrier.New(ctx, cfg.Carrier, carrier.Config{ RoomURL: cfg.RoomURL, @@ -126,7 +126,7 @@ func (p *streamTransport) Connect(ctx context.Context) error { defer cancel() if err := p.stream.Connect(connectCtx); err != nil { - return err + return fmt.Errorf("connect stream: %w", err) } var startErr error @@ -179,7 +179,9 @@ func (p *streamTransport) Close() error { if p.writerUp.Load() { <-p.writerDone } - return p.stream.Close() + if err := p.stream.Close(); err != nil { + return fmt.Errorf("close stream: %w", err) + } } return nil } @@ -302,14 +304,62 @@ func (p *streamTransport) drainTrack(track *webrtc.TrackRemote) { } } -func (p *streamTransport) readVP8Track(track *webrtc.TrackRemote) { - var vp8Pkt codecs.VP8Packet - var frameBuf []byte - buf := make([]byte, rtpBufSize) +type vp8FrameState struct { + vp8Pkt codecs.VP8Packet + frameBuf []byte + lastSeq uint16 + haveLastSeq bool + frameValid bool +} - var lastSeq uint16 - var haveLastSeq bool - frameValid := false +// processRTPPacket returns a complete KCP frame when the VP8 frame is fully assembled, nil otherwise. +// Detects packet loss/reordering to avoid silently corrupting fragmented VP8 frames. +func (s *vp8FrameState) processRTPPacket(pkt *rtp.Packet) []byte { + if s.haveLastSeq && pkt.SequenceNumber != s.lastSeq+1 { + s.frameValid = false + s.frameBuf = s.frameBuf[:0] + } + s.lastSeq = pkt.SequenceNumber + s.haveLastSeq = true + + vp8Payload, err := s.vp8Pkt.Unmarshal(pkt.Payload) + if err != nil { + s.frameValid = false + s.frameBuf = s.frameBuf[:0] + return nil + } + + if s.vp8Pkt.S == 1 { + s.frameBuf = s.frameBuf[:0] + s.frameValid = true + } + + if !s.frameValid { + return nil + } + + s.frameBuf = append(s.frameBuf, vp8Payload...) + + if !pkt.Marker { + return nil + } + + defer func() { + s.frameBuf = s.frameBuf[:0] + s.frameValid = false + }() + + if len(s.frameBuf) >= 4 && s.frameBuf[0] == kcpMagic { + frame := make([]byte, len(s.frameBuf)) + copy(frame, s.frameBuf) + return frame + } + return nil +} + +func (p *streamTransport) readVP8Track(track *webrtc.TrackRemote) { + var state vp8FrameState + buf := make([]byte, rtpBufSize) for { n, _, err := track.Read(buf) @@ -322,54 +372,16 @@ func (p *streamTransport) readVP8Track(track *webrtc.TrackRemote) { continue } - // Detect packet loss / reordering. A single missing RTP packet - // inside a fragmented VP8 frame would otherwise silently corrupt - // the assembled payload (and bleed into the next frame). KCP can - // recover from full-frame drops, but only if the frames it does - // receive are byte-perfect. - if haveLastSeq { - expected := lastSeq + 1 - if pkt.SequenceNumber != expected { - frameValid = false - frameBuf = frameBuf[:0] - } - } - lastSeq = pkt.SequenceNumber - haveLastSeq = true - - vp8Payload, err := vp8Pkt.Unmarshal(pkt.Payload) - if err != nil { - frameValid = false - frameBuf = frameBuf[:0] + frame := state.processRTPPacket(pkt) + if frame == nil { continue } - if vp8Pkt.S == 1 { - frameBuf = frameBuf[:0] - frameValid = true - } - - if !frameValid { - continue - } - - frameBuf = append(frameBuf, vp8Payload...) - - if pkt.Marker { - if len(frameBuf) >= 4 && frameBuf[0] == kcpMagic { - p.kcpMu.RLock() - rt := p.kcp - p.kcpMu.RUnlock() - if rt != nil { - // Copy out of the shared frame buffer before handing - // the payload off — KCP's deliver path is async. - payload := make([]byte, len(frameBuf)) - copy(payload, frameBuf) - rt.deliver(payload) - } - } - frameBuf = frameBuf[:0] - frameValid = false + p.kcpMu.RLock() + rt := p.kcp + p.kcpMu.RUnlock() + if rt != nil { + rt.deliver(frame) } } } diff --git a/internal/transport/vp8channel/transport_test.go b/internal/transport/vp8channel/transport_test.go index f8d7021..33feb2b 100644 --- a/internal/transport/vp8channel/transport_test.go +++ b/internal/transport/vp8channel/transport_test.go @@ -7,16 +7,64 @@ import ( "time" ) +func pumpPackets(stop <-chan struct{}, from <-chan []byte, to *kcpRuntime) { + for { + select { + case <-stop: + return + case pkt := <-from: + to.deliver(pkt) + } + } +} + +func checkMessages(t *testing.T, got, want [][]byte) { + t.Helper() + if len(got) != len(want) { + t.Fatalf("got %d messages, want %d", len(got), len(want)) + } + for i, m := range want { + if !bytes.Equal(got[i], m) { + t.Errorf("msg %d mismatch: got %d bytes, want %d", i, len(got[i]), len(m)) + } + } +} + +func buildReceiver(n int) (func([]byte), <-chan struct{}, func() [][]byte) { + var mu sync.Mutex + var recv [][]byte + done := make(chan struct{}) + cb := func(msg []byte) { + mu.Lock() + recv = append(recv, append([]byte(nil), msg...)) + count := len(recv) + mu.Unlock() + if count == n { + close(done) + } + } + get := func() [][]byte { + mu.Lock() + defer mu.Unlock() + return recv + } + return cb, done, get +} + // TestKCPLoopback runs two KCP runtimes back-to-back through an in-memory // pipe simulating a perfect carrier. Verifies that messages survive the // KCP layer with their boundaries intact. func TestKCPLoopback(t *testing.T) { + msgs := [][]byte{ + []byte("hello"), + bytes.Repeat([]byte("x"), 1000), + bytes.Repeat([]byte("y"), 20000), + } + a2b := make(chan []byte, 256) b2a := make(chan []byte, 256) - var bRecvMu sync.Mutex - var bRecv [][]byte - doneB := make(chan struct{}) + cb, doneB, getRecv := buildReceiver(len(msgs)) rtA, err := startKCP(a2b, nil) if err != nil { @@ -24,50 +72,18 @@ func TestKCPLoopback(t *testing.T) { } defer rtA.close() - rtB, err := startKCP(b2a, func(msg []byte) { - bRecvMu.Lock() - bRecv = append(bRecv, append([]byte(nil), msg...)) - n := len(bRecv) - bRecvMu.Unlock() - if n == 3 { - close(doneB) - } - }) + rtB, err := startKCP(b2a, cb) if err != nil { t.Fatalf("startKCP B: %v", err) } defer rtB.close() - // Pump packets between the two runtimes. stop := make(chan struct{}) defer close(stop) - go func() { - for { - select { - case <-stop: - return - case pkt := <-a2b: - rtB.deliver(pkt) - } - } - }() - go func() { - for { - select { - case <-stop: - return - case pkt := <-b2a: - rtA.deliver(pkt) - } - } - }() + go pumpPackets(stop, a2b, rtB) + go pumpPackets(stop, b2a, rtA) - msgs := [][]byte{ - []byte("hello"), - bytes.Repeat([]byte("x"), 1000), - bytes.Repeat([]byte("y"), 20000), - } for _, m := range msgs { if err := rtA.send(m); err != nil { t.Fatalf("send: %v", err) @@ -80,21 +96,10 @@ func TestKCPLoopback(t *testing.T) { t.Fatal("timeout waiting for messages") } - bRecvMu.Lock() - defer bRecvMu.Unlock() - if len(bRecv) != len(msgs) { - t.Fatalf("got %d messages, want %d", len(bRecv), len(msgs)) - } - for i, m := range msgs { - if !bytes.Equal(bRecv[i], m) { - t.Errorf("msg %d mismatch: got %d bytes, want %d", i, len(bRecv[i]), len(m)) - } - } + checkMessages(t, getRecv(), msgs) } func TestVP8KeepaliveDoesNotLookLikeKCP(t *testing.T) { - // Keepalive frames must not be mistaken for KCP packets by the receive - // path; otherwise the KCP stack would constantly chew on garbage. if len(vp8Keepalive) >= 1 && vp8Keepalive[0] == kcpMagic { t.Errorf("keepalive collides with kcp magic byte 0x%02x", kcpMagic) }