mirror of
https://github.com/openlibrecommunity/olcrtc.git
synced 2026-05-26 07:08:11 +00:00
fix: fix all golangci errors
This commit is contained in:
@@ -3,6 +3,7 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
@@ -16,6 +17,9 @@ import (
|
||||
"github.com/openlibrecommunity/olcrtc/internal/names"
|
||||
)
|
||||
|
||||
// ErrDataDirRequired is returned when no data directory is specified.
|
||||
var ErrDataDirRequired = errors.New("data directory required (use -data data)")
|
||||
|
||||
type config struct {
|
||||
mode string
|
||||
link string
|
||||
@@ -59,11 +63,11 @@ func run() error {
|
||||
configureLogging(cfg.debug)
|
||||
|
||||
if err := session.Validate(toSessionConfig(cfg)); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("validate config: %w", err)
|
||||
}
|
||||
|
||||
if cfg.dataDir == "" {
|
||||
return fmt.Errorf("data directory required (use -data data)")
|
||||
return ErrDataDirRequired
|
||||
}
|
||||
|
||||
dataDir, err := resolveDataDir(cfg.dataDir)
|
||||
@@ -119,10 +123,13 @@ func parseFlags() config {
|
||||
flag.StringVar(&cfg.videoBitrate, "video-bitrate", "", "Video bitrate (videochannel only)")
|
||||
flag.StringVar(&cfg.videoHW, "video-hw", "", "Hardware acceleration (none, nvenc)")
|
||||
flag.IntVar(&cfg.videoQRSize, "video-qr-size", 0, "Video QR code fragment size (videochannel only)")
|
||||
flag.StringVar(&cfg.videoQRRecovery, "video-qr-recovery", "low", "QR error correction: low (7%), medium (15%), high (25%), highest (30%)")
|
||||
flag.StringVar(&cfg.videoQRRecovery, "video-qr-recovery", "low",
|
||||
"QR error correction: low (7%), medium (15%), high (25%), highest (30%)")
|
||||
flag.StringVar(&cfg.videoCodec, "video-codec", "qrcode", "Visual codec: qrcode or tile")
|
||||
flag.IntVar(&cfg.videoTileModule, "video-tile-module", 0, "Tile module size in pixels 1..270 (videochannel tile only, default 4)")
|
||||
flag.IntVar(&cfg.videoTileRS, "video-tile-rs", 0, "Tile Reed-Solomon parity percent 0..200 (videochannel tile only, default 20)")
|
||||
flag.IntVar(&cfg.videoTileModule, "video-tile-module", 0,
|
||||
"Tile module size in pixels 1..270 (videochannel tile only, default 4)")
|
||||
flag.IntVar(&cfg.videoTileRS, "video-tile-rs", 0,
|
||||
"Tile Reed-Solomon parity percent 0..200 (videochannel tile only, default 20)")
|
||||
flag.IntVar(&cfg.vp8FPS, "vp8-fps", 0, "VP8 frames per second (vp8channel only, default 25)")
|
||||
flag.IntVar(&cfg.vp8BatchSize, "vp8-batch", 0, "VP8 frames per tick (vp8channel only, default 1)")
|
||||
flag.Parse()
|
||||
@@ -161,22 +168,22 @@ func loadNames(dataDir string) error {
|
||||
|
||||
func toSessionConfig(cfg config) session.Config {
|
||||
return session.Config{
|
||||
Mode: cfg.mode,
|
||||
Link: cfg.link,
|
||||
Transport: cfg.transport,
|
||||
Carrier: firstNonEmpty(cfg.carrier, cfg.provider),
|
||||
RoomID: cfg.roomID,
|
||||
KeyHex: cfg.keyHex,
|
||||
SOCKSHost: cfg.socksHost,
|
||||
SOCKSPort: cfg.socksPort,
|
||||
DNSServer: cfg.dnsServer,
|
||||
SOCKSProxyAddr: cfg.socksProxyAddr,
|
||||
SOCKSProxyPort: cfg.socksProxyPort,
|
||||
VideoWidth: cfg.videoWidth,
|
||||
VideoHeight: cfg.videoHeight,
|
||||
VideoFPS: cfg.videoFPS,
|
||||
VideoBitrate: cfg.videoBitrate,
|
||||
VideoHW: cfg.videoHW,
|
||||
Mode: cfg.mode,
|
||||
Link: cfg.link,
|
||||
Transport: cfg.transport,
|
||||
Carrier: firstNonEmpty(cfg.carrier, cfg.provider),
|
||||
RoomID: cfg.roomID,
|
||||
KeyHex: cfg.keyHex,
|
||||
SOCKSHost: cfg.socksHost,
|
||||
SOCKSPort: cfg.socksPort,
|
||||
DNSServer: cfg.dnsServer,
|
||||
SOCKSProxyAddr: cfg.socksProxyAddr,
|
||||
SOCKSProxyPort: cfg.socksProxyPort,
|
||||
VideoWidth: cfg.videoWidth,
|
||||
VideoHeight: cfg.videoHeight,
|
||||
VideoFPS: cfg.videoFPS,
|
||||
VideoBitrate: cfg.videoBitrate,
|
||||
VideoHW: cfg.videoHW,
|
||||
VideoQRSize: cfg.videoQRSize,
|
||||
VideoQRRecovery: cfg.videoQRRecovery,
|
||||
VideoCodec: cfg.videoCodec,
|
||||
|
||||
@@ -19,13 +19,19 @@ import (
|
||||
"github.com/openlibrecommunity/olcrtc/internal/transport/vp8channel"
|
||||
)
|
||||
|
||||
const (
|
||||
modeSRV = "srv"
|
||||
modeCNC = "cnc"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrRoomIDRequired indicates that a room id is required for the selected carrier.
|
||||
ErrRoomIDRequired = errors.New("room ID required (use -id <id>)")
|
||||
// ErrModeRequired indicates that mode is not one of the supported values.
|
||||
ErrModeRequired = errors.New("mode required (use -mode srv or -mode cnc)")
|
||||
// ErrCarrierRequired indicates that no carrier was selected.
|
||||
ErrCarrierRequired = errors.New("carrier required (use -carrier telemost, -carrier jazz or -carrier wbstream)")
|
||||
ErrCarrierRequired = errors.New(
|
||||
"carrier required (use -carrier telemost, -carrier jazz or -carrier wbstream)")
|
||||
// ErrUnsupportedCarrier indicates that carrier is not registered.
|
||||
ErrUnsupportedCarrier = errors.New("unsupported carrier")
|
||||
// ErrUnsupportedLink indicates that link is not registered.
|
||||
@@ -36,26 +42,40 @@ var (
|
||||
// ErrLinkRequired indicates that link is not provided.
|
||||
ErrLinkRequired = errors.New("link required (use -link direct)")
|
||||
// ErrTransportRequired indicates that transport is not provided.
|
||||
ErrTransportRequired = errors.New("transport required (use -transport datachannel, -transport videochannel, -transport seichannel or -transport vp8channel)")
|
||||
ErrTransportRequired = errors.New(
|
||||
"transport required (use -transport datachannel, -transport videochannel, " +
|
||||
"-transport seichannel or -transport vp8channel)")
|
||||
// ErrKeyRequired indicates that encryption key is not provided.
|
||||
ErrKeyRequired = errors.New("key required (use -key <hex>)")
|
||||
// ErrDNSServerRequired indicates that dns server is not provided.
|
||||
ErrDNSServerRequired = errors.New("dns server required (use -dns 1.1.1.1:53)")
|
||||
|
||||
// Videochannel errors
|
||||
ErrVideoWidthRequired = errors.New("video width required for videochannel (use -video-w)")
|
||||
ErrVideoHeightRequired = errors.New("video height required for videochannel (use -video-h)")
|
||||
ErrVideoFPSRequired = errors.New("video fps required for videochannel (use -video-fps)")
|
||||
ErrVideoBitrateRequired = errors.New("video bitrate required for videochannel (use -video-bitrate)")
|
||||
ErrVideoHWRequired = errors.New("video hardware acceleration required for videochannel (use -video-hw none/nvenc)")
|
||||
ErrVideoCodecInvalid = errors.New("invalid video codec for videochannel (use -video-codec qrcode or -video-codec tile)")
|
||||
// ErrVideoWidthRequired indicates that video width is required for videochannel.
|
||||
ErrVideoWidthRequired = errors.New("video width required for videochannel (use -video-w)")
|
||||
// ErrVideoHeightRequired indicates that video height is required for videochannel.
|
||||
ErrVideoHeightRequired = errors.New("video height required for videochannel (use -video-h)")
|
||||
// ErrVideoFPSRequired indicates that video fps is required for videochannel.
|
||||
ErrVideoFPSRequired = errors.New("video fps required for videochannel (use -video-fps)")
|
||||
// ErrVideoBitrateRequired indicates that video bitrate is required for videochannel.
|
||||
ErrVideoBitrateRequired = errors.New(
|
||||
"video bitrate required for videochannel (use -video-bitrate)")
|
||||
// ErrVideoHWRequired indicates that video hardware acceleration is required.
|
||||
ErrVideoHWRequired = errors.New(
|
||||
"video hardware acceleration required for videochannel (use -video-hw none/nvenc)")
|
||||
// ErrVideoCodecInvalid indicates that the video codec is not valid.
|
||||
ErrVideoCodecInvalid = errors.New(
|
||||
"invalid video codec for videochannel (use -video-codec qrcode or -video-codec tile)")
|
||||
// ErrTileCodecDimensions indicates that tile codec requires 1080x1080 dimensions.
|
||||
ErrTileCodecDimensions = errors.New("tile codec requires -video-w 1080 -video-h 1080")
|
||||
|
||||
// VP8channel errors
|
||||
ErrVP8FPSRequired = errors.New("vp8 fps required for vp8channel (use -vp8-fps)")
|
||||
// ErrVP8FPSRequired indicates that vp8 fps is required for vp8channel.
|
||||
ErrVP8FPSRequired = errors.New("vp8 fps required for vp8channel (use -vp8-fps)")
|
||||
// ErrVP8BatchSizeRequired indicates that vp8 batch size is required for vp8channel.
|
||||
ErrVP8BatchSizeRequired = errors.New("vp8 batch size required for vp8channel (use -vp8-batch)")
|
||||
|
||||
// CNC errors
|
||||
// ErrSOCKSHostRequired indicates that socks host is required for cnc mode.
|
||||
ErrSOCKSHostRequired = errors.New("socks host required for cnc mode (use -socks-host)")
|
||||
// ErrSOCKSPortRequired indicates that socks port is required for cnc mode.
|
||||
ErrSOCKSPortRequired = errors.New("socks port required for cnc mode (use -socks-port)")
|
||||
)
|
||||
|
||||
@@ -98,115 +118,143 @@ func RegisterDefaults() {
|
||||
|
||||
// Validate verifies that the runtime config refers to registered components and all required fields are present.
|
||||
func Validate(cfg Config) error {
|
||||
availableCarriers := carrier.Available()
|
||||
validCarrier := false
|
||||
for _, c := range availableCarriers {
|
||||
if cfg.Carrier == c {
|
||||
validCarrier = true
|
||||
break
|
||||
}
|
||||
if err := validateMode(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
availableTransports := transport.Available()
|
||||
validTransport := false
|
||||
for _, t := range availableTransports {
|
||||
if cfg.Transport == t {
|
||||
validTransport = true
|
||||
break
|
||||
}
|
||||
if err := validateCarrier(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
availableLinks := link.Available()
|
||||
validLink := false
|
||||
for _, l := range availableLinks {
|
||||
if cfg.Link == l {
|
||||
validLink = true
|
||||
break
|
||||
}
|
||||
if err := validateLink(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateTransportRegistration(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateCommon(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateTransportConfig(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
return validateModeConfig(cfg)
|
||||
}
|
||||
|
||||
if cfg.Mode == "" {
|
||||
return ErrModeRequired
|
||||
}
|
||||
if cfg.Mode != "srv" && cfg.Mode != "cnc" {
|
||||
func validateMode(cfg Config) error {
|
||||
if cfg.Mode == "" || (cfg.Mode != modeSRV && cfg.Mode != modeCNC) {
|
||||
return ErrModeRequired
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateCarrier(cfg Config) error {
|
||||
if cfg.Carrier == "" {
|
||||
return ErrCarrierRequired
|
||||
}
|
||||
if !validCarrier {
|
||||
return fmt.Errorf("%w: %s (available: %v)", ErrUnsupportedCarrier, cfg.Carrier, availableCarriers)
|
||||
for _, c := range carrier.Available() {
|
||||
if cfg.Carrier == c {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("%w: %s (available: %v)", ErrUnsupportedCarrier, cfg.Carrier, carrier.Available())
|
||||
}
|
||||
|
||||
func validateLink(cfg Config) error {
|
||||
if cfg.Link == "" {
|
||||
return ErrLinkRequired
|
||||
}
|
||||
if !validLink {
|
||||
return fmt.Errorf("%w: %s (available: %v)", ErrUnsupportedLink, cfg.Link, availableLinks)
|
||||
for _, l := range link.Available() {
|
||||
if cfg.Link == l {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("%w: %s (available: %v)", ErrUnsupportedLink, cfg.Link, link.Available())
|
||||
}
|
||||
|
||||
func validateTransportRegistration(cfg Config) error {
|
||||
if cfg.Transport == "" {
|
||||
return ErrTransportRequired
|
||||
}
|
||||
if !validTransport {
|
||||
return fmt.Errorf("%w: %s (available: %v)", ErrUnsupportedTransport, cfg.Transport, availableTransports)
|
||||
for _, t := range transport.Available() {
|
||||
if cfg.Transport == t {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("%w: %s (available: %v)", ErrUnsupportedTransport, cfg.Transport, transport.Available())
|
||||
}
|
||||
|
||||
func validateCommon(cfg Config) error {
|
||||
if cfg.RoomID == "" && cfg.Carrier != "jazz" {
|
||||
return ErrRoomIDRequired
|
||||
}
|
||||
|
||||
if cfg.KeyHex == "" {
|
||||
return ErrKeyRequired
|
||||
}
|
||||
|
||||
if cfg.DNSServer == "" {
|
||||
return ErrDNSServerRequired
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if cfg.Transport == "videochannel" {
|
||||
if cfg.VideoWidth == 0 {
|
||||
return ErrVideoWidthRequired
|
||||
}
|
||||
if cfg.VideoHeight == 0 {
|
||||
return ErrVideoHeightRequired
|
||||
}
|
||||
if cfg.VideoFPS == 0 {
|
||||
return ErrVideoFPSRequired
|
||||
}
|
||||
if cfg.VideoBitrate == "" {
|
||||
return ErrVideoBitrateRequired
|
||||
}
|
||||
if cfg.VideoHW == "" {
|
||||
return ErrVideoHWRequired
|
||||
}
|
||||
if cfg.VideoCodec != "" && cfg.VideoCodec != "qrcode" && cfg.VideoCodec != "tile" {
|
||||
return ErrVideoCodecInvalid
|
||||
}
|
||||
if cfg.VideoCodec == "tile" && (cfg.VideoWidth != 1080 || cfg.VideoHeight != 1080) {
|
||||
return errors.New("tile codec requires -video-w 1080 -video-h 1080")
|
||||
}
|
||||
func validateTransportConfig(cfg Config) error {
|
||||
switch cfg.Transport {
|
||||
case "videochannel":
|
||||
return validateVideoChannel(cfg)
|
||||
case "vp8channel":
|
||||
return validateVP8Channel(cfg)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.Transport == "vp8channel" {
|
||||
if cfg.VP8FPS == 0 {
|
||||
return ErrVP8FPSRequired
|
||||
}
|
||||
if cfg.VP8BatchSize == 0 {
|
||||
return ErrVP8BatchSizeRequired
|
||||
}
|
||||
func validateVideoCodec(cfg Config) error {
|
||||
if cfg.VideoCodec != "" && cfg.VideoCodec != "qrcode" && cfg.VideoCodec != "tile" {
|
||||
return ErrVideoCodecInvalid
|
||||
}
|
||||
|
||||
if cfg.Mode == "cnc" {
|
||||
if cfg.SOCKSHost == "" {
|
||||
return ErrSOCKSHostRequired
|
||||
}
|
||||
if cfg.SOCKSPort == 0 {
|
||||
return ErrSOCKSPortRequired
|
||||
}
|
||||
if cfg.VideoCodec == "tile" && (cfg.VideoWidth != 1080 || cfg.VideoHeight != 1080) {
|
||||
return ErrTileCodecDimensions
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateVideoChannel(cfg Config) error {
|
||||
if cfg.VideoWidth == 0 {
|
||||
return ErrVideoWidthRequired
|
||||
}
|
||||
if cfg.VideoHeight == 0 {
|
||||
return ErrVideoHeightRequired
|
||||
}
|
||||
if cfg.VideoFPS == 0 {
|
||||
return ErrVideoFPSRequired
|
||||
}
|
||||
if cfg.VideoBitrate == "" {
|
||||
return ErrVideoBitrateRequired
|
||||
}
|
||||
if cfg.VideoHW == "" {
|
||||
return ErrVideoHWRequired
|
||||
}
|
||||
return validateVideoCodec(cfg)
|
||||
}
|
||||
|
||||
func validateVP8Channel(cfg Config) error {
|
||||
if cfg.VP8FPS == 0 {
|
||||
return ErrVP8FPSRequired
|
||||
}
|
||||
if cfg.VP8BatchSize == 0 {
|
||||
return ErrVP8BatchSizeRequired
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateModeConfig(cfg Config) error {
|
||||
if cfg.Mode != modeCNC {
|
||||
return nil
|
||||
}
|
||||
if cfg.SOCKSHost == "" {
|
||||
return ErrSOCKSHostRequired
|
||||
}
|
||||
if cfg.SOCKSPort == 0 {
|
||||
return ErrSOCKSPortRequired
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -215,8 +263,8 @@ func Run(ctx context.Context, cfg Config) error {
|
||||
roomURL := buildRoomURL(cfg.Carrier, cfg.RoomID)
|
||||
|
||||
switch cfg.Mode {
|
||||
case "srv":
|
||||
return server.Run(
|
||||
case modeSRV:
|
||||
if err := server.Run(
|
||||
ctx,
|
||||
cfg.Link,
|
||||
cfg.Transport,
|
||||
@@ -238,9 +286,12 @@ func Run(ctx context.Context, cfg Config) error {
|
||||
cfg.VideoTileRS,
|
||||
cfg.VP8FPS,
|
||||
cfg.VP8BatchSize,
|
||||
)
|
||||
case "cnc":
|
||||
return client.Run(
|
||||
); err != nil {
|
||||
return fmt.Errorf("server: %w", err)
|
||||
}
|
||||
return nil
|
||||
case modeCNC:
|
||||
if err := client.Run(
|
||||
ctx,
|
||||
cfg.Link,
|
||||
cfg.Transport,
|
||||
@@ -263,7 +314,10 @@ func Run(ctx context.Context, cfg Config) error {
|
||||
cfg.VideoTileRS,
|
||||
cfg.VP8FPS,
|
||||
cfg.VP8BatchSize,
|
||||
)
|
||||
); err != nil {
|
||||
return fmt.Errorf("client: %w", err)
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return ErrModeRequired
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package carrier
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/provider"
|
||||
"github.com/pion/webrtc/v4"
|
||||
@@ -32,6 +33,11 @@ type VideoTrack interface {
|
||||
SetTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPReceiver))
|
||||
}
|
||||
|
||||
type videoTrackProvider interface {
|
||||
provider.Provider
|
||||
provider.VideoTrackCapable
|
||||
}
|
||||
|
||||
type legacySession struct {
|
||||
provider provider.Provider
|
||||
}
|
||||
@@ -39,7 +45,7 @@ type legacySession struct {
|
||||
// Capabilities reports the transport primitives supported by the legacy carrier.
|
||||
func (s *legacySession) Capabilities() Capabilities {
|
||||
caps := Capabilities{ByteStream: true}
|
||||
_, caps.VideoTrack = s.provider.(provider.VideoTrackCapable)
|
||||
_, caps.VideoTrack = s.provider.(videoTrackProvider)
|
||||
return caps
|
||||
}
|
||||
|
||||
@@ -50,20 +56,35 @@ func (s *legacySession) OpenByteStream() (ByteStream, error) {
|
||||
|
||||
// OpenVideoTrack adapts a legacy provider to the generic video track capability.
|
||||
func (s *legacySession) OpenVideoTrack() (VideoTrack, error) {
|
||||
publisher, ok := s.provider.(provider.VideoTrackCapable)
|
||||
vtp, ok := s.provider.(videoTrackProvider)
|
||||
if !ok {
|
||||
return nil, ErrVideoTrackUnsupported
|
||||
}
|
||||
return &legacyVideoTrack{provider: publisher}, nil
|
||||
return &legacyVideoTrack{provider: vtp}, nil
|
||||
}
|
||||
|
||||
type legacyByteStream struct {
|
||||
provider provider.Provider
|
||||
}
|
||||
|
||||
func (p *legacyByteStream) Connect(ctx context.Context) error { return p.provider.Connect(ctx) }
|
||||
func (p *legacyByteStream) Send(data []byte) error { return p.provider.Send(data) }
|
||||
func (p *legacyByteStream) Close() error { return p.provider.Close() }
|
||||
func (p *legacyByteStream) Connect(ctx context.Context) error {
|
||||
if err := p.provider.Connect(ctx); err != nil {
|
||||
return fmt.Errorf("connect: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (p *legacyByteStream) Send(data []byte) error {
|
||||
if err := p.provider.Send(data); err != nil {
|
||||
return fmt.Errorf("send: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (p *legacyByteStream) Close() error {
|
||||
if err := p.provider.Close(); err != nil {
|
||||
return fmt.Errorf("close: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *legacyByteStream) SetReconnectCallback(cb func()) {
|
||||
p.provider.SetReconnectCallback(func(_ *webrtc.DataChannel) {
|
||||
@@ -81,31 +102,38 @@ func (p *legacyByteStream) WatchConnection(ctx context.Context) {
|
||||
func (p *legacyByteStream) CanSend() bool { return p.provider.CanSend() }
|
||||
|
||||
type legacyVideoTrack struct {
|
||||
provider provider.VideoTrackCapable
|
||||
provider videoTrackProvider
|
||||
}
|
||||
|
||||
func (v *legacyVideoTrack) Connect(ctx context.Context) error {
|
||||
return v.provider.(provider.Provider).Connect(ctx)
|
||||
if err := v.provider.Connect(ctx); err != nil {
|
||||
return fmt.Errorf("connect: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (v *legacyVideoTrack) Close() error { return v.provider.(provider.Provider).Close() }
|
||||
func (v *legacyVideoTrack) SetShouldReconnect(fn func() bool) {
|
||||
v.provider.(provider.Provider).SetShouldReconnect(fn)
|
||||
}
|
||||
func (v *legacyVideoTrack) SetEndedCallback(cb func(string)) {
|
||||
v.provider.(provider.Provider).SetEndedCallback(cb)
|
||||
func (v *legacyVideoTrack) Close() error {
|
||||
if err := v.provider.Close(); err != nil {
|
||||
return fmt.Errorf("close: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (v *legacyVideoTrack) SetShouldReconnect(fn func() bool) { v.provider.SetShouldReconnect(fn) }
|
||||
func (v *legacyVideoTrack) SetEndedCallback(cb func(string)) { v.provider.SetEndedCallback(cb) }
|
||||
func (v *legacyVideoTrack) WatchConnection(ctx context.Context) {
|
||||
v.provider.(provider.Provider).WatchConnection(ctx)
|
||||
v.provider.WatchConnection(ctx)
|
||||
}
|
||||
func (v *legacyVideoTrack) CanSend() bool { return v.provider.(provider.Provider).CanSend() }
|
||||
func (v *legacyVideoTrack) CanSend() bool { return v.provider.CanSend() }
|
||||
func (v *legacyVideoTrack) AddTrack(track webrtc.TrackLocal) error {
|
||||
return v.provider.AddVideoTrack(track)
|
||||
if err := v.provider.AddVideoTrack(track); err != nil {
|
||||
return fmt.Errorf("add track: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (v *legacyVideoTrack) SetTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) {
|
||||
v.provider.SetVideoTrackHandler(cb)
|
||||
}
|
||||
func (v *legacyVideoTrack) SetReconnectCallback(cb func()) {
|
||||
v.provider.(provider.Provider).SetReconnectCallback(func(_ *webrtc.DataChannel) {
|
||||
v.provider.SetReconnectCallback(func(_ *webrtc.DataChannel) {
|
||||
if cb != nil {
|
||||
cb()
|
||||
}
|
||||
|
||||
@@ -51,6 +51,7 @@ type Config struct {
|
||||
// Factory creates a new carrier session.
|
||||
type Factory func(ctx context.Context, cfg Config) (Session, error)
|
||||
|
||||
//nolint:gochecknoglobals
|
||||
var registry = make(map[string]Factory)
|
||||
|
||||
// Register adds a carrier factory to the registry.
|
||||
|
||||
@@ -26,15 +26,25 @@ var (
|
||||
ErrConnectFailed = errors.New("tunnel connection failed")
|
||||
// ErrProxyAuth is returned when SOCKS proxy authentication fails.
|
||||
ErrProxyAuth = errors.New("SOCKS proxy auth failed")
|
||||
// ErrKeySize is returned when the encryption key is not 32 bytes.
|
||||
ErrKeySize = errors.New("key must be 32 bytes")
|
||||
// ErrInvalidSOCKSVersion is returned when the SOCKS version is not 5.
|
||||
ErrInvalidSOCKSVersion = errors.New("invalid socks version")
|
||||
// ErrUnsupportedSOCKSCommand is returned for unsupported SOCKS commands.
|
||||
ErrUnsupportedSOCKSCommand = errors.New("unsupported socks command")
|
||||
// ErrUnsupportedAddressType is returned for unsupported SOCKS address types.
|
||||
ErrUnsupportedAddressType = errors.New("unsupported address type")
|
||||
// ErrRemoteNotReady is returned when the server-side stream fails to signal readiness.
|
||||
ErrRemoteNotReady = errors.New("remote not ready")
|
||||
)
|
||||
|
||||
// Client handles local SOCKS5 connections and tunnels them to the server.
|
||||
type Client struct {
|
||||
ln link.Link
|
||||
cipher *crypto.Cipher
|
||||
conn *muxconn.Conn
|
||||
session *smux.Session
|
||||
sessMu sync.RWMutex
|
||||
ln link.Link
|
||||
cipher *crypto.Cipher
|
||||
conn *muxconn.Conn
|
||||
session *smux.Session
|
||||
sessMu sync.RWMutex
|
||||
dnsServer string
|
||||
}
|
||||
|
||||
@@ -63,7 +73,13 @@ func Run(
|
||||
vp8FPS int,
|
||||
vp8BatchSize int,
|
||||
) error {
|
||||
return RunWithReady(ctx, linkName, transportName, carrierName, roomURL, keyHex, localAddr, dnsServer, socksUser, socksPass, nil, videoWidth, videoHeight, videoFPS, videoBitrate, videoHW, videoQRSize, videoQRRecovery, videoCodec, videoTileModule, videoTileRS, vp8FPS, vp8BatchSize)
|
||||
return RunWithReady(
|
||||
ctx, linkName, transportName, carrierName, roomURL, keyHex, localAddr,
|
||||
dnsServer, socksUser, socksPass, nil,
|
||||
videoWidth, videoHeight, videoFPS, videoBitrate, videoHW,
|
||||
videoQRSize, videoQRRecovery, videoCodec, videoTileModule, videoTileRS,
|
||||
vp8FPS, vp8BatchSize,
|
||||
)
|
||||
}
|
||||
|
||||
// RunWithReady is like Run but accepts a callback that is called when the client is ready.
|
||||
@@ -118,7 +134,7 @@ func RunWithReady(
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on %s: %w", localAddr, err)
|
||||
}
|
||||
defer listener.Close()
|
||||
defer func() { _ = listener.Close() }()
|
||||
|
||||
logger.Infof("SOCKS5 server listening on %s", localAddr)
|
||||
|
||||
@@ -126,17 +142,10 @@ func RunWithReady(
|
||||
onReady()
|
||||
}
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- c.acceptLoop(runCtx, listener)
|
||||
}()
|
||||
go c.acceptLoop(runCtx, listener)
|
||||
|
||||
select {
|
||||
case <-runCtx.Done():
|
||||
return nil
|
||||
case err := <-errCh:
|
||||
return err
|
||||
}
|
||||
<-runCtx.Done()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) bringUpLink(
|
||||
@@ -227,8 +236,6 @@ func (c *Client) handleReconnect() {
|
||||
c.conn = nil
|
||||
}
|
||||
c.sessMu.Unlock()
|
||||
// New SOCKS5 connections will fail until the link comes back up; the
|
||||
// caller will reissue them. Existing streams die with the smux session.
|
||||
c.conn = muxconn.New(c.ln, c.cipher)
|
||||
sess, err := smux.Client(c.conn, smuxConfig())
|
||||
if err != nil {
|
||||
@@ -260,7 +267,7 @@ func setupCipher(keyHex string) (*crypto.Cipher, error) {
|
||||
return nil, fmt.Errorf("failed to decode key: %w", err)
|
||||
}
|
||||
if len(key) != 32 {
|
||||
return nil, fmt.Errorf("key must be 32 bytes, got %d", len(key))
|
||||
return nil, fmt.Errorf("%w: got %d", ErrKeySize, len(key))
|
||||
}
|
||||
|
||||
cipher, err := crypto.NewCipher(string(key))
|
||||
@@ -279,13 +286,13 @@ func (c *Client) onData(data []byte) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) acceptLoop(ctx context.Context, ln net.Listener) error {
|
||||
func (c *Client) acceptLoop(ctx context.Context, ln net.Listener) {
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
return
|
||||
default:
|
||||
logger.Warnf("Accept error: %v", err)
|
||||
continue
|
||||
@@ -295,8 +302,8 @@ func (c *Client) acceptLoop(ctx context.Context, ln net.Listener) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) handleSocks5(ctx context.Context, conn net.Conn) {
|
||||
defer conn.Close()
|
||||
func (c *Client) handleSocks5(_ context.Context, conn net.Conn) {
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
if err := c.socks5Handshake(conn); err != nil {
|
||||
return
|
||||
@@ -315,38 +322,25 @@ func (c *Client) handleSocks5(ctx context.Context, conn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
c.tunnel(conn, sess, targetAddr, targetPort)
|
||||
}
|
||||
|
||||
func (c *Client) tunnel(conn net.Conn, sess *smux.Session, targetAddr string, targetPort int) {
|
||||
stream, err := sess.OpenStream()
|
||||
if err != nil {
|
||||
logger.Warnf("OpenStream failed: %v", err)
|
||||
_, _ = conn.Write(replyHostUnreachable())
|
||||
return
|
||||
}
|
||||
defer stream.Close()
|
||||
defer func() { _ = stream.Close() }()
|
||||
|
||||
logger.Infof("sid=%d tunnel to %s:%d", stream.ID(), targetAddr, targetPort)
|
||||
|
||||
connectReq, _ := json.Marshal(map[string]any{
|
||||
"cmd": "connect",
|
||||
"addr": targetAddr,
|
||||
"port": targetPort,
|
||||
})
|
||||
|
||||
_ = stream.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
if _, err := stream.Write(connectReq); err != nil {
|
||||
logger.Warnf("sid=%d connect req failed: %v", stream.ID(), err)
|
||||
if err := c.sendConnectRequest(stream, targetAddr, targetPort); err != nil {
|
||||
logger.Warnf("sid=%d connect failed: %v", stream.ID(), err)
|
||||
_, _ = conn.Write(replyHostUnreachable())
|
||||
return
|
||||
}
|
||||
_ = stream.SetWriteDeadline(time.Time{})
|
||||
|
||||
ack := make([]byte, 1)
|
||||
_ = stream.SetReadDeadline(time.Now().Add(15 * time.Second))
|
||||
if _, err := io.ReadFull(stream, ack); err != nil || ack[0] != 0x00 {
|
||||
logger.Warnf("sid=%d remote ready failed: err=%v ack=%v", stream.ID(), err, ack)
|
||||
_, _ = conn.Write(replyHostUnreachable())
|
||||
return
|
||||
}
|
||||
_ = stream.SetReadDeadline(time.Time{})
|
||||
|
||||
if _, err := conn.Write(replySuccess()); err != nil {
|
||||
return
|
||||
@@ -357,24 +351,47 @@ func (c *Client) handleSocks5(ctx context.Context, conn net.Conn) {
|
||||
_ = stream.Close()
|
||||
}()
|
||||
_, _ = io.Copy(conn, stream)
|
||||
}
|
||||
|
||||
_ = ctx // keep signature
|
||||
func (c *Client) sendConnectRequest(stream *smux.Stream, targetAddr string, targetPort int) error {
|
||||
connectReq, err := json.Marshal(map[string]any{
|
||||
"cmd": "connect",
|
||||
"addr": targetAddr,
|
||||
"port": targetPort,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("sid=%d marshal connect req: %w", stream.ID(), err)
|
||||
}
|
||||
|
||||
_ = stream.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
if _, err := stream.Write(connectReq); err != nil {
|
||||
return fmt.Errorf("sid=%d write connect req: %w", stream.ID(), err)
|
||||
}
|
||||
_ = stream.SetWriteDeadline(time.Time{})
|
||||
|
||||
ack := make([]byte, 1)
|
||||
_ = stream.SetReadDeadline(time.Now().Add(15 * time.Second))
|
||||
if _, err := io.ReadFull(stream, ack); err != nil || ack[0] != 0x00 {
|
||||
return fmt.Errorf("sid=%d: %w (read_err=%w ack=%v)", stream.ID(), ErrRemoteNotReady, err, ack)
|
||||
}
|
||||
_ = stream.SetReadDeadline(time.Time{})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) socks5Handshake(conn net.Conn) error {
|
||||
buf := make([]byte, 2)
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("read socks5 header: %w", err)
|
||||
}
|
||||
if buf[0] != 5 {
|
||||
return fmt.Errorf("invalid socks version: %d", buf[0])
|
||||
return fmt.Errorf("%w: %d", ErrInvalidSOCKSVersion, buf[0])
|
||||
}
|
||||
methods := make([]byte, buf[1])
|
||||
if _, err := io.ReadFull(conn, methods); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("read socks5 methods: %w", err)
|
||||
}
|
||||
if _, err := conn.Write([]byte{5, 0}); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("write socks5 auth: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -382,43 +399,49 @@ func (c *Client) socks5Handshake(conn net.Conn) error {
|
||||
func (c *Client) socks5Request(conn net.Conn) (string, int, error) {
|
||||
header := make([]byte, 4)
|
||||
if _, err := io.ReadFull(conn, header); err != nil {
|
||||
return "", 0, err
|
||||
return "", 0, fmt.Errorf("read socks5 request: %w", err)
|
||||
}
|
||||
if header[1] != 1 {
|
||||
return "", 0, fmt.Errorf("unsupported socks command: %d", header[1])
|
||||
return "", 0, fmt.Errorf("%w: %d", ErrUnsupportedSOCKSCommand, header[1])
|
||||
}
|
||||
|
||||
var addr string
|
||||
switch header[3] {
|
||||
case 1: // IPv4
|
||||
buf := make([]byte, 4)
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
addr = net.IP(buf).String()
|
||||
case 3: // Domain
|
||||
lenBuf := make([]byte, 1)
|
||||
if _, err := io.ReadFull(conn, lenBuf); err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
buf := make([]byte, lenBuf[0])
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
addr = string(buf)
|
||||
default:
|
||||
return "", 0, fmt.Errorf("unsupported address type: %d", header[3])
|
||||
addr, err := c.readSocks5Addr(conn, header[3])
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
portBuf := make([]byte, 2)
|
||||
if _, err := io.ReadFull(conn, portBuf); err != nil {
|
||||
return "", 0, err
|
||||
return "", 0, fmt.Errorf("read socks5 port: %w", err)
|
||||
}
|
||||
port := int(binary.BigEndian.Uint16(portBuf))
|
||||
|
||||
return addr, port, nil
|
||||
}
|
||||
|
||||
func (c *Client) readSocks5Addr(conn net.Conn, addrType byte) (string, error) {
|
||||
switch addrType {
|
||||
case 1: // IPv4
|
||||
buf := make([]byte, 4)
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
return "", fmt.Errorf("read socks5 ipv4: %w", err)
|
||||
}
|
||||
return net.IP(buf).String(), nil
|
||||
case 3: // Domain
|
||||
lenBuf := make([]byte, 1)
|
||||
if _, err := io.ReadFull(conn, lenBuf); err != nil {
|
||||
return "", fmt.Errorf("read socks5 domain len: %w", err)
|
||||
}
|
||||
buf := make([]byte, lenBuf[0])
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
return "", fmt.Errorf("read socks5 domain: %w", err)
|
||||
}
|
||||
return string(buf), nil
|
||||
default:
|
||||
return "", fmt.Errorf("%w: %d", ErrUnsupportedAddressType, addrType)
|
||||
}
|
||||
}
|
||||
|
||||
func replySuccess() []byte {
|
||||
return []byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0}
|
||||
}
|
||||
|
||||
@@ -16,25 +16,25 @@ type directLink struct {
|
||||
// New creates a direct link that forwards bytes to the selected transport.
|
||||
func New(ctx context.Context, cfg link.Config) (link.Link, error) {
|
||||
tr, err := transport.New(ctx, cfg.Transport, transport.Config{
|
||||
Carrier: cfg.Carrier,
|
||||
RoomURL: cfg.RoomURL,
|
||||
Name: cfg.Name,
|
||||
OnData: cfg.OnData,
|
||||
DNSServer: cfg.DNSServer,
|
||||
ProxyAddr: cfg.ProxyAddr,
|
||||
ProxyPort: cfg.ProxyPort,
|
||||
VideoWidth: cfg.VideoWidth,
|
||||
VideoHeight: cfg.VideoHeight,
|
||||
VideoFPS: cfg.VideoFPS,
|
||||
VideoBitrate: cfg.VideoBitrate,
|
||||
VideoHW: cfg.VideoHW,
|
||||
Carrier: cfg.Carrier,
|
||||
RoomURL: cfg.RoomURL,
|
||||
Name: cfg.Name,
|
||||
OnData: cfg.OnData,
|
||||
DNSServer: cfg.DNSServer,
|
||||
ProxyAddr: cfg.ProxyAddr,
|
||||
ProxyPort: cfg.ProxyPort,
|
||||
VideoWidth: cfg.VideoWidth,
|
||||
VideoHeight: cfg.VideoHeight,
|
||||
VideoFPS: cfg.VideoFPS,
|
||||
VideoBitrate: cfg.VideoBitrate,
|
||||
VideoHW: cfg.VideoHW,
|
||||
VideoQRSize: cfg.VideoQRSize,
|
||||
VideoQRRecovery: cfg.VideoQRRecovery,
|
||||
VideoCodec: cfg.VideoCodec,
|
||||
VideoTileModule: cfg.VideoTileModule,
|
||||
VideoTileRS: cfg.VideoTileRS,
|
||||
VP8FPS: cfg.VP8FPS,
|
||||
VP8BatchSize: cfg.VP8BatchSize,
|
||||
VP8FPS: cfg.VP8FPS,
|
||||
VP8BatchSize: cfg.VP8BatchSize,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create transport for direct link: %w", err)
|
||||
@@ -43,9 +43,27 @@ func New(ctx context.Context, cfg link.Config) (link.Link, error) {
|
||||
return &directLink{transport: tr}, nil
|
||||
}
|
||||
|
||||
func (d *directLink) Connect(ctx context.Context) error { return d.transport.Connect(ctx) }
|
||||
func (d *directLink) Send(data []byte) error { return d.transport.Send(data) }
|
||||
func (d *directLink) Close() error { return d.transport.Close() }
|
||||
func (d *directLink) Connect(ctx context.Context) error {
|
||||
if err := d.transport.Connect(ctx); err != nil {
|
||||
return fmt.Errorf("transport connect: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *directLink) Send(data []byte) error {
|
||||
if err := d.transport.Send(data); err != nil {
|
||||
return fmt.Errorf("transport send: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *directLink) Close() error {
|
||||
if err := d.transport.Close(); err != nil {
|
||||
return fmt.Errorf("transport close: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *directLink) SetReconnectCallback(cb func()) { d.transport.SetReconnectCallback(cb) }
|
||||
func (d *directLink) SetShouldReconnect(fn func() bool) { d.transport.SetShouldReconnect(fn) }
|
||||
func (d *directLink) SetEndedCallback(cb func(string)) { d.transport.SetEndedCallback(cb) }
|
||||
|
||||
@@ -50,6 +50,7 @@ type Config struct {
|
||||
// Factory creates a link instance.
|
||||
type Factory func(ctx context.Context, cfg Config) (Link, error)
|
||||
|
||||
//nolint:gochecknoglobals
|
||||
var registry = make(map[string]Factory)
|
||||
|
||||
// Register adds a link factory to the registry.
|
||||
|
||||
@@ -17,6 +17,7 @@ package muxconn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -92,10 +93,10 @@ func (c *Conn) Write(p []byte) (int, error) {
|
||||
|
||||
enc, err := c.cipher.Encrypt(p)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return 0, fmt.Errorf("encrypt: %w", err)
|
||||
}
|
||||
if err := c.ln.Send(enc); err != nil {
|
||||
return 0, err
|
||||
return 0, fmt.Errorf("send: %w", err)
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
@@ -24,6 +24,13 @@ const (
|
||||
sendDelay = 2 * time.Millisecond
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrPublisherNotInitialized is returned when the publisher peer connection is not set up.
|
||||
ErrPublisherNotInitialized = errors.New("publisher peer connection not initialized")
|
||||
// ErrSubscriberMediaTimeout is returned when the subscriber media is not ready within the timeout period.
|
||||
ErrSubscriberMediaTimeout = errors.New("subscriber media timeout")
|
||||
)
|
||||
|
||||
// Peer represents a SaluteJazz WebRTC connection.
|
||||
type Peer struct {
|
||||
name string
|
||||
@@ -135,23 +142,23 @@ func (p *Peer) attachPendingVideoTracks() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Connect starts the WebRTC connection process.
|
||||
func (p *Peer) Connect(ctx context.Context) error {
|
||||
p.closed.Store(false)
|
||||
p.resetMediaState()
|
||||
|
||||
config := webrtc.Configuration{
|
||||
func defaultWebRTCConfig() webrtc.Configuration {
|
||||
return webrtc.Configuration{
|
||||
ICEServers: []webrtc.ICEServer{},
|
||||
SDPSemantics: webrtc.SDPSemanticsUnifiedPlan,
|
||||
BundlePolicy: webrtc.BundlePolicyMaxBundle,
|
||||
}
|
||||
}
|
||||
|
||||
settingEngine := webrtc.SettingEngine{}
|
||||
func (p *Peer) buildAPI() *webrtc.API {
|
||||
se := webrtc.SettingEngine{}
|
||||
if protect.Protector != nil {
|
||||
settingEngine.SetICEProxyDialer(protect.NewProxyDialer())
|
||||
se.SetICEProxyDialer(protect.NewProxyDialer())
|
||||
}
|
||||
api := webrtc.NewAPI(webrtc.WithSettingEngine(settingEngine))
|
||||
return webrtc.NewAPI(webrtc.WithSettingEngine(se))
|
||||
}
|
||||
|
||||
func (p *Peer) createPeerConnections(api *webrtc.API, config webrtc.Configuration) error {
|
||||
var err error
|
||||
p.pcSub, err = api.NewPeerConnection(config)
|
||||
if err != nil {
|
||||
@@ -162,7 +169,6 @@ func (p *Peer) Connect(ctx context.Context) error {
|
||||
if track.Kind() != webrtc.RTPCodecTypeVideo {
|
||||
return
|
||||
}
|
||||
|
||||
if cb := p.videoTrackHandler(); cb != nil {
|
||||
cb(track, receiver)
|
||||
}
|
||||
@@ -173,28 +179,63 @@ func (p *Peer) Connect(ctx context.Context) error {
|
||||
return fmt.Errorf("create publisher pc: %w", err)
|
||||
}
|
||||
p.pcPub.OnConnectionStateChange(p.onPublisherConnectionStateChange)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Peer) createDataChannel() (chan struct{}, error) {
|
||||
var err error
|
||||
p.dc, err = p.pcPub.CreateDataChannel("_reliable", &webrtc.DataChannelInit{
|
||||
Ordered: func() *bool { v := true; return &v }(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create datachannel: %w", err)
|
||||
}
|
||||
dcReady := make(chan struct{})
|
||||
p.setupDataChannelHandlers(dcReady)
|
||||
return dcReady, nil
|
||||
}
|
||||
|
||||
func (p *Peer) waitForReady(ctx context.Context, dcReady chan struct{}) error {
|
||||
if dcReady != nil {
|
||||
select {
|
||||
case <-dcReady:
|
||||
return nil
|
||||
case <-time.After(30 * time.Second):
|
||||
return provider.ErrDataChannelTimeout
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("connect cancelled: %w", ctx.Err())
|
||||
}
|
||||
}
|
||||
return p.waitForMediaReady(ctx, 30*time.Second)
|
||||
}
|
||||
|
||||
// Connect starts the WebRTC connection process.
|
||||
func (p *Peer) Connect(ctx context.Context) error {
|
||||
p.closed.Store(false)
|
||||
p.resetMediaState()
|
||||
|
||||
api := p.buildAPI()
|
||||
config := defaultWebRTCConfig()
|
||||
|
||||
if err := p.createPeerConnections(api, config); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := p.attachPendingVideoTracks(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var dcReady chan struct{}
|
||||
if p.onData != nil {
|
||||
p.dc, err = p.pcPub.CreateDataChannel("_reliable", &webrtc.DataChannelInit{
|
||||
Ordered: func() *bool { v := true; return &v }(),
|
||||
})
|
||||
var err error
|
||||
dcReady, err = p.createDataChannel()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create datachannel: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
dcReady = make(chan struct{})
|
||||
p.setupDataChannelHandlers(dcReady)
|
||||
}
|
||||
|
||||
if err := p.dialWebSocket(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := p.sendJoin(); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -205,18 +246,7 @@ func (p *Peer) Connect(ctx context.Context) error {
|
||||
p.handleSignaling(ctx)
|
||||
}()
|
||||
|
||||
if p.onData != nil {
|
||||
select {
|
||||
case <-dcReady:
|
||||
return nil
|
||||
case <-time.After(30 * time.Second):
|
||||
return provider.ErrDataChannelTimeout
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("connect cancelled: %w", ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
return p.waitForMediaReady(ctx, 30*time.Second)
|
||||
return p.waitForReady(ctx, dcReady)
|
||||
}
|
||||
|
||||
func (p *Peer) waitForMediaReady(ctx context.Context, timeout time.Duration) error {
|
||||
@@ -226,7 +256,7 @@ func (p *Peer) waitForMediaReady(ctx context.Context, timeout time.Duration) err
|
||||
select {
|
||||
case <-p.subscriberConn:
|
||||
case <-timer.C:
|
||||
return fmt.Errorf("subscriber media timeout")
|
||||
return ErrSubscriberMediaTimeout
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("connect cancelled: %w", ctx.Err())
|
||||
}
|
||||
@@ -320,30 +350,38 @@ func (p *Peer) setupDataChannelHandlers(dcReady chan struct{}) {
|
||||
}
|
||||
|
||||
func (p *Peer) onSubscriberConnectionStateChange(state webrtc.PeerConnectionState) {
|
||||
if state == webrtc.PeerConnectionStateConnected {
|
||||
switch state {
|
||||
case webrtc.PeerConnectionStateConnected:
|
||||
p.subscriberReady.Store(true)
|
||||
closeSignal(p.subscriberConn)
|
||||
} else if state == webrtc.PeerConnectionStateDisconnected ||
|
||||
state == webrtc.PeerConnectionStateFailed ||
|
||||
state == webrtc.PeerConnectionStateClosed {
|
||||
case webrtc.PeerConnectionStateDisconnected, webrtc.PeerConnectionStateFailed:
|
||||
p.subscriberReady.Store(false)
|
||||
if !p.closed.Load() && (state == webrtc.PeerConnectionStateDisconnected || state == webrtc.PeerConnectionStateFailed) {
|
||||
if !p.closed.Load() {
|
||||
p.queueReconnect()
|
||||
}
|
||||
case webrtc.PeerConnectionStateClosed:
|
||||
p.subscriberReady.Store(false)
|
||||
case webrtc.PeerConnectionStateUnknown,
|
||||
webrtc.PeerConnectionStateNew,
|
||||
webrtc.PeerConnectionStateConnecting:
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Peer) onPublisherConnectionStateChange(state webrtc.PeerConnectionState) {
|
||||
if state == webrtc.PeerConnectionStateConnected {
|
||||
switch state {
|
||||
case webrtc.PeerConnectionStateConnected:
|
||||
p.publisherReady.Store(true)
|
||||
closeSignal(p.publisherConn)
|
||||
} else if state == webrtc.PeerConnectionStateDisconnected ||
|
||||
state == webrtc.PeerConnectionStateFailed ||
|
||||
state == webrtc.PeerConnectionStateClosed {
|
||||
case webrtc.PeerConnectionStateDisconnected, webrtc.PeerConnectionStateFailed:
|
||||
p.publisherReady.Store(false)
|
||||
if !p.closed.Load() && (state == webrtc.PeerConnectionStateDisconnected || state == webrtc.PeerConnectionStateFailed) {
|
||||
if !p.closed.Load() {
|
||||
p.queueReconnect()
|
||||
}
|
||||
case webrtc.PeerConnectionStateClosed:
|
||||
p.publisherReady.Store(false)
|
||||
case webrtc.PeerConnectionStateUnknown,
|
||||
webrtc.PeerConnectionStateNew,
|
||||
webrtc.PeerConnectionStateConnecting:
|
||||
}
|
||||
}
|
||||
|
||||
@@ -651,11 +689,6 @@ func (p *Peer) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
// ErrPublisherNotInitialized is returned when the publisher peer connection is not set up.
|
||||
ErrPublisherNotInitialized = errors.New("publisher peer connection not initialized")
|
||||
)
|
||||
|
||||
// AddVideoTrack adds a video track to the publisher peer connection.
|
||||
func (p *Peer) AddVideoTrack(track webrtc.TrackLocal) error {
|
||||
p.videoTrackMu.Lock()
|
||||
|
||||
@@ -22,8 +22,6 @@ var (
|
||||
)
|
||||
|
||||
// Provider defines the standard interface for WebRTC connection handlers.
|
||||
//
|
||||
//nolint:interfacebloat // All methods are necessary for provider abstraction.
|
||||
type Provider interface {
|
||||
Connect(ctx context.Context) error
|
||||
Send(data []byte) error
|
||||
|
||||
@@ -42,6 +42,8 @@ var (
|
||||
ErrSessionClosed = errors.New("session closed")
|
||||
// ErrPeerClosed is returned when the peer is closed.
|
||||
ErrPeerClosed = errors.New("peer closed")
|
||||
// ErrSubscriberMediaTimeout is returned when subscriber media is not ready within the timeout period.
|
||||
ErrSubscriberMediaTimeout = errors.New("subscriber media timeout")
|
||||
)
|
||||
|
||||
// TrafficShape defines the parameters for outgoing traffic control.
|
||||
@@ -288,7 +290,7 @@ func (p *Peer) waitForMediaReady(ctx context.Context, timeout time.Duration) err
|
||||
select {
|
||||
case <-p.subscriberConn:
|
||||
case <-timer.C:
|
||||
return fmt.Errorf("subscriber media timeout")
|
||||
return ErrSubscriberMediaTimeout
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("connect context cancelled: %w", ctx.Err())
|
||||
}
|
||||
@@ -314,7 +316,8 @@ func (p *Peer) setupPeerConnections(config webrtc.Configuration) error {
|
||||
return
|
||||
}
|
||||
|
||||
logger.Infof("telemost remote video track: codec=%s stream=%s track=%s", track.Codec().MimeType, track.StreamID(), track.ID())
|
||||
logger.Infof("telemost remote video track: codec=%s stream=%s track=%s",
|
||||
track.Codec().MimeType, track.StreamID(), track.ID())
|
||||
|
||||
if cb := p.videoTrackHandler(); cb != nil {
|
||||
cb(track, receiver)
|
||||
@@ -342,29 +345,35 @@ func (p *Peer) onConnectionStateChange(state webrtc.PeerConnectionState) {
|
||||
|
||||
func (p *Peer) onSubscriberConnectionStateChange(state webrtc.PeerConnectionState) {
|
||||
logger.Debugf("telemost subscriber state: %s", state.String())
|
||||
if state == webrtc.PeerConnectionStateConnected {
|
||||
switch state {
|
||||
case webrtc.PeerConnectionStateConnected:
|
||||
p.subscriberReady.Store(true)
|
||||
closeSignal(p.subscriberConn)
|
||||
} else if state == webrtc.PeerConnectionStateDisconnected ||
|
||||
state == webrtc.PeerConnectionStateFailed ||
|
||||
state == webrtc.PeerConnectionStateClosed {
|
||||
case webrtc.PeerConnectionStateDisconnected,
|
||||
webrtc.PeerConnectionStateFailed,
|
||||
webrtc.PeerConnectionStateClosed:
|
||||
p.subscriberReady.Store(false)
|
||||
case webrtc.PeerConnectionStateUnknown,
|
||||
webrtc.PeerConnectionStateNew,
|
||||
webrtc.PeerConnectionStateConnecting:
|
||||
}
|
||||
|
||||
p.onConnectionStateChange(state)
|
||||
}
|
||||
|
||||
func (p *Peer) onPublisherConnectionStateChange(state webrtc.PeerConnectionState) {
|
||||
logger.Debugf("telemost publisher state: %s", state.String())
|
||||
if state == webrtc.PeerConnectionStateConnected {
|
||||
switch state {
|
||||
case webrtc.PeerConnectionStateConnected:
|
||||
p.publisherReady.Store(true)
|
||||
closeSignal(p.publisherConn)
|
||||
} else if state == webrtc.PeerConnectionStateDisconnected ||
|
||||
state == webrtc.PeerConnectionStateFailed ||
|
||||
state == webrtc.PeerConnectionStateClosed {
|
||||
case webrtc.PeerConnectionStateDisconnected,
|
||||
webrtc.PeerConnectionStateFailed,
|
||||
webrtc.PeerConnectionStateClosed:
|
||||
p.publisherReady.Store(false)
|
||||
case webrtc.PeerConnectionStateUnknown,
|
||||
webrtc.PeerConnectionStateNew,
|
||||
webrtc.PeerConnectionStateConnecting:
|
||||
}
|
||||
|
||||
p.onConnectionStateChange(state)
|
||||
}
|
||||
|
||||
@@ -656,7 +665,7 @@ func (p *Peer) sendSetSlots() error {
|
||||
p.wsMu.Lock()
|
||||
defer p.wsMu.Unlock()
|
||||
|
||||
return p.ws.WriteJSON(map[string]interface{}{
|
||||
if err := p.ws.WriteJSON(map[string]interface{}{
|
||||
"uid": uuid.New().String(),
|
||||
"setSlots": map[string]interface{}{
|
||||
"slots": []map[string]int{
|
||||
@@ -670,7 +679,52 @@ func (p *Peer) sendSetSlots() error {
|
||||
"selfViewVisibility": "ON_LOADING_THEN_SHOW",
|
||||
"gridConfig": map[string]interface{}{},
|
||||
},
|
||||
})
|
||||
}); err != nil {
|
||||
return fmt.Errorf("write set slots: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isNonTURNURL(url string) bool {
|
||||
return url != "" && !strings.HasPrefix(url, "turn:") && !strings.HasPrefix(url, "turns:")
|
||||
}
|
||||
|
||||
func parseICEURLs(server map[string]interface{}) []string {
|
||||
var urls []string
|
||||
switch rawURLs := server["urls"].(type) {
|
||||
case []interface{}:
|
||||
for _, rawURL := range rawURLs {
|
||||
if url, ok := rawURL.(string); ok && isNonTURNURL(url) {
|
||||
urls = append(urls, url)
|
||||
}
|
||||
}
|
||||
case []string:
|
||||
for _, url := range rawURLs {
|
||||
if isNonTURNURL(url) {
|
||||
urls = append(urls, url)
|
||||
}
|
||||
}
|
||||
}
|
||||
return urls
|
||||
}
|
||||
|
||||
func parseICEServer(rawServer interface{}) (webrtc.ICEServer, bool) {
|
||||
server, ok := rawServer.(map[string]interface{})
|
||||
if !ok {
|
||||
return webrtc.ICEServer{}, false
|
||||
}
|
||||
urls := parseICEURLs(server)
|
||||
if len(urls) == 0 {
|
||||
return webrtc.ICEServer{}, false
|
||||
}
|
||||
ice := webrtc.ICEServer{URLs: urls}
|
||||
if username, ok := server["username"].(string); ok {
|
||||
ice.Username = username
|
||||
}
|
||||
if credential, ok := server["credential"].(string); ok {
|
||||
ice.Credential = credential
|
||||
}
|
||||
return ice, true
|
||||
}
|
||||
|
||||
func (p *Peer) applyServerHelloConfig(serverHello map[string]interface{}) {
|
||||
@@ -686,39 +740,9 @@ func (p *Peer) applyServerHelloConfig(serverHello map[string]interface{}) {
|
||||
|
||||
iceServers := make([]webrtc.ICEServer, 0, len(rawServers))
|
||||
for _, rawServer := range rawServers {
|
||||
server, ok := rawServer.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
if ice, ok := parseICEServer(rawServer); ok {
|
||||
iceServers = append(iceServers, ice)
|
||||
}
|
||||
|
||||
var urls []string
|
||||
switch rawURLs := server["urls"].(type) {
|
||||
case []interface{}:
|
||||
for _, rawURL := range rawURLs {
|
||||
if url, ok := rawURL.(string); ok && url != "" && !strings.HasPrefix(url, "turn:") && !strings.HasPrefix(url, "turns:") {
|
||||
urls = append(urls, url)
|
||||
}
|
||||
}
|
||||
case []string:
|
||||
for _, url := range rawURLs {
|
||||
if !strings.HasPrefix(url, "turn:") && !strings.HasPrefix(url, "turns:") {
|
||||
urls = append(urls, url)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(urls) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
ice := webrtc.ICEServer{URLs: urls}
|
||||
if username, ok := server["username"].(string); ok {
|
||||
ice.Username = username
|
||||
}
|
||||
if credential, ok := server["credential"].(string); ok {
|
||||
ice.Credential = credential
|
||||
}
|
||||
iceServers = append(iceServers, ice)
|
||||
}
|
||||
|
||||
if len(iceServers) == 0 {
|
||||
|
||||
@@ -22,6 +22,8 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrKeyRequired is returned when no encryption key is provided.
|
||||
ErrKeyRequired = errors.New("key required (use -key <hex>)")
|
||||
// ErrKeySize is returned when the encryption key is not 32 bytes.
|
||||
ErrKeySize = errors.New("key must be 32 bytes")
|
||||
// ErrSocks5AuthFailed is returned when SOCKS5 authentication fails.
|
||||
@@ -100,17 +102,17 @@ func Run(
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.serve(runCtx)
|
||||
s.serve(runCtx)
|
||||
|
||||
s.shutdown()
|
||||
s.wg.Wait()
|
||||
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupCipher(keyHex string) (*crypto.Cipher, error) {
|
||||
if keyHex == "" {
|
||||
return nil, errors.New("key required (use -key <hex>)")
|
||||
return nil, ErrKeyRequired
|
||||
}
|
||||
|
||||
key, err := hex.DecodeString(keyHex)
|
||||
@@ -252,10 +254,12 @@ func (s *Server) onData(data []byte) {
|
||||
// serve drives the smux Accept loop, spawning a tunnel per inbound stream.
|
||||
// The loop tolerates session bounces (reconnects) by waiting until a fresh
|
||||
// session is installed instead of terminating the server.
|
||||
func (s *Server) serve(ctx context.Context) error {
|
||||
func (s *Server) serve(ctx context.Context) {
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return nil
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
s.sessMu.RLock()
|
||||
@@ -264,7 +268,7 @@ func (s *Server) serve(ctx context.Context) error {
|
||||
if sess == nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
return
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
continue
|
||||
}
|
||||
@@ -272,10 +276,10 @@ func (s *Server) serve(ctx context.Context) error {
|
||||
|
||||
stream, err := sess.AcceptStream()
|
||||
if err != nil {
|
||||
// Session is torn down (reconnect or close). If we're shutting
|
||||
// down, exit; otherwise wait for a new session and retry.
|
||||
if ctx.Err() != nil {
|
||||
return nil
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
logger.Infof("AcceptStream returned %v — waiting for new session", err)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
@@ -305,7 +309,7 @@ func (s *Server) shutdown() {
|
||||
}
|
||||
|
||||
func (s *Server) handleStream(_ context.Context, stream *smux.Stream) {
|
||||
defer stream.Close()
|
||||
defer func() { _ = stream.Close() }()
|
||||
|
||||
// Read the connect JSON. The client writes the whole JSON in one
|
||||
// stream.Write so it usually arrives intact; tolerate fragmentation
|
||||
@@ -356,7 +360,7 @@ func (s *Server) dispatch(stream *smux.Stream, req ConnectRequest) {
|
||||
logger.Infof("sid=%d dial %s failed (%v): %v", stream.ID(), addr, dialElapsed, err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
logger.Infof("sid=%d connected %s in %v", stream.ID(), addr, dialElapsed)
|
||||
|
||||
|
||||
@@ -44,17 +44,26 @@ func New(ctx context.Context, cfg transport.Config) (transport.Transport, error)
|
||||
|
||||
// Connect starts the transport connection.
|
||||
func (p *streamTransport) Connect(ctx context.Context) error {
|
||||
return p.stream.Connect(ctx)
|
||||
if err := p.stream.Connect(ctx); err != nil {
|
||||
return fmt.Errorf("stream connect: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send transmits data through the transport.
|
||||
func (p *streamTransport) Send(data []byte) error {
|
||||
return p.stream.Send(data)
|
||||
if err := p.stream.Send(data); err != nil {
|
||||
return fmt.Errorf("stream send: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close terminates the transport.
|
||||
func (p *streamTransport) Close() error {
|
||||
return p.stream.Close()
|
||||
if err := p.stream.Close(); err != nil {
|
||||
return fmt.Errorf("stream close: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReconnectCallback registers reconnect handling.
|
||||
|
||||
@@ -3,11 +3,20 @@ package seichannel
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/pion/webrtc/v4/pkg/media/h264reader"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrSEIPayloadTruncated is returned when the SEI payload is shorter than expected.
|
||||
ErrSEIPayloadTruncated = errors.New("sei payload truncated")
|
||||
// ErrSEIValueTruncated is returned when reading a SEI length-value runs past the buffer.
|
||||
ErrSEIValueTruncated = errors.New("sei value truncated")
|
||||
)
|
||||
|
||||
//nolint:gochecknoglobals
|
||||
var (
|
||||
videoSEIUUID = [16]byte{
|
||||
0x5d, 0xc0, 0x3b, 0xa8,
|
||||
@@ -21,19 +30,16 @@ var (
|
||||
baseIDR = mustDecodeHex("6588843a2628000902e0")
|
||||
)
|
||||
|
||||
func buildVideoAccessUnit(payload []byte) ([]byte, error) {
|
||||
func buildVideoAccessUnit(payload []byte) []byte {
|
||||
out := make([]byte, 0, len(baseSPS)+len(basePPS)+len(baseIDR)+64+len(payload))
|
||||
out = appendStartCode(out, baseSPS)
|
||||
out = appendStartCode(out, basePPS)
|
||||
if len(payload) > 0 {
|
||||
sei, err := buildSEINAL(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sei := buildSEINAL(payload)
|
||||
out = appendStartCode(out, sei)
|
||||
}
|
||||
out = appendStartCode(out, baseIDR)
|
||||
return out, nil
|
||||
return out
|
||||
}
|
||||
|
||||
func extractVideoPayloads(accessUnit []byte) ([][]byte, error) {
|
||||
@@ -63,7 +69,7 @@ func extractVideoPayloads(accessUnit []byte) ([][]byte, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func buildSEINAL(payload []byte) ([]byte, error) {
|
||||
func buildSEINAL(payload []byte) []byte {
|
||||
userData := make([]byte, 0, len(videoSEIUUID)+len(payload))
|
||||
userData = append(userData, videoSEIUUID[:]...)
|
||||
userData = append(userData, payload...)
|
||||
@@ -74,9 +80,11 @@ func buildSEINAL(payload []byte) ([]byte, error) {
|
||||
rbsp = append(rbsp, userData...)
|
||||
rbsp = append(rbsp, 0x80)
|
||||
|
||||
out := []byte{0x06}
|
||||
out = append(out, escapeRBSP(rbsp)...)
|
||||
return out, nil
|
||||
escaped := escapeRBSP(rbsp)
|
||||
out := make([]byte, 0, 1+len(escaped))
|
||||
out = append(out, 0x06)
|
||||
out = append(out, escaped...)
|
||||
return out
|
||||
}
|
||||
|
||||
func extractTransportSEI(rbsp []byte) ([][]byte, error) {
|
||||
@@ -101,7 +109,7 @@ func extractTransportSEI(rbsp []byte) ([][]byte, error) {
|
||||
pos = next
|
||||
|
||||
if pos+payloadSize > len(data) {
|
||||
return nil, fmt.Errorf("sei payload truncated")
|
||||
return nil, ErrSEIPayloadTruncated
|
||||
}
|
||||
|
||||
payload := data[pos : pos+payloadSize]
|
||||
@@ -127,14 +135,14 @@ func appendSEIValue(dst []byte, value int) []byte {
|
||||
dst = append(dst, 0xff)
|
||||
value -= 0xff
|
||||
}
|
||||
return append(dst, byte(value))
|
||||
return append(dst, byte(value)) //nolint:gosec
|
||||
}
|
||||
|
||||
func consumeSEIValue(data []byte, pos int) (int, int, error) {
|
||||
value := 0
|
||||
for {
|
||||
if pos >= len(data) {
|
||||
return 0, pos, fmt.Errorf("sei value truncated")
|
||||
return 0, pos, ErrSEIValueTruncated
|
||||
}
|
||||
b := int(data[pos])
|
||||
pos++
|
||||
@@ -170,11 +178,11 @@ func escapeRBSP(rbsp []byte) []byte {
|
||||
|
||||
func unescapeRBSP(rbsp []byte) []byte {
|
||||
out := make([]byte, 0, len(rbsp))
|
||||
for i := 0; i < len(rbsp); i++ {
|
||||
if i >= 2 && rbsp[i] == 0x03 && rbsp[i-1] == 0x00 && rbsp[i-2] == 0x00 {
|
||||
for i, b := range rbsp {
|
||||
if i >= 2 && b == 0x03 && rbsp[i-1] == 0x00 && rbsp[i-2] == 0x00 {
|
||||
continue
|
||||
}
|
||||
out = append(out, rbsp[i])
|
||||
out = append(out, b)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -40,6 +40,18 @@ var (
|
||||
ErrAckTimeout = errors.New("seichannel ack timeout")
|
||||
// ErrTransportClosed is returned when operations are attempted on a closed transport.
|
||||
ErrTransportClosed = errors.New("seichannel transport closed")
|
||||
// ErrFrameTooShort is returned when the received frame is too short to decode.
|
||||
ErrFrameTooShort = errors.New("frame too short")
|
||||
// ErrUnexpectedMagic is returned when the frame magic bytes do not match.
|
||||
ErrUnexpectedMagic = errors.New("unexpected frame magic")
|
||||
// ErrUnexpectedVersion is returned when the frame protocol version does not match.
|
||||
ErrUnexpectedVersion = errors.New("unexpected frame version")
|
||||
// ErrAckTooShort is returned when the ack frame is shorter than expected.
|
||||
ErrAckTooShort = errors.New("ack frame too short")
|
||||
// ErrDataTooShort is returned when the data frame is shorter than expected.
|
||||
ErrDataTooShort = errors.New("data frame too short")
|
||||
// ErrUnexpectedFrameType is returned for unknown frame type bytes.
|
||||
ErrUnexpectedFrameType = errors.New("unexpected frame type")
|
||||
)
|
||||
|
||||
type transportFrame struct {
|
||||
@@ -144,7 +156,7 @@ func (p *streamTransport) Connect(ctx context.Context) error {
|
||||
defer cancel()
|
||||
|
||||
if err := p.stream.Connect(connectCtx); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("connect stream: %w", err)
|
||||
}
|
||||
|
||||
p.startWriter.Do(func() {
|
||||
@@ -178,7 +190,7 @@ func (p *streamTransport) Send(data []byte) error {
|
||||
p.ackMu.Unlock()
|
||||
}()
|
||||
|
||||
for attempt := 0; attempt < maxSendAttempts; attempt++ {
|
||||
for range maxSendAttempts {
|
||||
for idx, fragment := range fragments {
|
||||
frame := encodeDataFrame(seq, crc, len(data), idx, len(fragments), fragment)
|
||||
if err := p.enqueueFrame(frame, false); err != nil {
|
||||
@@ -210,7 +222,9 @@ func (p *streamTransport) Close() error {
|
||||
if p.writerUp.Load() {
|
||||
<-p.writerDone
|
||||
}
|
||||
return p.stream.Close()
|
||||
if err := p.stream.Close(); err != nil {
|
||||
return fmt.Errorf("close stream: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -256,10 +270,7 @@ func (p *streamTransport) writerLoop() {
|
||||
ticker := time.NewTicker(defaultFrameInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
idle, err := buildVideoAccessUnit(nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
idle := buildVideoAccessUnit(nil)
|
||||
|
||||
for {
|
||||
select {
|
||||
@@ -273,10 +284,7 @@ func (p *streamTransport) writerLoop() {
|
||||
|
||||
sample := idle
|
||||
if payload != nil {
|
||||
sample, err = buildVideoAccessUnit(payload)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
sample = buildVideoAccessUnit(payload)
|
||||
}
|
||||
|
||||
_ = p.track.WriteSample(media.Sample{
|
||||
@@ -371,14 +379,7 @@ func (p *streamTransport) handleSample(sample []byte) {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *streamTransport) handleInboundFrame(frame transportFrame) {
|
||||
p.recvMu.Lock()
|
||||
if crc, ok := p.delivered[frame.seq]; ok && crc == frame.crc {
|
||||
p.recvMu.Unlock()
|
||||
p.sendAck(frame.seq, frame.crc)
|
||||
return
|
||||
}
|
||||
|
||||
func (p *streamTransport) upsertInbound(frame transportFrame) (*inboundMessage, bool) {
|
||||
msg, ok := p.inbound[frame.seq]
|
||||
if !ok || msg.crc != frame.crc || msg.totalLen != frame.totalLen || len(msg.frags) != int(frame.fragTotal) {
|
||||
msg = &inboundMessage{
|
||||
@@ -389,33 +390,45 @@ func (p *streamTransport) handleInboundFrame(frame transportFrame) {
|
||||
}
|
||||
p.inbound[frame.seq] = msg
|
||||
}
|
||||
|
||||
if int(frame.fragIdx) >= len(msg.frags) {
|
||||
p.recvMu.Unlock()
|
||||
return
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if msg.frags[frame.fragIdx] == nil {
|
||||
chunk := make([]byte, len(frame.payload))
|
||||
copy(chunk, frame.payload)
|
||||
msg.frags[frame.fragIdx] = chunk
|
||||
msg.remain--
|
||||
}
|
||||
return msg, msg.remain == 0
|
||||
}
|
||||
|
||||
if msg.remain > 0 {
|
||||
func (p *streamTransport) assembleMessage(msg *inboundMessage) []byte {
|
||||
data := make([]byte, 0, msg.totalLen)
|
||||
for _, frag := range msg.frags {
|
||||
data = append(data, frag...)
|
||||
}
|
||||
if uint32(len(data)) > msg.totalLen { //nolint:gosec
|
||||
data = data[:msg.totalLen]
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func (p *streamTransport) handleInboundFrame(frame transportFrame) {
|
||||
p.recvMu.Lock()
|
||||
if crc, ok := p.delivered[frame.seq]; ok && crc == frame.crc {
|
||||
p.recvMu.Unlock()
|
||||
p.sendAck(frame.seq, frame.crc)
|
||||
return
|
||||
}
|
||||
|
||||
msg, complete := p.upsertInbound(frame)
|
||||
if msg == nil || !complete {
|
||||
p.recvMu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
delete(p.inbound, frame.seq)
|
||||
data := make([]byte, 0, msg.totalLen)
|
||||
for _, frag := range msg.frags {
|
||||
data = append(data, frag...)
|
||||
}
|
||||
|
||||
if uint32(len(data)) > msg.totalLen {
|
||||
data = data[:msg.totalLen]
|
||||
}
|
||||
data := p.assembleMessage(msg)
|
||||
|
||||
if crc32.ChecksumIEEE(data) != msg.crc {
|
||||
p.recvMu.Unlock()
|
||||
@@ -480,9 +493,9 @@ func encodeDataFrame(seq, crc uint32, totalLen, fragIdx, fragTotal int, payload
|
||||
out[5] = frameTypeData
|
||||
binary.BigEndian.PutUint32(out[6:10], seq)
|
||||
binary.BigEndian.PutUint32(out[10:14], crc)
|
||||
binary.BigEndian.PutUint32(out[14:18], uint32(totalLen))
|
||||
binary.BigEndian.PutUint16(out[18:20], uint16(fragIdx))
|
||||
binary.BigEndian.PutUint16(out[20:22], uint16(fragTotal))
|
||||
binary.BigEndian.PutUint32(out[14:18], uint32(totalLen)) //nolint:gosec
|
||||
binary.BigEndian.PutUint16(out[18:20], uint16(fragIdx)) //nolint:gosec
|
||||
binary.BigEndian.PutUint16(out[20:22], uint16(fragTotal)) //nolint:gosec
|
||||
copy(out[22:], payload)
|
||||
return out
|
||||
}
|
||||
@@ -499,27 +512,27 @@ func encodeAckFrame(seq, crc uint32) []byte {
|
||||
|
||||
func decodeTransportFrame(data []byte) (transportFrame, error) {
|
||||
if len(data) < 6 {
|
||||
return transportFrame{}, fmt.Errorf("frame too short")
|
||||
return transportFrame{}, ErrFrameTooShort
|
||||
}
|
||||
if binary.BigEndian.Uint32(data[0:4]) != protocolMagic {
|
||||
return transportFrame{}, fmt.Errorf("unexpected frame magic")
|
||||
return transportFrame{}, ErrUnexpectedMagic
|
||||
}
|
||||
if data[4] != protocolVersion {
|
||||
return transportFrame{}, fmt.Errorf("unexpected frame version")
|
||||
return transportFrame{}, ErrUnexpectedVersion
|
||||
}
|
||||
|
||||
frame := transportFrame{typ: data[5]}
|
||||
switch frame.typ {
|
||||
case frameTypeAck:
|
||||
if len(data) < 14 {
|
||||
return transportFrame{}, fmt.Errorf("ack too short")
|
||||
return transportFrame{}, ErrAckTooShort
|
||||
}
|
||||
frame.seq = binary.BigEndian.Uint32(data[6:10])
|
||||
frame.crc = binary.BigEndian.Uint32(data[10:14])
|
||||
return frame, nil
|
||||
case frameTypeData:
|
||||
if len(data) < 22 {
|
||||
return transportFrame{}, fmt.Errorf("data too short")
|
||||
return transportFrame{}, ErrDataTooShort
|
||||
}
|
||||
frame.seq = binary.BigEndian.Uint32(data[6:10])
|
||||
frame.crc = binary.BigEndian.Uint32(data[10:14])
|
||||
@@ -529,6 +542,6 @@ func decodeTransportFrame(data []byte) (transportFrame, error) {
|
||||
frame.payload = append([]byte(nil), data[22:]...)
|
||||
return frame, nil
|
||||
default:
|
||||
return transportFrame{}, fmt.Errorf("unexpected frame type")
|
||||
return transportFrame{}, ErrUnexpectedFrameType
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,10 +7,7 @@ import (
|
||||
|
||||
func TestSEIRoundTrip(t *testing.T) {
|
||||
payload := []byte("hello over seichannel")
|
||||
accessUnit, err := buildVideoAccessUnit(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("buildVideoAccessUnit failed: %v", err)
|
||||
}
|
||||
accessUnit := buildVideoAccessUnit(payload)
|
||||
|
||||
got, err := extractVideoPayloads(accessUnit)
|
||||
if err != nil {
|
||||
|
||||
@@ -58,6 +58,7 @@ type Config struct {
|
||||
// Factory creates a transport instance.
|
||||
type Factory func(ctx context.Context, cfg Config) (Transport, error)
|
||||
|
||||
//nolint:gochecknoglobals
|
||||
var registry = make(map[string]Factory)
|
||||
|
||||
// Register adds a transport factory to the registry.
|
||||
|
||||
@@ -2,11 +2,13 @@ package videochannel
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -27,6 +29,12 @@ var (
|
||||
ErrFFmpegUnavailable = errors.New("ffmpeg is required for videochannel")
|
||||
// ErrUnsupportedVideoCodec is returned when videochannel cannot decode the negotiated codec.
|
||||
ErrUnsupportedVideoCodec = errors.New("unsupported video codec")
|
||||
// ErrEncoderTimeout is returned when the encoder does not produce a frame within the deadline.
|
||||
ErrEncoderTimeout = errors.New("encoder timeout")
|
||||
// ErrPopFrameTimeout is returned when no decoded frame is available within the deadline.
|
||||
ErrPopFrameTimeout = errors.New("pop frame timeout")
|
||||
// ErrUnexpectedFrameSize is returned when the raw frame size does not match expectations.
|
||||
ErrUnexpectedFrameSize = errors.New("unexpected encoder frame size")
|
||||
)
|
||||
|
||||
type codecSpec struct {
|
||||
@@ -38,8 +46,7 @@ type codecSpec struct {
|
||||
encodeArgs []string
|
||||
}
|
||||
|
||||
func codecSpecForCarrier(carrier string) codecSpec {
|
||||
// Natural default for most WebRTC providers
|
||||
func codecSpecForCarrier(_ string) codecSpec {
|
||||
return vp8CodecSpec()
|
||||
}
|
||||
|
||||
@@ -120,6 +127,49 @@ func vp8CodecSpec() codecSpec {
|
||||
}
|
||||
}
|
||||
|
||||
func resolveEncoderCodec(spec codecSpec, hw string) string {
|
||||
if hw != "nvenc" {
|
||||
return spec.encoder
|
||||
}
|
||||
switch spec.mimeType {
|
||||
case webrtc.MimeTypeH264:
|
||||
return "h264_nvenc"
|
||||
case webrtc.MimeTypeVP8:
|
||||
return "vp8_nvenc"
|
||||
case webrtc.MimeTypeVP9:
|
||||
return "vp9_nvenc"
|
||||
case webrtc.MimeTypeAV1:
|
||||
return "av1_nvenc"
|
||||
default:
|
||||
return spec.encoder
|
||||
}
|
||||
}
|
||||
|
||||
func buildEncoderArgs(spec codecSpec, vcodec string, width, height, fps int, bitrate string) []string {
|
||||
args := []string{
|
||||
"-loglevel", "error", "-threads", "1",
|
||||
"-f", "rawvideo",
|
||||
"-pix_fmt", "gray",
|
||||
"-video_size", strconv.Itoa(width) + "x" + strconv.Itoa(height),
|
||||
"-framerate", strconv.Itoa(fps),
|
||||
"-i", "pipe:0",
|
||||
"-an",
|
||||
}
|
||||
|
||||
if strings.HasSuffix(vcodec, "_nvenc") {
|
||||
args = append(args, "-c:v", vcodec, "-preset", "p1", "-tune", "ull", "-rc", "vbr")
|
||||
} else {
|
||||
args = append(args, spec.encodeArgs...)
|
||||
}
|
||||
|
||||
args = append(args, "-g", "1", "-pix_fmt", "yuv420p", "-b:v", bitrate)
|
||||
|
||||
if spec.mimeType == webrtc.MimeTypeH264 {
|
||||
return append(args, "-f", "h264", "pipe:1")
|
||||
}
|
||||
return append(args, "-f", "ivf", "pipe:1")
|
||||
}
|
||||
|
||||
type ffmpegEncoder struct {
|
||||
cmd *exec.Cmd
|
||||
stdin io.WriteCloser
|
||||
@@ -134,62 +184,20 @@ type ffmpegEncoder struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func newFFmpegEncoder(spec codecSpec, width, height, fps int, bitrate, hw string) (*ffmpegEncoder, error) {
|
||||
func newFFmpegEncoder(
|
||||
ctx context.Context,
|
||||
spec codecSpec,
|
||||
width, height, fps int,
|
||||
bitrate, hw string,
|
||||
) (*ffmpegEncoder, error) {
|
||||
if _, err := exec.LookPath("ffmpeg"); err != nil {
|
||||
return nil, ErrFFmpegUnavailable
|
||||
}
|
||||
|
||||
args := []string{"-loglevel", "error", "-threads", "1"}
|
||||
vcodec := resolveEncoderCodec(spec, hw)
|
||||
args := buildEncoderArgs(spec, vcodec, width, height, fps, bitrate)
|
||||
|
||||
// Determine encoder binary based on HW flag
|
||||
vcodec := spec.encoder
|
||||
if hw == "nvenc" {
|
||||
switch spec.mimeType {
|
||||
case webrtc.MimeTypeH264:
|
||||
vcodec = "h264_nvenc"
|
||||
case webrtc.MimeTypeVP8:
|
||||
vcodec = "vp8_nvenc"
|
||||
case webrtc.MimeTypeVP9:
|
||||
vcodec = "vp9_nvenc"
|
||||
case webrtc.MimeTypeAV1:
|
||||
vcodec = "av1_nvenc"
|
||||
}
|
||||
}
|
||||
|
||||
inputPixFmt := "gray"
|
||||
frameSize := width * height
|
||||
|
||||
args = append(args,
|
||||
"-f", "rawvideo",
|
||||
"-pix_fmt", inputPixFmt,
|
||||
"-video_size", fmt.Sprintf("%dx%d", width, height),
|
||||
"-framerate", fmt.Sprintf("%d", fps),
|
||||
"-i", "pipe:0",
|
||||
"-an",
|
||||
)
|
||||
|
||||
// Apply hardware specific flags if using NVENC
|
||||
if strings.HasSuffix(vcodec, "_nvenc") {
|
||||
args = append(args,
|
||||
"-c:v", vcodec,
|
||||
"-preset", "p1",
|
||||
"-tune", "ull",
|
||||
"-rc", "vbr",
|
||||
)
|
||||
} else {
|
||||
// Use software encoder args from spec
|
||||
args = append(args, spec.encodeArgs...)
|
||||
}
|
||||
|
||||
args = append(args, "-g", "1", "-pix_fmt", "yuv420p", "-b:v", bitrate)
|
||||
|
||||
if spec.mimeType == webrtc.MimeTypeH264 {
|
||||
args = append(args, "-f", "h264", "pipe:1")
|
||||
} else {
|
||||
args = append(args, "-f", "ivf", "pipe:1")
|
||||
}
|
||||
|
||||
cmd := exec.Command("ffmpeg", args...)
|
||||
cmd := exec.CommandContext(ctx, "ffmpeg", args...) //nolint:gosec
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encoder stdin: %w", err)
|
||||
@@ -212,7 +220,7 @@ func newFFmpegEncoder(spec codecSpec, width, height, fps int, bitrate, hw string
|
||||
frames: make(chan []byte, 8),
|
||||
width: width,
|
||||
height: height,
|
||||
frameSize: frameSize,
|
||||
frameSize: width * height,
|
||||
}
|
||||
|
||||
if spec.mimeType == webrtc.MimeTypeH264 {
|
||||
@@ -225,7 +233,7 @@ func newFFmpegEncoder(spec codecSpec, width, height, fps int, bitrate, hw string
|
||||
|
||||
func (e *ffmpegEncoder) EncodeFrame(frame []byte) ([]byte, error) {
|
||||
if len(frame) != e.frameSize {
|
||||
return nil, fmt.Errorf("unexpected encoder frame size: %d (expected %d)", len(frame), e.frameSize)
|
||||
return nil, fmt.Errorf("%w: got %d expected %d", ErrUnexpectedFrameSize, len(frame), e.frameSize)
|
||||
}
|
||||
if err := e.processErr(); err != nil {
|
||||
return nil, err
|
||||
@@ -244,7 +252,7 @@ func (e *ffmpegEncoder) EncodeFrame(frame []byte) ([]byte, error) {
|
||||
if err := e.processErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, fmt.Errorf("encoder timeout")
|
||||
return nil, ErrEncoderTimeout
|
||||
}
|
||||
}
|
||||
|
||||
@@ -327,6 +335,43 @@ func (e *ffmpegEncoder) processErr() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func resolveDecoderName(spec codecSpec, hw string) string {
|
||||
if hw != "nvenc" {
|
||||
return strings.ToLower(strings.TrimPrefix(spec.mimeType, "video/"))
|
||||
}
|
||||
switch spec.mimeType {
|
||||
case webrtc.MimeTypeH264:
|
||||
return "h264_cuvid"
|
||||
case webrtc.MimeTypeVP8:
|
||||
return "vp8_cuvid"
|
||||
case webrtc.MimeTypeVP9:
|
||||
return "vp9_cuvid"
|
||||
default:
|
||||
return strings.ToLower(strings.TrimPrefix(spec.mimeType, "video/"))
|
||||
}
|
||||
}
|
||||
|
||||
func buildDecoderArgs(spec codecSpec, decoderName string, width, height int, outputPixFmt string) []string {
|
||||
args := []string{"-loglevel", "error", "-threads", "1"}
|
||||
if spec.mimeType == webrtc.MimeTypeH264 {
|
||||
args = append(args, "-f", "h264")
|
||||
} else {
|
||||
args = append(args, "-f", "ivf")
|
||||
}
|
||||
|
||||
vfFilter := fmt.Sprintf("scale=%d:%d:flags=neighbor,format=%s", width, height, outputPixFmt)
|
||||
return append(args,
|
||||
"-flags", "low_delay",
|
||||
"-vcodec", decoderName,
|
||||
"-i", "pipe:0",
|
||||
"-an",
|
||||
"-vf", vfFilter,
|
||||
"-pix_fmt", outputPixFmt,
|
||||
"-f", "rawvideo",
|
||||
"pipe:1",
|
||||
)
|
||||
}
|
||||
|
||||
type ffmpegDecoder struct {
|
||||
cmd *exec.Cmd
|
||||
stdin io.WriteCloser
|
||||
@@ -341,46 +386,20 @@ type ffmpegDecoder struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func newFFmpegDecoder(spec codecSpec, width, height, fps int, hw string) (*ffmpegDecoder, error) {
|
||||
func newFFmpegDecoder(
|
||||
ctx context.Context,
|
||||
spec codecSpec,
|
||||
width, height, fps int,
|
||||
hw string,
|
||||
) (*ffmpegDecoder, error) {
|
||||
if _, err := exec.LookPath("ffmpeg"); err != nil {
|
||||
return nil, ErrFFmpegUnavailable
|
||||
}
|
||||
|
||||
decoderName := strings.ToLower(strings.TrimPrefix(spec.mimeType, "video/"))
|
||||
if hw == "nvenc" {
|
||||
switch spec.mimeType {
|
||||
case webrtc.MimeTypeH264:
|
||||
decoderName = "h264_cuvid"
|
||||
case webrtc.MimeTypeVP8:
|
||||
decoderName = "vp8_cuvid"
|
||||
case webrtc.MimeTypeVP9:
|
||||
decoderName = "vp9_cuvid"
|
||||
}
|
||||
}
|
||||
decoderName := resolveDecoderName(spec, hw)
|
||||
args := buildDecoderArgs(spec, decoderName, width, height, "gray")
|
||||
|
||||
outputPixFmt := "gray"
|
||||
frameSize := width * height
|
||||
|
||||
args := []string{"-loglevel", "error", "-threads", "1"}
|
||||
if spec.mimeType == webrtc.MimeTypeH264 {
|
||||
args = append(args, "-f", "h264")
|
||||
} else {
|
||||
args = append(args, "-f", "ivf")
|
||||
}
|
||||
|
||||
vfFilter := fmt.Sprintf("scale=%d:%d:flags=neighbor,format=%s", width, height, outputPixFmt)
|
||||
args = append(args,
|
||||
"-flags", "low_delay",
|
||||
"-vcodec", decoderName,
|
||||
"-i", "pipe:0",
|
||||
"-an",
|
||||
"-vf", vfFilter,
|
||||
"-pix_fmt", outputPixFmt,
|
||||
"-f", "rawvideo",
|
||||
"pipe:1",
|
||||
)
|
||||
|
||||
cmd := exec.Command("ffmpeg", args...)
|
||||
cmd := exec.CommandContext(ctx, "ffmpeg", args...) //nolint:gosec
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decoder stdin: %w", err)
|
||||
@@ -402,7 +421,7 @@ func newFFmpegDecoder(spec codecSpec, width, height, fps int, hw string) (*ffmpe
|
||||
stderr: stderr,
|
||||
frames: make(chan []byte, 32),
|
||||
mimeType: spec.mimeType,
|
||||
frameSize: frameSize,
|
||||
frameSize: width * height,
|
||||
}
|
||||
|
||||
if spec.mimeType != webrtc.MimeTypeH264 {
|
||||
@@ -441,7 +460,7 @@ func (d *ffmpegDecoder) PopFrame() ([]byte, error) {
|
||||
}
|
||||
return frame, nil
|
||||
case <-time.After(10 * time.Second):
|
||||
return nil, fmt.Errorf("pop frame timeout")
|
||||
return nil, ErrPopFrameTimeout
|
||||
}
|
||||
}
|
||||
|
||||
@@ -515,9 +534,9 @@ func writeIVFHeader(w io.Writer, fourCC string, width, height, frameRate int) er
|
||||
binary.LittleEndian.PutUint16(header[4:6], 0)
|
||||
binary.LittleEndian.PutUint16(header[6:8], 32)
|
||||
copy(header[8:12], []byte(fourCC))
|
||||
binary.LittleEndian.PutUint16(header[12:14], uint16(width))
|
||||
binary.LittleEndian.PutUint16(header[14:16], uint16(height))
|
||||
binary.LittleEndian.PutUint32(header[16:20], uint32(frameRate))
|
||||
binary.LittleEndian.PutUint16(header[12:14], uint16(width)) //nolint:gosec
|
||||
binary.LittleEndian.PutUint16(header[14:16], uint16(height)) //nolint:gosec
|
||||
binary.LittleEndian.PutUint32(header[16:20], uint32(frameRate)) //nolint:gosec
|
||||
binary.LittleEndian.PutUint32(header[20:24], 1)
|
||||
binary.LittleEndian.PutUint32(header[24:28], 0)
|
||||
binary.LittleEndian.PutUint32(header[28:32], 0)
|
||||
@@ -526,7 +545,7 @@ func writeIVFHeader(w io.Writer, fourCC string, width, height, frameRate int) er
|
||||
|
||||
func writeIVFFrame(w io.Writer, pts uint64, frame []byte) error {
|
||||
header := make([]byte, 12)
|
||||
binary.LittleEndian.PutUint32(header[0:4], uint32(len(frame)))
|
||||
binary.LittleEndian.PutUint32(header[0:4], uint32(len(frame))) //nolint:gosec
|
||||
binary.LittleEndian.PutUint64(header[4:12], pts)
|
||||
if err := writeAll(w, header); err != nil {
|
||||
return err
|
||||
@@ -538,9 +557,10 @@ func writeAll(w io.Writer, data []byte) error {
|
||||
for len(data) > 0 {
|
||||
n, err := w.Write(data)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("write: %w", err)
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ package videochannel
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"errors"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -12,6 +12,21 @@ const (
|
||||
frameTypeAck byte = 2
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrFrameTooShort is returned when the received frame is too short to decode.
|
||||
ErrFrameTooShort = errors.New("frame too short")
|
||||
// ErrUnexpectedMagic is returned when the frame magic bytes do not match.
|
||||
ErrUnexpectedMagic = errors.New("unexpected frame magic")
|
||||
// ErrUnexpectedVersion is returned when the frame protocol version does not match.
|
||||
ErrUnexpectedVersion = errors.New("unexpected frame version")
|
||||
// ErrAckTooShort is returned when the ack frame is shorter than expected.
|
||||
ErrAckTooShort = errors.New("ack frame too short")
|
||||
// ErrDataTooShort is returned when the data frame is shorter than expected.
|
||||
ErrDataTooShort = errors.New("data frame too short")
|
||||
// ErrUnexpectedFrameType is returned for unknown frame type bytes.
|
||||
ErrUnexpectedFrameType = errors.New("unexpected frame type")
|
||||
)
|
||||
|
||||
type transportFrame struct {
|
||||
typ byte
|
||||
seq uint32
|
||||
@@ -56,9 +71,9 @@ func encodeDataFrame(seq, crc uint32, totalLen, fragIdx, fragTotal int, payload
|
||||
out[5] = frameTypeData
|
||||
binary.BigEndian.PutUint32(out[6:10], seq)
|
||||
binary.BigEndian.PutUint32(out[10:14], crc)
|
||||
binary.BigEndian.PutUint32(out[14:18], uint32(totalLen))
|
||||
binary.BigEndian.PutUint16(out[18:20], uint16(fragIdx))
|
||||
binary.BigEndian.PutUint16(out[20:22], uint16(fragTotal))
|
||||
binary.BigEndian.PutUint32(out[14:18], uint32(totalLen)) //nolint:gosec
|
||||
binary.BigEndian.PutUint16(out[18:20], uint16(fragIdx)) //nolint:gosec
|
||||
binary.BigEndian.PutUint16(out[20:22], uint16(fragTotal)) //nolint:gosec
|
||||
copy(out[22:], payload)
|
||||
return out
|
||||
}
|
||||
@@ -75,27 +90,27 @@ func encodeAckFrame(seq, crc uint32) []byte {
|
||||
|
||||
func decodeTransportFrame(data []byte) (transportFrame, error) {
|
||||
if len(data) < 6 {
|
||||
return transportFrame{}, fmt.Errorf("frame too short")
|
||||
return transportFrame{}, ErrFrameTooShort
|
||||
}
|
||||
if binary.BigEndian.Uint32(data[0:4]) != protocolMagic {
|
||||
return transportFrame{}, fmt.Errorf("unexpected frame magic")
|
||||
return transportFrame{}, ErrUnexpectedMagic
|
||||
}
|
||||
if data[4] != protocolVersion {
|
||||
return transportFrame{}, fmt.Errorf("unexpected frame version")
|
||||
return transportFrame{}, ErrUnexpectedVersion
|
||||
}
|
||||
|
||||
frame := transportFrame{typ: data[5]}
|
||||
switch frame.typ {
|
||||
case frameTypeAck:
|
||||
if len(data) < 14 {
|
||||
return transportFrame{}, fmt.Errorf("ack too short")
|
||||
return transportFrame{}, ErrAckTooShort
|
||||
}
|
||||
frame.seq = binary.BigEndian.Uint32(data[6:10])
|
||||
frame.crc = binary.BigEndian.Uint32(data[10:14])
|
||||
return frame, nil
|
||||
case frameTypeData:
|
||||
if len(data) < 22 {
|
||||
return transportFrame{}, fmt.Errorf("data too short")
|
||||
return transportFrame{}, ErrDataTooShort
|
||||
}
|
||||
frame.seq = binary.BigEndian.Uint32(data[6:10])
|
||||
frame.crc = binary.BigEndian.Uint32(data[10:14])
|
||||
@@ -105,6 +120,6 @@ func decodeTransportFrame(data []byte) (transportFrame, error) {
|
||||
frame.payload = append([]byte(nil), data[22:]...)
|
||||
return frame, nil
|
||||
default:
|
||||
return transportFrame{}, fmt.Errorf("unexpected frame type")
|
||||
return transportFrame{}, ErrUnexpectedFrameType
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,9 +70,8 @@ type streamTransport struct {
|
||||
videoCodec string
|
||||
videoTileModule int
|
||||
videoTileRS int
|
||||
runCtx context.Context //nolint:containedctx
|
||||
|
||||
// cached encoded idle frame — rendered and encoded once, reused on every tick
|
||||
// where the outbound queue is empty to avoid re-encoding an identical blank frame.
|
||||
idleFrame []byte
|
||||
idleFrameMu sync.Mutex
|
||||
}
|
||||
@@ -144,6 +143,7 @@ func New(ctx context.Context, cfg transport.Config) (transport.Transport, error)
|
||||
videoCodec: cfg.VideoCodec,
|
||||
videoTileModule: tileModule,
|
||||
videoTileRS: tileRS,
|
||||
runCtx: ctx,
|
||||
}
|
||||
|
||||
if err := stream.AddTrack(track); err != nil {
|
||||
@@ -159,14 +159,14 @@ func (p *streamTransport) Connect(ctx context.Context) error {
|
||||
connectCtx, cancel := context.WithTimeout(ctx, defaultConnectTimeout)
|
||||
defer cancel()
|
||||
|
||||
encoder, err := newFFmpegEncoder(p.codec, p.videoW, p.videoH, p.videoFPS, p.videoBitrate, p.videoHW)
|
||||
encoder, err := newFFmpegEncoder(ctx, p.codec, p.videoW, p.videoH, p.videoFPS, p.videoBitrate, p.videoHW)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("new encoder: %w", err)
|
||||
}
|
||||
|
||||
if err := p.stream.Connect(connectCtx); err != nil {
|
||||
_ = encoder.Close()
|
||||
return err
|
||||
return fmt.Errorf("connect stream: %w", err)
|
||||
}
|
||||
|
||||
p.encoderMu.Lock()
|
||||
@@ -212,7 +212,7 @@ func (p *streamTransport) Send(data []byte) error {
|
||||
p.ackMu.Unlock()
|
||||
}()
|
||||
|
||||
for attempt := 0; attempt < maxSendAttempts; attempt++ {
|
||||
for range maxSendAttempts {
|
||||
for idx, fragment := range fragments {
|
||||
frame := encodeDataFrame(seq, crc, len(data), idx, len(fragments), fragment)
|
||||
if err := p.enqueueFrame(frame, false); err != nil {
|
||||
@@ -257,7 +257,9 @@ func (p *streamTransport) Close() error {
|
||||
if p.writerUp.Load() {
|
||||
<-p.writerDone
|
||||
}
|
||||
return p.stream.Close()
|
||||
if err := p.stream.Close(); err != nil {
|
||||
return fmt.Errorf("close stream: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -301,6 +303,47 @@ func (p *streamTransport) Features() transport.Features {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *streamTransport) writeIdleFrame(enc *ffmpegEncoder, frameDuration time.Duration) {
|
||||
p.idleFrameMu.Lock()
|
||||
cached := p.idleFrame
|
||||
p.idleFrameMu.Unlock()
|
||||
|
||||
if cached == nil {
|
||||
rawFrame, err := p.renderFrame(nil)
|
||||
if err != nil {
|
||||
logger.Debugf("videochannel render idle error: %v", err)
|
||||
return
|
||||
}
|
||||
sample, err := enc.EncodeFrame(rawFrame)
|
||||
if err != nil {
|
||||
logger.Warnf("videochannel encoder idle error: %v", err)
|
||||
return
|
||||
}
|
||||
p.idleFrameMu.Lock()
|
||||
p.idleFrame = sample
|
||||
p.idleFrameMu.Unlock()
|
||||
cached = sample
|
||||
}
|
||||
|
||||
_ = p.track.WriteSample(media.Sample{Data: cached, Duration: frameDuration})
|
||||
}
|
||||
|
||||
func (p *streamTransport) writePayloadFrame(enc *ffmpegEncoder, payload []byte, frameDuration time.Duration) {
|
||||
rawFrame, err := p.renderFrame(payload)
|
||||
if err != nil {
|
||||
logger.Debugf("videochannel render error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
sample, err := enc.EncodeFrame(rawFrame)
|
||||
if err != nil {
|
||||
logger.Warnf("videochannel encoder error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
_ = p.track.WriteSample(media.Sample{Data: sample, Duration: frameDuration})
|
||||
}
|
||||
|
||||
func (p *streamTransport) writerLoop() {
|
||||
defer close(p.writerDone)
|
||||
defer func() {
|
||||
@@ -334,58 +377,24 @@ func (p *streamTransport) writerLoop() {
|
||||
continue
|
||||
}
|
||||
|
||||
// idle frame: payload is nil — reuse previously encoded sample to avoid
|
||||
// re-rendering and re-encoding an identical blank frame every tick.
|
||||
if payload == nil {
|
||||
p.idleFrameMu.Lock()
|
||||
cached := p.idleFrame
|
||||
p.idleFrameMu.Unlock()
|
||||
|
||||
if cached == nil {
|
||||
// first time — render + encode once, then cache
|
||||
rawFrame, err := renderVisualFrame(nil, p.videoW, p.videoH, p.videoCodec, p.videoQRRecovery, p.videoTileModule, p.videoTileRS)
|
||||
if err != nil {
|
||||
logger.Debugf("videochannel render idle error: %v", err)
|
||||
continue
|
||||
}
|
||||
sample, err := enc.EncodeFrame(rawFrame)
|
||||
if err != nil {
|
||||
logger.Warnf("videochannel encoder idle error: %v", err)
|
||||
continue
|
||||
}
|
||||
p.idleFrameMu.Lock()
|
||||
p.idleFrame = sample
|
||||
p.idleFrameMu.Unlock()
|
||||
cached = sample
|
||||
}
|
||||
|
||||
_ = p.track.WriteSample(media.Sample{
|
||||
Data: cached,
|
||||
Duration: frameDuration,
|
||||
})
|
||||
continue
|
||||
p.writeIdleFrame(enc, frameDuration)
|
||||
} else {
|
||||
p.writePayloadFrame(enc, payload, frameDuration)
|
||||
}
|
||||
|
||||
rawFrame, err := renderVisualFrame(payload, p.videoW, p.videoH, p.videoCodec, p.videoQRRecovery, p.videoTileModule, p.videoTileRS)
|
||||
if err != nil {
|
||||
logger.Debugf("videochannel render error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
sample, err := enc.EncodeFrame(rawFrame)
|
||||
if err != nil {
|
||||
logger.Warnf("videochannel encoder error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
_ = p.track.WriteSample(media.Sample{
|
||||
Data: sample,
|
||||
Duration: frameDuration,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *streamTransport) renderFrame(payload []byte) ([]byte, error) {
|
||||
return renderVisualFrame(
|
||||
payload,
|
||||
p.videoW, p.videoH,
|
||||
p.videoCodec, p.videoQRRecovery,
|
||||
p.videoTileModule, p.videoTileRS,
|
||||
)
|
||||
}
|
||||
|
||||
func (p *streamTransport) nextOutboundFrame() ([]byte, bool) {
|
||||
select {
|
||||
case <-p.closeCh:
|
||||
@@ -425,6 +434,61 @@ func (p *streamTransport) enqueueFrame(frame []byte, priority bool) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *streamTransport) popDecoderFrames(decoder *ffmpegDecoder) {
|
||||
defer func() {
|
||||
p.decoderMu.Lock()
|
||||
if p.decoder == decoder {
|
||||
p.decoder = nil
|
||||
}
|
||||
p.decoderMu.Unlock()
|
||||
_ = decoder.Close()
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-p.closeCh:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
frame, err := decoder.PopFrame()
|
||||
if err != nil {
|
||||
if !errors.Is(err, ErrTransportClosed) && !p.closed.Load() {
|
||||
logger.Warnf("videochannel decoder pop error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
p.handleFrame(frame)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *streamTransport) readDecoderInput(track *webrtc.TrackRemote, decoder *ffmpegDecoder, codec codecSpec) {
|
||||
sb := samplebuilder.New(sampleBuilderMaxLate, codec.depacketizer(), track.Codec().ClockRate)
|
||||
for {
|
||||
select {
|
||||
case <-p.closeCh:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
packet, _, err := track.ReadRTP()
|
||||
if err != nil {
|
||||
sb.Flush()
|
||||
return
|
||||
}
|
||||
|
||||
sb.Push(packet)
|
||||
for sample := sb.Pop(); sample != nil; sample = sb.Pop() {
|
||||
if err := decoder.PushSample(sample.Data); err != nil {
|
||||
if !p.closed.Load() {
|
||||
logger.Warnf("videochannel decoder push error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *streamTransport) handleRemoteTrack(track *webrtc.TrackRemote, _ *webrtc.RTPReceiver) {
|
||||
codec, ok := codecSpecForMime(track.Codec().MimeType)
|
||||
if !ok {
|
||||
@@ -432,7 +496,7 @@ func (p *streamTransport) handleRemoteTrack(track *webrtc.TrackRemote, _ *webrtc
|
||||
return
|
||||
}
|
||||
|
||||
decoder, err := newFFmpegDecoder(codec, p.videoW, p.videoH, p.videoFPS, p.videoHW)
|
||||
decoder, err := newFFmpegDecoder(p.runCtx, codec, p.videoW, p.videoH, p.videoFPS, p.videoHW)
|
||||
if err != nil {
|
||||
logger.Warnf("videochannel decoder init failed: %v", err)
|
||||
return
|
||||
@@ -450,60 +514,8 @@ func (p *streamTransport) handleRemoteTrack(track *webrtc.TrackRemote, _ *webrtc
|
||||
p.decoder = decoder
|
||||
p.decoderMu.Unlock()
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
p.decoderMu.Lock()
|
||||
if p.decoder == decoder {
|
||||
p.decoder = nil
|
||||
}
|
||||
p.decoderMu.Unlock()
|
||||
_ = decoder.Close()
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-p.closeCh:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
frame, err := decoder.PopFrame()
|
||||
if err != nil {
|
||||
if !errors.Is(err, ErrTransportClosed) && !p.closed.Load() {
|
||||
logger.Warnf("videochannel decoder pop error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
p.handleFrame(frame)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
sb := samplebuilder.New(sampleBuilderMaxLate, codec.depacketizer(), track.Codec().ClockRate)
|
||||
for {
|
||||
select {
|
||||
case <-p.closeCh:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
packet, _, err := track.ReadRTP()
|
||||
if err != nil {
|
||||
sb.Flush()
|
||||
return
|
||||
}
|
||||
|
||||
sb.Push(packet)
|
||||
for sample := sb.Pop(); sample != nil; sample = sb.Pop() {
|
||||
if err := decoder.PushSample(sample.Data); err != nil {
|
||||
if !p.closed.Load() {
|
||||
logger.Warnf("videochannel decoder push error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
go p.popDecoderFrames(decoder)
|
||||
go p.readDecoderInput(track, decoder, codec)
|
||||
}
|
||||
|
||||
func (p *streamTransport) handleFrame(frame []byte) {
|
||||
@@ -531,14 +543,7 @@ func (p *streamTransport) handleFrame(frame []byte) {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *streamTransport) handleInboundFrame(frame transportFrame) {
|
||||
p.recvMu.Lock()
|
||||
if crc, ok := p.delivered[frame.seq]; ok && crc == frame.crc {
|
||||
p.recvMu.Unlock()
|
||||
p.sendAck(frame.seq, frame.crc)
|
||||
return
|
||||
}
|
||||
|
||||
func (p *streamTransport) upsertInbound(frame transportFrame) (*inboundMessage, bool) {
|
||||
msg, ok := p.inbound[frame.seq]
|
||||
if !ok || msg.crc != frame.crc || msg.totalLen != frame.totalLen || len(msg.frags) != int(frame.fragTotal) {
|
||||
msg = &inboundMessage{
|
||||
@@ -549,33 +554,45 @@ func (p *streamTransport) handleInboundFrame(frame transportFrame) {
|
||||
}
|
||||
p.inbound[frame.seq] = msg
|
||||
}
|
||||
|
||||
if int(frame.fragIdx) >= len(msg.frags) {
|
||||
p.recvMu.Unlock()
|
||||
return
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if msg.frags[frame.fragIdx] == nil {
|
||||
chunk := make([]byte, len(frame.payload))
|
||||
copy(chunk, frame.payload)
|
||||
msg.frags[frame.fragIdx] = chunk
|
||||
msg.remain--
|
||||
}
|
||||
return msg, msg.remain == 0
|
||||
}
|
||||
|
||||
if msg.remain > 0 {
|
||||
func (p *streamTransport) assembleMessage(msg *inboundMessage) []byte {
|
||||
data := make([]byte, 0, msg.totalLen)
|
||||
for _, frag := range msg.frags {
|
||||
data = append(data, frag...)
|
||||
}
|
||||
if uint32(len(data)) > msg.totalLen { //nolint:gosec
|
||||
data = data[:msg.totalLen]
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func (p *streamTransport) handleInboundFrame(frame transportFrame) {
|
||||
p.recvMu.Lock()
|
||||
if crc, ok := p.delivered[frame.seq]; ok && crc == frame.crc {
|
||||
p.recvMu.Unlock()
|
||||
p.sendAck(frame.seq, frame.crc)
|
||||
return
|
||||
}
|
||||
|
||||
msg, complete := p.upsertInbound(frame)
|
||||
if msg == nil || !complete {
|
||||
p.recvMu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
delete(p.inbound, frame.seq)
|
||||
data := make([]byte, 0, msg.totalLen)
|
||||
for _, frag := range msg.frags {
|
||||
data = append(data, frag...)
|
||||
}
|
||||
|
||||
if uint32(len(data)) > msg.totalLen {
|
||||
data = data[:msg.totalLen]
|
||||
}
|
||||
data := p.assembleMessage(msg)
|
||||
|
||||
if crc32.ChecksumIEEE(data) != msg.crc {
|
||||
p.recvMu.Unlock()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package videochannel
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -8,6 +9,9 @@ import (
|
||||
grtile "github.com/zarazaex69/gr/tile"
|
||||
)
|
||||
|
||||
// ErrUnexpectedQRFrameSize is returned when the decoded frame size does not match the expected dimensions.
|
||||
var ErrUnexpectedQRFrameSize = errors.New("unexpected qr frame size")
|
||||
|
||||
func eccLevel(level string) grqr.ECCLevel {
|
||||
switch level {
|
||||
case "medium":
|
||||
@@ -21,7 +25,12 @@ func eccLevel(level string) grqr.ECCLevel {
|
||||
}
|
||||
}
|
||||
|
||||
func renderVisualFrame(payload []byte, width, height int, codec, recoveryLevel string, tileModule, tileRS int) ([]byte, error) {
|
||||
func renderVisualFrame(
|
||||
payload []byte,
|
||||
width, height int,
|
||||
codec, recoveryLevel string,
|
||||
tileModule, tileRS int,
|
||||
) ([]byte, error) {
|
||||
if codec == "tile" {
|
||||
return renderTileFrame(payload, tileModule, tileRS)
|
||||
}
|
||||
@@ -47,7 +56,11 @@ func renderQRFrame(payload []byte, width, height int, recoveryLevel string) ([]b
|
||||
return nil, fmt.Errorf("qr codec: %w", err)
|
||||
}
|
||||
|
||||
return c.Encode(payload)
|
||||
result, err := c.Encode(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("qr encode: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func renderTileFrame(payload []byte, tileModule, tileRS int) ([]byte, error) {
|
||||
@@ -64,7 +77,11 @@ func renderTileFrame(payload []byte, tileModule, tileRS int) ([]byte, error) {
|
||||
return nil, fmt.Errorf("tile codec: %w", err)
|
||||
}
|
||||
|
||||
return c.Encode(payload, 0, 1)
|
||||
result, err := c.Encode(payload, 0, 1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tile encode: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func extractVisualPayload(frame []byte, width, height int, codec string, tileModule, tileRS int) ([]byte, error) {
|
||||
@@ -76,7 +93,8 @@ func extractVisualPayload(frame []byte, width, height int, codec string, tileMod
|
||||
|
||||
func extractQRPayload(frame []byte, width, height int) ([]byte, error) {
|
||||
if len(frame) != width*height {
|
||||
return nil, fmt.Errorf("unexpected frame size: %d (expected %dx%d=%d)", len(frame), width, height, width*height)
|
||||
return nil, fmt.Errorf("%w: got %d expected %dx%d=%d",
|
||||
ErrUnexpectedQRFrameSize, len(frame), width, height, width*height)
|
||||
}
|
||||
|
||||
c, err := grqr.New(grqr.Config{
|
||||
@@ -111,7 +129,7 @@ func extractTilePayload(frame []byte, tileModule, tileRS int) ([]byte, error) {
|
||||
|
||||
result, err := c.Decode(frame)
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
return nil, nil //nolint:nilerr
|
||||
}
|
||||
|
||||
return result.Payload, nil
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package vp8channel provides byte transport over VP8 video frames using KCP.
|
||||
package vp8channel
|
||||
|
||||
import (
|
||||
@@ -58,7 +59,7 @@ type kcpRuntime struct {
|
||||
func startKCP(out chan<- []byte, onData func([]byte)) (*kcpRuntime, error) {
|
||||
c := newKCPConn(out, inboundQueueSize)
|
||||
|
||||
sess, err := kcp.NewConn3(kcpConvID, fakeAddr, nil, 0, 0, c)
|
||||
sess, err := kcp.NewConn3(kcpConvID, fakeUDPAddr(), nil, 0, 0, c)
|
||||
if err != nil {
|
||||
_ = c.Close()
|
||||
return nil, fmt.Errorf("kcp new conn: %w", err)
|
||||
@@ -71,7 +72,6 @@ func startKCP(out chan<- []byte, onData func([]byte)) (*kcpRuntime, error) {
|
||||
sess.SetNoDelay(1, 10, 2, 1)
|
||||
sess.SetWindowSize(kcpSndWnd, kcpRcvWnd)
|
||||
sess.SetMtu(kcpMTU)
|
||||
sess.SetStreamMode(true) // see kcpLenPrefix comment above
|
||||
sess.SetACKNoDelay(true)
|
||||
sess.SetWriteDelay(false)
|
||||
|
||||
@@ -127,16 +127,17 @@ func (r *kcpRuntime) send(msg []byte) error {
|
||||
return ErrKCPMessageTooLarge
|
||||
}
|
||||
var hdr [kcpLenPrefix]byte
|
||||
//nolint:gosec
|
||||
binary.BigEndian.PutUint32(hdr[:], uint32(len(msg)))
|
||||
|
||||
r.writeMu.Lock()
|
||||
defer r.writeMu.Unlock()
|
||||
|
||||
if _, err := r.sess.Write(hdr[:]); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("kcp write header: %w", err)
|
||||
}
|
||||
if _, err := r.sess.Write(msg); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("kcp write payload: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -6,10 +6,9 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// fakeAddr is a placeholder address used by the KCP session. The underlying
|
||||
// "packet conn" is a point-to-point pipe over the VP8 carrier and has no real
|
||||
// notion of an address, but kcp-go's API requires one.
|
||||
var fakeAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1}
|
||||
func fakeUDPAddr() *net.UDPAddr {
|
||||
return &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1}
|
||||
}
|
||||
|
||||
// kcpConn is a net.PacketConn implementation that bridges kcp-go on top of
|
||||
// the vp8channel byte-message carrier.
|
||||
@@ -62,7 +61,7 @@ func (c *kcpConn) ReadFrom(p []byte) (int, net.Addr, error) {
|
||||
if !deadline.IsZero() {
|
||||
d := time.Until(deadline)
|
||||
if d <= 0 {
|
||||
return 0, nil, errTimeout{}
|
||||
return 0, nil, TimeoutError{}
|
||||
}
|
||||
t := time.NewTimer(d)
|
||||
defer t.Stop()
|
||||
@@ -72,11 +71,11 @@ func (c *kcpConn) ReadFrom(p []byte) (int, net.Addr, error) {
|
||||
select {
|
||||
case msg := <-c.in:
|
||||
n := copy(p, msg)
|
||||
return n, fakeAddr, nil
|
||||
return n, fakeUDPAddr(), nil
|
||||
case <-c.closed:
|
||||
return 0, nil, net.ErrClosed
|
||||
case <-timerC:
|
||||
return 0, nil, errTimeout{}
|
||||
return 0, nil, TimeoutError{}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -92,7 +91,7 @@ func (c *kcpConn) WriteTo(p []byte, _ net.Addr) (int, error) {
|
||||
if !deadline.IsZero() {
|
||||
d := time.Until(deadline)
|
||||
if d <= 0 {
|
||||
return 0, errTimeout{}
|
||||
return 0, TimeoutError{}
|
||||
}
|
||||
t := time.NewTimer(d)
|
||||
defer t.Stop()
|
||||
@@ -105,7 +104,7 @@ func (c *kcpConn) WriteTo(p []byte, _ net.Addr) (int, error) {
|
||||
case <-c.closed:
|
||||
return 0, net.ErrClosed
|
||||
case <-timerC:
|
||||
return 0, errTimeout{}
|
||||
return 0, TimeoutError{}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,7 +113,7 @@ func (c *kcpConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *kcpConn) LocalAddr() net.Addr { return fakeAddr }
|
||||
func (c *kcpConn) LocalAddr() net.Addr { return fakeUDPAddr() }
|
||||
|
||||
func (c *kcpConn) SetDeadline(t time.Time) error {
|
||||
_ = c.SetReadDeadline(t)
|
||||
@@ -136,8 +135,13 @@ func (c *kcpConn) SetWriteDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type errTimeout struct{}
|
||||
// TimeoutError is a net.Error indicating a deadline exceeded.
|
||||
type TimeoutError struct{}
|
||||
|
||||
func (errTimeout) Error() string { return "i/o timeout" }
|
||||
func (errTimeout) Timeout() bool { return true }
|
||||
func (errTimeout) Temporary() bool { return true }
|
||||
func (TimeoutError) Error() string { return "i/o timeout" }
|
||||
|
||||
// Timeout reports that this error is a timeout.
|
||||
func (TimeoutError) Timeout() bool { return true }
|
||||
|
||||
// Temporary reports that this error is temporary.
|
||||
func (TimeoutError) Temporary() bool { return true }
|
||||
|
||||
@@ -27,14 +27,13 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrVideoTrackUnsupported is returned when a carrier cannot expose video tracks.
|
||||
ErrVideoTrackUnsupported = errors.New("carrier does not support video tracks")
|
||||
ErrTransportClosed = errors.New("vp8channel transport closed")
|
||||
// ErrTransportClosed is returned when operations are attempted on a closed transport.
|
||||
ErrTransportClosed = errors.New("vp8channel transport closed")
|
||||
)
|
||||
|
||||
// vp8Keepalive is a minimal VP8 keyframe used as idle filler so that the SFU
|
||||
// keeps the track flowing when KCP has nothing to send. It is never delivered
|
||||
// to KCP because KCP packets always start with the convid (0xC0FFEE01 LE)
|
||||
// and would never collide with this keyframe payload.
|
||||
//nolint:gochecknoglobals
|
||||
var vp8Keepalive = []byte{
|
||||
0x30, 0x01, 0x00, 0x9d, 0x01, 0x2a, 0x10, 0x00,
|
||||
0x10, 0x00, 0x00, 0x47, 0x08, 0x85, 0x85, 0x88,
|
||||
@@ -64,6 +63,7 @@ type streamTransport struct {
|
||||
kcpMu sync.RWMutex
|
||||
}
|
||||
|
||||
// New creates a vp8channel transport backed by a carrier-specific provider.
|
||||
func New(ctx context.Context, cfg transport.Config) (transport.Transport, error) {
|
||||
session, err := carrier.New(ctx, cfg.Carrier, carrier.Config{
|
||||
RoomURL: cfg.RoomURL,
|
||||
@@ -126,7 +126,7 @@ func (p *streamTransport) Connect(ctx context.Context) error {
|
||||
defer cancel()
|
||||
|
||||
if err := p.stream.Connect(connectCtx); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("connect stream: %w", err)
|
||||
}
|
||||
|
||||
var startErr error
|
||||
@@ -179,7 +179,9 @@ func (p *streamTransport) Close() error {
|
||||
if p.writerUp.Load() {
|
||||
<-p.writerDone
|
||||
}
|
||||
return p.stream.Close()
|
||||
if err := p.stream.Close(); err != nil {
|
||||
return fmt.Errorf("close stream: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -302,14 +304,62 @@ func (p *streamTransport) drainTrack(track *webrtc.TrackRemote) {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *streamTransport) readVP8Track(track *webrtc.TrackRemote) {
|
||||
var vp8Pkt codecs.VP8Packet
|
||||
var frameBuf []byte
|
||||
buf := make([]byte, rtpBufSize)
|
||||
type vp8FrameState struct {
|
||||
vp8Pkt codecs.VP8Packet
|
||||
frameBuf []byte
|
||||
lastSeq uint16
|
||||
haveLastSeq bool
|
||||
frameValid bool
|
||||
}
|
||||
|
||||
var lastSeq uint16
|
||||
var haveLastSeq bool
|
||||
frameValid := false
|
||||
// processRTPPacket returns a complete KCP frame when the VP8 frame is fully assembled, nil otherwise.
|
||||
// Detects packet loss/reordering to avoid silently corrupting fragmented VP8 frames.
|
||||
func (s *vp8FrameState) processRTPPacket(pkt *rtp.Packet) []byte {
|
||||
if s.haveLastSeq && pkt.SequenceNumber != s.lastSeq+1 {
|
||||
s.frameValid = false
|
||||
s.frameBuf = s.frameBuf[:0]
|
||||
}
|
||||
s.lastSeq = pkt.SequenceNumber
|
||||
s.haveLastSeq = true
|
||||
|
||||
vp8Payload, err := s.vp8Pkt.Unmarshal(pkt.Payload)
|
||||
if err != nil {
|
||||
s.frameValid = false
|
||||
s.frameBuf = s.frameBuf[:0]
|
||||
return nil
|
||||
}
|
||||
|
||||
if s.vp8Pkt.S == 1 {
|
||||
s.frameBuf = s.frameBuf[:0]
|
||||
s.frameValid = true
|
||||
}
|
||||
|
||||
if !s.frameValid {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.frameBuf = append(s.frameBuf, vp8Payload...)
|
||||
|
||||
if !pkt.Marker {
|
||||
return nil
|
||||
}
|
||||
|
||||
defer func() {
|
||||
s.frameBuf = s.frameBuf[:0]
|
||||
s.frameValid = false
|
||||
}()
|
||||
|
||||
if len(s.frameBuf) >= 4 && s.frameBuf[0] == kcpMagic {
|
||||
frame := make([]byte, len(s.frameBuf))
|
||||
copy(frame, s.frameBuf)
|
||||
return frame
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *streamTransport) readVP8Track(track *webrtc.TrackRemote) {
|
||||
var state vp8FrameState
|
||||
buf := make([]byte, rtpBufSize)
|
||||
|
||||
for {
|
||||
n, _, err := track.Read(buf)
|
||||
@@ -322,54 +372,16 @@ func (p *streamTransport) readVP8Track(track *webrtc.TrackRemote) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Detect packet loss / reordering. A single missing RTP packet
|
||||
// inside a fragmented VP8 frame would otherwise silently corrupt
|
||||
// the assembled payload (and bleed into the next frame). KCP can
|
||||
// recover from full-frame drops, but only if the frames it does
|
||||
// receive are byte-perfect.
|
||||
if haveLastSeq {
|
||||
expected := lastSeq + 1
|
||||
if pkt.SequenceNumber != expected {
|
||||
frameValid = false
|
||||
frameBuf = frameBuf[:0]
|
||||
}
|
||||
}
|
||||
lastSeq = pkt.SequenceNumber
|
||||
haveLastSeq = true
|
||||
|
||||
vp8Payload, err := vp8Pkt.Unmarshal(pkt.Payload)
|
||||
if err != nil {
|
||||
frameValid = false
|
||||
frameBuf = frameBuf[:0]
|
||||
frame := state.processRTPPacket(pkt)
|
||||
if frame == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if vp8Pkt.S == 1 {
|
||||
frameBuf = frameBuf[:0]
|
||||
frameValid = true
|
||||
}
|
||||
|
||||
if !frameValid {
|
||||
continue
|
||||
}
|
||||
|
||||
frameBuf = append(frameBuf, vp8Payload...)
|
||||
|
||||
if pkt.Marker {
|
||||
if len(frameBuf) >= 4 && frameBuf[0] == kcpMagic {
|
||||
p.kcpMu.RLock()
|
||||
rt := p.kcp
|
||||
p.kcpMu.RUnlock()
|
||||
if rt != nil {
|
||||
// Copy out of the shared frame buffer before handing
|
||||
// the payload off — KCP's deliver path is async.
|
||||
payload := make([]byte, len(frameBuf))
|
||||
copy(payload, frameBuf)
|
||||
rt.deliver(payload)
|
||||
}
|
||||
}
|
||||
frameBuf = frameBuf[:0]
|
||||
frameValid = false
|
||||
p.kcpMu.RLock()
|
||||
rt := p.kcp
|
||||
p.kcpMu.RUnlock()
|
||||
if rt != nil {
|
||||
rt.deliver(frame)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,16 +7,64 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func pumpPackets(stop <-chan struct{}, from <-chan []byte, to *kcpRuntime) {
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case pkt := <-from:
|
||||
to.deliver(pkt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func checkMessages(t *testing.T, got, want [][]byte) {
|
||||
t.Helper()
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("got %d messages, want %d", len(got), len(want))
|
||||
}
|
||||
for i, m := range want {
|
||||
if !bytes.Equal(got[i], m) {
|
||||
t.Errorf("msg %d mismatch: got %d bytes, want %d", i, len(got[i]), len(m))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func buildReceiver(n int) (func([]byte), <-chan struct{}, func() [][]byte) {
|
||||
var mu sync.Mutex
|
||||
var recv [][]byte
|
||||
done := make(chan struct{})
|
||||
cb := func(msg []byte) {
|
||||
mu.Lock()
|
||||
recv = append(recv, append([]byte(nil), msg...))
|
||||
count := len(recv)
|
||||
mu.Unlock()
|
||||
if count == n {
|
||||
close(done)
|
||||
}
|
||||
}
|
||||
get := func() [][]byte {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return recv
|
||||
}
|
||||
return cb, done, get
|
||||
}
|
||||
|
||||
// TestKCPLoopback runs two KCP runtimes back-to-back through an in-memory
|
||||
// pipe simulating a perfect carrier. Verifies that messages survive the
|
||||
// KCP layer with their boundaries intact.
|
||||
func TestKCPLoopback(t *testing.T) {
|
||||
msgs := [][]byte{
|
||||
[]byte("hello"),
|
||||
bytes.Repeat([]byte("x"), 1000),
|
||||
bytes.Repeat([]byte("y"), 20000),
|
||||
}
|
||||
|
||||
a2b := make(chan []byte, 256)
|
||||
b2a := make(chan []byte, 256)
|
||||
|
||||
var bRecvMu sync.Mutex
|
||||
var bRecv [][]byte
|
||||
doneB := make(chan struct{})
|
||||
cb, doneB, getRecv := buildReceiver(len(msgs))
|
||||
|
||||
rtA, err := startKCP(a2b, nil)
|
||||
if err != nil {
|
||||
@@ -24,50 +72,18 @@ func TestKCPLoopback(t *testing.T) {
|
||||
}
|
||||
defer rtA.close()
|
||||
|
||||
rtB, err := startKCP(b2a, func(msg []byte) {
|
||||
bRecvMu.Lock()
|
||||
bRecv = append(bRecv, append([]byte(nil), msg...))
|
||||
n := len(bRecv)
|
||||
bRecvMu.Unlock()
|
||||
if n == 3 {
|
||||
close(doneB)
|
||||
}
|
||||
})
|
||||
rtB, err := startKCP(b2a, cb)
|
||||
if err != nil {
|
||||
t.Fatalf("startKCP B: %v", err)
|
||||
}
|
||||
defer rtB.close()
|
||||
|
||||
// Pump packets between the two runtimes.
|
||||
stop := make(chan struct{})
|
||||
defer close(stop)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case pkt := <-a2b:
|
||||
rtB.deliver(pkt)
|
||||
}
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case pkt := <-b2a:
|
||||
rtA.deliver(pkt)
|
||||
}
|
||||
}
|
||||
}()
|
||||
go pumpPackets(stop, a2b, rtB)
|
||||
go pumpPackets(stop, b2a, rtA)
|
||||
|
||||
msgs := [][]byte{
|
||||
[]byte("hello"),
|
||||
bytes.Repeat([]byte("x"), 1000),
|
||||
bytes.Repeat([]byte("y"), 20000),
|
||||
}
|
||||
for _, m := range msgs {
|
||||
if err := rtA.send(m); err != nil {
|
||||
t.Fatalf("send: %v", err)
|
||||
@@ -80,21 +96,10 @@ func TestKCPLoopback(t *testing.T) {
|
||||
t.Fatal("timeout waiting for messages")
|
||||
}
|
||||
|
||||
bRecvMu.Lock()
|
||||
defer bRecvMu.Unlock()
|
||||
if len(bRecv) != len(msgs) {
|
||||
t.Fatalf("got %d messages, want %d", len(bRecv), len(msgs))
|
||||
}
|
||||
for i, m := range msgs {
|
||||
if !bytes.Equal(bRecv[i], m) {
|
||||
t.Errorf("msg %d mismatch: got %d bytes, want %d", i, len(bRecv[i]), len(m))
|
||||
}
|
||||
}
|
||||
checkMessages(t, getRecv(), msgs)
|
||||
}
|
||||
|
||||
func TestVP8KeepaliveDoesNotLookLikeKCP(t *testing.T) {
|
||||
// Keepalive frames must not be mistaken for KCP packets by the receive
|
||||
// path; otherwise the KCP stack would constantly chew on garbage.
|
||||
if len(vp8Keepalive) >= 1 && vp8Keepalive[0] == kcpMagic {
|
||||
t.Errorf("keepalive collides with kcp magic byte 0x%02x", kcpMagic)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user