fix: fix all golangci errors

This commit is contained in:
zarazaex69
2026-05-03 06:10:48 +03:00
parent 6183233eeb
commit dd606ddfb2
25 changed files with 1072 additions and 760 deletions

View File

@@ -3,6 +3,7 @@ package main
import (
"context"
"errors"
"flag"
"fmt"
"os"
@@ -16,6 +17,9 @@ import (
"github.com/openlibrecommunity/olcrtc/internal/names"
)
// ErrDataDirRequired is returned when no data directory is specified.
var ErrDataDirRequired = errors.New("data directory required (use -data data)")
type config struct {
mode string
link string
@@ -59,11 +63,11 @@ func run() error {
configureLogging(cfg.debug)
if err := session.Validate(toSessionConfig(cfg)); err != nil {
return err
return fmt.Errorf("validate config: %w", err)
}
if cfg.dataDir == "" {
return fmt.Errorf("data directory required (use -data data)")
return ErrDataDirRequired
}
dataDir, err := resolveDataDir(cfg.dataDir)
@@ -119,10 +123,13 @@ func parseFlags() config {
flag.StringVar(&cfg.videoBitrate, "video-bitrate", "", "Video bitrate (videochannel only)")
flag.StringVar(&cfg.videoHW, "video-hw", "", "Hardware acceleration (none, nvenc)")
flag.IntVar(&cfg.videoQRSize, "video-qr-size", 0, "Video QR code fragment size (videochannel only)")
flag.StringVar(&cfg.videoQRRecovery, "video-qr-recovery", "low", "QR error correction: low (7%), medium (15%), high (25%), highest (30%)")
flag.StringVar(&cfg.videoQRRecovery, "video-qr-recovery", "low",
"QR error correction: low (7%), medium (15%), high (25%), highest (30%)")
flag.StringVar(&cfg.videoCodec, "video-codec", "qrcode", "Visual codec: qrcode or tile")
flag.IntVar(&cfg.videoTileModule, "video-tile-module", 0, "Tile module size in pixels 1..270 (videochannel tile only, default 4)")
flag.IntVar(&cfg.videoTileRS, "video-tile-rs", 0, "Tile Reed-Solomon parity percent 0..200 (videochannel tile only, default 20)")
flag.IntVar(&cfg.videoTileModule, "video-tile-module", 0,
"Tile module size in pixels 1..270 (videochannel tile only, default 4)")
flag.IntVar(&cfg.videoTileRS, "video-tile-rs", 0,
"Tile Reed-Solomon parity percent 0..200 (videochannel tile only, default 20)")
flag.IntVar(&cfg.vp8FPS, "vp8-fps", 0, "VP8 frames per second (vp8channel only, default 25)")
flag.IntVar(&cfg.vp8BatchSize, "vp8-batch", 0, "VP8 frames per tick (vp8channel only, default 1)")
flag.Parse()

View File

@@ -19,13 +19,19 @@ import (
"github.com/openlibrecommunity/olcrtc/internal/transport/vp8channel"
)
const (
modeSRV = "srv"
modeCNC = "cnc"
)
var (
// ErrRoomIDRequired indicates that a room id is required for the selected carrier.
ErrRoomIDRequired = errors.New("room ID required (use -id <id>)")
// ErrModeRequired indicates that mode is not one of the supported values.
ErrModeRequired = errors.New("mode required (use -mode srv or -mode cnc)")
// ErrCarrierRequired indicates that no carrier was selected.
ErrCarrierRequired = errors.New("carrier required (use -carrier telemost, -carrier jazz or -carrier wbstream)")
ErrCarrierRequired = errors.New(
"carrier required (use -carrier telemost, -carrier jazz or -carrier wbstream)")
// ErrUnsupportedCarrier indicates that carrier is not registered.
ErrUnsupportedCarrier = errors.New("unsupported carrier")
// ErrUnsupportedLink indicates that link is not registered.
@@ -36,26 +42,40 @@ var (
// ErrLinkRequired indicates that link is not provided.
ErrLinkRequired = errors.New("link required (use -link direct)")
// ErrTransportRequired indicates that transport is not provided.
ErrTransportRequired = errors.New("transport required (use -transport datachannel, -transport videochannel, -transport seichannel or -transport vp8channel)")
ErrTransportRequired = errors.New(
"transport required (use -transport datachannel, -transport videochannel, " +
"-transport seichannel or -transport vp8channel)")
// ErrKeyRequired indicates that encryption key is not provided.
ErrKeyRequired = errors.New("key required (use -key <hex>)")
// ErrDNSServerRequired indicates that dns server is not provided.
ErrDNSServerRequired = errors.New("dns server required (use -dns 1.1.1.1:53)")
// Videochannel errors
// ErrVideoWidthRequired indicates that video width is required for videochannel.
ErrVideoWidthRequired = errors.New("video width required for videochannel (use -video-w)")
// ErrVideoHeightRequired indicates that video height is required for videochannel.
ErrVideoHeightRequired = errors.New("video height required for videochannel (use -video-h)")
// ErrVideoFPSRequired indicates that video fps is required for videochannel.
ErrVideoFPSRequired = errors.New("video fps required for videochannel (use -video-fps)")
ErrVideoBitrateRequired = errors.New("video bitrate required for videochannel (use -video-bitrate)")
ErrVideoHWRequired = errors.New("video hardware acceleration required for videochannel (use -video-hw none/nvenc)")
ErrVideoCodecInvalid = errors.New("invalid video codec for videochannel (use -video-codec qrcode or -video-codec tile)")
// ErrVideoBitrateRequired indicates that video bitrate is required for videochannel.
ErrVideoBitrateRequired = errors.New(
"video bitrate required for videochannel (use -video-bitrate)")
// ErrVideoHWRequired indicates that video hardware acceleration is required.
ErrVideoHWRequired = errors.New(
"video hardware acceleration required for videochannel (use -video-hw none/nvenc)")
// ErrVideoCodecInvalid indicates that the video codec is not valid.
ErrVideoCodecInvalid = errors.New(
"invalid video codec for videochannel (use -video-codec qrcode or -video-codec tile)")
// ErrTileCodecDimensions indicates that tile codec requires 1080x1080 dimensions.
ErrTileCodecDimensions = errors.New("tile codec requires -video-w 1080 -video-h 1080")
// VP8channel errors
// ErrVP8FPSRequired indicates that vp8 fps is required for vp8channel.
ErrVP8FPSRequired = errors.New("vp8 fps required for vp8channel (use -vp8-fps)")
// ErrVP8BatchSizeRequired indicates that vp8 batch size is required for vp8channel.
ErrVP8BatchSizeRequired = errors.New("vp8 batch size required for vp8channel (use -vp8-batch)")
// CNC errors
// ErrSOCKSHostRequired indicates that socks host is required for cnc mode.
ErrSOCKSHostRequired = errors.New("socks host required for cnc mode (use -socks-host)")
// ErrSOCKSPortRequired indicates that socks port is required for cnc mode.
ErrSOCKSPortRequired = errors.New("socks port required for cnc mode (use -socks-port)")
)
@@ -98,74 +118,105 @@ func RegisterDefaults() {
// Validate verifies that the runtime config refers to registered components and all required fields are present.
func Validate(cfg Config) error {
availableCarriers := carrier.Available()
validCarrier := false
for _, c := range availableCarriers {
if cfg.Carrier == c {
validCarrier = true
break
if err := validateMode(cfg); err != nil {
return err
}
if err := validateCarrier(cfg); err != nil {
return err
}
if err := validateLink(cfg); err != nil {
return err
}
if err := validateTransportRegistration(cfg); err != nil {
return err
}
if err := validateCommon(cfg); err != nil {
return err
}
if err := validateTransportConfig(cfg); err != nil {
return err
}
return validateModeConfig(cfg)
}
availableTransports := transport.Available()
validTransport := false
for _, t := range availableTransports {
if cfg.Transport == t {
validTransport = true
break
}
}
availableLinks := link.Available()
validLink := false
for _, l := range availableLinks {
if cfg.Link == l {
validLink = true
break
}
}
if cfg.Mode == "" {
func validateMode(cfg Config) error {
if cfg.Mode == "" || (cfg.Mode != modeSRV && cfg.Mode != modeCNC) {
return ErrModeRequired
}
if cfg.Mode != "srv" && cfg.Mode != "cnc" {
return ErrModeRequired
return nil
}
func validateCarrier(cfg Config) error {
if cfg.Carrier == "" {
return ErrCarrierRequired
}
if !validCarrier {
return fmt.Errorf("%w: %s (available: %v)", ErrUnsupportedCarrier, cfg.Carrier, availableCarriers)
for _, c := range carrier.Available() {
if cfg.Carrier == c {
return nil
}
}
return fmt.Errorf("%w: %s (available: %v)", ErrUnsupportedCarrier, cfg.Carrier, carrier.Available())
}
func validateLink(cfg Config) error {
if cfg.Link == "" {
return ErrLinkRequired
}
if !validLink {
return fmt.Errorf("%w: %s (available: %v)", ErrUnsupportedLink, cfg.Link, availableLinks)
for _, l := range link.Available() {
if cfg.Link == l {
return nil
}
}
return fmt.Errorf("%w: %s (available: %v)", ErrUnsupportedLink, cfg.Link, link.Available())
}
func validateTransportRegistration(cfg Config) error {
if cfg.Transport == "" {
return ErrTransportRequired
}
if !validTransport {
return fmt.Errorf("%w: %s (available: %v)", ErrUnsupportedTransport, cfg.Transport, availableTransports)
for _, t := range transport.Available() {
if cfg.Transport == t {
return nil
}
}
return fmt.Errorf("%w: %s (available: %v)", ErrUnsupportedTransport, cfg.Transport, transport.Available())
}
func validateCommon(cfg Config) error {
if cfg.RoomID == "" && cfg.Carrier != "jazz" {
return ErrRoomIDRequired
}
if cfg.KeyHex == "" {
return ErrKeyRequired
}
if cfg.DNSServer == "" {
return ErrDNSServerRequired
}
return nil
}
if cfg.Transport == "videochannel" {
func validateTransportConfig(cfg Config) error {
switch cfg.Transport {
case "videochannel":
return validateVideoChannel(cfg)
case "vp8channel":
return validateVP8Channel(cfg)
default:
return nil
}
}
func validateVideoCodec(cfg Config) error {
if cfg.VideoCodec != "" && cfg.VideoCodec != "qrcode" && cfg.VideoCodec != "tile" {
return ErrVideoCodecInvalid
}
if cfg.VideoCodec == "tile" && (cfg.VideoWidth != 1080 || cfg.VideoHeight != 1080) {
return ErrTileCodecDimensions
}
return nil
}
func validateVideoChannel(cfg Config) error {
if cfg.VideoWidth == 0 {
return ErrVideoWidthRequired
}
@@ -181,32 +232,29 @@ func Validate(cfg Config) error {
if cfg.VideoHW == "" {
return ErrVideoHWRequired
}
if cfg.VideoCodec != "" && cfg.VideoCodec != "qrcode" && cfg.VideoCodec != "tile" {
return ErrVideoCodecInvalid
}
if cfg.VideoCodec == "tile" && (cfg.VideoWidth != 1080 || cfg.VideoHeight != 1080) {
return errors.New("tile codec requires -video-w 1080 -video-h 1080")
}
return validateVideoCodec(cfg)
}
if cfg.Transport == "vp8channel" {
func validateVP8Channel(cfg Config) error {
if cfg.VP8FPS == 0 {
return ErrVP8FPSRequired
}
if cfg.VP8BatchSize == 0 {
return ErrVP8BatchSizeRequired
}
return nil
}
if cfg.Mode == "cnc" {
func validateModeConfig(cfg Config) error {
if cfg.Mode != modeCNC {
return nil
}
if cfg.SOCKSHost == "" {
return ErrSOCKSHostRequired
}
if cfg.SOCKSPort == 0 {
return ErrSOCKSPortRequired
}
}
return nil
}
@@ -215,8 +263,8 @@ func Run(ctx context.Context, cfg Config) error {
roomURL := buildRoomURL(cfg.Carrier, cfg.RoomID)
switch cfg.Mode {
case "srv":
return server.Run(
case modeSRV:
if err := server.Run(
ctx,
cfg.Link,
cfg.Transport,
@@ -238,9 +286,12 @@ func Run(ctx context.Context, cfg Config) error {
cfg.VideoTileRS,
cfg.VP8FPS,
cfg.VP8BatchSize,
)
case "cnc":
return client.Run(
); err != nil {
return fmt.Errorf("server: %w", err)
}
return nil
case modeCNC:
if err := client.Run(
ctx,
cfg.Link,
cfg.Transport,
@@ -263,7 +314,10 @@ func Run(ctx context.Context, cfg Config) error {
cfg.VideoTileRS,
cfg.VP8FPS,
cfg.VP8BatchSize,
)
); err != nil {
return fmt.Errorf("client: %w", err)
}
return nil
default:
return ErrModeRequired
}

View File

@@ -2,6 +2,7 @@ package carrier
import (
"context"
"fmt"
"github.com/openlibrecommunity/olcrtc/internal/provider"
"github.com/pion/webrtc/v4"
@@ -32,6 +33,11 @@ type VideoTrack interface {
SetTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPReceiver))
}
type videoTrackProvider interface {
provider.Provider
provider.VideoTrackCapable
}
type legacySession struct {
provider provider.Provider
}
@@ -39,7 +45,7 @@ type legacySession struct {
// Capabilities reports the transport primitives supported by the legacy carrier.
func (s *legacySession) Capabilities() Capabilities {
caps := Capabilities{ByteStream: true}
_, caps.VideoTrack = s.provider.(provider.VideoTrackCapable)
_, caps.VideoTrack = s.provider.(videoTrackProvider)
return caps
}
@@ -50,20 +56,35 @@ func (s *legacySession) OpenByteStream() (ByteStream, error) {
// OpenVideoTrack adapts a legacy provider to the generic video track capability.
func (s *legacySession) OpenVideoTrack() (VideoTrack, error) {
publisher, ok := s.provider.(provider.VideoTrackCapable)
vtp, ok := s.provider.(videoTrackProvider)
if !ok {
return nil, ErrVideoTrackUnsupported
}
return &legacyVideoTrack{provider: publisher}, nil
return &legacyVideoTrack{provider: vtp}, nil
}
type legacyByteStream struct {
provider provider.Provider
}
func (p *legacyByteStream) Connect(ctx context.Context) error { return p.provider.Connect(ctx) }
func (p *legacyByteStream) Send(data []byte) error { return p.provider.Send(data) }
func (p *legacyByteStream) Close() error { return p.provider.Close() }
func (p *legacyByteStream) Connect(ctx context.Context) error {
if err := p.provider.Connect(ctx); err != nil {
return fmt.Errorf("connect: %w", err)
}
return nil
}
func (p *legacyByteStream) Send(data []byte) error {
if err := p.provider.Send(data); err != nil {
return fmt.Errorf("send: %w", err)
}
return nil
}
func (p *legacyByteStream) Close() error {
if err := p.provider.Close(); err != nil {
return fmt.Errorf("close: %w", err)
}
return nil
}
func (p *legacyByteStream) SetReconnectCallback(cb func()) {
p.provider.SetReconnectCallback(func(_ *webrtc.DataChannel) {
@@ -81,31 +102,38 @@ func (p *legacyByteStream) WatchConnection(ctx context.Context) {
func (p *legacyByteStream) CanSend() bool { return p.provider.CanSend() }
type legacyVideoTrack struct {
provider provider.VideoTrackCapable
provider videoTrackProvider
}
func (v *legacyVideoTrack) Connect(ctx context.Context) error {
return v.provider.(provider.Provider).Connect(ctx)
if err := v.provider.Connect(ctx); err != nil {
return fmt.Errorf("connect: %w", err)
}
func (v *legacyVideoTrack) Close() error { return v.provider.(provider.Provider).Close() }
func (v *legacyVideoTrack) SetShouldReconnect(fn func() bool) {
v.provider.(provider.Provider).SetShouldReconnect(fn)
return nil
}
func (v *legacyVideoTrack) SetEndedCallback(cb func(string)) {
v.provider.(provider.Provider).SetEndedCallback(cb)
func (v *legacyVideoTrack) Close() error {
if err := v.provider.Close(); err != nil {
return fmt.Errorf("close: %w", err)
}
return nil
}
func (v *legacyVideoTrack) SetShouldReconnect(fn func() bool) { v.provider.SetShouldReconnect(fn) }
func (v *legacyVideoTrack) SetEndedCallback(cb func(string)) { v.provider.SetEndedCallback(cb) }
func (v *legacyVideoTrack) WatchConnection(ctx context.Context) {
v.provider.(provider.Provider).WatchConnection(ctx)
v.provider.WatchConnection(ctx)
}
func (v *legacyVideoTrack) CanSend() bool { return v.provider.(provider.Provider).CanSend() }
func (v *legacyVideoTrack) CanSend() bool { return v.provider.CanSend() }
func (v *legacyVideoTrack) AddTrack(track webrtc.TrackLocal) error {
return v.provider.AddVideoTrack(track)
if err := v.provider.AddVideoTrack(track); err != nil {
return fmt.Errorf("add track: %w", err)
}
return nil
}
func (v *legacyVideoTrack) SetTrackHandler(cb func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) {
v.provider.SetVideoTrackHandler(cb)
}
func (v *legacyVideoTrack) SetReconnectCallback(cb func()) {
v.provider.(provider.Provider).SetReconnectCallback(func(_ *webrtc.DataChannel) {
v.provider.SetReconnectCallback(func(_ *webrtc.DataChannel) {
if cb != nil {
cb()
}

View File

@@ -51,6 +51,7 @@ type Config struct {
// Factory creates a new carrier session.
type Factory func(ctx context.Context, cfg Config) (Session, error)
//nolint:gochecknoglobals
var registry = make(map[string]Factory)
// Register adds a carrier factory to the registry.

View File

@@ -26,6 +26,16 @@ var (
ErrConnectFailed = errors.New("tunnel connection failed")
// ErrProxyAuth is returned when SOCKS proxy authentication fails.
ErrProxyAuth = errors.New("SOCKS proxy auth failed")
// ErrKeySize is returned when the encryption key is not 32 bytes.
ErrKeySize = errors.New("key must be 32 bytes")
// ErrInvalidSOCKSVersion is returned when the SOCKS version is not 5.
ErrInvalidSOCKSVersion = errors.New("invalid socks version")
// ErrUnsupportedSOCKSCommand is returned for unsupported SOCKS commands.
ErrUnsupportedSOCKSCommand = errors.New("unsupported socks command")
// ErrUnsupportedAddressType is returned for unsupported SOCKS address types.
ErrUnsupportedAddressType = errors.New("unsupported address type")
// ErrRemoteNotReady is returned when the server-side stream fails to signal readiness.
ErrRemoteNotReady = errors.New("remote not ready")
)
// Client handles local SOCKS5 connections and tunnels them to the server.
@@ -63,7 +73,13 @@ func Run(
vp8FPS int,
vp8BatchSize int,
) error {
return RunWithReady(ctx, linkName, transportName, carrierName, roomURL, keyHex, localAddr, dnsServer, socksUser, socksPass, nil, videoWidth, videoHeight, videoFPS, videoBitrate, videoHW, videoQRSize, videoQRRecovery, videoCodec, videoTileModule, videoTileRS, vp8FPS, vp8BatchSize)
return RunWithReady(
ctx, linkName, transportName, carrierName, roomURL, keyHex, localAddr,
dnsServer, socksUser, socksPass, nil,
videoWidth, videoHeight, videoFPS, videoBitrate, videoHW,
videoQRSize, videoQRRecovery, videoCodec, videoTileModule, videoTileRS,
vp8FPS, vp8BatchSize,
)
}
// RunWithReady is like Run but accepts a callback that is called when the client is ready.
@@ -118,7 +134,7 @@ func RunWithReady(
if err != nil {
return fmt.Errorf("failed to listen on %s: %w", localAddr, err)
}
defer listener.Close()
defer func() { _ = listener.Close() }()
logger.Infof("SOCKS5 server listening on %s", localAddr)
@@ -126,17 +142,10 @@ func RunWithReady(
onReady()
}
errCh := make(chan error, 1)
go func() {
errCh <- c.acceptLoop(runCtx, listener)
}()
go c.acceptLoop(runCtx, listener)
select {
case <-runCtx.Done():
<-runCtx.Done()
return nil
case err := <-errCh:
return err
}
}
func (c *Client) bringUpLink(
@@ -227,8 +236,6 @@ func (c *Client) handleReconnect() {
c.conn = nil
}
c.sessMu.Unlock()
// New SOCKS5 connections will fail until the link comes back up; the
// caller will reissue them. Existing streams die with the smux session.
c.conn = muxconn.New(c.ln, c.cipher)
sess, err := smux.Client(c.conn, smuxConfig())
if err != nil {
@@ -260,7 +267,7 @@ func setupCipher(keyHex string) (*crypto.Cipher, error) {
return nil, fmt.Errorf("failed to decode key: %w", err)
}
if len(key) != 32 {
return nil, fmt.Errorf("key must be 32 bytes, got %d", len(key))
return nil, fmt.Errorf("%w: got %d", ErrKeySize, len(key))
}
cipher, err := crypto.NewCipher(string(key))
@@ -279,13 +286,13 @@ func (c *Client) onData(data []byte) {
}
}
func (c *Client) acceptLoop(ctx context.Context, ln net.Listener) error {
func (c *Client) acceptLoop(ctx context.Context, ln net.Listener) {
for {
conn, err := ln.Accept()
if err != nil {
select {
case <-ctx.Done():
return nil
return
default:
logger.Warnf("Accept error: %v", err)
continue
@@ -295,8 +302,8 @@ func (c *Client) acceptLoop(ctx context.Context, ln net.Listener) error {
}
}
func (c *Client) handleSocks5(ctx context.Context, conn net.Conn) {
defer conn.Close()
func (c *Client) handleSocks5(_ context.Context, conn net.Conn) {
defer func() { _ = conn.Close() }()
if err := c.socks5Handshake(conn); err != nil {
return
@@ -315,38 +322,25 @@ func (c *Client) handleSocks5(ctx context.Context, conn net.Conn) {
return
}
c.tunnel(conn, sess, targetAddr, targetPort)
}
func (c *Client) tunnel(conn net.Conn, sess *smux.Session, targetAddr string, targetPort int) {
stream, err := sess.OpenStream()
if err != nil {
logger.Warnf("OpenStream failed: %v", err)
_, _ = conn.Write(replyHostUnreachable())
return
}
defer stream.Close()
defer func() { _ = stream.Close() }()
logger.Infof("sid=%d tunnel to %s:%d", stream.ID(), targetAddr, targetPort)
connectReq, _ := json.Marshal(map[string]any{
"cmd": "connect",
"addr": targetAddr,
"port": targetPort,
})
_ = stream.SetWriteDeadline(time.Now().Add(10 * time.Second))
if _, err := stream.Write(connectReq); err != nil {
logger.Warnf("sid=%d connect req failed: %v", stream.ID(), err)
if err := c.sendConnectRequest(stream, targetAddr, targetPort); err != nil {
logger.Warnf("sid=%d connect failed: %v", stream.ID(), err)
_, _ = conn.Write(replyHostUnreachable())
return
}
_ = stream.SetWriteDeadline(time.Time{})
ack := make([]byte, 1)
_ = stream.SetReadDeadline(time.Now().Add(15 * time.Second))
if _, err := io.ReadFull(stream, ack); err != nil || ack[0] != 0x00 {
logger.Warnf("sid=%d remote ready failed: err=%v ack=%v", stream.ID(), err, ack)
_, _ = conn.Write(replyHostUnreachable())
return
}
_ = stream.SetReadDeadline(time.Time{})
if _, err := conn.Write(replySuccess()); err != nil {
return
@@ -357,24 +351,47 @@ func (c *Client) handleSocks5(ctx context.Context, conn net.Conn) {
_ = stream.Close()
}()
_, _ = io.Copy(conn, stream)
}
_ = ctx // keep signature
func (c *Client) sendConnectRequest(stream *smux.Stream, targetAddr string, targetPort int) error {
connectReq, err := json.Marshal(map[string]any{
"cmd": "connect",
"addr": targetAddr,
"port": targetPort,
})
if err != nil {
return fmt.Errorf("sid=%d marshal connect req: %w", stream.ID(), err)
}
_ = stream.SetWriteDeadline(time.Now().Add(10 * time.Second))
if _, err := stream.Write(connectReq); err != nil {
return fmt.Errorf("sid=%d write connect req: %w", stream.ID(), err)
}
_ = stream.SetWriteDeadline(time.Time{})
ack := make([]byte, 1)
_ = stream.SetReadDeadline(time.Now().Add(15 * time.Second))
if _, err := io.ReadFull(stream, ack); err != nil || ack[0] != 0x00 {
return fmt.Errorf("sid=%d: %w (read_err=%w ack=%v)", stream.ID(), ErrRemoteNotReady, err, ack)
}
_ = stream.SetReadDeadline(time.Time{})
return nil
}
func (c *Client) socks5Handshake(conn net.Conn) error {
buf := make([]byte, 2)
if _, err := io.ReadFull(conn, buf); err != nil {
return err
return fmt.Errorf("read socks5 header: %w", err)
}
if buf[0] != 5 {
return fmt.Errorf("invalid socks version: %d", buf[0])
return fmt.Errorf("%w: %d", ErrInvalidSOCKSVersion, buf[0])
}
methods := make([]byte, buf[1])
if _, err := io.ReadFull(conn, methods); err != nil {
return err
return fmt.Errorf("read socks5 methods: %w", err)
}
if _, err := conn.Write([]byte{5, 0}); err != nil {
return err
return fmt.Errorf("write socks5 auth: %w", err)
}
return nil
}
@@ -382,43 +399,49 @@ func (c *Client) socks5Handshake(conn net.Conn) error {
func (c *Client) socks5Request(conn net.Conn) (string, int, error) {
header := make([]byte, 4)
if _, err := io.ReadFull(conn, header); err != nil {
return "", 0, err
return "", 0, fmt.Errorf("read socks5 request: %w", err)
}
if header[1] != 1 {
return "", 0, fmt.Errorf("unsupported socks command: %d", header[1])
return "", 0, fmt.Errorf("%w: %d", ErrUnsupportedSOCKSCommand, header[1])
}
var addr string
switch header[3] {
case 1: // IPv4
buf := make([]byte, 4)
if _, err := io.ReadFull(conn, buf); err != nil {
addr, err := c.readSocks5Addr(conn, header[3])
if err != nil {
return "", 0, err
}
addr = net.IP(buf).String()
case 3: // Domain
lenBuf := make([]byte, 1)
if _, err := io.ReadFull(conn, lenBuf); err != nil {
return "", 0, err
}
buf := make([]byte, lenBuf[0])
if _, err := io.ReadFull(conn, buf); err != nil {
return "", 0, err
}
addr = string(buf)
default:
return "", 0, fmt.Errorf("unsupported address type: %d", header[3])
}
portBuf := make([]byte, 2)
if _, err := io.ReadFull(conn, portBuf); err != nil {
return "", 0, err
return "", 0, fmt.Errorf("read socks5 port: %w", err)
}
port := int(binary.BigEndian.Uint16(portBuf))
return addr, port, nil
}
func (c *Client) readSocks5Addr(conn net.Conn, addrType byte) (string, error) {
switch addrType {
case 1: // IPv4
buf := make([]byte, 4)
if _, err := io.ReadFull(conn, buf); err != nil {
return "", fmt.Errorf("read socks5 ipv4: %w", err)
}
return net.IP(buf).String(), nil
case 3: // Domain
lenBuf := make([]byte, 1)
if _, err := io.ReadFull(conn, lenBuf); err != nil {
return "", fmt.Errorf("read socks5 domain len: %w", err)
}
buf := make([]byte, lenBuf[0])
if _, err := io.ReadFull(conn, buf); err != nil {
return "", fmt.Errorf("read socks5 domain: %w", err)
}
return string(buf), nil
default:
return "", fmt.Errorf("%w: %d", ErrUnsupportedAddressType, addrType)
}
}
func replySuccess() []byte {
return []byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0}
}

View File

@@ -43,9 +43,27 @@ func New(ctx context.Context, cfg link.Config) (link.Link, error) {
return &directLink{transport: tr}, nil
}
func (d *directLink) Connect(ctx context.Context) error { return d.transport.Connect(ctx) }
func (d *directLink) Send(data []byte) error { return d.transport.Send(data) }
func (d *directLink) Close() error { return d.transport.Close() }
func (d *directLink) Connect(ctx context.Context) error {
if err := d.transport.Connect(ctx); err != nil {
return fmt.Errorf("transport connect: %w", err)
}
return nil
}
func (d *directLink) Send(data []byte) error {
if err := d.transport.Send(data); err != nil {
return fmt.Errorf("transport send: %w", err)
}
return nil
}
func (d *directLink) Close() error {
if err := d.transport.Close(); err != nil {
return fmt.Errorf("transport close: %w", err)
}
return nil
}
func (d *directLink) SetReconnectCallback(cb func()) { d.transport.SetReconnectCallback(cb) }
func (d *directLink) SetShouldReconnect(fn func() bool) { d.transport.SetShouldReconnect(fn) }
func (d *directLink) SetEndedCallback(cb func(string)) { d.transport.SetEndedCallback(cb) }

View File

@@ -50,6 +50,7 @@ type Config struct {
// Factory creates a link instance.
type Factory func(ctx context.Context, cfg Config) (Link, error)
//nolint:gochecknoglobals
var registry = make(map[string]Factory)
// Register adds a link factory to the registry.

View File

@@ -17,6 +17,7 @@ package muxconn
import (
"errors"
"fmt"
"io"
"sync"
"time"
@@ -92,10 +93,10 @@ func (c *Conn) Write(p []byte) (int, error) {
enc, err := c.cipher.Encrypt(p)
if err != nil {
return 0, err
return 0, fmt.Errorf("encrypt: %w", err)
}
if err := c.ln.Send(enc); err != nil {
return 0, err
return 0, fmt.Errorf("send: %w", err)
}
return len(p), nil
}

View File

@@ -24,6 +24,13 @@ const (
sendDelay = 2 * time.Millisecond
)
var (
// ErrPublisherNotInitialized is returned when the publisher peer connection is not set up.
ErrPublisherNotInitialized = errors.New("publisher peer connection not initialized")
// ErrSubscriberMediaTimeout is returned when the subscriber media is not ready within the timeout period.
ErrSubscriberMediaTimeout = errors.New("subscriber media timeout")
)
// Peer represents a SaluteJazz WebRTC connection.
type Peer struct {
name string
@@ -135,23 +142,23 @@ func (p *Peer) attachPendingVideoTracks() error {
return nil
}
// Connect starts the WebRTC connection process.
func (p *Peer) Connect(ctx context.Context) error {
p.closed.Store(false)
p.resetMediaState()
config := webrtc.Configuration{
func defaultWebRTCConfig() webrtc.Configuration {
return webrtc.Configuration{
ICEServers: []webrtc.ICEServer{},
SDPSemantics: webrtc.SDPSemanticsUnifiedPlan,
BundlePolicy: webrtc.BundlePolicyMaxBundle,
}
settingEngine := webrtc.SettingEngine{}
if protect.Protector != nil {
settingEngine.SetICEProxyDialer(protect.NewProxyDialer())
}
api := webrtc.NewAPI(webrtc.WithSettingEngine(settingEngine))
func (p *Peer) buildAPI() *webrtc.API {
se := webrtc.SettingEngine{}
if protect.Protector != nil {
se.SetICEProxyDialer(protect.NewProxyDialer())
}
return webrtc.NewAPI(webrtc.WithSettingEngine(se))
}
func (p *Peer) createPeerConnections(api *webrtc.API, config webrtc.Configuration) error {
var err error
p.pcSub, err = api.NewPeerConnection(config)
if err != nil {
@@ -162,7 +169,6 @@ func (p *Peer) Connect(ctx context.Context) error {
if track.Kind() != webrtc.RTPCodecTypeVideo {
return
}
if cb := p.videoTrackHandler(); cb != nil {
cb(track, receiver)
}
@@ -173,28 +179,63 @@ func (p *Peer) Connect(ctx context.Context) error {
return fmt.Errorf("create publisher pc: %w", err)
}
p.pcPub.OnConnectionStateChange(p.onPublisherConnectionStateChange)
return nil
}
func (p *Peer) createDataChannel() (chan struct{}, error) {
var err error
p.dc, err = p.pcPub.CreateDataChannel("_reliable", &webrtc.DataChannelInit{
Ordered: func() *bool { v := true; return &v }(),
})
if err != nil {
return nil, fmt.Errorf("create datachannel: %w", err)
}
dcReady := make(chan struct{})
p.setupDataChannelHandlers(dcReady)
return dcReady, nil
}
func (p *Peer) waitForReady(ctx context.Context, dcReady chan struct{}) error {
if dcReady != nil {
select {
case <-dcReady:
return nil
case <-time.After(30 * time.Second):
return provider.ErrDataChannelTimeout
case <-ctx.Done():
return fmt.Errorf("connect cancelled: %w", ctx.Err())
}
}
return p.waitForMediaReady(ctx, 30*time.Second)
}
// Connect starts the WebRTC connection process.
func (p *Peer) Connect(ctx context.Context) error {
p.closed.Store(false)
p.resetMediaState()
api := p.buildAPI()
config := defaultWebRTCConfig()
if err := p.createPeerConnections(api, config); err != nil {
return err
}
if err := p.attachPendingVideoTracks(); err != nil {
return err
}
var dcReady chan struct{}
if p.onData != nil {
p.dc, err = p.pcPub.CreateDataChannel("_reliable", &webrtc.DataChannelInit{
Ordered: func() *bool { v := true; return &v }(),
})
var err error
dcReady, err = p.createDataChannel()
if err != nil {
return fmt.Errorf("create datachannel: %w", err)
return err
}
dcReady = make(chan struct{})
p.setupDataChannelHandlers(dcReady)
}
if err := p.dialWebSocket(); err != nil {
return err
}
if err := p.sendJoin(); err != nil {
return err
}
@@ -205,18 +246,7 @@ func (p *Peer) Connect(ctx context.Context) error {
p.handleSignaling(ctx)
}()
if p.onData != nil {
select {
case <-dcReady:
return nil
case <-time.After(30 * time.Second):
return provider.ErrDataChannelTimeout
case <-ctx.Done():
return fmt.Errorf("connect cancelled: %w", ctx.Err())
}
}
return p.waitForMediaReady(ctx, 30*time.Second)
return p.waitForReady(ctx, dcReady)
}
func (p *Peer) waitForMediaReady(ctx context.Context, timeout time.Duration) error {
@@ -226,7 +256,7 @@ func (p *Peer) waitForMediaReady(ctx context.Context, timeout time.Duration) err
select {
case <-p.subscriberConn:
case <-timer.C:
return fmt.Errorf("subscriber media timeout")
return ErrSubscriberMediaTimeout
case <-ctx.Done():
return fmt.Errorf("connect cancelled: %w", ctx.Err())
}
@@ -320,30 +350,38 @@ func (p *Peer) setupDataChannelHandlers(dcReady chan struct{}) {
}
func (p *Peer) onSubscriberConnectionStateChange(state webrtc.PeerConnectionState) {
if state == webrtc.PeerConnectionStateConnected {
switch state {
case webrtc.PeerConnectionStateConnected:
p.subscriberReady.Store(true)
closeSignal(p.subscriberConn)
} else if state == webrtc.PeerConnectionStateDisconnected ||
state == webrtc.PeerConnectionStateFailed ||
state == webrtc.PeerConnectionStateClosed {
case webrtc.PeerConnectionStateDisconnected, webrtc.PeerConnectionStateFailed:
p.subscriberReady.Store(false)
if !p.closed.Load() && (state == webrtc.PeerConnectionStateDisconnected || state == webrtc.PeerConnectionStateFailed) {
if !p.closed.Load() {
p.queueReconnect()
}
case webrtc.PeerConnectionStateClosed:
p.subscriberReady.Store(false)
case webrtc.PeerConnectionStateUnknown,
webrtc.PeerConnectionStateNew,
webrtc.PeerConnectionStateConnecting:
}
}
func (p *Peer) onPublisherConnectionStateChange(state webrtc.PeerConnectionState) {
if state == webrtc.PeerConnectionStateConnected {
switch state {
case webrtc.PeerConnectionStateConnected:
p.publisherReady.Store(true)
closeSignal(p.publisherConn)
} else if state == webrtc.PeerConnectionStateDisconnected ||
state == webrtc.PeerConnectionStateFailed ||
state == webrtc.PeerConnectionStateClosed {
case webrtc.PeerConnectionStateDisconnected, webrtc.PeerConnectionStateFailed:
p.publisherReady.Store(false)
if !p.closed.Load() && (state == webrtc.PeerConnectionStateDisconnected || state == webrtc.PeerConnectionStateFailed) {
if !p.closed.Load() {
p.queueReconnect()
}
case webrtc.PeerConnectionStateClosed:
p.publisherReady.Store(false)
case webrtc.PeerConnectionStateUnknown,
webrtc.PeerConnectionStateNew,
webrtc.PeerConnectionStateConnecting:
}
}
@@ -651,11 +689,6 @@ func (p *Peer) Close() error {
return nil
}
var (
// ErrPublisherNotInitialized is returned when the publisher peer connection is not set up.
ErrPublisherNotInitialized = errors.New("publisher peer connection not initialized")
)
// AddVideoTrack adds a video track to the publisher peer connection.
func (p *Peer) AddVideoTrack(track webrtc.TrackLocal) error {
p.videoTrackMu.Lock()

View File

@@ -22,8 +22,6 @@ var (
)
// Provider defines the standard interface for WebRTC connection handlers.
//
//nolint:interfacebloat // All methods are necessary for provider abstraction.
type Provider interface {
Connect(ctx context.Context) error
Send(data []byte) error

View File

@@ -42,6 +42,8 @@ var (
ErrSessionClosed = errors.New("session closed")
// ErrPeerClosed is returned when the peer is closed.
ErrPeerClosed = errors.New("peer closed")
// ErrSubscriberMediaTimeout is returned when subscriber media is not ready within the timeout period.
ErrSubscriberMediaTimeout = errors.New("subscriber media timeout")
)
// TrafficShape defines the parameters for outgoing traffic control.
@@ -288,7 +290,7 @@ func (p *Peer) waitForMediaReady(ctx context.Context, timeout time.Duration) err
select {
case <-p.subscriberConn:
case <-timer.C:
return fmt.Errorf("subscriber media timeout")
return ErrSubscriberMediaTimeout
case <-ctx.Done():
return fmt.Errorf("connect context cancelled: %w", ctx.Err())
}
@@ -314,7 +316,8 @@ func (p *Peer) setupPeerConnections(config webrtc.Configuration) error {
return
}
logger.Infof("telemost remote video track: codec=%s stream=%s track=%s", track.Codec().MimeType, track.StreamID(), track.ID())
logger.Infof("telemost remote video track: codec=%s stream=%s track=%s",
track.Codec().MimeType, track.StreamID(), track.ID())
if cb := p.videoTrackHandler(); cb != nil {
cb(track, receiver)
@@ -342,29 +345,35 @@ func (p *Peer) onConnectionStateChange(state webrtc.PeerConnectionState) {
func (p *Peer) onSubscriberConnectionStateChange(state webrtc.PeerConnectionState) {
logger.Debugf("telemost subscriber state: %s", state.String())
if state == webrtc.PeerConnectionStateConnected {
switch state {
case webrtc.PeerConnectionStateConnected:
p.subscriberReady.Store(true)
closeSignal(p.subscriberConn)
} else if state == webrtc.PeerConnectionStateDisconnected ||
state == webrtc.PeerConnectionStateFailed ||
state == webrtc.PeerConnectionStateClosed {
case webrtc.PeerConnectionStateDisconnected,
webrtc.PeerConnectionStateFailed,
webrtc.PeerConnectionStateClosed:
p.subscriberReady.Store(false)
case webrtc.PeerConnectionStateUnknown,
webrtc.PeerConnectionStateNew,
webrtc.PeerConnectionStateConnecting:
}
p.onConnectionStateChange(state)
}
func (p *Peer) onPublisherConnectionStateChange(state webrtc.PeerConnectionState) {
logger.Debugf("telemost publisher state: %s", state.String())
if state == webrtc.PeerConnectionStateConnected {
switch state {
case webrtc.PeerConnectionStateConnected:
p.publisherReady.Store(true)
closeSignal(p.publisherConn)
} else if state == webrtc.PeerConnectionStateDisconnected ||
state == webrtc.PeerConnectionStateFailed ||
state == webrtc.PeerConnectionStateClosed {
case webrtc.PeerConnectionStateDisconnected,
webrtc.PeerConnectionStateFailed,
webrtc.PeerConnectionStateClosed:
p.publisherReady.Store(false)
case webrtc.PeerConnectionStateUnknown,
webrtc.PeerConnectionStateNew,
webrtc.PeerConnectionStateConnecting:
}
p.onConnectionStateChange(state)
}
@@ -656,7 +665,7 @@ func (p *Peer) sendSetSlots() error {
p.wsMu.Lock()
defer p.wsMu.Unlock()
return p.ws.WriteJSON(map[string]interface{}{
if err := p.ws.WriteJSON(map[string]interface{}{
"uid": uuid.New().String(),
"setSlots": map[string]interface{}{
"slots": []map[string]int{
@@ -670,7 +679,52 @@ func (p *Peer) sendSetSlots() error {
"selfViewVisibility": "ON_LOADING_THEN_SHOW",
"gridConfig": map[string]interface{}{},
},
})
}); err != nil {
return fmt.Errorf("write set slots: %w", err)
}
return nil
}
func isNonTURNURL(url string) bool {
return url != "" && !strings.HasPrefix(url, "turn:") && !strings.HasPrefix(url, "turns:")
}
func parseICEURLs(server map[string]interface{}) []string {
var urls []string
switch rawURLs := server["urls"].(type) {
case []interface{}:
for _, rawURL := range rawURLs {
if url, ok := rawURL.(string); ok && isNonTURNURL(url) {
urls = append(urls, url)
}
}
case []string:
for _, url := range rawURLs {
if isNonTURNURL(url) {
urls = append(urls, url)
}
}
}
return urls
}
func parseICEServer(rawServer interface{}) (webrtc.ICEServer, bool) {
server, ok := rawServer.(map[string]interface{})
if !ok {
return webrtc.ICEServer{}, false
}
urls := parseICEURLs(server)
if len(urls) == 0 {
return webrtc.ICEServer{}, false
}
ice := webrtc.ICEServer{URLs: urls}
if username, ok := server["username"].(string); ok {
ice.Username = username
}
if credential, ok := server["credential"].(string); ok {
ice.Credential = credential
}
return ice, true
}
func (p *Peer) applyServerHelloConfig(serverHello map[string]interface{}) {
@@ -686,40 +740,10 @@ func (p *Peer) applyServerHelloConfig(serverHello map[string]interface{}) {
iceServers := make([]webrtc.ICEServer, 0, len(rawServers))
for _, rawServer := range rawServers {
server, ok := rawServer.(map[string]interface{})
if !ok {
continue
}
var urls []string
switch rawURLs := server["urls"].(type) {
case []interface{}:
for _, rawURL := range rawURLs {
if url, ok := rawURL.(string); ok && url != "" && !strings.HasPrefix(url, "turn:") && !strings.HasPrefix(url, "turns:") {
urls = append(urls, url)
}
}
case []string:
for _, url := range rawURLs {
if !strings.HasPrefix(url, "turn:") && !strings.HasPrefix(url, "turns:") {
urls = append(urls, url)
}
}
}
if len(urls) == 0 {
continue
}
ice := webrtc.ICEServer{URLs: urls}
if username, ok := server["username"].(string); ok {
ice.Username = username
}
if credential, ok := server["credential"].(string); ok {
ice.Credential = credential
}
if ice, ok := parseICEServer(rawServer); ok {
iceServers = append(iceServers, ice)
}
}
if len(iceServers) == 0 {
return

View File

@@ -22,6 +22,8 @@ import (
)
var (
// ErrKeyRequired is returned when no encryption key is provided.
ErrKeyRequired = errors.New("key required (use -key <hex>)")
// ErrKeySize is returned when the encryption key is not 32 bytes.
ErrKeySize = errors.New("key must be 32 bytes")
// ErrSocks5AuthFailed is returned when SOCKS5 authentication fails.
@@ -100,17 +102,17 @@ func Run(
return err
}
err = s.serve(runCtx)
s.serve(runCtx)
s.shutdown()
s.wg.Wait()
return err
return nil
}
func setupCipher(keyHex string) (*crypto.Cipher, error) {
if keyHex == "" {
return nil, errors.New("key required (use -key <hex>)")
return nil, ErrKeyRequired
}
key, err := hex.DecodeString(keyHex)
@@ -252,10 +254,12 @@ func (s *Server) onData(data []byte) {
// serve drives the smux Accept loop, spawning a tunnel per inbound stream.
// The loop tolerates session bounces (reconnects) by waiting until a fresh
// session is installed instead of terminating the server.
func (s *Server) serve(ctx context.Context) error {
func (s *Server) serve(ctx context.Context) {
for {
if ctx.Err() != nil {
return nil
select {
case <-ctx.Done():
return
default:
}
s.sessMu.RLock()
@@ -264,7 +268,7 @@ func (s *Server) serve(ctx context.Context) error {
if sess == nil {
select {
case <-ctx.Done():
return nil
return
case <-time.After(50 * time.Millisecond):
continue
}
@@ -272,10 +276,10 @@ func (s *Server) serve(ctx context.Context) error {
stream, err := sess.AcceptStream()
if err != nil {
// Session is torn down (reconnect or close). If we're shutting
// down, exit; otherwise wait for a new session and retry.
if ctx.Err() != nil {
return nil
select {
case <-ctx.Done():
return
default:
}
logger.Infof("AcceptStream returned %v — waiting for new session", err)
time.Sleep(100 * time.Millisecond)
@@ -305,7 +309,7 @@ func (s *Server) shutdown() {
}
func (s *Server) handleStream(_ context.Context, stream *smux.Stream) {
defer stream.Close()
defer func() { _ = stream.Close() }()
// Read the connect JSON. The client writes the whole JSON in one
// stream.Write so it usually arrives intact; tolerate fragmentation
@@ -356,7 +360,7 @@ func (s *Server) dispatch(stream *smux.Stream, req ConnectRequest) {
logger.Infof("sid=%d dial %s failed (%v): %v", stream.ID(), addr, dialElapsed, err)
return
}
defer conn.Close()
defer func() { _ = conn.Close() }()
logger.Infof("sid=%d connected %s in %v", stream.ID(), addr, dialElapsed)

View File

@@ -44,17 +44,26 @@ func New(ctx context.Context, cfg transport.Config) (transport.Transport, error)
// Connect starts the transport connection.
func (p *streamTransport) Connect(ctx context.Context) error {
return p.stream.Connect(ctx)
if err := p.stream.Connect(ctx); err != nil {
return fmt.Errorf("stream connect: %w", err)
}
return nil
}
// Send transmits data through the transport.
func (p *streamTransport) Send(data []byte) error {
return p.stream.Send(data)
if err := p.stream.Send(data); err != nil {
return fmt.Errorf("stream send: %w", err)
}
return nil
}
// Close terminates the transport.
func (p *streamTransport) Close() error {
return p.stream.Close()
if err := p.stream.Close(); err != nil {
return fmt.Errorf("stream close: %w", err)
}
return nil
}
// SetReconnectCallback registers reconnect handling.

View File

@@ -3,11 +3,20 @@ package seichannel
import (
"bytes"
"encoding/hex"
"errors"
"fmt"
"github.com/pion/webrtc/v4/pkg/media/h264reader"
)
var (
// ErrSEIPayloadTruncated is returned when the SEI payload is shorter than expected.
ErrSEIPayloadTruncated = errors.New("sei payload truncated")
// ErrSEIValueTruncated is returned when reading a SEI length-value runs past the buffer.
ErrSEIValueTruncated = errors.New("sei value truncated")
)
//nolint:gochecknoglobals
var (
videoSEIUUID = [16]byte{
0x5d, 0xc0, 0x3b, 0xa8,
@@ -21,19 +30,16 @@ var (
baseIDR = mustDecodeHex("6588843a2628000902e0")
)
func buildVideoAccessUnit(payload []byte) ([]byte, error) {
func buildVideoAccessUnit(payload []byte) []byte {
out := make([]byte, 0, len(baseSPS)+len(basePPS)+len(baseIDR)+64+len(payload))
out = appendStartCode(out, baseSPS)
out = appendStartCode(out, basePPS)
if len(payload) > 0 {
sei, err := buildSEINAL(payload)
if err != nil {
return nil, err
}
sei := buildSEINAL(payload)
out = appendStartCode(out, sei)
}
out = appendStartCode(out, baseIDR)
return out, nil
return out
}
func extractVideoPayloads(accessUnit []byte) ([][]byte, error) {
@@ -63,7 +69,7 @@ func extractVideoPayloads(accessUnit []byte) ([][]byte, error) {
}
}
func buildSEINAL(payload []byte) ([]byte, error) {
func buildSEINAL(payload []byte) []byte {
userData := make([]byte, 0, len(videoSEIUUID)+len(payload))
userData = append(userData, videoSEIUUID[:]...)
userData = append(userData, payload...)
@@ -74,9 +80,11 @@ func buildSEINAL(payload []byte) ([]byte, error) {
rbsp = append(rbsp, userData...)
rbsp = append(rbsp, 0x80)
out := []byte{0x06}
out = append(out, escapeRBSP(rbsp)...)
return out, nil
escaped := escapeRBSP(rbsp)
out := make([]byte, 0, 1+len(escaped))
out = append(out, 0x06)
out = append(out, escaped...)
return out
}
func extractTransportSEI(rbsp []byte) ([][]byte, error) {
@@ -101,7 +109,7 @@ func extractTransportSEI(rbsp []byte) ([][]byte, error) {
pos = next
if pos+payloadSize > len(data) {
return nil, fmt.Errorf("sei payload truncated")
return nil, ErrSEIPayloadTruncated
}
payload := data[pos : pos+payloadSize]
@@ -127,14 +135,14 @@ func appendSEIValue(dst []byte, value int) []byte {
dst = append(dst, 0xff)
value -= 0xff
}
return append(dst, byte(value))
return append(dst, byte(value)) //nolint:gosec
}
func consumeSEIValue(data []byte, pos int) (int, int, error) {
value := 0
for {
if pos >= len(data) {
return 0, pos, fmt.Errorf("sei value truncated")
return 0, pos, ErrSEIValueTruncated
}
b := int(data[pos])
pos++
@@ -170,11 +178,11 @@ func escapeRBSP(rbsp []byte) []byte {
func unescapeRBSP(rbsp []byte) []byte {
out := make([]byte, 0, len(rbsp))
for i := 0; i < len(rbsp); i++ {
if i >= 2 && rbsp[i] == 0x03 && rbsp[i-1] == 0x00 && rbsp[i-2] == 0x00 {
for i, b := range rbsp {
if i >= 2 && b == 0x03 && rbsp[i-1] == 0x00 && rbsp[i-2] == 0x00 {
continue
}
out = append(out, rbsp[i])
out = append(out, b)
}
return out
}

View File

@@ -40,6 +40,18 @@ var (
ErrAckTimeout = errors.New("seichannel ack timeout")
// ErrTransportClosed is returned when operations are attempted on a closed transport.
ErrTransportClosed = errors.New("seichannel transport closed")
// ErrFrameTooShort is returned when the received frame is too short to decode.
ErrFrameTooShort = errors.New("frame too short")
// ErrUnexpectedMagic is returned when the frame magic bytes do not match.
ErrUnexpectedMagic = errors.New("unexpected frame magic")
// ErrUnexpectedVersion is returned when the frame protocol version does not match.
ErrUnexpectedVersion = errors.New("unexpected frame version")
// ErrAckTooShort is returned when the ack frame is shorter than expected.
ErrAckTooShort = errors.New("ack frame too short")
// ErrDataTooShort is returned when the data frame is shorter than expected.
ErrDataTooShort = errors.New("data frame too short")
// ErrUnexpectedFrameType is returned for unknown frame type bytes.
ErrUnexpectedFrameType = errors.New("unexpected frame type")
)
type transportFrame struct {
@@ -144,7 +156,7 @@ func (p *streamTransport) Connect(ctx context.Context) error {
defer cancel()
if err := p.stream.Connect(connectCtx); err != nil {
return err
return fmt.Errorf("connect stream: %w", err)
}
p.startWriter.Do(func() {
@@ -178,7 +190,7 @@ func (p *streamTransport) Send(data []byte) error {
p.ackMu.Unlock()
}()
for attempt := 0; attempt < maxSendAttempts; attempt++ {
for range maxSendAttempts {
for idx, fragment := range fragments {
frame := encodeDataFrame(seq, crc, len(data), idx, len(fragments), fragment)
if err := p.enqueueFrame(frame, false); err != nil {
@@ -210,7 +222,9 @@ func (p *streamTransport) Close() error {
if p.writerUp.Load() {
<-p.writerDone
}
return p.stream.Close()
if err := p.stream.Close(); err != nil {
return fmt.Errorf("close stream: %w", err)
}
}
return nil
}
@@ -256,10 +270,7 @@ func (p *streamTransport) writerLoop() {
ticker := time.NewTicker(defaultFrameInterval)
defer ticker.Stop()
idle, err := buildVideoAccessUnit(nil)
if err != nil {
return
}
idle := buildVideoAccessUnit(nil)
for {
select {
@@ -273,10 +284,7 @@ func (p *streamTransport) writerLoop() {
sample := idle
if payload != nil {
sample, err = buildVideoAccessUnit(payload)
if err != nil {
continue
}
sample = buildVideoAccessUnit(payload)
}
_ = p.track.WriteSample(media.Sample{
@@ -371,14 +379,7 @@ func (p *streamTransport) handleSample(sample []byte) {
}
}
func (p *streamTransport) handleInboundFrame(frame transportFrame) {
p.recvMu.Lock()
if crc, ok := p.delivered[frame.seq]; ok && crc == frame.crc {
p.recvMu.Unlock()
p.sendAck(frame.seq, frame.crc)
return
}
func (p *streamTransport) upsertInbound(frame transportFrame) (*inboundMessage, bool) {
msg, ok := p.inbound[frame.seq]
if !ok || msg.crc != frame.crc || msg.totalLen != frame.totalLen || len(msg.frags) != int(frame.fragTotal) {
msg = &inboundMessage{
@@ -389,33 +390,45 @@ func (p *streamTransport) handleInboundFrame(frame transportFrame) {
}
p.inbound[frame.seq] = msg
}
if int(frame.fragIdx) >= len(msg.frags) {
p.recvMu.Unlock()
return
return nil, false
}
if msg.frags[frame.fragIdx] == nil {
chunk := make([]byte, len(frame.payload))
copy(chunk, frame.payload)
msg.frags[frame.fragIdx] = chunk
msg.remain--
}
return msg, msg.remain == 0
}
if msg.remain > 0 {
func (p *streamTransport) assembleMessage(msg *inboundMessage) []byte {
data := make([]byte, 0, msg.totalLen)
for _, frag := range msg.frags {
data = append(data, frag...)
}
if uint32(len(data)) > msg.totalLen { //nolint:gosec
data = data[:msg.totalLen]
}
return data
}
func (p *streamTransport) handleInboundFrame(frame transportFrame) {
p.recvMu.Lock()
if crc, ok := p.delivered[frame.seq]; ok && crc == frame.crc {
p.recvMu.Unlock()
p.sendAck(frame.seq, frame.crc)
return
}
msg, complete := p.upsertInbound(frame)
if msg == nil || !complete {
p.recvMu.Unlock()
return
}
delete(p.inbound, frame.seq)
data := make([]byte, 0, msg.totalLen)
for _, frag := range msg.frags {
data = append(data, frag...)
}
if uint32(len(data)) > msg.totalLen {
data = data[:msg.totalLen]
}
data := p.assembleMessage(msg)
if crc32.ChecksumIEEE(data) != msg.crc {
p.recvMu.Unlock()
@@ -480,9 +493,9 @@ func encodeDataFrame(seq, crc uint32, totalLen, fragIdx, fragTotal int, payload
out[5] = frameTypeData
binary.BigEndian.PutUint32(out[6:10], seq)
binary.BigEndian.PutUint32(out[10:14], crc)
binary.BigEndian.PutUint32(out[14:18], uint32(totalLen))
binary.BigEndian.PutUint16(out[18:20], uint16(fragIdx))
binary.BigEndian.PutUint16(out[20:22], uint16(fragTotal))
binary.BigEndian.PutUint32(out[14:18], uint32(totalLen)) //nolint:gosec
binary.BigEndian.PutUint16(out[18:20], uint16(fragIdx)) //nolint:gosec
binary.BigEndian.PutUint16(out[20:22], uint16(fragTotal)) //nolint:gosec
copy(out[22:], payload)
return out
}
@@ -499,27 +512,27 @@ func encodeAckFrame(seq, crc uint32) []byte {
func decodeTransportFrame(data []byte) (transportFrame, error) {
if len(data) < 6 {
return transportFrame{}, fmt.Errorf("frame too short")
return transportFrame{}, ErrFrameTooShort
}
if binary.BigEndian.Uint32(data[0:4]) != protocolMagic {
return transportFrame{}, fmt.Errorf("unexpected frame magic")
return transportFrame{}, ErrUnexpectedMagic
}
if data[4] != protocolVersion {
return transportFrame{}, fmt.Errorf("unexpected frame version")
return transportFrame{}, ErrUnexpectedVersion
}
frame := transportFrame{typ: data[5]}
switch frame.typ {
case frameTypeAck:
if len(data) < 14 {
return transportFrame{}, fmt.Errorf("ack too short")
return transportFrame{}, ErrAckTooShort
}
frame.seq = binary.BigEndian.Uint32(data[6:10])
frame.crc = binary.BigEndian.Uint32(data[10:14])
return frame, nil
case frameTypeData:
if len(data) < 22 {
return transportFrame{}, fmt.Errorf("data too short")
return transportFrame{}, ErrDataTooShort
}
frame.seq = binary.BigEndian.Uint32(data[6:10])
frame.crc = binary.BigEndian.Uint32(data[10:14])
@@ -529,6 +542,6 @@ func decodeTransportFrame(data []byte) (transportFrame, error) {
frame.payload = append([]byte(nil), data[22:]...)
return frame, nil
default:
return transportFrame{}, fmt.Errorf("unexpected frame type")
return transportFrame{}, ErrUnexpectedFrameType
}
}

View File

@@ -7,10 +7,7 @@ import (
func TestSEIRoundTrip(t *testing.T) {
payload := []byte("hello over seichannel")
accessUnit, err := buildVideoAccessUnit(payload)
if err != nil {
t.Fatalf("buildVideoAccessUnit failed: %v", err)
}
accessUnit := buildVideoAccessUnit(payload)
got, err := extractVideoPayloads(accessUnit)
if err != nil {

View File

@@ -58,6 +58,7 @@ type Config struct {
// Factory creates a transport instance.
type Factory func(ctx context.Context, cfg Config) (Transport, error)
//nolint:gochecknoglobals
var registry = make(map[string]Factory)
// Register adds a transport factory to the registry.

View File

@@ -2,11 +2,13 @@ package videochannel
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"os/exec"
"strconv"
"strings"
"sync"
"sync/atomic"
@@ -27,6 +29,12 @@ var (
ErrFFmpegUnavailable = errors.New("ffmpeg is required for videochannel")
// ErrUnsupportedVideoCodec is returned when videochannel cannot decode the negotiated codec.
ErrUnsupportedVideoCodec = errors.New("unsupported video codec")
// ErrEncoderTimeout is returned when the encoder does not produce a frame within the deadline.
ErrEncoderTimeout = errors.New("encoder timeout")
// ErrPopFrameTimeout is returned when no decoded frame is available within the deadline.
ErrPopFrameTimeout = errors.New("pop frame timeout")
// ErrUnexpectedFrameSize is returned when the raw frame size does not match expectations.
ErrUnexpectedFrameSize = errors.New("unexpected encoder frame size")
)
type codecSpec struct {
@@ -38,8 +46,7 @@ type codecSpec struct {
encodeArgs []string
}
func codecSpecForCarrier(carrier string) codecSpec {
// Natural default for most WebRTC providers
func codecSpecForCarrier(_ string) codecSpec {
return vp8CodecSpec()
}
@@ -120,6 +127,49 @@ func vp8CodecSpec() codecSpec {
}
}
func resolveEncoderCodec(spec codecSpec, hw string) string {
if hw != "nvenc" {
return spec.encoder
}
switch spec.mimeType {
case webrtc.MimeTypeH264:
return "h264_nvenc"
case webrtc.MimeTypeVP8:
return "vp8_nvenc"
case webrtc.MimeTypeVP9:
return "vp9_nvenc"
case webrtc.MimeTypeAV1:
return "av1_nvenc"
default:
return spec.encoder
}
}
func buildEncoderArgs(spec codecSpec, vcodec string, width, height, fps int, bitrate string) []string {
args := []string{
"-loglevel", "error", "-threads", "1",
"-f", "rawvideo",
"-pix_fmt", "gray",
"-video_size", strconv.Itoa(width) + "x" + strconv.Itoa(height),
"-framerate", strconv.Itoa(fps),
"-i", "pipe:0",
"-an",
}
if strings.HasSuffix(vcodec, "_nvenc") {
args = append(args, "-c:v", vcodec, "-preset", "p1", "-tune", "ull", "-rc", "vbr")
} else {
args = append(args, spec.encodeArgs...)
}
args = append(args, "-g", "1", "-pix_fmt", "yuv420p", "-b:v", bitrate)
if spec.mimeType == webrtc.MimeTypeH264 {
return append(args, "-f", "h264", "pipe:1")
}
return append(args, "-f", "ivf", "pipe:1")
}
type ffmpegEncoder struct {
cmd *exec.Cmd
stdin io.WriteCloser
@@ -134,62 +184,20 @@ type ffmpegEncoder struct {
err error
}
func newFFmpegEncoder(spec codecSpec, width, height, fps int, bitrate, hw string) (*ffmpegEncoder, error) {
func newFFmpegEncoder(
ctx context.Context,
spec codecSpec,
width, height, fps int,
bitrate, hw string,
) (*ffmpegEncoder, error) {
if _, err := exec.LookPath("ffmpeg"); err != nil {
return nil, ErrFFmpegUnavailable
}
args := []string{"-loglevel", "error", "-threads", "1"}
vcodec := resolveEncoderCodec(spec, hw)
args := buildEncoderArgs(spec, vcodec, width, height, fps, bitrate)
// Determine encoder binary based on HW flag
vcodec := spec.encoder
if hw == "nvenc" {
switch spec.mimeType {
case webrtc.MimeTypeH264:
vcodec = "h264_nvenc"
case webrtc.MimeTypeVP8:
vcodec = "vp8_nvenc"
case webrtc.MimeTypeVP9:
vcodec = "vp9_nvenc"
case webrtc.MimeTypeAV1:
vcodec = "av1_nvenc"
}
}
inputPixFmt := "gray"
frameSize := width * height
args = append(args,
"-f", "rawvideo",
"-pix_fmt", inputPixFmt,
"-video_size", fmt.Sprintf("%dx%d", width, height),
"-framerate", fmt.Sprintf("%d", fps),
"-i", "pipe:0",
"-an",
)
// Apply hardware specific flags if using NVENC
if strings.HasSuffix(vcodec, "_nvenc") {
args = append(args,
"-c:v", vcodec,
"-preset", "p1",
"-tune", "ull",
"-rc", "vbr",
)
} else {
// Use software encoder args from spec
args = append(args, spec.encodeArgs...)
}
args = append(args, "-g", "1", "-pix_fmt", "yuv420p", "-b:v", bitrate)
if spec.mimeType == webrtc.MimeTypeH264 {
args = append(args, "-f", "h264", "pipe:1")
} else {
args = append(args, "-f", "ivf", "pipe:1")
}
cmd := exec.Command("ffmpeg", args...)
cmd := exec.CommandContext(ctx, "ffmpeg", args...) //nolint:gosec
stdin, err := cmd.StdinPipe()
if err != nil {
return nil, fmt.Errorf("encoder stdin: %w", err)
@@ -212,7 +220,7 @@ func newFFmpegEncoder(spec codecSpec, width, height, fps int, bitrate, hw string
frames: make(chan []byte, 8),
width: width,
height: height,
frameSize: frameSize,
frameSize: width * height,
}
if spec.mimeType == webrtc.MimeTypeH264 {
@@ -225,7 +233,7 @@ func newFFmpegEncoder(spec codecSpec, width, height, fps int, bitrate, hw string
func (e *ffmpegEncoder) EncodeFrame(frame []byte) ([]byte, error) {
if len(frame) != e.frameSize {
return nil, fmt.Errorf("unexpected encoder frame size: %d (expected %d)", len(frame), e.frameSize)
return nil, fmt.Errorf("%w: got %d expected %d", ErrUnexpectedFrameSize, len(frame), e.frameSize)
}
if err := e.processErr(); err != nil {
return nil, err
@@ -244,7 +252,7 @@ func (e *ffmpegEncoder) EncodeFrame(frame []byte) ([]byte, error) {
if err := e.processErr(); err != nil {
return nil, err
}
return nil, fmt.Errorf("encoder timeout")
return nil, ErrEncoderTimeout
}
}
@@ -327,6 +335,43 @@ func (e *ffmpegEncoder) processErr() error {
return nil
}
func resolveDecoderName(spec codecSpec, hw string) string {
if hw != "nvenc" {
return strings.ToLower(strings.TrimPrefix(spec.mimeType, "video/"))
}
switch spec.mimeType {
case webrtc.MimeTypeH264:
return "h264_cuvid"
case webrtc.MimeTypeVP8:
return "vp8_cuvid"
case webrtc.MimeTypeVP9:
return "vp9_cuvid"
default:
return strings.ToLower(strings.TrimPrefix(spec.mimeType, "video/"))
}
}
func buildDecoderArgs(spec codecSpec, decoderName string, width, height int, outputPixFmt string) []string {
args := []string{"-loglevel", "error", "-threads", "1"}
if spec.mimeType == webrtc.MimeTypeH264 {
args = append(args, "-f", "h264")
} else {
args = append(args, "-f", "ivf")
}
vfFilter := fmt.Sprintf("scale=%d:%d:flags=neighbor,format=%s", width, height, outputPixFmt)
return append(args,
"-flags", "low_delay",
"-vcodec", decoderName,
"-i", "pipe:0",
"-an",
"-vf", vfFilter,
"-pix_fmt", outputPixFmt,
"-f", "rawvideo",
"pipe:1",
)
}
type ffmpegDecoder struct {
cmd *exec.Cmd
stdin io.WriteCloser
@@ -341,46 +386,20 @@ type ffmpegDecoder struct {
err error
}
func newFFmpegDecoder(spec codecSpec, width, height, fps int, hw string) (*ffmpegDecoder, error) {
func newFFmpegDecoder(
ctx context.Context,
spec codecSpec,
width, height, fps int,
hw string,
) (*ffmpegDecoder, error) {
if _, err := exec.LookPath("ffmpeg"); err != nil {
return nil, ErrFFmpegUnavailable
}
decoderName := strings.ToLower(strings.TrimPrefix(spec.mimeType, "video/"))
if hw == "nvenc" {
switch spec.mimeType {
case webrtc.MimeTypeH264:
decoderName = "h264_cuvid"
case webrtc.MimeTypeVP8:
decoderName = "vp8_cuvid"
case webrtc.MimeTypeVP9:
decoderName = "vp9_cuvid"
}
}
decoderName := resolveDecoderName(spec, hw)
args := buildDecoderArgs(spec, decoderName, width, height, "gray")
outputPixFmt := "gray"
frameSize := width * height
args := []string{"-loglevel", "error", "-threads", "1"}
if spec.mimeType == webrtc.MimeTypeH264 {
args = append(args, "-f", "h264")
} else {
args = append(args, "-f", "ivf")
}
vfFilter := fmt.Sprintf("scale=%d:%d:flags=neighbor,format=%s", width, height, outputPixFmt)
args = append(args,
"-flags", "low_delay",
"-vcodec", decoderName,
"-i", "pipe:0",
"-an",
"-vf", vfFilter,
"-pix_fmt", outputPixFmt,
"-f", "rawvideo",
"pipe:1",
)
cmd := exec.Command("ffmpeg", args...)
cmd := exec.CommandContext(ctx, "ffmpeg", args...) //nolint:gosec
stdin, err := cmd.StdinPipe()
if err != nil {
return nil, fmt.Errorf("decoder stdin: %w", err)
@@ -402,7 +421,7 @@ func newFFmpegDecoder(spec codecSpec, width, height, fps int, hw string) (*ffmpe
stderr: stderr,
frames: make(chan []byte, 32),
mimeType: spec.mimeType,
frameSize: frameSize,
frameSize: width * height,
}
if spec.mimeType != webrtc.MimeTypeH264 {
@@ -441,7 +460,7 @@ func (d *ffmpegDecoder) PopFrame() ([]byte, error) {
}
return frame, nil
case <-time.After(10 * time.Second):
return nil, fmt.Errorf("pop frame timeout")
return nil, ErrPopFrameTimeout
}
}
@@ -515,9 +534,9 @@ func writeIVFHeader(w io.Writer, fourCC string, width, height, frameRate int) er
binary.LittleEndian.PutUint16(header[4:6], 0)
binary.LittleEndian.PutUint16(header[6:8], 32)
copy(header[8:12], []byte(fourCC))
binary.LittleEndian.PutUint16(header[12:14], uint16(width))
binary.LittleEndian.PutUint16(header[14:16], uint16(height))
binary.LittleEndian.PutUint32(header[16:20], uint32(frameRate))
binary.LittleEndian.PutUint16(header[12:14], uint16(width)) //nolint:gosec
binary.LittleEndian.PutUint16(header[14:16], uint16(height)) //nolint:gosec
binary.LittleEndian.PutUint32(header[16:20], uint32(frameRate)) //nolint:gosec
binary.LittleEndian.PutUint32(header[20:24], 1)
binary.LittleEndian.PutUint32(header[24:28], 0)
binary.LittleEndian.PutUint32(header[28:32], 0)
@@ -526,7 +545,7 @@ func writeIVFHeader(w io.Writer, fourCC string, width, height, frameRate int) er
func writeIVFFrame(w io.Writer, pts uint64, frame []byte) error {
header := make([]byte, 12)
binary.LittleEndian.PutUint32(header[0:4], uint32(len(frame)))
binary.LittleEndian.PutUint32(header[0:4], uint32(len(frame))) //nolint:gosec
binary.LittleEndian.PutUint64(header[4:12], pts)
if err := writeAll(w, header); err != nil {
return err
@@ -538,9 +557,10 @@ func writeAll(w io.Writer, data []byte) error {
for len(data) > 0 {
n, err := w.Write(data)
if err != nil {
return err
return fmt.Errorf("write: %w", err)
}
data = data[n:]
}
return nil
}

View File

@@ -2,7 +2,7 @@ package videochannel
import (
"encoding/binary"
"fmt"
"errors"
)
const (
@@ -12,6 +12,21 @@ const (
frameTypeAck byte = 2
)
var (
// ErrFrameTooShort is returned when the received frame is too short to decode.
ErrFrameTooShort = errors.New("frame too short")
// ErrUnexpectedMagic is returned when the frame magic bytes do not match.
ErrUnexpectedMagic = errors.New("unexpected frame magic")
// ErrUnexpectedVersion is returned when the frame protocol version does not match.
ErrUnexpectedVersion = errors.New("unexpected frame version")
// ErrAckTooShort is returned when the ack frame is shorter than expected.
ErrAckTooShort = errors.New("ack frame too short")
// ErrDataTooShort is returned when the data frame is shorter than expected.
ErrDataTooShort = errors.New("data frame too short")
// ErrUnexpectedFrameType is returned for unknown frame type bytes.
ErrUnexpectedFrameType = errors.New("unexpected frame type")
)
type transportFrame struct {
typ byte
seq uint32
@@ -56,9 +71,9 @@ func encodeDataFrame(seq, crc uint32, totalLen, fragIdx, fragTotal int, payload
out[5] = frameTypeData
binary.BigEndian.PutUint32(out[6:10], seq)
binary.BigEndian.PutUint32(out[10:14], crc)
binary.BigEndian.PutUint32(out[14:18], uint32(totalLen))
binary.BigEndian.PutUint16(out[18:20], uint16(fragIdx))
binary.BigEndian.PutUint16(out[20:22], uint16(fragTotal))
binary.BigEndian.PutUint32(out[14:18], uint32(totalLen)) //nolint:gosec
binary.BigEndian.PutUint16(out[18:20], uint16(fragIdx)) //nolint:gosec
binary.BigEndian.PutUint16(out[20:22], uint16(fragTotal)) //nolint:gosec
copy(out[22:], payload)
return out
}
@@ -75,27 +90,27 @@ func encodeAckFrame(seq, crc uint32) []byte {
func decodeTransportFrame(data []byte) (transportFrame, error) {
if len(data) < 6 {
return transportFrame{}, fmt.Errorf("frame too short")
return transportFrame{}, ErrFrameTooShort
}
if binary.BigEndian.Uint32(data[0:4]) != protocolMagic {
return transportFrame{}, fmt.Errorf("unexpected frame magic")
return transportFrame{}, ErrUnexpectedMagic
}
if data[4] != protocolVersion {
return transportFrame{}, fmt.Errorf("unexpected frame version")
return transportFrame{}, ErrUnexpectedVersion
}
frame := transportFrame{typ: data[5]}
switch frame.typ {
case frameTypeAck:
if len(data) < 14 {
return transportFrame{}, fmt.Errorf("ack too short")
return transportFrame{}, ErrAckTooShort
}
frame.seq = binary.BigEndian.Uint32(data[6:10])
frame.crc = binary.BigEndian.Uint32(data[10:14])
return frame, nil
case frameTypeData:
if len(data) < 22 {
return transportFrame{}, fmt.Errorf("data too short")
return transportFrame{}, ErrDataTooShort
}
frame.seq = binary.BigEndian.Uint32(data[6:10])
frame.crc = binary.BigEndian.Uint32(data[10:14])
@@ -105,6 +120,6 @@ func decodeTransportFrame(data []byte) (transportFrame, error) {
frame.payload = append([]byte(nil), data[22:]...)
return frame, nil
default:
return transportFrame{}, fmt.Errorf("unexpected frame type")
return transportFrame{}, ErrUnexpectedFrameType
}
}

View File

@@ -70,9 +70,8 @@ type streamTransport struct {
videoCodec string
videoTileModule int
videoTileRS int
runCtx context.Context //nolint:containedctx
// cached encoded idle frame — rendered and encoded once, reused on every tick
// where the outbound queue is empty to avoid re-encoding an identical blank frame.
idleFrame []byte
idleFrameMu sync.Mutex
}
@@ -144,6 +143,7 @@ func New(ctx context.Context, cfg transport.Config) (transport.Transport, error)
videoCodec: cfg.VideoCodec,
videoTileModule: tileModule,
videoTileRS: tileRS,
runCtx: ctx,
}
if err := stream.AddTrack(track); err != nil {
@@ -159,14 +159,14 @@ func (p *streamTransport) Connect(ctx context.Context) error {
connectCtx, cancel := context.WithTimeout(ctx, defaultConnectTimeout)
defer cancel()
encoder, err := newFFmpegEncoder(p.codec, p.videoW, p.videoH, p.videoFPS, p.videoBitrate, p.videoHW)
encoder, err := newFFmpegEncoder(ctx, p.codec, p.videoW, p.videoH, p.videoFPS, p.videoBitrate, p.videoHW)
if err != nil {
return err
return fmt.Errorf("new encoder: %w", err)
}
if err := p.stream.Connect(connectCtx); err != nil {
_ = encoder.Close()
return err
return fmt.Errorf("connect stream: %w", err)
}
p.encoderMu.Lock()
@@ -212,7 +212,7 @@ func (p *streamTransport) Send(data []byte) error {
p.ackMu.Unlock()
}()
for attempt := 0; attempt < maxSendAttempts; attempt++ {
for range maxSendAttempts {
for idx, fragment := range fragments {
frame := encodeDataFrame(seq, crc, len(data), idx, len(fragments), fragment)
if err := p.enqueueFrame(frame, false); err != nil {
@@ -257,7 +257,9 @@ func (p *streamTransport) Close() error {
if p.writerUp.Load() {
<-p.writerDone
}
return p.stream.Close()
if err := p.stream.Close(); err != nil {
return fmt.Errorf("close stream: %w", err)
}
}
return nil
}
@@ -301,6 +303,47 @@ func (p *streamTransport) Features() transport.Features {
}
}
func (p *streamTransport) writeIdleFrame(enc *ffmpegEncoder, frameDuration time.Duration) {
p.idleFrameMu.Lock()
cached := p.idleFrame
p.idleFrameMu.Unlock()
if cached == nil {
rawFrame, err := p.renderFrame(nil)
if err != nil {
logger.Debugf("videochannel render idle error: %v", err)
return
}
sample, err := enc.EncodeFrame(rawFrame)
if err != nil {
logger.Warnf("videochannel encoder idle error: %v", err)
return
}
p.idleFrameMu.Lock()
p.idleFrame = sample
p.idleFrameMu.Unlock()
cached = sample
}
_ = p.track.WriteSample(media.Sample{Data: cached, Duration: frameDuration})
}
func (p *streamTransport) writePayloadFrame(enc *ffmpegEncoder, payload []byte, frameDuration time.Duration) {
rawFrame, err := p.renderFrame(payload)
if err != nil {
logger.Debugf("videochannel render error: %v", err)
return
}
sample, err := enc.EncodeFrame(rawFrame)
if err != nil {
logger.Warnf("videochannel encoder error: %v", err)
return
}
_ = p.track.WriteSample(media.Sample{Data: sample, Duration: frameDuration})
}
func (p *streamTransport) writerLoop() {
defer close(p.writerDone)
defer func() {
@@ -334,56 +377,22 @@ func (p *streamTransport) writerLoop() {
continue
}
// idle frame: payload is nil — reuse previously encoded sample to avoid
// re-rendering and re-encoding an identical blank frame every tick.
if payload == nil {
p.idleFrameMu.Lock()
cached := p.idleFrame
p.idleFrameMu.Unlock()
if cached == nil {
// first time — render + encode once, then cache
rawFrame, err := renderVisualFrame(nil, p.videoW, p.videoH, p.videoCodec, p.videoQRRecovery, p.videoTileModule, p.videoTileRS)
if err != nil {
logger.Debugf("videochannel render idle error: %v", err)
continue
p.writeIdleFrame(enc, frameDuration)
} else {
p.writePayloadFrame(enc, payload, frameDuration)
}
}
sample, err := enc.EncodeFrame(rawFrame)
if err != nil {
logger.Warnf("videochannel encoder idle error: %v", err)
continue
}
p.idleFrameMu.Lock()
p.idleFrame = sample
p.idleFrameMu.Unlock()
cached = sample
}
_ = p.track.WriteSample(media.Sample{
Data: cached,
Duration: frameDuration,
})
continue
}
rawFrame, err := renderVisualFrame(payload, p.videoW, p.videoH, p.videoCodec, p.videoQRRecovery, p.videoTileModule, p.videoTileRS)
if err != nil {
logger.Debugf("videochannel render error: %v", err)
continue
}
sample, err := enc.EncodeFrame(rawFrame)
if err != nil {
logger.Warnf("videochannel encoder error: %v", err)
continue
}
_ = p.track.WriteSample(media.Sample{
Data: sample,
Duration: frameDuration,
})
}
}
func (p *streamTransport) renderFrame(payload []byte) ([]byte, error) {
return renderVisualFrame(
payload,
p.videoW, p.videoH,
p.videoCodec, p.videoQRRecovery,
p.videoTileModule, p.videoTileRS,
)
}
func (p *streamTransport) nextOutboundFrame() ([]byte, bool) {
@@ -425,32 +434,7 @@ func (p *streamTransport) enqueueFrame(frame []byte, priority bool) error {
}
}
func (p *streamTransport) handleRemoteTrack(track *webrtc.TrackRemote, _ *webrtc.RTPReceiver) {
codec, ok := codecSpecForMime(track.Codec().MimeType)
if !ok {
logger.Warnf("videochannel unsupported remote codec: %s", track.Codec().MimeType)
return
}
decoder, err := newFFmpegDecoder(codec, p.videoW, p.videoH, p.videoFPS, p.videoHW)
if err != nil {
logger.Warnf("videochannel decoder init failed: %v", err)
return
}
p.decoderMu.Lock()
if p.closed.Load() {
p.decoderMu.Unlock()
_ = decoder.Close()
return
}
if p.decoder != nil {
_ = p.decoder.Close()
}
p.decoder = decoder
p.decoderMu.Unlock()
go func() {
func (p *streamTransport) popDecoderFrames(decoder *ffmpegDecoder) {
defer func() {
p.decoderMu.Lock()
if p.decoder == decoder {
@@ -476,9 +460,9 @@ func (p *streamTransport) handleRemoteTrack(track *webrtc.TrackRemote, _ *webrtc
}
p.handleFrame(frame)
}
}()
}
go func() {
func (p *streamTransport) readDecoderInput(track *webrtc.TrackRemote, decoder *ffmpegDecoder, codec codecSpec) {
sb := samplebuilder.New(sampleBuilderMaxLate, codec.depacketizer(), track.Codec().ClockRate)
for {
select {
@@ -503,7 +487,35 @@ func (p *streamTransport) handleRemoteTrack(track *webrtc.TrackRemote, _ *webrtc
}
}
}
}()
}
func (p *streamTransport) handleRemoteTrack(track *webrtc.TrackRemote, _ *webrtc.RTPReceiver) {
codec, ok := codecSpecForMime(track.Codec().MimeType)
if !ok {
logger.Warnf("videochannel unsupported remote codec: %s", track.Codec().MimeType)
return
}
decoder, err := newFFmpegDecoder(p.runCtx, codec, p.videoW, p.videoH, p.videoFPS, p.videoHW)
if err != nil {
logger.Warnf("videochannel decoder init failed: %v", err)
return
}
p.decoderMu.Lock()
if p.closed.Load() {
p.decoderMu.Unlock()
_ = decoder.Close()
return
}
if p.decoder != nil {
_ = p.decoder.Close()
}
p.decoder = decoder
p.decoderMu.Unlock()
go p.popDecoderFrames(decoder)
go p.readDecoderInput(track, decoder, codec)
}
func (p *streamTransport) handleFrame(frame []byte) {
@@ -531,14 +543,7 @@ func (p *streamTransport) handleFrame(frame []byte) {
}
}
func (p *streamTransport) handleInboundFrame(frame transportFrame) {
p.recvMu.Lock()
if crc, ok := p.delivered[frame.seq]; ok && crc == frame.crc {
p.recvMu.Unlock()
p.sendAck(frame.seq, frame.crc)
return
}
func (p *streamTransport) upsertInbound(frame transportFrame) (*inboundMessage, bool) {
msg, ok := p.inbound[frame.seq]
if !ok || msg.crc != frame.crc || msg.totalLen != frame.totalLen || len(msg.frags) != int(frame.fragTotal) {
msg = &inboundMessage{
@@ -549,33 +554,45 @@ func (p *streamTransport) handleInboundFrame(frame transportFrame) {
}
p.inbound[frame.seq] = msg
}
if int(frame.fragIdx) >= len(msg.frags) {
p.recvMu.Unlock()
return
return nil, false
}
if msg.frags[frame.fragIdx] == nil {
chunk := make([]byte, len(frame.payload))
copy(chunk, frame.payload)
msg.frags[frame.fragIdx] = chunk
msg.remain--
}
return msg, msg.remain == 0
}
if msg.remain > 0 {
func (p *streamTransport) assembleMessage(msg *inboundMessage) []byte {
data := make([]byte, 0, msg.totalLen)
for _, frag := range msg.frags {
data = append(data, frag...)
}
if uint32(len(data)) > msg.totalLen { //nolint:gosec
data = data[:msg.totalLen]
}
return data
}
func (p *streamTransport) handleInboundFrame(frame transportFrame) {
p.recvMu.Lock()
if crc, ok := p.delivered[frame.seq]; ok && crc == frame.crc {
p.recvMu.Unlock()
p.sendAck(frame.seq, frame.crc)
return
}
msg, complete := p.upsertInbound(frame)
if msg == nil || !complete {
p.recvMu.Unlock()
return
}
delete(p.inbound, frame.seq)
data := make([]byte, 0, msg.totalLen)
for _, frag := range msg.frags {
data = append(data, frag...)
}
if uint32(len(data)) > msg.totalLen {
data = data[:msg.totalLen]
}
data := p.assembleMessage(msg)
if crc32.ChecksumIEEE(data) != msg.crc {
p.recvMu.Unlock()

View File

@@ -1,6 +1,7 @@
package videochannel
import (
"errors"
"fmt"
"strings"
@@ -8,6 +9,9 @@ import (
grtile "github.com/zarazaex69/gr/tile"
)
// ErrUnexpectedQRFrameSize is returned when the decoded frame size does not match the expected dimensions.
var ErrUnexpectedQRFrameSize = errors.New("unexpected qr frame size")
func eccLevel(level string) grqr.ECCLevel {
switch level {
case "medium":
@@ -21,7 +25,12 @@ func eccLevel(level string) grqr.ECCLevel {
}
}
func renderVisualFrame(payload []byte, width, height int, codec, recoveryLevel string, tileModule, tileRS int) ([]byte, error) {
func renderVisualFrame(
payload []byte,
width, height int,
codec, recoveryLevel string,
tileModule, tileRS int,
) ([]byte, error) {
if codec == "tile" {
return renderTileFrame(payload, tileModule, tileRS)
}
@@ -47,7 +56,11 @@ func renderQRFrame(payload []byte, width, height int, recoveryLevel string) ([]b
return nil, fmt.Errorf("qr codec: %w", err)
}
return c.Encode(payload)
result, err := c.Encode(payload)
if err != nil {
return nil, fmt.Errorf("qr encode: %w", err)
}
return result, nil
}
func renderTileFrame(payload []byte, tileModule, tileRS int) ([]byte, error) {
@@ -64,7 +77,11 @@ func renderTileFrame(payload []byte, tileModule, tileRS int) ([]byte, error) {
return nil, fmt.Errorf("tile codec: %w", err)
}
return c.Encode(payload, 0, 1)
result, err := c.Encode(payload, 0, 1)
if err != nil {
return nil, fmt.Errorf("tile encode: %w", err)
}
return result, nil
}
func extractVisualPayload(frame []byte, width, height int, codec string, tileModule, tileRS int) ([]byte, error) {
@@ -76,7 +93,8 @@ func extractVisualPayload(frame []byte, width, height int, codec string, tileMod
func extractQRPayload(frame []byte, width, height int) ([]byte, error) {
if len(frame) != width*height {
return nil, fmt.Errorf("unexpected frame size: %d (expected %dx%d=%d)", len(frame), width, height, width*height)
return nil, fmt.Errorf("%w: got %d expected %dx%d=%d",
ErrUnexpectedQRFrameSize, len(frame), width, height, width*height)
}
c, err := grqr.New(grqr.Config{
@@ -111,7 +129,7 @@ func extractTilePayload(frame []byte, tileModule, tileRS int) ([]byte, error) {
result, err := c.Decode(frame)
if err != nil {
return nil, nil
return nil, nil //nolint:nilerr
}
return result.Payload, nil

View File

@@ -1,3 +1,4 @@
// Package vp8channel provides byte transport over VP8 video frames using KCP.
package vp8channel
import (
@@ -58,7 +59,7 @@ type kcpRuntime struct {
func startKCP(out chan<- []byte, onData func([]byte)) (*kcpRuntime, error) {
c := newKCPConn(out, inboundQueueSize)
sess, err := kcp.NewConn3(kcpConvID, fakeAddr, nil, 0, 0, c)
sess, err := kcp.NewConn3(kcpConvID, fakeUDPAddr(), nil, 0, 0, c)
if err != nil {
_ = c.Close()
return nil, fmt.Errorf("kcp new conn: %w", err)
@@ -71,7 +72,6 @@ func startKCP(out chan<- []byte, onData func([]byte)) (*kcpRuntime, error) {
sess.SetNoDelay(1, 10, 2, 1)
sess.SetWindowSize(kcpSndWnd, kcpRcvWnd)
sess.SetMtu(kcpMTU)
sess.SetStreamMode(true) // see kcpLenPrefix comment above
sess.SetACKNoDelay(true)
sess.SetWriteDelay(false)
@@ -127,16 +127,17 @@ func (r *kcpRuntime) send(msg []byte) error {
return ErrKCPMessageTooLarge
}
var hdr [kcpLenPrefix]byte
//nolint:gosec
binary.BigEndian.PutUint32(hdr[:], uint32(len(msg)))
r.writeMu.Lock()
defer r.writeMu.Unlock()
if _, err := r.sess.Write(hdr[:]); err != nil {
return err
return fmt.Errorf("kcp write header: %w", err)
}
if _, err := r.sess.Write(msg); err != nil {
return err
return fmt.Errorf("kcp write payload: %w", err)
}
return nil
}

View File

@@ -6,10 +6,9 @@ import (
"time"
)
// fakeAddr is a placeholder address used by the KCP session. The underlying
// "packet conn" is a point-to-point pipe over the VP8 carrier and has no real
// notion of an address, but kcp-go's API requires one.
var fakeAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1}
func fakeUDPAddr() *net.UDPAddr {
return &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1}
}
// kcpConn is a net.PacketConn implementation that bridges kcp-go on top of
// the vp8channel byte-message carrier.
@@ -62,7 +61,7 @@ func (c *kcpConn) ReadFrom(p []byte) (int, net.Addr, error) {
if !deadline.IsZero() {
d := time.Until(deadline)
if d <= 0 {
return 0, nil, errTimeout{}
return 0, nil, TimeoutError{}
}
t := time.NewTimer(d)
defer t.Stop()
@@ -72,11 +71,11 @@ func (c *kcpConn) ReadFrom(p []byte) (int, net.Addr, error) {
select {
case msg := <-c.in:
n := copy(p, msg)
return n, fakeAddr, nil
return n, fakeUDPAddr(), nil
case <-c.closed:
return 0, nil, net.ErrClosed
case <-timerC:
return 0, nil, errTimeout{}
return 0, nil, TimeoutError{}
}
}
@@ -92,7 +91,7 @@ func (c *kcpConn) WriteTo(p []byte, _ net.Addr) (int, error) {
if !deadline.IsZero() {
d := time.Until(deadline)
if d <= 0 {
return 0, errTimeout{}
return 0, TimeoutError{}
}
t := time.NewTimer(d)
defer t.Stop()
@@ -105,7 +104,7 @@ func (c *kcpConn) WriteTo(p []byte, _ net.Addr) (int, error) {
case <-c.closed:
return 0, net.ErrClosed
case <-timerC:
return 0, errTimeout{}
return 0, TimeoutError{}
}
}
@@ -114,7 +113,7 @@ func (c *kcpConn) Close() error {
return nil
}
func (c *kcpConn) LocalAddr() net.Addr { return fakeAddr }
func (c *kcpConn) LocalAddr() net.Addr { return fakeUDPAddr() }
func (c *kcpConn) SetDeadline(t time.Time) error {
_ = c.SetReadDeadline(t)
@@ -136,8 +135,13 @@ func (c *kcpConn) SetWriteDeadline(t time.Time) error {
return nil
}
type errTimeout struct{}
// TimeoutError is a net.Error indicating a deadline exceeded.
type TimeoutError struct{}
func (errTimeout) Error() string { return "i/o timeout" }
func (errTimeout) Timeout() bool { return true }
func (errTimeout) Temporary() bool { return true }
func (TimeoutError) Error() string { return "i/o timeout" }
// Timeout reports that this error is a timeout.
func (TimeoutError) Timeout() bool { return true }
// Temporary reports that this error is temporary.
func (TimeoutError) Temporary() bool { return true }

View File

@@ -27,14 +27,13 @@ const (
)
var (
// ErrVideoTrackUnsupported is returned when a carrier cannot expose video tracks.
ErrVideoTrackUnsupported = errors.New("carrier does not support video tracks")
// ErrTransportClosed is returned when operations are attempted on a closed transport.
ErrTransportClosed = errors.New("vp8channel transport closed")
)
// vp8Keepalive is a minimal VP8 keyframe used as idle filler so that the SFU
// keeps the track flowing when KCP has nothing to send. It is never delivered
// to KCP because KCP packets always start with the convid (0xC0FFEE01 LE)
// and would never collide with this keyframe payload.
//nolint:gochecknoglobals
var vp8Keepalive = []byte{
0x30, 0x01, 0x00, 0x9d, 0x01, 0x2a, 0x10, 0x00,
0x10, 0x00, 0x00, 0x47, 0x08, 0x85, 0x85, 0x88,
@@ -64,6 +63,7 @@ type streamTransport struct {
kcpMu sync.RWMutex
}
// New creates a vp8channel transport backed by a carrier-specific provider.
func New(ctx context.Context, cfg transport.Config) (transport.Transport, error) {
session, err := carrier.New(ctx, cfg.Carrier, carrier.Config{
RoomURL: cfg.RoomURL,
@@ -126,7 +126,7 @@ func (p *streamTransport) Connect(ctx context.Context) error {
defer cancel()
if err := p.stream.Connect(connectCtx); err != nil {
return err
return fmt.Errorf("connect stream: %w", err)
}
var startErr error
@@ -179,7 +179,9 @@ func (p *streamTransport) Close() error {
if p.writerUp.Load() {
<-p.writerDone
}
return p.stream.Close()
if err := p.stream.Close(); err != nil {
return fmt.Errorf("close stream: %w", err)
}
}
return nil
}
@@ -302,14 +304,62 @@ func (p *streamTransport) drainTrack(track *webrtc.TrackRemote) {
}
}
func (p *streamTransport) readVP8Track(track *webrtc.TrackRemote) {
var vp8Pkt codecs.VP8Packet
var frameBuf []byte
buf := make([]byte, rtpBufSize)
type vp8FrameState struct {
vp8Pkt codecs.VP8Packet
frameBuf []byte
lastSeq uint16
haveLastSeq bool
frameValid bool
}
var lastSeq uint16
var haveLastSeq bool
frameValid := false
// processRTPPacket returns a complete KCP frame when the VP8 frame is fully assembled, nil otherwise.
// Detects packet loss/reordering to avoid silently corrupting fragmented VP8 frames.
func (s *vp8FrameState) processRTPPacket(pkt *rtp.Packet) []byte {
if s.haveLastSeq && pkt.SequenceNumber != s.lastSeq+1 {
s.frameValid = false
s.frameBuf = s.frameBuf[:0]
}
s.lastSeq = pkt.SequenceNumber
s.haveLastSeq = true
vp8Payload, err := s.vp8Pkt.Unmarshal(pkt.Payload)
if err != nil {
s.frameValid = false
s.frameBuf = s.frameBuf[:0]
return nil
}
if s.vp8Pkt.S == 1 {
s.frameBuf = s.frameBuf[:0]
s.frameValid = true
}
if !s.frameValid {
return nil
}
s.frameBuf = append(s.frameBuf, vp8Payload...)
if !pkt.Marker {
return nil
}
defer func() {
s.frameBuf = s.frameBuf[:0]
s.frameValid = false
}()
if len(s.frameBuf) >= 4 && s.frameBuf[0] == kcpMagic {
frame := make([]byte, len(s.frameBuf))
copy(frame, s.frameBuf)
return frame
}
return nil
}
func (p *streamTransport) readVP8Track(track *webrtc.TrackRemote) {
var state vp8FrameState
buf := make([]byte, rtpBufSize)
for {
n, _, err := track.Read(buf)
@@ -322,54 +372,16 @@ func (p *streamTransport) readVP8Track(track *webrtc.TrackRemote) {
continue
}
// Detect packet loss / reordering. A single missing RTP packet
// inside a fragmented VP8 frame would otherwise silently corrupt
// the assembled payload (and bleed into the next frame). KCP can
// recover from full-frame drops, but only if the frames it does
// receive are byte-perfect.
if haveLastSeq {
expected := lastSeq + 1
if pkt.SequenceNumber != expected {
frameValid = false
frameBuf = frameBuf[:0]
}
}
lastSeq = pkt.SequenceNumber
haveLastSeq = true
vp8Payload, err := vp8Pkt.Unmarshal(pkt.Payload)
if err != nil {
frameValid = false
frameBuf = frameBuf[:0]
frame := state.processRTPPacket(pkt)
if frame == nil {
continue
}
if vp8Pkt.S == 1 {
frameBuf = frameBuf[:0]
frameValid = true
}
if !frameValid {
continue
}
frameBuf = append(frameBuf, vp8Payload...)
if pkt.Marker {
if len(frameBuf) >= 4 && frameBuf[0] == kcpMagic {
p.kcpMu.RLock()
rt := p.kcp
p.kcpMu.RUnlock()
if rt != nil {
// Copy out of the shared frame buffer before handing
// the payload off — KCP's deliver path is async.
payload := make([]byte, len(frameBuf))
copy(payload, frameBuf)
rt.deliver(payload)
}
}
frameBuf = frameBuf[:0]
frameValid = false
rt.deliver(frame)
}
}
}

View File

@@ -7,16 +7,64 @@ import (
"time"
)
func pumpPackets(stop <-chan struct{}, from <-chan []byte, to *kcpRuntime) {
for {
select {
case <-stop:
return
case pkt := <-from:
to.deliver(pkt)
}
}
}
func checkMessages(t *testing.T, got, want [][]byte) {
t.Helper()
if len(got) != len(want) {
t.Fatalf("got %d messages, want %d", len(got), len(want))
}
for i, m := range want {
if !bytes.Equal(got[i], m) {
t.Errorf("msg %d mismatch: got %d bytes, want %d", i, len(got[i]), len(m))
}
}
}
func buildReceiver(n int) (func([]byte), <-chan struct{}, func() [][]byte) {
var mu sync.Mutex
var recv [][]byte
done := make(chan struct{})
cb := func(msg []byte) {
mu.Lock()
recv = append(recv, append([]byte(nil), msg...))
count := len(recv)
mu.Unlock()
if count == n {
close(done)
}
}
get := func() [][]byte {
mu.Lock()
defer mu.Unlock()
return recv
}
return cb, done, get
}
// TestKCPLoopback runs two KCP runtimes back-to-back through an in-memory
// pipe simulating a perfect carrier. Verifies that messages survive the
// KCP layer with their boundaries intact.
func TestKCPLoopback(t *testing.T) {
msgs := [][]byte{
[]byte("hello"),
bytes.Repeat([]byte("x"), 1000),
bytes.Repeat([]byte("y"), 20000),
}
a2b := make(chan []byte, 256)
b2a := make(chan []byte, 256)
var bRecvMu sync.Mutex
var bRecv [][]byte
doneB := make(chan struct{})
cb, doneB, getRecv := buildReceiver(len(msgs))
rtA, err := startKCP(a2b, nil)
if err != nil {
@@ -24,50 +72,18 @@ func TestKCPLoopback(t *testing.T) {
}
defer rtA.close()
rtB, err := startKCP(b2a, func(msg []byte) {
bRecvMu.Lock()
bRecv = append(bRecv, append([]byte(nil), msg...))
n := len(bRecv)
bRecvMu.Unlock()
if n == 3 {
close(doneB)
}
})
rtB, err := startKCP(b2a, cb)
if err != nil {
t.Fatalf("startKCP B: %v", err)
}
defer rtB.close()
// Pump packets between the two runtimes.
stop := make(chan struct{})
defer close(stop)
go func() {
for {
select {
case <-stop:
return
case pkt := <-a2b:
rtB.deliver(pkt)
}
}
}()
go func() {
for {
select {
case <-stop:
return
case pkt := <-b2a:
rtA.deliver(pkt)
}
}
}()
go pumpPackets(stop, a2b, rtB)
go pumpPackets(stop, b2a, rtA)
msgs := [][]byte{
[]byte("hello"),
bytes.Repeat([]byte("x"), 1000),
bytes.Repeat([]byte("y"), 20000),
}
for _, m := range msgs {
if err := rtA.send(m); err != nil {
t.Fatalf("send: %v", err)
@@ -80,21 +96,10 @@ func TestKCPLoopback(t *testing.T) {
t.Fatal("timeout waiting for messages")
}
bRecvMu.Lock()
defer bRecvMu.Unlock()
if len(bRecv) != len(msgs) {
t.Fatalf("got %d messages, want %d", len(bRecv), len(msgs))
}
for i, m := range msgs {
if !bytes.Equal(bRecv[i], m) {
t.Errorf("msg %d mismatch: got %d bytes, want %d", i, len(bRecv[i]), len(m))
}
}
checkMessages(t, getRecv(), msgs)
}
func TestVP8KeepaliveDoesNotLookLikeKCP(t *testing.T) {
// Keepalive frames must not be mistaken for KCP packets by the receive
// path; otherwise the KCP stack would constantly chew on garbage.
if len(vp8Keepalive) >= 1 && vp8Keepalive[0] == kcpMagic {
t.Errorf("keepalive collides with kcp magic byte 0x%02x", kcpMagic)
}