mirror of
https://github.com/openlibrecommunity/olcrtc.git
synced 2026-05-26 07:08:11 +00:00
Merge pull request #24 from openlibrecommunity/refactor/all
Refactor/all
This commit is contained in:
@@ -6,7 +6,6 @@ import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
@@ -46,7 +45,7 @@ var (
|
||||
|
||||
func main() {
|
||||
if err := run(); err != nil {
|
||||
log.Print(err)
|
||||
logger.Error(err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
@@ -83,7 +82,7 @@ func run() error {
|
||||
|
||||
select {
|
||||
case <-sigCh:
|
||||
log.Println("Shutting down gracefully...")
|
||||
logger.Info("Shutting down gracefully...")
|
||||
cancel()
|
||||
return waitForShutdown(errCh)
|
||||
case err := <-errCh:
|
||||
@@ -112,12 +111,8 @@ func parseFlags() config {
|
||||
|
||||
func configureLogging(debug bool) {
|
||||
if debug {
|
||||
log.SetFlags(log.Ltime | log.Lshortfile)
|
||||
logger.SetVerbose(true)
|
||||
return
|
||||
}
|
||||
|
||||
log.SetFlags(log.Ltime)
|
||||
}
|
||||
|
||||
func validateConfig(cfg config) error {
|
||||
@@ -135,7 +130,7 @@ func validateConfig(cfg config) error {
|
||||
return errProviderRequired
|
||||
case !validProvider:
|
||||
return fmt.Errorf("%w: %s (available: %v)", errUnsupportedProvider, cfg.provider, available)
|
||||
case cfg.roomID == "":
|
||||
case cfg.roomID == "" && cfg.provider != "jazz":
|
||||
return errRoomIDRequired
|
||||
case cfg.mode != "srv" && cfg.mode != "cnc":
|
||||
return errModeRequired
|
||||
@@ -187,8 +182,8 @@ func runMode(ctx context.Context, cfg config, errCh chan<- error) {
|
||||
cfg.provider,
|
||||
roomURL,
|
||||
cfg.keyHex,
|
||||
cfg.socksPort,
|
||||
cfg.socksHost,
|
||||
fmt.Sprintf("%s:%d", cfg.socksHost, cfg.socksPort),
|
||||
cfg.dnsServer,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
@@ -220,11 +215,11 @@ func waitForShutdown(errCh <-chan error) error {
|
||||
select {
|
||||
case err := <-done:
|
||||
if err == nil {
|
||||
log.Println("Shutdown complete")
|
||||
logger.Info("Shutdown complete")
|
||||
}
|
||||
return err
|
||||
case <-time.After(5 * time.Second):
|
||||
log.Println("Shutdown timeout, forcing exit")
|
||||
logger.Warn("Shutdown timeout, forcing exit")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,9 +10,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -26,232 +24,403 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
errInvalidKeyLength = errors.New("key must be 32 bytes")
|
||||
errInvalidKeyStringLength = errors.New("key string length must be 32")
|
||||
errNoConnectedPeers = errors.New("no connected peers available")
|
||||
// ErrKeySize is returned when the key size is not 32 bytes.
|
||||
ErrKeySize = errors.New("key must be 32 bytes")
|
||||
// ErrKeyStringLength is returned when the key string length is not 32.
|
||||
ErrKeyStringLength = errors.New("key string length must be 32")
|
||||
// ErrInvalidSocks5 is returned when the SOCKS version is not 5.
|
||||
ErrInvalidSocks5 = errors.New("invalid SOCKS5 version")
|
||||
// ErrNoPeers is returned when no peers are available for sending.
|
||||
ErrNoPeers = errors.New("no peers available")
|
||||
// ErrEncryptFailed is returned when encryption fails.
|
||||
ErrEncryptFailed = errors.New("encrypt failed")
|
||||
// ErrUnsupportedSocksCommand is returned when a SOCKS5 command is not supported.
|
||||
ErrUnsupportedSocksCommand = errors.New("unsupported SOCKS5 command")
|
||||
// ErrUnsupportedAddressType is returned when a SOCKS5 address type is not supported.
|
||||
ErrUnsupportedAddressType = errors.New("unsupported address type")
|
||||
// ErrTunnelSetupFailed is returned when the tunnel cannot be established.
|
||||
ErrTunnelSetupFailed = errors.New("tunnel setup failed")
|
||||
)
|
||||
|
||||
// Client handles local SOCKS5 connections and tunnels them through WebRTC.
|
||||
// Client handles local SOCKS5 connections and tunnels them via WebRTC.
|
||||
type Client struct {
|
||||
peers []provider.Provider
|
||||
cipher *crypto.Cipher
|
||||
mux *mux.Multiplexer
|
||||
clientID uint32
|
||||
peerIdx atomic.Uint32
|
||||
wg sync.WaitGroup
|
||||
peers []provider.Provider
|
||||
cipher *crypto.Cipher
|
||||
mux *mux.Multiplexer
|
||||
connections map[uint16]net.Conn
|
||||
connMu sync.RWMutex
|
||||
peerIdx atomic.Uint32
|
||||
clientID uint32
|
||||
activeClients atomic.Int32
|
||||
wg sync.WaitGroup
|
||||
dnsServer string
|
||||
}
|
||||
|
||||
const defaultSOCKSListenHost = "127.0.0.1"
|
||||
|
||||
// Run starts the client with the specified parameters.
|
||||
func Run(
|
||||
ctx context.Context,
|
||||
providerName,
|
||||
roomURL,
|
||||
keyHex string,
|
||||
socksPort int,
|
||||
socksHost,
|
||||
socksUser,
|
||||
localAddr string,
|
||||
dnsServer,
|
||||
socksUser string,
|
||||
socksPass string,
|
||||
) error {
|
||||
return RunWithReady(ctx, providerName, roomURL, keyHex, socksPort, socksHost, socksUser, socksPass, nil)
|
||||
return RunWithReady(ctx, providerName, roomURL, keyHex, localAddr, dnsServer, socksUser, socksPass, nil)
|
||||
}
|
||||
|
||||
// RunWithReady starts the client and calls onReady when it is listening.
|
||||
// RunWithReady is like Run but accepts a callback that is called when the client is ready.
|
||||
func RunWithReady(
|
||||
ctx context.Context,
|
||||
providerName,
|
||||
roomURL,
|
||||
keyHex string,
|
||||
socksPort int,
|
||||
socksHost,
|
||||
socksUser,
|
||||
socksPass string,
|
||||
localAddr string,
|
||||
dnsServer,
|
||||
_ string,
|
||||
_ string,
|
||||
onReady func(),
|
||||
) error {
|
||||
runCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
key, err := decodeKey(keyHex)
|
||||
cipher, err := setupCipher(keyHex)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decodeKey failed: %w", err)
|
||||
return fmt.Errorf("setupCipher failed: %w", err)
|
||||
}
|
||||
|
||||
keyStr := string(key)
|
||||
if len(keyStr) != 32 {
|
||||
return fmt.Errorf("%w: got %d", errInvalidKeyStringLength, len(keyStr))
|
||||
}
|
||||
|
||||
cipher, err := crypto.NewCipher(keyStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create cipher: %w", err)
|
||||
clientIDBytes := make([]byte, 4)
|
||||
if _, err := rand.Read(clientIDBytes); err != nil {
|
||||
return fmt.Errorf("failed to generate client ID: %w", err)
|
||||
}
|
||||
clientID := binary.BigEndian.Uint32(clientIDBytes)
|
||||
|
||||
c := &Client{
|
||||
cipher: cipher,
|
||||
clientID: uint32(time.Now().UnixNano() & 0xFFFFFFFF),
|
||||
peers: make([]provider.Provider, 0, 1),
|
||||
cipher: cipher,
|
||||
connections: make(map[uint16]net.Conn),
|
||||
peers: make([]provider.Provider, 0),
|
||||
clientID: clientID,
|
||||
dnsServer: dnsServer,
|
||||
}
|
||||
|
||||
c.mux = mux.New(c.clientID, c.sendFrame)
|
||||
c.setupMux()
|
||||
|
||||
for peerID := range 1 {
|
||||
if err := c.addPeer(runCtx, providerName, roomURL, peerID, cancel); err != nil {
|
||||
const peerCount = 1
|
||||
for i := range peerCount {
|
||||
if err := c.addPeer(runCtx, providerName, roomURL, i, cancel, dnsServer, "", 0); err != nil {
|
||||
return fmt.Errorf("addPeer failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
c.sendResetSignal()
|
||||
lc := net.ListenConfig{}
|
||||
ln, err := lc.Listen(runCtx, "tcp", localAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen failed: %w", err)
|
||||
}
|
||||
defer func() { _ = ln.Close() }()
|
||||
|
||||
err = c.runSOCKS5(runCtx, socksHost, socksPort, socksUser, socksPass, onReady)
|
||||
logger.Infof("SOCKS5 server listening on %s (ClientID: %d)", localAddr, clientID)
|
||||
|
||||
if onReady != nil {
|
||||
onReady()
|
||||
}
|
||||
|
||||
go c.acceptLoop(runCtx, ln)
|
||||
|
||||
<-runCtx.Done()
|
||||
c.shutdown()
|
||||
c.wg.Wait()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func decodeKey(keyHex string) ([]byte, error) {
|
||||
if keyHex == "" {
|
||||
key := make([]byte, 32)
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
return nil, fmt.Errorf("generate random key: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("Generated key: %x", key)
|
||||
return key, nil
|
||||
}
|
||||
|
||||
key, err := hex.DecodeString(keyHex)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode hex key: %w", err)
|
||||
}
|
||||
|
||||
if len(key) != 32 {
|
||||
return nil, fmt.Errorf("%w: got %d", errInvalidKeyLength, len(key))
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (c *Client) sendFrame(frame []byte) error {
|
||||
waitUntilPeersCanSend(c.peers)
|
||||
|
||||
encrypted, err := c.cipher.Encrypt(frame)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt outgoing frame: %w", err)
|
||||
}
|
||||
|
||||
peer, err := c.nextPeer()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := peer.Send(encrypted); err != nil {
|
||||
return fmt.Errorf("send frame via peer: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func waitUntilPeersCanSend(peers []provider.Provider) {
|
||||
for {
|
||||
canSend := true
|
||||
for _, peer := range peers {
|
||||
if !peer.CanSend() {
|
||||
canSend = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if canSend {
|
||||
return
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
func setupCipher(keyHex string) (*crypto.Cipher, error) {
|
||||
key, err := hex.DecodeString(keyHex)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode key: %w", err)
|
||||
}
|
||||
if len(key) != 32 {
|
||||
return nil, ErrKeySize
|
||||
}
|
||||
|
||||
keyStr := string(key)
|
||||
if len(keyStr) != 32 {
|
||||
return nil, ErrKeyStringLength
|
||||
}
|
||||
|
||||
cipher, err := crypto.NewCipher(keyStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
return cipher, nil
|
||||
}
|
||||
|
||||
// nextPeer returns the next provider for load balancing.
|
||||
//
|
||||
//nolint:ireturn
|
||||
func (c *Client) nextPeer() (provider.Provider, error) {
|
||||
switch len(c.peers) {
|
||||
case 0:
|
||||
return nil, errNoConnectedPeers
|
||||
case 1:
|
||||
return c.peers[0], nil
|
||||
default:
|
||||
return c.peers[int(c.peerIdx.Add(1)%2)], nil
|
||||
}
|
||||
func (c *Client) setupMux() {
|
||||
c.mux = mux.New(c.clientID, func(frame []byte) error {
|
||||
for {
|
||||
canSend := true
|
||||
for _, peer := range c.peers {
|
||||
if !peer.CanSend() {
|
||||
canSend = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if canSend {
|
||||
break
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
encrypted, err := c.cipher.Encrypt(frame)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %w", ErrEncryptFailed, err)
|
||||
}
|
||||
if len(c.peers) == 0 {
|
||||
return ErrNoPeers
|
||||
}
|
||||
idx := c.peerIdx.Add(1) % uint32(len(c.peers)) //nolint:gosec
|
||||
return c.peers[idx].Send(encrypted)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Client) addPeer(
|
||||
runCtx context.Context,
|
||||
ctx context.Context,
|
||||
providerName,
|
||||
roomURL string,
|
||||
peerID int,
|
||||
cancel context.CancelFunc,
|
||||
dnsServer,
|
||||
socksProxyAddr string,
|
||||
socksProxyPort int,
|
||||
) error {
|
||||
peer, err := provider.New(runCtx, providerName, provider.Config{
|
||||
RoomURL: roomURL,
|
||||
Name: names.Generate(),
|
||||
OnData: c.onData,
|
||||
peer, err := provider.New(ctx, providerName, provider.Config{
|
||||
RoomURL: roomURL,
|
||||
Name: names.Generate(),
|
||||
OnData: c.onData,
|
||||
DNSServer: dnsServer,
|
||||
ProxyAddr: socksProxyAddr,
|
||||
ProxyPort: socksProxyPort,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("create peer %d: %w", peerID, err)
|
||||
return fmt.Errorf("failed to create peer: %w", err)
|
||||
}
|
||||
|
||||
peer.SetEndedCallback(func(reason string) {
|
||||
log.Printf("Client peer %d reported conference end: %s", peerID, reason)
|
||||
logger.Infof("Client peer %d reported conference end: %s", peerID, reason)
|
||||
cancel()
|
||||
})
|
||||
|
||||
peer.SetReconnectCallback(func(dc *webrtc.DataChannel) {
|
||||
c.onReconnect(peerID, dc)
|
||||
})
|
||||
|
||||
c.peers = append(c.peers, peer)
|
||||
|
||||
log.Printf("Connecting peer %d to %s...", peerID, providerName)
|
||||
if err := peer.Connect(runCtx); err != nil {
|
||||
return fmt.Errorf("connect peer %d: %w", peerID, err)
|
||||
peer.SetReconnectCallback(func(dc *webrtc.DataChannel) {
|
||||
c.handlePeerReconnect(peerID, dc)
|
||||
})
|
||||
|
||||
logger.Infof("Connecting peer %d to %s...", peerID, providerName)
|
||||
if err := peer.Connect(ctx); err != nil {
|
||||
return fmt.Errorf("failed to connect peer: %w", err)
|
||||
}
|
||||
log.Printf("Peer %d connected", peerID)
|
||||
logger.Infof("Peer %d connected", peerID)
|
||||
|
||||
c.wg.Add(1)
|
||||
go func() {
|
||||
defer c.wg.Done()
|
||||
peer.WatchConnection(runCtx)
|
||||
peer.WatchConnection(ctx)
|
||||
}()
|
||||
|
||||
// Send initial reset to clean up any stale connections for this clientID on server
|
||||
if err := c.mux.SendClientReset(); err != nil {
|
||||
logger.Warnf("Failed to send initial client reset: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) onReconnect(peerID int, dc *webrtc.DataChannel) {
|
||||
log.Printf("peer %d reconnect event: dc=%v", peerID, dc != nil)
|
||||
func (c *Client) handlePeerReconnect(peerID int, dc *webrtc.DataChannel) {
|
||||
logger.Infof("peer %d reconnect event: dc=%v", peerID, dc != nil)
|
||||
|
||||
c.connMu.Lock()
|
||||
for sid, conn := range c.connections {
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
delete(c.connections, sid)
|
||||
}
|
||||
c.connMu.Unlock()
|
||||
|
||||
if dc != nil {
|
||||
c.mux.UpdateSendFunc(c.sendFrame)
|
||||
c.mux.UpdateSendFunc(func(frame []byte) error {
|
||||
encrypted, err := c.cipher.Encrypt(frame)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %w", ErrEncryptFailed, err)
|
||||
}
|
||||
if len(c.peers) == 0 {
|
||||
return ErrNoPeers
|
||||
}
|
||||
idx := c.peerIdx.Add(1) % uint32(len(c.peers)) //nolint:gosec
|
||||
return c.peers[idx].Send(encrypted)
|
||||
})
|
||||
c.mux.Reset()
|
||||
|
||||
if err := c.mux.SendClientReset(); err != nil {
|
||||
logger.Warnf("Failed to send client reset after reconnect: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) sendResetSignal() {
|
||||
resetFrame := mux.BuildControlFrame(c.clientID, mux.ControlResetClient)
|
||||
encrypted, err := c.cipher.Encrypt(resetFrame)
|
||||
if err != nil {
|
||||
log.Printf("Failed to encrypt reset signal: %v", err)
|
||||
func (c *Client) acceptLoop(ctx context.Context, ln net.Listener) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
logger.Debugf("Accept error: %v", err)
|
||||
continue
|
||||
}
|
||||
go c.handleSOCKS5(ctx, conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) handleSOCKS5(ctx context.Context, conn net.Conn) {
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
if err := c.socks5Handshake(conn); err != nil {
|
||||
logger.Debugf("SOCKS5 handshake failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, peer := range c.peers {
|
||||
if err := peer.Send(encrypted); err != nil {
|
||||
log.Printf("Failed to send reset signal to server: %v", err)
|
||||
}
|
||||
addr, port, err := c.socks5Request(conn)
|
||||
if err != nil {
|
||||
logger.Debugf("SOCKS5 request failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Sent reset signal to server (clientID=%d)", c.clientID)
|
||||
sid := c.mux.OpenStream()
|
||||
c.connMu.Lock()
|
||||
c.connections[sid] = conn
|
||||
c.connMu.Unlock()
|
||||
|
||||
logger.Infof("sid=%d tunnel to %s:%d", sid, addr, port)
|
||||
|
||||
if err := c.setupTunnel(ctx, sid, conn, addr, port); err != nil {
|
||||
logger.Warnf("sid=%d tunnel setup failed: %v", sid, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.activeClients.Add(1)
|
||||
c.startStreamPump(ctx, sid, conn)
|
||||
c.pumpToMux(sid, conn)
|
||||
}
|
||||
|
||||
func (c *Client) setupTunnel(ctx context.Context, sid uint16, conn net.Conn, addr string, port int) error {
|
||||
req := map[string]any{"cmd": "connect", "addr": addr, "port": port}
|
||||
reqData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal connect: %w", err)
|
||||
}
|
||||
|
||||
if err := c.mux.SendData(sid, reqData); err != nil {
|
||||
return fmt.Errorf("send connect: %w", err)
|
||||
}
|
||||
|
||||
dataReady := c.mux.WaitForData(sid)
|
||||
select {
|
||||
case <-dataReady:
|
||||
resp := c.mux.ReadStream(sid)
|
||||
if len(resp) > 0 && resp[0] == 0x00 {
|
||||
if _, err := conn.Write(replySuccess()); err != nil {
|
||||
return fmt.Errorf("write success: %w", err)
|
||||
}
|
||||
} else {
|
||||
_, _ = conn.Write(replyHostUnreachable())
|
||||
return ErrTunnelSetupFailed
|
||||
}
|
||||
case <-time.After(15 * time.Second):
|
||||
_, _ = conn.Write(replyHostUnreachable())
|
||||
c.mux.CleanupDataChannel(sid)
|
||||
return fmt.Errorf("%w: timeout", ErrTunnelSetupFailed)
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("context cancelled: %w", ctx.Err())
|
||||
}
|
||||
c.mux.CleanupDataChannel(sid)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) socks5Handshake(conn net.Conn) error {
|
||||
buf := make([]byte, 2)
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
return fmt.Errorf("read header: %w", err)
|
||||
}
|
||||
|
||||
if buf[0] != 5 {
|
||||
return ErrInvalidSocks5
|
||||
}
|
||||
|
||||
methods := make([]byte, int(buf[1]))
|
||||
if _, err := io.ReadFull(conn, methods); err != nil {
|
||||
return fmt.Errorf("read methods: %w", err)
|
||||
}
|
||||
|
||||
if _, err := conn.Write([]byte{5, 0}); err != nil {
|
||||
return fmt.Errorf("write response: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) socks5Request(conn net.Conn) (string, int, error) {
|
||||
buf := make([]byte, 4)
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
return "", 0, fmt.Errorf("read request header: %w", err)
|
||||
}
|
||||
|
||||
if buf[0] != 5 || buf[1] != 1 {
|
||||
return "", 0, fmt.Errorf("%w: cmd=%d", ErrUnsupportedSocksCommand, buf[1])
|
||||
}
|
||||
|
||||
addr, err := c.readSocks5Addr(conn, buf[3])
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
portBuf := make([]byte, 2)
|
||||
if _, err := io.ReadFull(conn, portBuf); err != nil {
|
||||
return "", 0, fmt.Errorf("read port: %w", err)
|
||||
}
|
||||
port := int(binary.BigEndian.Uint16(portBuf))
|
||||
|
||||
return addr, port, nil
|
||||
}
|
||||
|
||||
func (c *Client) readSocks5Addr(conn net.Conn, addrType byte) (string, error) {
|
||||
switch addrType {
|
||||
case 1: // IPv4
|
||||
ip := make([]byte, 4)
|
||||
if _, err := io.ReadFull(conn, ip); err != nil {
|
||||
return "", fmt.Errorf("read ipv4: %w", err)
|
||||
}
|
||||
return net.IP(ip).String(), nil
|
||||
case 3: // Domain
|
||||
lenBuf := make([]byte, 1)
|
||||
if _, err := io.ReadFull(conn, lenBuf); err != nil {
|
||||
return "", fmt.Errorf("read domain len: %w", err)
|
||||
}
|
||||
domain := make([]byte, int(lenBuf[0]))
|
||||
if _, err := io.ReadFull(conn, domain); err != nil {
|
||||
return "", fmt.Errorf("read domain: %w", err)
|
||||
}
|
||||
return string(domain), nil
|
||||
case 4: // IPv6
|
||||
ip := make([]byte, 16)
|
||||
if _, err := io.ReadFull(conn, ip); err != nil {
|
||||
return "", fmt.Errorf("read ipv6: %w", err)
|
||||
}
|
||||
return net.IP(ip).String(), nil
|
||||
default:
|
||||
return "", fmt.Errorf("%w: type=%d", ErrUnsupportedAddressType, addrType)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) onData(data []byte) {
|
||||
@@ -264,347 +433,100 @@ func (c *Client) onData(data []byte) {
|
||||
c.mux.HandleFrame(plaintext)
|
||||
}
|
||||
|
||||
func (c *Client) runSOCKS5(
|
||||
ctx context.Context,
|
||||
host string,
|
||||
port int,
|
||||
username,
|
||||
password string,
|
||||
onReady func(),
|
||||
) error {
|
||||
if host == "" {
|
||||
host = defaultSOCKSListenHost
|
||||
}
|
||||
|
||||
listenAddr := net.JoinHostPort(host, strconv.Itoa(port))
|
||||
var lc net.ListenConfig
|
||||
listener, err := lc.Listen(ctx, "tcp", listenAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen on %s: %w", listenAddr, err)
|
||||
}
|
||||
|
||||
log.Printf("SOCKS5 proxy listening on %s (auth=%v)", listenAddr, username != "")
|
||||
if onReady != nil {
|
||||
onReady()
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
if err := listener.Close(); err != nil {
|
||||
logger.Debugf("SOCKS5 listener close error: %v", err)
|
||||
func (c *Client) shutdown() {
|
||||
c.connMu.Lock()
|
||||
for _, conn := range c.connections {
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
}
|
||||
c.connMu.Unlock()
|
||||
|
||||
for i, peer := range c.peers {
|
||||
logger.Infof("closing peer %d", i)
|
||||
_ = peer.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) pumpToMux(sid uint16, conn net.Conn) {
|
||||
defer func() {
|
||||
c.activeClients.Add(-1)
|
||||
_ = c.mux.CloseStream(sid)
|
||||
c.connMu.Lock()
|
||||
delete(c.connections, sid)
|
||||
c.connMu.Unlock()
|
||||
}()
|
||||
|
||||
buf := make([]byte, 16384)
|
||||
totalSent := uint64(0)
|
||||
lastLog := time.Now()
|
||||
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.closePeers()
|
||||
return nil
|
||||
default:
|
||||
log.Printf("accept error: %v", err)
|
||||
continue
|
||||
if totalSent > 1024*1024 {
|
||||
logger.Infof("sid=%d done total=%dMB", sid, totalSent/(1024*1024))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
go c.handleSOCKS5(conn, username, password)
|
||||
}
|
||||
}
|
||||
for !c.canSendData() {
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
|
||||
func (c *Client) closePeers() {
|
||||
for _, peer := range c.peers {
|
||||
if err := peer.Close(); err != nil {
|
||||
logger.Debugf("Peer close error: %v", err)
|
||||
if err := c.mux.SendData(sid, buf[:n]); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
totalSent += uint64(n) //nolint:gosec
|
||||
if time.Since(lastLog) > 5*time.Second {
|
||||
logger.Infof("sid=%d sent=%dMB", sid, totalSent/(1024*1024))
|
||||
lastLog = time.Now()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:cyclop // SOCKS5 parsing is inherently stateful and mirrors the protocol handshake.
|
||||
func (c *Client) handleSOCKS5(conn net.Conn, username, password string) {
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
logger.Debugf("SOCKS5 connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
buf := make([]byte, 513)
|
||||
if !readSOCKSVersionAndMethods(conn, buf) {
|
||||
return
|
||||
}
|
||||
|
||||
nmethods := buf[1]
|
||||
if _, err := io.ReadFull(conn, buf[:nmethods]); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
requireAuth := username != ""
|
||||
wantMethod := byte(0x00)
|
||||
if requireAuth {
|
||||
wantMethod = 0x02
|
||||
}
|
||||
|
||||
if !supportsMethod(buf[:nmethods], wantMethod) {
|
||||
writeResponse(conn, replyUnsupportedSOCKSMethod())
|
||||
return
|
||||
}
|
||||
writeResponse(conn, []byte{5, wantMethod})
|
||||
|
||||
if requireAuth && !authenticateSOCKSUser(conn, buf, username, password) {
|
||||
return
|
||||
}
|
||||
|
||||
addr, port, ok := readConnectTarget(conn, buf)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
sid := c.mux.OpenStream()
|
||||
logger.Verbosef("SOCKS5 opened stream sid=%d for %s:%d", sid, addr, port)
|
||||
log.Printf("sid=%d socks5 %s:%d", sid, addr, port)
|
||||
|
||||
if !c.sendConnectRequest(sid, addr, port) {
|
||||
return
|
||||
}
|
||||
|
||||
if !c.waitConnectResponse(conn, sid) {
|
||||
return
|
||||
}
|
||||
|
||||
c.mux.ReadStream(sid)
|
||||
writeResponse(conn, replySuccess())
|
||||
c.proxyStream(conn, sid)
|
||||
}
|
||||
|
||||
func readSOCKSVersionAndMethods(conn net.Conn, buf []byte) bool {
|
||||
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return buf[0] == 5
|
||||
}
|
||||
|
||||
func supportsMethod(methods []byte, wantMethod byte) bool {
|
||||
for _, method := range methods {
|
||||
if method == wantMethod {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func authenticateSOCKSUser(conn net.Conn, buf []byte, username, password string) bool {
|
||||
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
|
||||
return false
|
||||
}
|
||||
if buf[0] != 0x01 {
|
||||
return false
|
||||
}
|
||||
|
||||
ulen := int(buf[1])
|
||||
if _, err := io.ReadFull(conn, buf[:ulen+1]); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
gotUser := string(buf[:ulen])
|
||||
plen := int(buf[ulen])
|
||||
if _, err := io.ReadFull(conn, buf[:plen]); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
gotPass := string(buf[:plen])
|
||||
if gotUser != username || gotPass != password {
|
||||
writeResponse(conn, replyAuthFailed())
|
||||
return false
|
||||
}
|
||||
|
||||
writeResponse(conn, replyAuthOK())
|
||||
return true
|
||||
}
|
||||
|
||||
func readConnectTarget(conn net.Conn, buf []byte) (string, uint16, bool) {
|
||||
if _, err := io.ReadFull(conn, buf[:4]); err != nil {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
if buf[1] != 1 {
|
||||
writeResponse(conn, replyCommandNotSupported())
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
addr, ok := readTargetAddress(conn, buf, buf[3])
|
||||
if !ok {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
return addr, binary.BigEndian.Uint16(buf[:2]), true
|
||||
}
|
||||
|
||||
func readTargetAddress(conn net.Conn, buf []byte, atyp byte) (string, bool) {
|
||||
switch atyp {
|
||||
case 1:
|
||||
if _, err := io.ReadFull(conn, buf[:4]); err != nil {
|
||||
return "", false
|
||||
}
|
||||
return fmt.Sprintf("%d.%d.%d.%d", buf[0], buf[1], buf[2], buf[3]), true
|
||||
case 3:
|
||||
if _, err := io.ReadFull(conn, buf[:1]); err != nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
length := buf[0]
|
||||
if _, err := io.ReadFull(conn, buf[:length]); err != nil {
|
||||
return "", false
|
||||
}
|
||||
return string(buf[:length]), true
|
||||
default:
|
||||
writeResponse(conn, replyAddressNotSupported())
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) sendConnectRequest(sid uint16, addr string, port uint16) bool {
|
||||
reqData, err := json.Marshal(struct {
|
||||
Cmd string `json:"cmd"`
|
||||
Addr string `json:"addr"`
|
||||
Port uint16 `json:"port"`
|
||||
}{
|
||||
Cmd: "connect",
|
||||
Addr: addr,
|
||||
Port: port,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Debugf("Connect request marshal error: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if err := c.mux.SendData(sid, reqData); err != nil {
|
||||
logger.Debugf("Connect request send error: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *Client) waitConnectResponse(conn net.Conn, sid uint16) bool {
|
||||
dataReady := c.mux.WaitForData(sid)
|
||||
timeout := time.NewTimer(10 * time.Second)
|
||||
defer timeout.Stop()
|
||||
|
||||
select {
|
||||
case <-dataReady:
|
||||
stream := c.mux.GetStream(sid)
|
||||
if stream == nil || len(stream.RecvBuf()) == 0 {
|
||||
writeResponse(conn, replyHostUnreachable())
|
||||
return false
|
||||
}
|
||||
case <-timeout.C:
|
||||
writeResponse(conn, replyHostUnreachable())
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
//nolint:cyclop // The stream pump handles two coordinated goroutines and shutdown races in one place.
|
||||
func (c *Client) proxyStream(conn net.Conn, sid uint16) {
|
||||
done := make(chan struct{})
|
||||
streamClosed := make(chan struct{})
|
||||
|
||||
func (c *Client) startStreamPump(ctx context.Context, sid uint16, conn net.Conn) {
|
||||
c.wg.Add(1)
|
||||
go func() {
|
||||
defer close(done)
|
||||
buf := make([]byte, 32768)
|
||||
for {
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
if err := c.mux.CloseStream(sid); err != nil {
|
||||
logger.Debugf("Close stream error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err := c.mux.SendData(sid, buf[:n]); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer close(streamClosed)
|
||||
defer c.mux.CleanupDataChannel(sid)
|
||||
defer c.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(10 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
data := c.mux.ReadStream(sid)
|
||||
if len(data) > 0 && !writeStreamData(conn, data) {
|
||||
return
|
||||
if len(data) > 0 {
|
||||
if _, err := conn.Write(data); err != nil {
|
||||
_ = c.mux.CloseStream(sid)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if c.mux.StreamClosed(sid) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-streamClosed:
|
||||
}
|
||||
}
|
||||
|
||||
func writeStreamData(conn net.Conn, data []byte) bool {
|
||||
for len(data) > 0 {
|
||||
n, err := conn.Write(data)
|
||||
if err != nil {
|
||||
func (c *Client) canSendData() bool {
|
||||
for _, peer := range c.peers {
|
||||
if !peer.CanSend() {
|
||||
return false
|
||||
}
|
||||
data = data[n:]
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func writeResponse(conn net.Conn, response []byte) {
|
||||
if _, err := conn.Write(response); err != nil {
|
||||
logger.Debugf("SOCKS5 response write error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func replyUnsupportedSOCKSMethod() []byte {
|
||||
return []byte{5, 0xFF}
|
||||
}
|
||||
|
||||
func replyAuthFailed() []byte {
|
||||
return []byte{0x01, 0x01}
|
||||
}
|
||||
|
||||
func replyAuthOK() []byte {
|
||||
return []byte{0x01, 0x00}
|
||||
}
|
||||
|
||||
func replyCommandNotSupported() []byte {
|
||||
return []byte{5, 7, 0, 1, 0, 0, 0, 0, 0, 0}
|
||||
}
|
||||
|
||||
func replyAddressNotSupported() []byte {
|
||||
return []byte{5, 8, 0, 1, 0, 0, 0, 0, 0, 0}
|
||||
}
|
||||
|
||||
func replyHostUnreachable() []byte {
|
||||
return []byte{5, 4, 0, 1, 0, 0, 0, 0, 0, 0}
|
||||
}
|
||||
|
||||
func replySuccess() []byte {
|
||||
return []byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0}
|
||||
}
|
||||
|
||||
func replyHostUnreachable() []byte {
|
||||
return []byte{5, 4, 0, 1, 0, 0, 0, 0, 0, 0}
|
||||
}
|
||||
|
||||
@@ -11,15 +11,19 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidKeySize = errors.New("invalid key size") //nolint:revive
|
||||
ErrCiphertextTooShort = errors.New("ciphertext too short") //nolint:revive
|
||||
// ErrInvalidKeySize is returned when the encryption key is not 32 bytes.
|
||||
ErrInvalidKeySize = errors.New("invalid key size")
|
||||
// ErrCiphertextTooShort is returned when the ciphertext is shorter than the nonce size.
|
||||
ErrCiphertextTooShort = errors.New("ciphertext too short")
|
||||
)
|
||||
|
||||
type Cipher struct { //nolint:revive
|
||||
// Cipher provides AEAD encryption and decryption using ChaCha20-Poly1305.
|
||||
type Cipher struct {
|
||||
aead cipher.AEAD
|
||||
}
|
||||
|
||||
func NewCipher(keyStr string) (*Cipher, error) { //nolint:revive
|
||||
// NewCipher creates a new Cipher instance with the given 32-byte key.
|
||||
func NewCipher(keyStr string) (*Cipher, error) {
|
||||
key := []byte(keyStr)
|
||||
if len(key) != chacha20poly1305.KeySize {
|
||||
return nil, ErrInvalidKeySize
|
||||
@@ -33,23 +37,26 @@ func NewCipher(keyStr string) (*Cipher, error) { //nolint:revive
|
||||
return &Cipher{aead: aead}, nil
|
||||
}
|
||||
|
||||
func (c *Cipher) Encrypt(plaintext []byte) ([]byte, error) { //nolint:revive
|
||||
// Encrypt encrypts plaintext and prepends a random nonce.
|
||||
func (c *Cipher) Encrypt(plaintext []byte) ([]byte, error) {
|
||||
nonce := make([]byte, c.aead.NonceSize())
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, fmt.Errorf("failed to read nonce: %w", err)
|
||||
return nil, fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
ciphertext := c.aead.Seal(nonce, nonce, plaintext, nil)
|
||||
return ciphertext, nil
|
||||
// Seal appends the ciphertext to the nonce
|
||||
return c.aead.Seal(nonce, nonce, plaintext, nil), nil
|
||||
}
|
||||
|
||||
func (c *Cipher) Decrypt(ciphertext []byte) ([]byte, error) { //nolint:revive
|
||||
if len(ciphertext) < c.aead.NonceSize() {
|
||||
// Decrypt decrypts ciphertext that has a nonce prepended.
|
||||
func (c *Cipher) Decrypt(ciphertext []byte) ([]byte, error) {
|
||||
nonceSize := c.aead.NonceSize()
|
||||
if len(ciphertext) < nonceSize {
|
||||
return nil, ErrCiphertextTooShort
|
||||
}
|
||||
|
||||
nonce := ciphertext[:c.aead.NonceSize()]
|
||||
encrypted := ciphertext[c.aead.NonceSize():]
|
||||
nonce := ciphertext[:nonceSize]
|
||||
encrypted := ciphertext[nonceSize:]
|
||||
|
||||
res, err := c.aead.Open(nil, nonce, encrypted, nil)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,28 +1,66 @@
|
||||
package logger //nolint:revive
|
||||
// Package logger provides a simple leveled logging interface.
|
||||
package logger
|
||||
|
||||
import (
|
||||
"log"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
var verboseEnabled atomic.Bool //nolint:gochecknoglobals
|
||||
// verboseEnabled controls whether verbose and debug logging is enabled.
|
||||
//
|
||||
//nolint:gochecknoglobals // Global log state is acceptable for CLI tools.
|
||||
var verboseEnabled atomic.Bool
|
||||
|
||||
func SetVerbose(enabled bool) { //nolint:revive
|
||||
// SetVerbose enables or disables verbose/debug logging.
|
||||
func SetVerbose(enabled bool) {
|
||||
verboseEnabled.Store(enabled)
|
||||
}
|
||||
|
||||
func IsVerbose() bool { //nolint:revive
|
||||
// IsVerbose returns true if verbose logging is enabled.
|
||||
func IsVerbose() bool {
|
||||
return verboseEnabled.Load()
|
||||
}
|
||||
|
||||
func Verbosef(format string, v ...interface{}) { //nolint:revive
|
||||
// Info logs an informational message.
|
||||
func Info(v ...any) {
|
||||
log.Print(v...)
|
||||
}
|
||||
|
||||
// Infof logs a formatted informational message.
|
||||
func Infof(format string, v ...any) {
|
||||
log.Printf(format, v...)
|
||||
}
|
||||
|
||||
// Warn logs a warning message.
|
||||
func Warn(v ...any) {
|
||||
log.Print(v...)
|
||||
}
|
||||
|
||||
// Warnf logs a formatted warning message.
|
||||
func Warnf(format string, v ...any) {
|
||||
log.Printf(format, v...)
|
||||
}
|
||||
|
||||
// Error logs an error message.
|
||||
func Error(v ...any) {
|
||||
log.Print(v...)
|
||||
}
|
||||
|
||||
// Errorf logs a formatted error message.
|
||||
func Errorf(format string, v ...any) {
|
||||
log.Printf(format, v...)
|
||||
}
|
||||
|
||||
// Verbosef logs a formatted message if verbose logging is enabled.
|
||||
func Verbosef(format string, v ...any) {
|
||||
if verboseEnabled.Load() {
|
||||
log.Printf("[VERBOSE] "+format, v...)
|
||||
log.Printf(format, v...)
|
||||
}
|
||||
}
|
||||
|
||||
func Debugf(format string, v ...interface{}) { //nolint:revive
|
||||
// Debugf logs a formatted message if verbose logging is enabled.
|
||||
func Debugf(format string, v ...any) {
|
||||
if verboseEnabled.Load() {
|
||||
log.Printf("[DEBUG] "+format, v...)
|
||||
log.Printf(format, v...)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,29 +5,43 @@ import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrClientResetID = errors.New("client reset requires a non-zero client id") //nolint:revive
|
||||
// ErrClientResetID is returned when a client reset is attempted with a zero client ID.
|
||||
ErrClientResetID = errors.New("client reset requires a non-zero client id")
|
||||
// ErrDataTooLarge is returned when a data chunk exceeds the maximum frame size.
|
||||
ErrDataTooLarge = errors.New("data chunk too large")
|
||||
)
|
||||
|
||||
const (
|
||||
ControlStreamID uint16 = 0xFFFF //nolint:revive
|
||||
ControlLength uint16 = 0xFFFF //nolint:revive
|
||||
// HeaderSize is the size of the frame header in bytes.
|
||||
HeaderSize = 12
|
||||
|
||||
// ControlStreamID is a special stream ID used for control frames.
|
||||
ControlStreamID uint16 = 0xFFFF
|
||||
|
||||
// ControlResetClient is a control frame type used to signal a client reset.
|
||||
ControlResetClient uint32 = 1
|
||||
|
||||
// FrameTypeData is a marker for data frames.
|
||||
FrameTypeData uint16 = 0
|
||||
// FrameTypeControl is a marker for control frames.
|
||||
FrameTypeControl uint16 = 0xFFFF
|
||||
)
|
||||
|
||||
type ControlFrame struct { //nolint:revive
|
||||
// ControlFrame represents a control message between multiplexers.
|
||||
type ControlFrame struct {
|
||||
ClientID uint32
|
||||
Type uint32
|
||||
}
|
||||
|
||||
type Stream struct { //nolint:revive
|
||||
// Stream represents a single multiplexed data stream.
|
||||
type Stream struct {
|
||||
ID uint16
|
||||
ClientID uint32
|
||||
recvBuf []byte
|
||||
@@ -37,13 +51,15 @@ type Stream struct { //nolint:revive
|
||||
outOfOrder map[uint32][]byte
|
||||
}
|
||||
|
||||
func (s *Stream) RecvBuf() []byte { //nolint:revive
|
||||
// RecvBuf returns the current receive buffer content.
|
||||
func (s *Stream) RecvBuf() []byte {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.recvBuf
|
||||
}
|
||||
|
||||
type Multiplexer struct { //nolint:revive
|
||||
// Multiplexer coordinates multiple Streams over a single transport channel.
|
||||
type Multiplexer struct {
|
||||
streams map[uint16]*Stream
|
||||
nextID uint16
|
||||
clientID uint32
|
||||
@@ -55,10 +71,14 @@ type Multiplexer struct { //nolint:revive
|
||||
dataReadyMu sync.Mutex
|
||||
sendSeq map[uint16]uint32
|
||||
sendSeqMu sync.Mutex
|
||||
|
||||
// bufferCond is used to wait for space in receive buffers
|
||||
bufferCond *sync.Cond
|
||||
}
|
||||
|
||||
func New(clientID uint32, onSend func([]byte) error) *Multiplexer { //nolint:revive
|
||||
return &Multiplexer{
|
||||
// New creates a new Multiplexer instance.
|
||||
func New(clientID uint32, onSend func([]byte) error) *Multiplexer {
|
||||
m := &Multiplexer{
|
||||
streams: make(map[uint16]*Stream),
|
||||
nextID: 1,
|
||||
clientID: clientID,
|
||||
@@ -68,9 +88,12 @@ func New(clientID uint32, onSend func([]byte) error) *Multiplexer { //nolint:rev
|
||||
dataReady: make(map[uint16]chan struct{}),
|
||||
sendSeq: make(map[uint16]uint32),
|
||||
}
|
||||
m.bufferCond = sync.NewCond(&m.mu)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *Multiplexer) OpenStream() uint16 { //nolint:revive
|
||||
// OpenStream allocates and returns a new unique stream ID.
|
||||
func (m *Multiplexer) OpenStream() uint16 {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
@@ -93,7 +116,8 @@ func (m *Multiplexer) OpenStream() uint16 { //nolint:revive
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Multiplexer) SendData(sid uint16, data []byte) error { //nolint:revive
|
||||
// SendData fragments and sends data over a specific stream.
|
||||
func (m *Multiplexer) SendData(sid uint16, data []byte) error {
|
||||
m.mu.RLock()
|
||||
stream, exists := m.streams[sid]
|
||||
m.mu.RUnlock()
|
||||
@@ -103,11 +127,6 @@ func (m *Multiplexer) SendData(sid uint16, data []byte) error { //nolint:revive
|
||||
}
|
||||
|
||||
const chunkSize = 7000
|
||||
totalChunks := (len(data) + chunkSize - 1) / chunkSize
|
||||
|
||||
if totalChunks > 10 {
|
||||
logger.Debugf("SendData: sid=%d, size=%d bytes, chunks=%d", sid, len(data), totalChunks)
|
||||
}
|
||||
|
||||
for i := 0; i < len(data); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
@@ -122,12 +141,16 @@ func (m *Multiplexer) SendData(sid uint16, data []byte) error { //nolint:revive
|
||||
m.sendSeq[sid]++
|
||||
m.sendSeqMu.Unlock()
|
||||
|
||||
frame := make([]byte, 12+len(chunk))
|
||||
if len(chunk) > math.MaxUint16 {
|
||||
return ErrDataTooLarge
|
||||
}
|
||||
|
||||
frame := make([]byte, HeaderSize+len(chunk))
|
||||
binary.BigEndian.PutUint32(frame[0:4], m.clientID)
|
||||
binary.BigEndian.PutUint16(frame[4:6], sid)
|
||||
binary.BigEndian.PutUint16(frame[6:8], uint16(uint32(len(chunk)))) //nolint:gosec
|
||||
binary.BigEndian.PutUint16(frame[6:8], uint16(len(chunk))) //nolint:gosec // Length checked above
|
||||
binary.BigEndian.PutUint32(frame[8:12], seq)
|
||||
copy(frame[12:], chunk)
|
||||
copy(frame[HeaderSize:], chunk)
|
||||
|
||||
if err := m.onSend(frame); err != nil {
|
||||
return fmt.Errorf("onSend failed: %w", err)
|
||||
@@ -137,7 +160,8 @@ func (m *Multiplexer) SendData(sid uint16, data []byte) error { //nolint:revive
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Multiplexer) CloseStream(sid uint16) error { //nolint:revive
|
||||
// CloseStream signals that a stream should be terminated.
|
||||
func (m *Multiplexer) CloseStream(sid uint16) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
@@ -149,7 +173,10 @@ func (m *Multiplexer) CloseStream(sid uint16) error { //nolint:revive
|
||||
delete(m.sendSeq, sid)
|
||||
m.sendSeqMu.Unlock()
|
||||
|
||||
frame := make([]byte, 12)
|
||||
// Notify anyone waiting for buffer space that a stream is closed
|
||||
m.bufferCond.Broadcast()
|
||||
|
||||
frame := make([]byte, HeaderSize)
|
||||
binary.BigEndian.PutUint32(frame[0:4], m.clientID)
|
||||
binary.BigEndian.PutUint16(frame[4:6], sid)
|
||||
binary.BigEndian.PutUint16(frame[6:8], 0)
|
||||
@@ -161,7 +188,8 @@ func (m *Multiplexer) CloseStream(sid uint16) error { //nolint:revive
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Multiplexer) SendClientReset() error { //nolint:revive
|
||||
// SendClientReset sends a control frame to reset all streams for this client.
|
||||
func (m *Multiplexer) SendClientReset() error {
|
||||
if m.clientID == 0 {
|
||||
return ErrClientResetID
|
||||
}
|
||||
@@ -171,23 +199,25 @@ func (m *Multiplexer) SendClientReset() error { //nolint:revive
|
||||
return nil
|
||||
}
|
||||
|
||||
func BuildControlFrame(clientID uint32, controlType uint32) []byte { //nolint:revive
|
||||
frame := make([]byte, 12)
|
||||
// BuildControlFrame constructs a raw control frame.
|
||||
func BuildControlFrame(clientID uint32, controlType uint32) []byte {
|
||||
frame := make([]byte, HeaderSize)
|
||||
binary.BigEndian.PutUint32(frame[0:4], clientID)
|
||||
binary.BigEndian.PutUint16(frame[4:6], ControlStreamID)
|
||||
binary.BigEndian.PutUint16(frame[6:8], ControlLength)
|
||||
binary.BigEndian.PutUint16(frame[6:8], 0xFFFF) // Use 0xFFFF as a marker for control
|
||||
binary.BigEndian.PutUint32(frame[8:12], controlType)
|
||||
return frame
|
||||
}
|
||||
|
||||
func ParseControlFrame(frame []byte) (ControlFrame, bool) { //nolint:revive
|
||||
if len(frame) < 12 {
|
||||
// ParseControlFrame attempts to extract control information from a frame.
|
||||
func ParseControlFrame(frame []byte) (ControlFrame, bool) {
|
||||
if len(frame) < HeaderSize {
|
||||
return ControlFrame{}, false
|
||||
}
|
||||
|
||||
sid := binary.BigEndian.Uint16(frame[4:6])
|
||||
length := binary.BigEndian.Uint16(frame[6:8])
|
||||
if sid != ControlStreamID || length != ControlLength {
|
||||
if sid != ControlStreamID || length != 0xFFFF {
|
||||
return ControlFrame{}, false
|
||||
}
|
||||
|
||||
@@ -197,14 +227,15 @@ func ParseControlFrame(frame []byte) (ControlFrame, bool) { //nolint:revive
|
||||
}, true
|
||||
}
|
||||
|
||||
func (m *Multiplexer) HandleFrame(frame []byte) { //nolint:revive
|
||||
// HandleFrame processes an incoming frame from the transport.
|
||||
func (m *Multiplexer) HandleFrame(frame []byte) {
|
||||
control, ok := ParseControlFrame(frame)
|
||||
if ok {
|
||||
m.handleControlFrame(control)
|
||||
return
|
||||
}
|
||||
|
||||
if len(frame) < 12 {
|
||||
if len(frame) < HeaderSize {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -218,11 +249,11 @@ func (m *Multiplexer) HandleFrame(frame []byte) { //nolint:revive
|
||||
return
|
||||
}
|
||||
|
||||
if len(frame) < 12+int(length) {
|
||||
if len(frame) < HeaderSize+int(length) {
|
||||
return
|
||||
}
|
||||
|
||||
m.processDataFrame(sid, clientID, seq, frame[12:12+length])
|
||||
m.processDataFrame(sid, clientID, seq, frame[HeaderSize:HeaderSize+int(length)])
|
||||
}
|
||||
|
||||
func (m *Multiplexer) handleCloseStreamFrame(sid uint16, clientID uint32) {
|
||||
@@ -230,6 +261,7 @@ func (m *Multiplexer) handleCloseStreamFrame(sid uint16, clientID uint32) {
|
||||
defer m.mu.Unlock()
|
||||
if stream, exists := m.streams[sid]; exists && stream.ClientID == clientID {
|
||||
stream.closed = true
|
||||
m.bufferCond.Broadcast()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -279,6 +311,7 @@ func (m *Multiplexer) getOrCreateStream(sid uint16, clientID uint32) *Stream {
|
||||
stream.closed = false
|
||||
stream.nextSeq = 0
|
||||
stream.outOfOrder = make(map[uint32][]byte)
|
||||
m.bufferCond.Broadcast()
|
||||
}
|
||||
return stream
|
||||
}
|
||||
@@ -319,7 +352,8 @@ func (m *Multiplexer) handleControlFrame(control ControlFrame) {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Multiplexer) ResetClient(clientID uint32) { //nolint:revive
|
||||
// ResetClient closes and removes all streams associated with a client ID.
|
||||
func (m *Multiplexer) ResetClient(clientID uint32) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
@@ -329,6 +363,7 @@ func (m *Multiplexer) ResetClient(clientID uint32) { //nolint:revive
|
||||
delete(m.streams, streamSid)
|
||||
}
|
||||
}
|
||||
m.bufferCond.Broadcast()
|
||||
}
|
||||
|
||||
func (m *Multiplexer) waitForBufferSpace(sid uint16, clientID uint32, need int) *Stream {
|
||||
@@ -340,13 +375,13 @@ func (m *Multiplexer) waitForBufferSpace(sid uint16, clientID uint32, need int)
|
||||
if len(stream.recvBuf)+need <= m.maxBufferSize {
|
||||
return stream
|
||||
}
|
||||
m.mu.Unlock()
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
m.mu.Lock()
|
||||
// Wait for space to become available
|
||||
m.bufferCond.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Multiplexer) ReadStream(sid uint16) []byte { //nolint:revive
|
||||
// ReadStream retrieves and clears the current receive buffer for a stream.
|
||||
func (m *Multiplexer) ReadStream(sid uint16) []byte {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
@@ -357,10 +392,15 @@ func (m *Multiplexer) ReadStream(sid uint16) []byte { //nolint:revive
|
||||
|
||||
data := stream.recvBuf
|
||||
stream.recvBuf = make([]byte, 0)
|
||||
|
||||
// Notify producers that space is now available
|
||||
m.bufferCond.Broadcast()
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
func (m *Multiplexer) StreamClosed(sid uint16) bool { //nolint:revive
|
||||
// StreamClosed returns true if the stream is closed or doesn't exist.
|
||||
func (m *Multiplexer) StreamClosed(sid uint16) bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
@@ -368,7 +408,8 @@ func (m *Multiplexer) StreamClosed(sid uint16) bool { //nolint:revive
|
||||
return !exists || stream.closed
|
||||
}
|
||||
|
||||
func (m *Multiplexer) GetStreams() []uint16 { //nolint:revive
|
||||
// GetStreams returns a list of all active stream IDs.
|
||||
func (m *Multiplexer) GetStreams() []uint16 {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
@@ -379,13 +420,15 @@ func (m *Multiplexer) GetStreams() []uint16 { //nolint:revive
|
||||
return sids
|
||||
}
|
||||
|
||||
func (m *Multiplexer) GetStream(sid uint16) *Stream { //nolint:revive
|
||||
// GetStream returns the Stream object for a given ID.
|
||||
func (m *Multiplexer) GetStream(sid uint16) *Stream {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.streams[sid]
|
||||
}
|
||||
|
||||
func (m *Multiplexer) Reset() { //nolint:revive
|
||||
// Reset clears all multiplexer state and closes all streams.
|
||||
func (m *Multiplexer) Reset() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
@@ -399,16 +442,20 @@ func (m *Multiplexer) Reset() { //nolint:revive
|
||||
m.sendSeqMu.Lock()
|
||||
m.sendSeq = make(map[uint16]uint32)
|
||||
m.sendSeqMu.Unlock()
|
||||
|
||||
m.bufferCond.Broadcast()
|
||||
}
|
||||
|
||||
func (m *Multiplexer) UpdateSendFunc(onSend func([]byte) error) { //nolint:revive
|
||||
// UpdateSendFunc updates the function used to transmit raw frames.
|
||||
func (m *Multiplexer) UpdateSendFunc(onSend func([]byte) error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.onSend = onSend
|
||||
}
|
||||
|
||||
func (m *Multiplexer) WaitForData(sid uint16) <-chan struct{} { //nolint:revive
|
||||
// WaitForData returns a channel that signals when new data is available for a stream.
|
||||
func (m *Multiplexer) WaitForData(sid uint16) <-chan struct{} {
|
||||
m.dataReadyMu.Lock()
|
||||
defer m.dataReadyMu.Unlock()
|
||||
|
||||
@@ -418,7 +465,8 @@ func (m *Multiplexer) WaitForData(sid uint16) <-chan struct{} { //nolint:revive
|
||||
return m.dataReady[sid]
|
||||
}
|
||||
|
||||
func (m *Multiplexer) CleanupDataChannel(sid uint16) { //nolint:revive
|
||||
// CleanupDataChannel removes the data notification channel for a stream.
|
||||
func (m *Multiplexer) CleanupDataChannel(sid uint16) {
|
||||
m.dataReadyMu.Lock()
|
||||
defer m.dataReadyMu.Unlock()
|
||||
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
// Package provider defines common errors for WebRTC providers.
|
||||
package provider
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
// ErrProviderNotFound is returned when the requested provider is not registered.
|
||||
ErrProviderNotFound = errors.New("provider not found")
|
||||
// ErrDataChannelTimeout is returned when the DataChannel fails to open in time.
|
||||
ErrDataChannelTimeout = errors.New("datachannel timeout")
|
||||
// ErrDataChannelNotReady is returned when attempting to send data before the DataChannel is open.
|
||||
ErrDataChannelNotReady = errors.New("datachannel not ready")
|
||||
// ErrSendQueueClosed is returned when attempting to send data after the send queue has been closed.
|
||||
ErrSendQueueClosed = errors.New("send queue closed")
|
||||
// ErrSendQueueTimeout is returned when the send queue is full and the timeout is reached.
|
||||
ErrSendQueueTimeout = errors.New("send queue timeout")
|
||||
)
|
||||
@@ -3,6 +3,7 @@ package jazz
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
@@ -533,6 +534,23 @@ func (p *Peer) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
// ErrPublisherNotInitialized is returned when the publisher peer connection is not set up.
|
||||
ErrPublisherNotInitialized = errors.New("publisher peer connection not initialized")
|
||||
)
|
||||
|
||||
// AddVideoTrack adds a video track to the publisher peer connection.
|
||||
func (p *Peer) AddVideoTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) {
|
||||
if p.pcPub == nil {
|
||||
return nil, ErrPublisherNotInitialized
|
||||
}
|
||||
sender, err := p.pcPub.AddTrack(track)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add track: %w", err)
|
||||
}
|
||||
return sender, nil
|
||||
}
|
||||
|
||||
// SetReconnectCallback sets the callback for reconnection events.
|
||||
func (p *Peer) SetReconnectCallback(cb func(*webrtc.DataChannel)) {
|
||||
p.onReconnect = cb
|
||||
|
||||
@@ -72,3 +72,8 @@ func (j *jazzProvider) GetSendQueue() chan []byte {
|
||||
func (j *jazzProvider) GetBufferedAmount() uint64 {
|
||||
return j.peer.GetBufferedAmount()
|
||||
}
|
||||
|
||||
// AddVideoTrack adds a video track to the jazz connection.
|
||||
func (j *jazzProvider) AddVideoTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) {
|
||||
return j.peer.AddVideoTrack(track)
|
||||
}
|
||||
|
||||
@@ -3,11 +3,27 @@ package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/pion/webrtc/v4"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrProviderNotFound is returned when a requested provider is not registered.
|
||||
ErrProviderNotFound = errors.New("provider not found")
|
||||
// ErrDataChannelTimeout is returned when the DataChannel fails to open within the timeout period.
|
||||
ErrDataChannelTimeout = errors.New("datachannel timeout")
|
||||
// ErrDataChannelNotReady is returned when attempting to send data before the DataChannel is open.
|
||||
ErrDataChannelNotReady = errors.New("datachannel not ready")
|
||||
// ErrSendQueueClosed is returned when attempting to send data after the send queue has been closed.
|
||||
ErrSendQueueClosed = errors.New("send queue closed")
|
||||
// ErrSendQueueTimeout is returned when the send queue is full and the timeout is reached.
|
||||
ErrSendQueueTimeout = errors.New("send queue timeout")
|
||||
)
|
||||
|
||||
// Provider defines the standard interface for WebRTC connection handlers.
|
||||
//
|
||||
//nolint:interfacebloat // All methods are necessary for provider abstraction.
|
||||
type Provider interface {
|
||||
Connect(ctx context.Context) error
|
||||
Send(data []byte) error
|
||||
@@ -19,6 +35,9 @@ type Provider interface {
|
||||
CanSend() bool
|
||||
GetSendQueue() chan []byte
|
||||
GetBufferedAmount() uint64
|
||||
|
||||
// AddVideoTrack adds a video track to the connection.
|
||||
AddVideoTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error)
|
||||
}
|
||||
|
||||
// Config holds common configuration for all providers.
|
||||
@@ -34,7 +53,9 @@ type Config struct {
|
||||
// Factory is a function that creates a new Provider instance.
|
||||
type Factory func(ctx context.Context, cfg Config) (Provider, error)
|
||||
|
||||
//nolint:gochecknoglobals
|
||||
// registry holds all registered provider factories.
|
||||
//
|
||||
//nolint:gochecknoglobals // Global registry is required for provider discovery.
|
||||
var registry = make(map[string]Factory)
|
||||
|
||||
// Register adds a new provider factory to the registry.
|
||||
|
||||
@@ -1160,3 +1160,20 @@ func (p *Peer) CanSend() bool {
|
||||
}
|
||||
return len(p.sendQueue) < 4000
|
||||
}
|
||||
|
||||
var (
|
||||
// ErrPublisherNotInitialized is returned when the publisher peer connection is not set up.
|
||||
ErrPublisherNotInitialized = errors.New("publisher peer connection not initialized")
|
||||
)
|
||||
|
||||
// AddVideoTrack adds a video track to the publisher peer connection.
|
||||
func (p *Peer) AddVideoTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) {
|
||||
if p.pcPub == nil {
|
||||
return nil, ErrPublisherNotInitialized
|
||||
}
|
||||
sender, err := p.pcPub.AddTrack(track)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add track: %w", err)
|
||||
}
|
||||
return sender, nil
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ type telemostProvider struct {
|
||||
peer *Peer
|
||||
}
|
||||
|
||||
// New creates a new Yandex Telemost provider instance.
|
||||
// New creates a new Telemost provider instance.
|
||||
func New(ctx context.Context, cfg provider.Config) (provider.Provider, error) {
|
||||
peer, err := NewPeer(ctx, cfg.RoomURL, cfg.Name, cfg.OnData)
|
||||
if err != nil {
|
||||
@@ -72,3 +72,9 @@ func (t *telemostProvider) GetSendQueue() chan []byte {
|
||||
func (t *telemostProvider) GetBufferedAmount() uint64 {
|
||||
return t.peer.GetBufferedAmount()
|
||||
}
|
||||
|
||||
// AddVideoTrack adds a video track to the telemost connection.
|
||||
func (t *telemostProvider) AddVideoTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) {
|
||||
return t.peer.AddVideoTrack(track)
|
||||
}
|
||||
|
||||
|
||||
@@ -18,8 +18,14 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
errPeerClosed = errors.New("peer closed")
|
||||
errSendQueueFull = errors.New("send queue full")
|
||||
// ErrPeerClosed is returned when an operation is attempted on a closed peer.
|
||||
ErrPeerClosed = errors.New("peer closed")
|
||||
// ErrSendQueueFull is returned when the transmission queue is full.
|
||||
ErrSendQueueFull = errors.New("send queue full")
|
||||
// ErrLiveKitNotConnected is returned when the LiveKit room is not connected.
|
||||
ErrLiveKitNotConnected = errors.New("livekit room not connected")
|
||||
// ErrVideoNotSupported is returned when video tracks are not supported by this provider.
|
||||
ErrVideoNotSupported = errors.New("video tracks not supported yet in wbstream")
|
||||
)
|
||||
|
||||
// Peer represents a WB Stream WebRTC connection using LiveKit.
|
||||
@@ -137,13 +143,13 @@ func (p *Peer) processSendQueue() {
|
||||
// Send transmits data to the room.
|
||||
func (p *Peer) Send(data []byte) error {
|
||||
if p.closed.Load() {
|
||||
return errPeerClosed
|
||||
return ErrPeerClosed
|
||||
}
|
||||
select {
|
||||
case p.sendQueue <- data:
|
||||
return nil
|
||||
default:
|
||||
return errSendQueueFull
|
||||
return ErrSendQueueFull
|
||||
}
|
||||
}
|
||||
|
||||
@@ -193,3 +199,19 @@ func (p *Peer) GetSendQueue() chan []byte {
|
||||
func (p *Peer) GetBufferedAmount() uint64 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// AddVideoTrack adds a video track to the LiveKit room.
|
||||
func (p *Peer) AddVideoTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) {
|
||||
if p.room == nil || p.room.LocalParticipant == nil {
|
||||
return nil, ErrLiveKitNotConnected
|
||||
}
|
||||
|
||||
_, err := p.room.LocalParticipant.PublishTrack(track, &lksdk.TrackPublicationOptions{
|
||||
Name: "video",
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to publish track: %w", err)
|
||||
}
|
||||
|
||||
return nil, ErrVideoNotSupported
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"github.com/pion/webrtc/v4"
|
||||
)
|
||||
|
||||
type wbstreamProvider struct {
|
||||
type wbStreamProvider struct {
|
||||
peer *Peer
|
||||
}
|
||||
|
||||
@@ -20,55 +20,61 @@ func New(ctx context.Context, cfg provider.Config) (provider.Provider, error) {
|
||||
return nil, fmt.Errorf("create wbstream peer: %w", err)
|
||||
}
|
||||
|
||||
return &wbstreamProvider{peer: peer}, nil
|
||||
return &wbStreamProvider{peer: peer}, nil
|
||||
}
|
||||
|
||||
// Connect starts the provider connection.
|
||||
func (w *wbstreamProvider) Connect(ctx context.Context) error {
|
||||
func (w *wbStreamProvider) Connect(ctx context.Context) error {
|
||||
return w.peer.Connect(ctx)
|
||||
}
|
||||
|
||||
// Send transmits data to the room.
|
||||
func (w *wbstreamProvider) Send(data []byte) error {
|
||||
func (w *wbStreamProvider) Send(data []byte) error {
|
||||
return w.peer.Send(data)
|
||||
}
|
||||
|
||||
// Close terminates the provider connection.
|
||||
func (w *wbstreamProvider) Close() error {
|
||||
func (w *wbStreamProvider) Close() error {
|
||||
return w.peer.Close()
|
||||
}
|
||||
|
||||
// SetReconnectCallback sets the function to call on reconnection.
|
||||
func (w *wbstreamProvider) SetReconnectCallback(cb func(*webrtc.DataChannel)) {
|
||||
func (w *wbStreamProvider) SetReconnectCallback(cb func(*webrtc.DataChannel)) {
|
||||
w.peer.SetReconnectCallback(cb)
|
||||
}
|
||||
|
||||
// SetShouldReconnect sets the function to determine if reconnection should occur.
|
||||
func (w *wbstreamProvider) SetShouldReconnect(fn func() bool) {
|
||||
func (w *wbStreamProvider) SetShouldReconnect(fn func() bool) {
|
||||
w.peer.SetShouldReconnect(fn)
|
||||
}
|
||||
|
||||
// SetEndedCallback sets the function to call when the session ends.
|
||||
func (w *wbstreamProvider) SetEndedCallback(cb func(string)) {
|
||||
func (w *wbStreamProvider) SetEndedCallback(cb func(string)) {
|
||||
w.peer.SetEndedCallback(cb)
|
||||
}
|
||||
|
||||
// WatchConnection monitors the provider connection state.
|
||||
func (w *wbstreamProvider) WatchConnection(ctx context.Context) {
|
||||
func (w *wbStreamProvider) WatchConnection(ctx context.Context) {
|
||||
w.peer.WatchConnection(ctx)
|
||||
}
|
||||
|
||||
// CanSend checks if the provider is ready to transmit data.
|
||||
func (w *wbstreamProvider) CanSend() bool {
|
||||
func (w *wbStreamProvider) CanSend() bool {
|
||||
return w.peer.CanSend()
|
||||
}
|
||||
|
||||
// GetSendQueue returns the data transmission queue.
|
||||
func (w *wbstreamProvider) GetSendQueue() chan []byte {
|
||||
func (w *wbStreamProvider) GetSendQueue() chan []byte {
|
||||
return w.peer.GetSendQueue()
|
||||
}
|
||||
|
||||
// GetBufferedAmount returns the current WebRTC buffered amount.
|
||||
func (w *wbstreamProvider) GetBufferedAmount() uint64 {
|
||||
func (w *wbStreamProvider) GetBufferedAmount() uint64 {
|
||||
return w.peer.GetBufferedAmount()
|
||||
}
|
||||
|
||||
// AddVideoTrack adds a video track to the wbstream connection.
|
||||
func (w *wbStreamProvider) AddVideoTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) {
|
||||
return w.peer.AddVideoTrack(track)
|
||||
}
|
||||
|
||||
|
||||
@@ -205,7 +205,7 @@ func (s *Server) addPeer(
|
||||
}
|
||||
|
||||
peer.SetEndedCallback(func(reason string) {
|
||||
log.Printf("Server peer %d reported conference end: %s", peerID, reason)
|
||||
logger.Infof("Server peer %d reported conference end: %s", peerID, reason)
|
||||
cancel()
|
||||
})
|
||||
s.peers = append(s.peers, peer)
|
||||
@@ -214,11 +214,11 @@ func (s *Server) addPeer(
|
||||
s.handlePeerReconnect(peerID, dc)
|
||||
})
|
||||
|
||||
log.Printf("Connecting peer %d to %s...", peerID, providerName)
|
||||
logger.Infof("Connecting peer %d to %s...", peerID, providerName)
|
||||
if err := peer.Connect(ctx); err != nil {
|
||||
return fmt.Errorf("failed to connect peer: %w", err)
|
||||
}
|
||||
log.Printf("Peer %d connected", peerID)
|
||||
logger.Infof("Peer %d connected", peerID)
|
||||
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
@@ -229,7 +229,7 @@ func (s *Server) addPeer(
|
||||
}
|
||||
|
||||
func (s *Server) handlePeerReconnect(peerID int, dc *webrtc.DataChannel) {
|
||||
log.Printf("peer %d reconnect event: dc=%v", peerID, dc != nil)
|
||||
logger.Infof("peer %d reconnect event: dc=%v", peerID, dc != nil)
|
||||
|
||||
s.connMu.Lock()
|
||||
for sid, conn := range s.connections {
|
||||
@@ -303,7 +303,7 @@ func (s *Server) onData(data []byte) {
|
||||
}
|
||||
|
||||
if control, ok := mux.ParseControlFrame(plaintext); ok && control.Type == mux.ControlResetClient {
|
||||
log.Printf("Received reset signal from client (clientID=%d)", control.ClientID)
|
||||
logger.Infof("Received reset signal from client (clientID=%d)", control.ClientID)
|
||||
s.closeClientConnections(control.ClientID)
|
||||
}
|
||||
|
||||
@@ -350,7 +350,7 @@ func (s *Server) shutdown() {
|
||||
s.connMu.Unlock()
|
||||
|
||||
for i, peer := range s.peers {
|
||||
log.Printf("closing peer %d", i)
|
||||
logger.Infof("closing peer %d", i)
|
||||
_ = peer.Close()
|
||||
}
|
||||
}
|
||||
@@ -374,7 +374,7 @@ func (s *Server) processMuxStreams(ctx context.Context) {
|
||||
|
||||
var req ConnectRequest
|
||||
if err := json.Unmarshal(data, &req); err == nil && req.Cmd == "connect" {
|
||||
log.Printf("sid=%d connect %s:%d", sid, req.Addr, req.Port)
|
||||
logger.Infof("sid=%d connect %s:%d", sid, req.Addr, req.Port)
|
||||
s.closeStreamConnection(sid)
|
||||
go s.handleConnect(ctx, sid, req)
|
||||
}
|
||||
@@ -437,7 +437,7 @@ func (s *Server) handleConnect(ctx context.Context, sid uint16, req ConnectReque
|
||||
dialElapsed := time.Since(dialStart)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("sid=%d dial %s failed (%v): %v", sid, addr, dialElapsed, err)
|
||||
logger.Infof("sid=%d dial %s failed (%v): %v", sid, addr, dialElapsed, err)
|
||||
_ = s.mux.CloseStream(sid)
|
||||
return
|
||||
}
|
||||
@@ -446,7 +446,7 @@ func (s *Server) handleConnect(ctx context.Context, sid uint16, req ConnectReque
|
||||
s.connections[sid] = conn
|
||||
s.connMu.Unlock()
|
||||
|
||||
log.Printf("sid=%d connected %s in %v", sid, addr, dialElapsed)
|
||||
logger.Infof("sid=%d connected %s in %v", sid, addr, dialElapsed)
|
||||
|
||||
s.activeClients.Add(1)
|
||||
_ = s.mux.SendData(sid, []byte{0x00})
|
||||
@@ -504,7 +504,7 @@ func (s *Server) pumpToMux(sid uint16, conn net.Conn) {
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
if totalSent > 1024*1024 {
|
||||
log.Printf("sid=%d done total=%dMB", sid, totalSent/(1024*1024))
|
||||
logger.Infof("sid=%d done total=%dMB", sid, totalSent/(1024*1024))
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -519,7 +519,7 @@ func (s *Server) pumpToMux(sid uint16, conn net.Conn) {
|
||||
|
||||
totalSent += uint64(n) //nolint:gosec
|
||||
if time.Since(lastLog) > 5*time.Second {
|
||||
log.Printf("sid=%d sent=%dMB", sid, totalSent/(1024*1024))
|
||||
logger.Infof("sid=%d sent=%dMB", sid, totalSent/(1024*1024))
|
||||
lastLog = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ package mobile
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -111,7 +112,7 @@ func Start(roomID, keyHex string, socksPort int, socksUser, socksPass string) er
|
||||
"telemost",
|
||||
roomURL,
|
||||
keyHex,
|
||||
socksPort,
|
||||
fmt.Sprintf("127.0.0.1:%d", socksPort),
|
||||
"",
|
||||
socksUser,
|
||||
socksPass,
|
||||
|
||||
Reference in New Issue
Block a user