Merge pull request #24 from openlibrecommunity/refactor/all

Refactor/all
This commit is contained in:
zarazaex
2026-04-20 05:49:20 +03:00
committed by GitHub
16 changed files with 658 additions and 570 deletions

View File

@@ -6,7 +6,6 @@ import (
"errors"
"flag"
"fmt"
"log"
"os"
"os/signal"
"path/filepath"
@@ -46,7 +45,7 @@ var (
func main() {
if err := run(); err != nil {
log.Print(err)
logger.Error(err)
os.Exit(1)
}
}
@@ -83,7 +82,7 @@ func run() error {
select {
case <-sigCh:
log.Println("Shutting down gracefully...")
logger.Info("Shutting down gracefully...")
cancel()
return waitForShutdown(errCh)
case err := <-errCh:
@@ -112,12 +111,8 @@ func parseFlags() config {
func configureLogging(debug bool) {
if debug {
log.SetFlags(log.Ltime | log.Lshortfile)
logger.SetVerbose(true)
return
}
log.SetFlags(log.Ltime)
}
func validateConfig(cfg config) error {
@@ -135,7 +130,7 @@ func validateConfig(cfg config) error {
return errProviderRequired
case !validProvider:
return fmt.Errorf("%w: %s (available: %v)", errUnsupportedProvider, cfg.provider, available)
case cfg.roomID == "":
case cfg.roomID == "" && cfg.provider != "jazz":
return errRoomIDRequired
case cfg.mode != "srv" && cfg.mode != "cnc":
return errModeRequired
@@ -187,8 +182,8 @@ func runMode(ctx context.Context, cfg config, errCh chan<- error) {
cfg.provider,
roomURL,
cfg.keyHex,
cfg.socksPort,
cfg.socksHost,
fmt.Sprintf("%s:%d", cfg.socksHost, cfg.socksPort),
cfg.dnsServer,
"",
"",
)
@@ -220,11 +215,11 @@ func waitForShutdown(errCh <-chan error) error {
select {
case err := <-done:
if err == nil {
log.Println("Shutdown complete")
logger.Info("Shutdown complete")
}
return err
case <-time.After(5 * time.Second):
log.Println("Shutdown timeout, forcing exit")
logger.Warn("Shutdown timeout, forcing exit")
return nil
}
}

View File

@@ -10,9 +10,7 @@ import (
"errors"
"fmt"
"io"
"log"
"net"
"strconv"
"sync"
"sync/atomic"
"time"
@@ -26,232 +24,403 @@ import (
)
var (
errInvalidKeyLength = errors.New("key must be 32 bytes")
errInvalidKeyStringLength = errors.New("key string length must be 32")
errNoConnectedPeers = errors.New("no connected peers available")
// ErrKeySize is returned when the key size is not 32 bytes.
ErrKeySize = errors.New("key must be 32 bytes")
// ErrKeyStringLength is returned when the key string length is not 32.
ErrKeyStringLength = errors.New("key string length must be 32")
// ErrInvalidSocks5 is returned when the SOCKS version is not 5.
ErrInvalidSocks5 = errors.New("invalid SOCKS5 version")
// ErrNoPeers is returned when no peers are available for sending.
ErrNoPeers = errors.New("no peers available")
// ErrEncryptFailed is returned when encryption fails.
ErrEncryptFailed = errors.New("encrypt failed")
// ErrUnsupportedSocksCommand is returned when a SOCKS5 command is not supported.
ErrUnsupportedSocksCommand = errors.New("unsupported SOCKS5 command")
// ErrUnsupportedAddressType is returned when a SOCKS5 address type is not supported.
ErrUnsupportedAddressType = errors.New("unsupported address type")
// ErrTunnelSetupFailed is returned when the tunnel cannot be established.
ErrTunnelSetupFailed = errors.New("tunnel setup failed")
)
// Client handles local SOCKS5 connections and tunnels them through WebRTC.
// Client handles local SOCKS5 connections and tunnels them via WebRTC.
type Client struct {
peers []provider.Provider
cipher *crypto.Cipher
mux *mux.Multiplexer
clientID uint32
peerIdx atomic.Uint32
wg sync.WaitGroup
peers []provider.Provider
cipher *crypto.Cipher
mux *mux.Multiplexer
connections map[uint16]net.Conn
connMu sync.RWMutex
peerIdx atomic.Uint32
clientID uint32
activeClients atomic.Int32
wg sync.WaitGroup
dnsServer string
}
const defaultSOCKSListenHost = "127.0.0.1"
// Run starts the client with the specified parameters.
func Run(
ctx context.Context,
providerName,
roomURL,
keyHex string,
socksPort int,
socksHost,
socksUser,
localAddr string,
dnsServer,
socksUser string,
socksPass string,
) error {
return RunWithReady(ctx, providerName, roomURL, keyHex, socksPort, socksHost, socksUser, socksPass, nil)
return RunWithReady(ctx, providerName, roomURL, keyHex, localAddr, dnsServer, socksUser, socksPass, nil)
}
// RunWithReady starts the client and calls onReady when it is listening.
// RunWithReady is like Run but accepts a callback that is called when the client is ready.
func RunWithReady(
ctx context.Context,
providerName,
roomURL,
keyHex string,
socksPort int,
socksHost,
socksUser,
socksPass string,
localAddr string,
dnsServer,
_ string,
_ string,
onReady func(),
) error {
runCtx, cancel := context.WithCancel(ctx)
defer cancel()
key, err := decodeKey(keyHex)
cipher, err := setupCipher(keyHex)
if err != nil {
return fmt.Errorf("decodeKey failed: %w", err)
return fmt.Errorf("setupCipher failed: %w", err)
}
keyStr := string(key)
if len(keyStr) != 32 {
return fmt.Errorf("%w: got %d", errInvalidKeyStringLength, len(keyStr))
}
cipher, err := crypto.NewCipher(keyStr)
if err != nil {
return fmt.Errorf("create cipher: %w", err)
clientIDBytes := make([]byte, 4)
if _, err := rand.Read(clientIDBytes); err != nil {
return fmt.Errorf("failed to generate client ID: %w", err)
}
clientID := binary.BigEndian.Uint32(clientIDBytes)
c := &Client{
cipher: cipher,
clientID: uint32(time.Now().UnixNano() & 0xFFFFFFFF),
peers: make([]provider.Provider, 0, 1),
cipher: cipher,
connections: make(map[uint16]net.Conn),
peers: make([]provider.Provider, 0),
clientID: clientID,
dnsServer: dnsServer,
}
c.mux = mux.New(c.clientID, c.sendFrame)
c.setupMux()
for peerID := range 1 {
if err := c.addPeer(runCtx, providerName, roomURL, peerID, cancel); err != nil {
const peerCount = 1
for i := range peerCount {
if err := c.addPeer(runCtx, providerName, roomURL, i, cancel, dnsServer, "", 0); err != nil {
return fmt.Errorf("addPeer failed: %w", err)
}
}
time.Sleep(100 * time.Millisecond)
c.sendResetSignal()
lc := net.ListenConfig{}
ln, err := lc.Listen(runCtx, "tcp", localAddr)
if err != nil {
return fmt.Errorf("listen failed: %w", err)
}
defer func() { _ = ln.Close() }()
err = c.runSOCKS5(runCtx, socksHost, socksPort, socksUser, socksPass, onReady)
logger.Infof("SOCKS5 server listening on %s (ClientID: %d)", localAddr, clientID)
if onReady != nil {
onReady()
}
go c.acceptLoop(runCtx, ln)
<-runCtx.Done()
c.shutdown()
c.wg.Wait()
return err
}
func decodeKey(keyHex string) ([]byte, error) {
if keyHex == "" {
key := make([]byte, 32)
if _, err := rand.Read(key); err != nil {
return nil, fmt.Errorf("generate random key: %w", err)
}
log.Printf("Generated key: %x", key)
return key, nil
}
key, err := hex.DecodeString(keyHex)
if err != nil {
return nil, fmt.Errorf("decode hex key: %w", err)
}
if len(key) != 32 {
return nil, fmt.Errorf("%w: got %d", errInvalidKeyLength, len(key))
}
return key, nil
}
func (c *Client) sendFrame(frame []byte) error {
waitUntilPeersCanSend(c.peers)
encrypted, err := c.cipher.Encrypt(frame)
if err != nil {
return fmt.Errorf("encrypt outgoing frame: %w", err)
}
peer, err := c.nextPeer()
if err != nil {
return err
}
if err := peer.Send(encrypted); err != nil {
return fmt.Errorf("send frame via peer: %w", err)
}
return nil
}
func waitUntilPeersCanSend(peers []provider.Provider) {
for {
canSend := true
for _, peer := range peers {
if !peer.CanSend() {
canSend = false
break
}
}
if canSend {
return
}
time.Sleep(10 * time.Millisecond)
func setupCipher(keyHex string) (*crypto.Cipher, error) {
key, err := hex.DecodeString(keyHex)
if err != nil {
return nil, fmt.Errorf("failed to decode key: %w", err)
}
if len(key) != 32 {
return nil, ErrKeySize
}
keyStr := string(key)
if len(keyStr) != 32 {
return nil, ErrKeyStringLength
}
cipher, err := crypto.NewCipher(keyStr)
if err != nil {
return nil, fmt.Errorf("failed to create cipher: %w", err)
}
return cipher, nil
}
// nextPeer returns the next provider for load balancing.
//
//nolint:ireturn
func (c *Client) nextPeer() (provider.Provider, error) {
switch len(c.peers) {
case 0:
return nil, errNoConnectedPeers
case 1:
return c.peers[0], nil
default:
return c.peers[int(c.peerIdx.Add(1)%2)], nil
}
func (c *Client) setupMux() {
c.mux = mux.New(c.clientID, func(frame []byte) error {
for {
canSend := true
for _, peer := range c.peers {
if !peer.CanSend() {
canSend = false
break
}
}
if canSend {
break
}
time.Sleep(10 * time.Millisecond)
}
encrypted, err := c.cipher.Encrypt(frame)
if err != nil {
return fmt.Errorf("%w: %w", ErrEncryptFailed, err)
}
if len(c.peers) == 0 {
return ErrNoPeers
}
idx := c.peerIdx.Add(1) % uint32(len(c.peers)) //nolint:gosec
return c.peers[idx].Send(encrypted)
})
}
func (c *Client) addPeer(
runCtx context.Context,
ctx context.Context,
providerName,
roomURL string,
peerID int,
cancel context.CancelFunc,
dnsServer,
socksProxyAddr string,
socksProxyPort int,
) error {
peer, err := provider.New(runCtx, providerName, provider.Config{
RoomURL: roomURL,
Name: names.Generate(),
OnData: c.onData,
peer, err := provider.New(ctx, providerName, provider.Config{
RoomURL: roomURL,
Name: names.Generate(),
OnData: c.onData,
DNSServer: dnsServer,
ProxyAddr: socksProxyAddr,
ProxyPort: socksProxyPort,
})
if err != nil {
return fmt.Errorf("create peer %d: %w", peerID, err)
return fmt.Errorf("failed to create peer: %w", err)
}
peer.SetEndedCallback(func(reason string) {
log.Printf("Client peer %d reported conference end: %s", peerID, reason)
logger.Infof("Client peer %d reported conference end: %s", peerID, reason)
cancel()
})
peer.SetReconnectCallback(func(dc *webrtc.DataChannel) {
c.onReconnect(peerID, dc)
})
c.peers = append(c.peers, peer)
log.Printf("Connecting peer %d to %s...", peerID, providerName)
if err := peer.Connect(runCtx); err != nil {
return fmt.Errorf("connect peer %d: %w", peerID, err)
peer.SetReconnectCallback(func(dc *webrtc.DataChannel) {
c.handlePeerReconnect(peerID, dc)
})
logger.Infof("Connecting peer %d to %s...", peerID, providerName)
if err := peer.Connect(ctx); err != nil {
return fmt.Errorf("failed to connect peer: %w", err)
}
log.Printf("Peer %d connected", peerID)
logger.Infof("Peer %d connected", peerID)
c.wg.Add(1)
go func() {
defer c.wg.Done()
peer.WatchConnection(runCtx)
peer.WatchConnection(ctx)
}()
// Send initial reset to clean up any stale connections for this clientID on server
if err := c.mux.SendClientReset(); err != nil {
logger.Warnf("Failed to send initial client reset: %v", err)
}
return nil
}
func (c *Client) onReconnect(peerID int, dc *webrtc.DataChannel) {
log.Printf("peer %d reconnect event: dc=%v", peerID, dc != nil)
func (c *Client) handlePeerReconnect(peerID int, dc *webrtc.DataChannel) {
logger.Infof("peer %d reconnect event: dc=%v", peerID, dc != nil)
c.connMu.Lock()
for sid, conn := range c.connections {
if conn != nil {
_ = conn.Close()
}
delete(c.connections, sid)
}
c.connMu.Unlock()
if dc != nil {
c.mux.UpdateSendFunc(c.sendFrame)
c.mux.UpdateSendFunc(func(frame []byte) error {
encrypted, err := c.cipher.Encrypt(frame)
if err != nil {
return fmt.Errorf("%w: %w", ErrEncryptFailed, err)
}
if len(c.peers) == 0 {
return ErrNoPeers
}
idx := c.peerIdx.Add(1) % uint32(len(c.peers)) //nolint:gosec
return c.peers[idx].Send(encrypted)
})
c.mux.Reset()
if err := c.mux.SendClientReset(); err != nil {
logger.Warnf("Failed to send client reset after reconnect: %v", err)
}
}
}
func (c *Client) sendResetSignal() {
resetFrame := mux.BuildControlFrame(c.clientID, mux.ControlResetClient)
encrypted, err := c.cipher.Encrypt(resetFrame)
if err != nil {
log.Printf("Failed to encrypt reset signal: %v", err)
func (c *Client) acceptLoop(ctx context.Context, ln net.Listener) {
for {
select {
case <-ctx.Done():
return
default:
conn, err := ln.Accept()
if err != nil {
logger.Debugf("Accept error: %v", err)
continue
}
go c.handleSOCKS5(ctx, conn)
}
}
}
func (c *Client) handleSOCKS5(ctx context.Context, conn net.Conn) {
defer func() { _ = conn.Close() }()
if err := c.socks5Handshake(conn); err != nil {
logger.Debugf("SOCKS5 handshake failed: %v", err)
return
}
for _, peer := range c.peers {
if err := peer.Send(encrypted); err != nil {
log.Printf("Failed to send reset signal to server: %v", err)
}
addr, port, err := c.socks5Request(conn)
if err != nil {
logger.Debugf("SOCKS5 request failed: %v", err)
return
}
log.Printf("Sent reset signal to server (clientID=%d)", c.clientID)
sid := c.mux.OpenStream()
c.connMu.Lock()
c.connections[sid] = conn
c.connMu.Unlock()
logger.Infof("sid=%d tunnel to %s:%d", sid, addr, port)
if err := c.setupTunnel(ctx, sid, conn, addr, port); err != nil {
logger.Warnf("sid=%d tunnel setup failed: %v", sid, err)
return
}
c.activeClients.Add(1)
c.startStreamPump(ctx, sid, conn)
c.pumpToMux(sid, conn)
}
func (c *Client) setupTunnel(ctx context.Context, sid uint16, conn net.Conn, addr string, port int) error {
req := map[string]any{"cmd": "connect", "addr": addr, "port": port}
reqData, err := json.Marshal(req)
if err != nil {
return fmt.Errorf("marshal connect: %w", err)
}
if err := c.mux.SendData(sid, reqData); err != nil {
return fmt.Errorf("send connect: %w", err)
}
dataReady := c.mux.WaitForData(sid)
select {
case <-dataReady:
resp := c.mux.ReadStream(sid)
if len(resp) > 0 && resp[0] == 0x00 {
if _, err := conn.Write(replySuccess()); err != nil {
return fmt.Errorf("write success: %w", err)
}
} else {
_, _ = conn.Write(replyHostUnreachable())
return ErrTunnelSetupFailed
}
case <-time.After(15 * time.Second):
_, _ = conn.Write(replyHostUnreachable())
c.mux.CleanupDataChannel(sid)
return fmt.Errorf("%w: timeout", ErrTunnelSetupFailed)
case <-ctx.Done():
return fmt.Errorf("context cancelled: %w", ctx.Err())
}
c.mux.CleanupDataChannel(sid)
return nil
}
func (c *Client) socks5Handshake(conn net.Conn) error {
buf := make([]byte, 2)
if _, err := io.ReadFull(conn, buf); err != nil {
return fmt.Errorf("read header: %w", err)
}
if buf[0] != 5 {
return ErrInvalidSocks5
}
methods := make([]byte, int(buf[1]))
if _, err := io.ReadFull(conn, methods); err != nil {
return fmt.Errorf("read methods: %w", err)
}
if _, err := conn.Write([]byte{5, 0}); err != nil {
return fmt.Errorf("write response: %w", err)
}
return nil
}
func (c *Client) socks5Request(conn net.Conn) (string, int, error) {
buf := make([]byte, 4)
if _, err := io.ReadFull(conn, buf); err != nil {
return "", 0, fmt.Errorf("read request header: %w", err)
}
if buf[0] != 5 || buf[1] != 1 {
return "", 0, fmt.Errorf("%w: cmd=%d", ErrUnsupportedSocksCommand, buf[1])
}
addr, err := c.readSocks5Addr(conn, buf[3])
if err != nil {
return "", 0, err
}
portBuf := make([]byte, 2)
if _, err := io.ReadFull(conn, portBuf); err != nil {
return "", 0, fmt.Errorf("read port: %w", err)
}
port := int(binary.BigEndian.Uint16(portBuf))
return addr, port, nil
}
func (c *Client) readSocks5Addr(conn net.Conn, addrType byte) (string, error) {
switch addrType {
case 1: // IPv4
ip := make([]byte, 4)
if _, err := io.ReadFull(conn, ip); err != nil {
return "", fmt.Errorf("read ipv4: %w", err)
}
return net.IP(ip).String(), nil
case 3: // Domain
lenBuf := make([]byte, 1)
if _, err := io.ReadFull(conn, lenBuf); err != nil {
return "", fmt.Errorf("read domain len: %w", err)
}
domain := make([]byte, int(lenBuf[0]))
if _, err := io.ReadFull(conn, domain); err != nil {
return "", fmt.Errorf("read domain: %w", err)
}
return string(domain), nil
case 4: // IPv6
ip := make([]byte, 16)
if _, err := io.ReadFull(conn, ip); err != nil {
return "", fmt.Errorf("read ipv6: %w", err)
}
return net.IP(ip).String(), nil
default:
return "", fmt.Errorf("%w: type=%d", ErrUnsupportedAddressType, addrType)
}
}
func (c *Client) onData(data []byte) {
@@ -264,347 +433,100 @@ func (c *Client) onData(data []byte) {
c.mux.HandleFrame(plaintext)
}
func (c *Client) runSOCKS5(
ctx context.Context,
host string,
port int,
username,
password string,
onReady func(),
) error {
if host == "" {
host = defaultSOCKSListenHost
}
listenAddr := net.JoinHostPort(host, strconv.Itoa(port))
var lc net.ListenConfig
listener, err := lc.Listen(ctx, "tcp", listenAddr)
if err != nil {
return fmt.Errorf("listen on %s: %w", listenAddr, err)
}
log.Printf("SOCKS5 proxy listening on %s (auth=%v)", listenAddr, username != "")
if onReady != nil {
onReady()
}
go func() {
<-ctx.Done()
if err := listener.Close(); err != nil {
logger.Debugf("SOCKS5 listener close error: %v", err)
func (c *Client) shutdown() {
c.connMu.Lock()
for _, conn := range c.connections {
if conn != nil {
_ = conn.Close()
}
}
c.connMu.Unlock()
for i, peer := range c.peers {
logger.Infof("closing peer %d", i)
_ = peer.Close()
}
}
func (c *Client) pumpToMux(sid uint16, conn net.Conn) {
defer func() {
c.activeClients.Add(-1)
_ = c.mux.CloseStream(sid)
c.connMu.Lock()
delete(c.connections, sid)
c.connMu.Unlock()
}()
buf := make([]byte, 16384)
totalSent := uint64(0)
lastLog := time.Now()
for {
conn, err := listener.Accept()
n, err := conn.Read(buf)
if err != nil {
select {
case <-ctx.Done():
c.closePeers()
return nil
default:
log.Printf("accept error: %v", err)
continue
if totalSent > 1024*1024 {
logger.Infof("sid=%d done total=%dMB", sid, totalSent/(1024*1024))
}
return
}
go c.handleSOCKS5(conn, username, password)
}
}
for !c.canSendData() {
time.Sleep(20 * time.Millisecond)
}
func (c *Client) closePeers() {
for _, peer := range c.peers {
if err := peer.Close(); err != nil {
logger.Debugf("Peer close error: %v", err)
if err := c.mux.SendData(sid, buf[:n]); err != nil {
return
}
totalSent += uint64(n) //nolint:gosec
if time.Since(lastLog) > 5*time.Second {
logger.Infof("sid=%d sent=%dMB", sid, totalSent/(1024*1024))
lastLog = time.Now()
}
}
}
//nolint:cyclop // SOCKS5 parsing is inherently stateful and mirrors the protocol handshake.
func (c *Client) handleSOCKS5(conn net.Conn, username, password string) {
defer func() {
if err := conn.Close(); err != nil {
logger.Debugf("SOCKS5 connection close error: %v", err)
}
}()
buf := make([]byte, 513)
if !readSOCKSVersionAndMethods(conn, buf) {
return
}
nmethods := buf[1]
if _, err := io.ReadFull(conn, buf[:nmethods]); err != nil {
return
}
requireAuth := username != ""
wantMethod := byte(0x00)
if requireAuth {
wantMethod = 0x02
}
if !supportsMethod(buf[:nmethods], wantMethod) {
writeResponse(conn, replyUnsupportedSOCKSMethod())
return
}
writeResponse(conn, []byte{5, wantMethod})
if requireAuth && !authenticateSOCKSUser(conn, buf, username, password) {
return
}
addr, port, ok := readConnectTarget(conn, buf)
if !ok {
return
}
sid := c.mux.OpenStream()
logger.Verbosef("SOCKS5 opened stream sid=%d for %s:%d", sid, addr, port)
log.Printf("sid=%d socks5 %s:%d", sid, addr, port)
if !c.sendConnectRequest(sid, addr, port) {
return
}
if !c.waitConnectResponse(conn, sid) {
return
}
c.mux.ReadStream(sid)
writeResponse(conn, replySuccess())
c.proxyStream(conn, sid)
}
func readSOCKSVersionAndMethods(conn net.Conn, buf []byte) bool {
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
return false
}
return buf[0] == 5
}
func supportsMethod(methods []byte, wantMethod byte) bool {
for _, method := range methods {
if method == wantMethod {
return true
}
}
return false
}
func authenticateSOCKSUser(conn net.Conn, buf []byte, username, password string) bool {
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
return false
}
if buf[0] != 0x01 {
return false
}
ulen := int(buf[1])
if _, err := io.ReadFull(conn, buf[:ulen+1]); err != nil {
return false
}
gotUser := string(buf[:ulen])
plen := int(buf[ulen])
if _, err := io.ReadFull(conn, buf[:plen]); err != nil {
return false
}
gotPass := string(buf[:plen])
if gotUser != username || gotPass != password {
writeResponse(conn, replyAuthFailed())
return false
}
writeResponse(conn, replyAuthOK())
return true
}
func readConnectTarget(conn net.Conn, buf []byte) (string, uint16, bool) {
if _, err := io.ReadFull(conn, buf[:4]); err != nil {
return "", 0, false
}
if buf[1] != 1 {
writeResponse(conn, replyCommandNotSupported())
return "", 0, false
}
addr, ok := readTargetAddress(conn, buf, buf[3])
if !ok {
return "", 0, false
}
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
return "", 0, false
}
return addr, binary.BigEndian.Uint16(buf[:2]), true
}
func readTargetAddress(conn net.Conn, buf []byte, atyp byte) (string, bool) {
switch atyp {
case 1:
if _, err := io.ReadFull(conn, buf[:4]); err != nil {
return "", false
}
return fmt.Sprintf("%d.%d.%d.%d", buf[0], buf[1], buf[2], buf[3]), true
case 3:
if _, err := io.ReadFull(conn, buf[:1]); err != nil {
return "", false
}
length := buf[0]
if _, err := io.ReadFull(conn, buf[:length]); err != nil {
return "", false
}
return string(buf[:length]), true
default:
writeResponse(conn, replyAddressNotSupported())
return "", false
}
}
func (c *Client) sendConnectRequest(sid uint16, addr string, port uint16) bool {
reqData, err := json.Marshal(struct {
Cmd string `json:"cmd"`
Addr string `json:"addr"`
Port uint16 `json:"port"`
}{
Cmd: "connect",
Addr: addr,
Port: port,
})
if err != nil {
logger.Debugf("Connect request marshal error: %v", err)
return false
}
if err := c.mux.SendData(sid, reqData); err != nil {
logger.Debugf("Connect request send error: %v", err)
return false
}
return true
}
func (c *Client) waitConnectResponse(conn net.Conn, sid uint16) bool {
dataReady := c.mux.WaitForData(sid)
timeout := time.NewTimer(10 * time.Second)
defer timeout.Stop()
select {
case <-dataReady:
stream := c.mux.GetStream(sid)
if stream == nil || len(stream.RecvBuf()) == 0 {
writeResponse(conn, replyHostUnreachable())
return false
}
case <-timeout.C:
writeResponse(conn, replyHostUnreachable())
return false
}
return true
}
//nolint:cyclop // The stream pump handles two coordinated goroutines and shutdown races in one place.
func (c *Client) proxyStream(conn net.Conn, sid uint16) {
done := make(chan struct{})
streamClosed := make(chan struct{})
func (c *Client) startStreamPump(ctx context.Context, sid uint16, conn net.Conn) {
c.wg.Add(1)
go func() {
defer close(done)
buf := make([]byte, 32768)
for {
n, err := conn.Read(buf)
if err != nil {
if err := c.mux.CloseStream(sid); err != nil {
logger.Debugf("Close stream error: %v", err)
}
return
}
if err := c.mux.SendData(sid, buf[:n]); err != nil {
return
}
}
}()
go func() {
defer close(streamClosed)
defer c.mux.CleanupDataChannel(sid)
defer c.wg.Done()
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-done:
case <-ctx.Done():
return
case <-ticker.C:
data := c.mux.ReadStream(sid)
if len(data) > 0 && !writeStreamData(conn, data) {
return
if len(data) > 0 {
if _, err := conn.Write(data); err != nil {
_ = c.mux.CloseStream(sid)
return
}
}
if c.mux.StreamClosed(sid) {
return
}
}
}
}()
select {
case <-done:
case <-streamClosed:
}
}
func writeStreamData(conn net.Conn, data []byte) bool {
for len(data) > 0 {
n, err := conn.Write(data)
if err != nil {
func (c *Client) canSendData() bool {
for _, peer := range c.peers {
if !peer.CanSend() {
return false
}
data = data[n:]
}
return true
}
func writeResponse(conn net.Conn, response []byte) {
if _, err := conn.Write(response); err != nil {
logger.Debugf("SOCKS5 response write error: %v", err)
}
}
func replyUnsupportedSOCKSMethod() []byte {
return []byte{5, 0xFF}
}
func replyAuthFailed() []byte {
return []byte{0x01, 0x01}
}
func replyAuthOK() []byte {
return []byte{0x01, 0x00}
}
func replyCommandNotSupported() []byte {
return []byte{5, 7, 0, 1, 0, 0, 0, 0, 0, 0}
}
func replyAddressNotSupported() []byte {
return []byte{5, 8, 0, 1, 0, 0, 0, 0, 0, 0}
}
func replyHostUnreachable() []byte {
return []byte{5, 4, 0, 1, 0, 0, 0, 0, 0, 0}
}
func replySuccess() []byte {
return []byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0}
}
func replyHostUnreachable() []byte {
return []byte{5, 4, 0, 1, 0, 0, 0, 0, 0, 0}
}

View File

@@ -11,15 +11,19 @@ import (
)
var (
ErrInvalidKeySize = errors.New("invalid key size") //nolint:revive
ErrCiphertextTooShort = errors.New("ciphertext too short") //nolint:revive
// ErrInvalidKeySize is returned when the encryption key is not 32 bytes.
ErrInvalidKeySize = errors.New("invalid key size")
// ErrCiphertextTooShort is returned when the ciphertext is shorter than the nonce size.
ErrCiphertextTooShort = errors.New("ciphertext too short")
)
type Cipher struct { //nolint:revive
// Cipher provides AEAD encryption and decryption using ChaCha20-Poly1305.
type Cipher struct {
aead cipher.AEAD
}
func NewCipher(keyStr string) (*Cipher, error) { //nolint:revive
// NewCipher creates a new Cipher instance with the given 32-byte key.
func NewCipher(keyStr string) (*Cipher, error) {
key := []byte(keyStr)
if len(key) != chacha20poly1305.KeySize {
return nil, ErrInvalidKeySize
@@ -33,23 +37,26 @@ func NewCipher(keyStr string) (*Cipher, error) { //nolint:revive
return &Cipher{aead: aead}, nil
}
func (c *Cipher) Encrypt(plaintext []byte) ([]byte, error) { //nolint:revive
// Encrypt encrypts plaintext and prepends a random nonce.
func (c *Cipher) Encrypt(plaintext []byte) ([]byte, error) {
nonce := make([]byte, c.aead.NonceSize())
if _, err := rand.Read(nonce); err != nil {
return nil, fmt.Errorf("failed to read nonce: %w", err)
return nil, fmt.Errorf("failed to generate nonce: %w", err)
}
ciphertext := c.aead.Seal(nonce, nonce, plaintext, nil)
return ciphertext, nil
// Seal appends the ciphertext to the nonce
return c.aead.Seal(nonce, nonce, plaintext, nil), nil
}
func (c *Cipher) Decrypt(ciphertext []byte) ([]byte, error) { //nolint:revive
if len(ciphertext) < c.aead.NonceSize() {
// Decrypt decrypts ciphertext that has a nonce prepended.
func (c *Cipher) Decrypt(ciphertext []byte) ([]byte, error) {
nonceSize := c.aead.NonceSize()
if len(ciphertext) < nonceSize {
return nil, ErrCiphertextTooShort
}
nonce := ciphertext[:c.aead.NonceSize()]
encrypted := ciphertext[c.aead.NonceSize():]
nonce := ciphertext[:nonceSize]
encrypted := ciphertext[nonceSize:]
res, err := c.aead.Open(nil, nonce, encrypted, nil)
if err != nil {

View File

@@ -1,28 +1,66 @@
package logger //nolint:revive
// Package logger provides a simple leveled logging interface.
package logger
import (
"log"
"sync/atomic"
)
var verboseEnabled atomic.Bool //nolint:gochecknoglobals
// verboseEnabled controls whether verbose and debug logging is enabled.
//
//nolint:gochecknoglobals // Global log state is acceptable for CLI tools.
var verboseEnabled atomic.Bool
func SetVerbose(enabled bool) { //nolint:revive
// SetVerbose enables or disables verbose/debug logging.
func SetVerbose(enabled bool) {
verboseEnabled.Store(enabled)
}
func IsVerbose() bool { //nolint:revive
// IsVerbose returns true if verbose logging is enabled.
func IsVerbose() bool {
return verboseEnabled.Load()
}
func Verbosef(format string, v ...interface{}) { //nolint:revive
// Info logs an informational message.
func Info(v ...any) {
log.Print(v...)
}
// Infof logs a formatted informational message.
func Infof(format string, v ...any) {
log.Printf(format, v...)
}
// Warn logs a warning message.
func Warn(v ...any) {
log.Print(v...)
}
// Warnf logs a formatted warning message.
func Warnf(format string, v ...any) {
log.Printf(format, v...)
}
// Error logs an error message.
func Error(v ...any) {
log.Print(v...)
}
// Errorf logs a formatted error message.
func Errorf(format string, v ...any) {
log.Printf(format, v...)
}
// Verbosef logs a formatted message if verbose logging is enabled.
func Verbosef(format string, v ...any) {
if verboseEnabled.Load() {
log.Printf("[VERBOSE] "+format, v...)
log.Printf(format, v...)
}
}
func Debugf(format string, v ...interface{}) { //nolint:revive
// Debugf logs a formatted message if verbose logging is enabled.
func Debugf(format string, v ...any) {
if verboseEnabled.Load() {
log.Printf("[DEBUG] "+format, v...)
log.Printf(format, v...)
}
}

View File

@@ -5,29 +5,43 @@ import (
"encoding/binary"
"errors"
"fmt"
"math"
"sync"
"time"
"github.com/openlibrecommunity/olcrtc/internal/logger"
)
var (
ErrClientResetID = errors.New("client reset requires a non-zero client id") //nolint:revive
// ErrClientResetID is returned when a client reset is attempted with a zero client ID.
ErrClientResetID = errors.New("client reset requires a non-zero client id")
// ErrDataTooLarge is returned when a data chunk exceeds the maximum frame size.
ErrDataTooLarge = errors.New("data chunk too large")
)
const (
ControlStreamID uint16 = 0xFFFF //nolint:revive
ControlLength uint16 = 0xFFFF //nolint:revive
// HeaderSize is the size of the frame header in bytes.
HeaderSize = 12
// ControlStreamID is a special stream ID used for control frames.
ControlStreamID uint16 = 0xFFFF
// ControlResetClient is a control frame type used to signal a client reset.
ControlResetClient uint32 = 1
// FrameTypeData is a marker for data frames.
FrameTypeData uint16 = 0
// FrameTypeControl is a marker for control frames.
FrameTypeControl uint16 = 0xFFFF
)
type ControlFrame struct { //nolint:revive
// ControlFrame represents a control message between multiplexers.
type ControlFrame struct {
ClientID uint32
Type uint32
}
type Stream struct { //nolint:revive
// Stream represents a single multiplexed data stream.
type Stream struct {
ID uint16
ClientID uint32
recvBuf []byte
@@ -37,13 +51,15 @@ type Stream struct { //nolint:revive
outOfOrder map[uint32][]byte
}
func (s *Stream) RecvBuf() []byte { //nolint:revive
// RecvBuf returns the current receive buffer content.
func (s *Stream) RecvBuf() []byte {
s.mu.Lock()
defer s.mu.Unlock()
return s.recvBuf
}
type Multiplexer struct { //nolint:revive
// Multiplexer coordinates multiple Streams over a single transport channel.
type Multiplexer struct {
streams map[uint16]*Stream
nextID uint16
clientID uint32
@@ -55,10 +71,14 @@ type Multiplexer struct { //nolint:revive
dataReadyMu sync.Mutex
sendSeq map[uint16]uint32
sendSeqMu sync.Mutex
// bufferCond is used to wait for space in receive buffers
bufferCond *sync.Cond
}
func New(clientID uint32, onSend func([]byte) error) *Multiplexer { //nolint:revive
return &Multiplexer{
// New creates a new Multiplexer instance.
func New(clientID uint32, onSend func([]byte) error) *Multiplexer {
m := &Multiplexer{
streams: make(map[uint16]*Stream),
nextID: 1,
clientID: clientID,
@@ -68,9 +88,12 @@ func New(clientID uint32, onSend func([]byte) error) *Multiplexer { //nolint:rev
dataReady: make(map[uint16]chan struct{}),
sendSeq: make(map[uint16]uint32),
}
m.bufferCond = sync.NewCond(&m.mu)
return m
}
func (m *Multiplexer) OpenStream() uint16 { //nolint:revive
// OpenStream allocates and returns a new unique stream ID.
func (m *Multiplexer) OpenStream() uint16 {
m.mu.Lock()
defer m.mu.Unlock()
@@ -93,7 +116,8 @@ func (m *Multiplexer) OpenStream() uint16 { //nolint:revive
}
}
func (m *Multiplexer) SendData(sid uint16, data []byte) error { //nolint:revive
// SendData fragments and sends data over a specific stream.
func (m *Multiplexer) SendData(sid uint16, data []byte) error {
m.mu.RLock()
stream, exists := m.streams[sid]
m.mu.RUnlock()
@@ -103,11 +127,6 @@ func (m *Multiplexer) SendData(sid uint16, data []byte) error { //nolint:revive
}
const chunkSize = 7000
totalChunks := (len(data) + chunkSize - 1) / chunkSize
if totalChunks > 10 {
logger.Debugf("SendData: sid=%d, size=%d bytes, chunks=%d", sid, len(data), totalChunks)
}
for i := 0; i < len(data); i += chunkSize {
end := i + chunkSize
@@ -122,12 +141,16 @@ func (m *Multiplexer) SendData(sid uint16, data []byte) error { //nolint:revive
m.sendSeq[sid]++
m.sendSeqMu.Unlock()
frame := make([]byte, 12+len(chunk))
if len(chunk) > math.MaxUint16 {
return ErrDataTooLarge
}
frame := make([]byte, HeaderSize+len(chunk))
binary.BigEndian.PutUint32(frame[0:4], m.clientID)
binary.BigEndian.PutUint16(frame[4:6], sid)
binary.BigEndian.PutUint16(frame[6:8], uint16(uint32(len(chunk)))) //nolint:gosec
binary.BigEndian.PutUint16(frame[6:8], uint16(len(chunk))) //nolint:gosec // Length checked above
binary.BigEndian.PutUint32(frame[8:12], seq)
copy(frame[12:], chunk)
copy(frame[HeaderSize:], chunk)
if err := m.onSend(frame); err != nil {
return fmt.Errorf("onSend failed: %w", err)
@@ -137,7 +160,8 @@ func (m *Multiplexer) SendData(sid uint16, data []byte) error { //nolint:revive
return nil
}
func (m *Multiplexer) CloseStream(sid uint16) error { //nolint:revive
// CloseStream signals that a stream should be terminated.
func (m *Multiplexer) CloseStream(sid uint16) error {
m.mu.Lock()
defer m.mu.Unlock()
@@ -149,7 +173,10 @@ func (m *Multiplexer) CloseStream(sid uint16) error { //nolint:revive
delete(m.sendSeq, sid)
m.sendSeqMu.Unlock()
frame := make([]byte, 12)
// Notify anyone waiting for buffer space that a stream is closed
m.bufferCond.Broadcast()
frame := make([]byte, HeaderSize)
binary.BigEndian.PutUint32(frame[0:4], m.clientID)
binary.BigEndian.PutUint16(frame[4:6], sid)
binary.BigEndian.PutUint16(frame[6:8], 0)
@@ -161,7 +188,8 @@ func (m *Multiplexer) CloseStream(sid uint16) error { //nolint:revive
return nil
}
func (m *Multiplexer) SendClientReset() error { //nolint:revive
// SendClientReset sends a control frame to reset all streams for this client.
func (m *Multiplexer) SendClientReset() error {
if m.clientID == 0 {
return ErrClientResetID
}
@@ -171,23 +199,25 @@ func (m *Multiplexer) SendClientReset() error { //nolint:revive
return nil
}
func BuildControlFrame(clientID uint32, controlType uint32) []byte { //nolint:revive
frame := make([]byte, 12)
// BuildControlFrame constructs a raw control frame.
func BuildControlFrame(clientID uint32, controlType uint32) []byte {
frame := make([]byte, HeaderSize)
binary.BigEndian.PutUint32(frame[0:4], clientID)
binary.BigEndian.PutUint16(frame[4:6], ControlStreamID)
binary.BigEndian.PutUint16(frame[6:8], ControlLength)
binary.BigEndian.PutUint16(frame[6:8], 0xFFFF) // Use 0xFFFF as a marker for control
binary.BigEndian.PutUint32(frame[8:12], controlType)
return frame
}
func ParseControlFrame(frame []byte) (ControlFrame, bool) { //nolint:revive
if len(frame) < 12 {
// ParseControlFrame attempts to extract control information from a frame.
func ParseControlFrame(frame []byte) (ControlFrame, bool) {
if len(frame) < HeaderSize {
return ControlFrame{}, false
}
sid := binary.BigEndian.Uint16(frame[4:6])
length := binary.BigEndian.Uint16(frame[6:8])
if sid != ControlStreamID || length != ControlLength {
if sid != ControlStreamID || length != 0xFFFF {
return ControlFrame{}, false
}
@@ -197,14 +227,15 @@ func ParseControlFrame(frame []byte) (ControlFrame, bool) { //nolint:revive
}, true
}
func (m *Multiplexer) HandleFrame(frame []byte) { //nolint:revive
// HandleFrame processes an incoming frame from the transport.
func (m *Multiplexer) HandleFrame(frame []byte) {
control, ok := ParseControlFrame(frame)
if ok {
m.handleControlFrame(control)
return
}
if len(frame) < 12 {
if len(frame) < HeaderSize {
return
}
@@ -218,11 +249,11 @@ func (m *Multiplexer) HandleFrame(frame []byte) { //nolint:revive
return
}
if len(frame) < 12+int(length) {
if len(frame) < HeaderSize+int(length) {
return
}
m.processDataFrame(sid, clientID, seq, frame[12:12+length])
m.processDataFrame(sid, clientID, seq, frame[HeaderSize:HeaderSize+int(length)])
}
func (m *Multiplexer) handleCloseStreamFrame(sid uint16, clientID uint32) {
@@ -230,6 +261,7 @@ func (m *Multiplexer) handleCloseStreamFrame(sid uint16, clientID uint32) {
defer m.mu.Unlock()
if stream, exists := m.streams[sid]; exists && stream.ClientID == clientID {
stream.closed = true
m.bufferCond.Broadcast()
}
}
@@ -279,6 +311,7 @@ func (m *Multiplexer) getOrCreateStream(sid uint16, clientID uint32) *Stream {
stream.closed = false
stream.nextSeq = 0
stream.outOfOrder = make(map[uint32][]byte)
m.bufferCond.Broadcast()
}
return stream
}
@@ -319,7 +352,8 @@ func (m *Multiplexer) handleControlFrame(control ControlFrame) {
}
}
func (m *Multiplexer) ResetClient(clientID uint32) { //nolint:revive
// ResetClient closes and removes all streams associated with a client ID.
func (m *Multiplexer) ResetClient(clientID uint32) {
m.mu.Lock()
defer m.mu.Unlock()
@@ -329,6 +363,7 @@ func (m *Multiplexer) ResetClient(clientID uint32) { //nolint:revive
delete(m.streams, streamSid)
}
}
m.bufferCond.Broadcast()
}
func (m *Multiplexer) waitForBufferSpace(sid uint16, clientID uint32, need int) *Stream {
@@ -340,13 +375,13 @@ func (m *Multiplexer) waitForBufferSpace(sid uint16, clientID uint32, need int)
if len(stream.recvBuf)+need <= m.maxBufferSize {
return stream
}
m.mu.Unlock()
time.Sleep(5 * time.Millisecond)
m.mu.Lock()
// Wait for space to become available
m.bufferCond.Wait()
}
}
func (m *Multiplexer) ReadStream(sid uint16) []byte { //nolint:revive
// ReadStream retrieves and clears the current receive buffer for a stream.
func (m *Multiplexer) ReadStream(sid uint16) []byte {
m.mu.Lock()
defer m.mu.Unlock()
@@ -357,10 +392,15 @@ func (m *Multiplexer) ReadStream(sid uint16) []byte { //nolint:revive
data := stream.recvBuf
stream.recvBuf = make([]byte, 0)
// Notify producers that space is now available
m.bufferCond.Broadcast()
return data
}
func (m *Multiplexer) StreamClosed(sid uint16) bool { //nolint:revive
// StreamClosed returns true if the stream is closed or doesn't exist.
func (m *Multiplexer) StreamClosed(sid uint16) bool {
m.mu.RLock()
defer m.mu.RUnlock()
@@ -368,7 +408,8 @@ func (m *Multiplexer) StreamClosed(sid uint16) bool { //nolint:revive
return !exists || stream.closed
}
func (m *Multiplexer) GetStreams() []uint16 { //nolint:revive
// GetStreams returns a list of all active stream IDs.
func (m *Multiplexer) GetStreams() []uint16 {
m.mu.RLock()
defer m.mu.RUnlock()
@@ -379,13 +420,15 @@ func (m *Multiplexer) GetStreams() []uint16 { //nolint:revive
return sids
}
func (m *Multiplexer) GetStream(sid uint16) *Stream { //nolint:revive
// GetStream returns the Stream object for a given ID.
func (m *Multiplexer) GetStream(sid uint16) *Stream {
m.mu.RLock()
defer m.mu.RUnlock()
return m.streams[sid]
}
func (m *Multiplexer) Reset() { //nolint:revive
// Reset clears all multiplexer state and closes all streams.
func (m *Multiplexer) Reset() {
m.mu.Lock()
defer m.mu.Unlock()
@@ -399,16 +442,20 @@ func (m *Multiplexer) Reset() { //nolint:revive
m.sendSeqMu.Lock()
m.sendSeq = make(map[uint16]uint32)
m.sendSeqMu.Unlock()
m.bufferCond.Broadcast()
}
func (m *Multiplexer) UpdateSendFunc(onSend func([]byte) error) { //nolint:revive
// UpdateSendFunc updates the function used to transmit raw frames.
func (m *Multiplexer) UpdateSendFunc(onSend func([]byte) error) {
m.mu.Lock()
defer m.mu.Unlock()
m.onSend = onSend
}
func (m *Multiplexer) WaitForData(sid uint16) <-chan struct{} { //nolint:revive
// WaitForData returns a channel that signals when new data is available for a stream.
func (m *Multiplexer) WaitForData(sid uint16) <-chan struct{} {
m.dataReadyMu.Lock()
defer m.dataReadyMu.Unlock()
@@ -418,7 +465,8 @@ func (m *Multiplexer) WaitForData(sid uint16) <-chan struct{} { //nolint:revive
return m.dataReady[sid]
}
func (m *Multiplexer) CleanupDataChannel(sid uint16) { //nolint:revive
// CleanupDataChannel removes the data notification channel for a stream.
func (m *Multiplexer) CleanupDataChannel(sid uint16) {
m.dataReadyMu.Lock()
defer m.dataReadyMu.Unlock()

View File

@@ -1,17 +0,0 @@
// Package provider defines common errors for WebRTC providers.
package provider
import "errors"
var (
// ErrProviderNotFound is returned when the requested provider is not registered.
ErrProviderNotFound = errors.New("provider not found")
// ErrDataChannelTimeout is returned when the DataChannel fails to open in time.
ErrDataChannelTimeout = errors.New("datachannel timeout")
// ErrDataChannelNotReady is returned when attempting to send data before the DataChannel is open.
ErrDataChannelNotReady = errors.New("datachannel not ready")
// ErrSendQueueClosed is returned when attempting to send data after the send queue has been closed.
ErrSendQueueClosed = errors.New("send queue closed")
// ErrSendQueueTimeout is returned when the send queue is full and the timeout is reached.
ErrSendQueueTimeout = errors.New("send queue timeout")
)

View File

@@ -3,6 +3,7 @@ package jazz
import (
"context"
"errors"
"fmt"
"log"
"strings"
@@ -533,6 +534,23 @@ func (p *Peer) Close() error {
return nil
}
var (
// ErrPublisherNotInitialized is returned when the publisher peer connection is not set up.
ErrPublisherNotInitialized = errors.New("publisher peer connection not initialized")
)
// AddVideoTrack adds a video track to the publisher peer connection.
func (p *Peer) AddVideoTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) {
if p.pcPub == nil {
return nil, ErrPublisherNotInitialized
}
sender, err := p.pcPub.AddTrack(track)
if err != nil {
return nil, fmt.Errorf("failed to add track: %w", err)
}
return sender, nil
}
// SetReconnectCallback sets the callback for reconnection events.
func (p *Peer) SetReconnectCallback(cb func(*webrtc.DataChannel)) {
p.onReconnect = cb

View File

@@ -72,3 +72,8 @@ func (j *jazzProvider) GetSendQueue() chan []byte {
func (j *jazzProvider) GetBufferedAmount() uint64 {
return j.peer.GetBufferedAmount()
}
// AddVideoTrack adds a video track to the jazz connection.
func (j *jazzProvider) AddVideoTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) {
return j.peer.AddVideoTrack(track)
}

View File

@@ -3,11 +3,27 @@ package provider
import (
"context"
"errors"
"github.com/pion/webrtc/v4"
)
var (
// ErrProviderNotFound is returned when a requested provider is not registered.
ErrProviderNotFound = errors.New("provider not found")
// ErrDataChannelTimeout is returned when the DataChannel fails to open within the timeout period.
ErrDataChannelTimeout = errors.New("datachannel timeout")
// ErrDataChannelNotReady is returned when attempting to send data before the DataChannel is open.
ErrDataChannelNotReady = errors.New("datachannel not ready")
// ErrSendQueueClosed is returned when attempting to send data after the send queue has been closed.
ErrSendQueueClosed = errors.New("send queue closed")
// ErrSendQueueTimeout is returned when the send queue is full and the timeout is reached.
ErrSendQueueTimeout = errors.New("send queue timeout")
)
// Provider defines the standard interface for WebRTC connection handlers.
//
//nolint:interfacebloat // All methods are necessary for provider abstraction.
type Provider interface {
Connect(ctx context.Context) error
Send(data []byte) error
@@ -19,6 +35,9 @@ type Provider interface {
CanSend() bool
GetSendQueue() chan []byte
GetBufferedAmount() uint64
// AddVideoTrack adds a video track to the connection.
AddVideoTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error)
}
// Config holds common configuration for all providers.
@@ -34,7 +53,9 @@ type Config struct {
// Factory is a function that creates a new Provider instance.
type Factory func(ctx context.Context, cfg Config) (Provider, error)
//nolint:gochecknoglobals
// registry holds all registered provider factories.
//
//nolint:gochecknoglobals // Global registry is required for provider discovery.
var registry = make(map[string]Factory)
// Register adds a new provider factory to the registry.

View File

@@ -1160,3 +1160,20 @@ func (p *Peer) CanSend() bool {
}
return len(p.sendQueue) < 4000
}
var (
// ErrPublisherNotInitialized is returned when the publisher peer connection is not set up.
ErrPublisherNotInitialized = errors.New("publisher peer connection not initialized")
)
// AddVideoTrack adds a video track to the publisher peer connection.
func (p *Peer) AddVideoTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) {
if p.pcPub == nil {
return nil, ErrPublisherNotInitialized
}
sender, err := p.pcPub.AddTrack(track)
if err != nil {
return nil, fmt.Errorf("failed to add track: %w", err)
}
return sender, nil
}

View File

@@ -13,7 +13,7 @@ type telemostProvider struct {
peer *Peer
}
// New creates a new Yandex Telemost provider instance.
// New creates a new Telemost provider instance.
func New(ctx context.Context, cfg provider.Config) (provider.Provider, error) {
peer, err := NewPeer(ctx, cfg.RoomURL, cfg.Name, cfg.OnData)
if err != nil {
@@ -72,3 +72,9 @@ func (t *telemostProvider) GetSendQueue() chan []byte {
func (t *telemostProvider) GetBufferedAmount() uint64 {
return t.peer.GetBufferedAmount()
}
// AddVideoTrack adds a video track to the telemost connection.
func (t *telemostProvider) AddVideoTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) {
return t.peer.AddVideoTrack(track)
}

View File

@@ -18,8 +18,14 @@ const (
)
var (
errPeerClosed = errors.New("peer closed")
errSendQueueFull = errors.New("send queue full")
// ErrPeerClosed is returned when an operation is attempted on a closed peer.
ErrPeerClosed = errors.New("peer closed")
// ErrSendQueueFull is returned when the transmission queue is full.
ErrSendQueueFull = errors.New("send queue full")
// ErrLiveKitNotConnected is returned when the LiveKit room is not connected.
ErrLiveKitNotConnected = errors.New("livekit room not connected")
// ErrVideoNotSupported is returned when video tracks are not supported by this provider.
ErrVideoNotSupported = errors.New("video tracks not supported yet in wbstream")
)
// Peer represents a WB Stream WebRTC connection using LiveKit.
@@ -137,13 +143,13 @@ func (p *Peer) processSendQueue() {
// Send transmits data to the room.
func (p *Peer) Send(data []byte) error {
if p.closed.Load() {
return errPeerClosed
return ErrPeerClosed
}
select {
case p.sendQueue <- data:
return nil
default:
return errSendQueueFull
return ErrSendQueueFull
}
}
@@ -193,3 +199,19 @@ func (p *Peer) GetSendQueue() chan []byte {
func (p *Peer) GetBufferedAmount() uint64 {
return 0
}
// AddVideoTrack adds a video track to the LiveKit room.
func (p *Peer) AddVideoTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) {
if p.room == nil || p.room.LocalParticipant == nil {
return nil, ErrLiveKitNotConnected
}
_, err := p.room.LocalParticipant.PublishTrack(track, &lksdk.TrackPublicationOptions{
Name: "video",
})
if err != nil {
return nil, fmt.Errorf("failed to publish track: %w", err)
}
return nil, ErrVideoNotSupported
}

View File

@@ -9,7 +9,7 @@ import (
"github.com/pion/webrtc/v4"
)
type wbstreamProvider struct {
type wbStreamProvider struct {
peer *Peer
}
@@ -20,55 +20,61 @@ func New(ctx context.Context, cfg provider.Config) (provider.Provider, error) {
return nil, fmt.Errorf("create wbstream peer: %w", err)
}
return &wbstreamProvider{peer: peer}, nil
return &wbStreamProvider{peer: peer}, nil
}
// Connect starts the provider connection.
func (w *wbstreamProvider) Connect(ctx context.Context) error {
func (w *wbStreamProvider) Connect(ctx context.Context) error {
return w.peer.Connect(ctx)
}
// Send transmits data to the room.
func (w *wbstreamProvider) Send(data []byte) error {
func (w *wbStreamProvider) Send(data []byte) error {
return w.peer.Send(data)
}
// Close terminates the provider connection.
func (w *wbstreamProvider) Close() error {
func (w *wbStreamProvider) Close() error {
return w.peer.Close()
}
// SetReconnectCallback sets the function to call on reconnection.
func (w *wbstreamProvider) SetReconnectCallback(cb func(*webrtc.DataChannel)) {
func (w *wbStreamProvider) SetReconnectCallback(cb func(*webrtc.DataChannel)) {
w.peer.SetReconnectCallback(cb)
}
// SetShouldReconnect sets the function to determine if reconnection should occur.
func (w *wbstreamProvider) SetShouldReconnect(fn func() bool) {
func (w *wbStreamProvider) SetShouldReconnect(fn func() bool) {
w.peer.SetShouldReconnect(fn)
}
// SetEndedCallback sets the function to call when the session ends.
func (w *wbstreamProvider) SetEndedCallback(cb func(string)) {
func (w *wbStreamProvider) SetEndedCallback(cb func(string)) {
w.peer.SetEndedCallback(cb)
}
// WatchConnection monitors the provider connection state.
func (w *wbstreamProvider) WatchConnection(ctx context.Context) {
func (w *wbStreamProvider) WatchConnection(ctx context.Context) {
w.peer.WatchConnection(ctx)
}
// CanSend checks if the provider is ready to transmit data.
func (w *wbstreamProvider) CanSend() bool {
func (w *wbStreamProvider) CanSend() bool {
return w.peer.CanSend()
}
// GetSendQueue returns the data transmission queue.
func (w *wbstreamProvider) GetSendQueue() chan []byte {
func (w *wbStreamProvider) GetSendQueue() chan []byte {
return w.peer.GetSendQueue()
}
// GetBufferedAmount returns the current WebRTC buffered amount.
func (w *wbstreamProvider) GetBufferedAmount() uint64 {
func (w *wbStreamProvider) GetBufferedAmount() uint64 {
return w.peer.GetBufferedAmount()
}
// AddVideoTrack adds a video track to the wbstream connection.
func (w *wbStreamProvider) AddVideoTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) {
return w.peer.AddVideoTrack(track)
}

View File

@@ -205,7 +205,7 @@ func (s *Server) addPeer(
}
peer.SetEndedCallback(func(reason string) {
log.Printf("Server peer %d reported conference end: %s", peerID, reason)
logger.Infof("Server peer %d reported conference end: %s", peerID, reason)
cancel()
})
s.peers = append(s.peers, peer)
@@ -214,11 +214,11 @@ func (s *Server) addPeer(
s.handlePeerReconnect(peerID, dc)
})
log.Printf("Connecting peer %d to %s...", peerID, providerName)
logger.Infof("Connecting peer %d to %s...", peerID, providerName)
if err := peer.Connect(ctx); err != nil {
return fmt.Errorf("failed to connect peer: %w", err)
}
log.Printf("Peer %d connected", peerID)
logger.Infof("Peer %d connected", peerID)
s.wg.Add(1)
go func() {
@@ -229,7 +229,7 @@ func (s *Server) addPeer(
}
func (s *Server) handlePeerReconnect(peerID int, dc *webrtc.DataChannel) {
log.Printf("peer %d reconnect event: dc=%v", peerID, dc != nil)
logger.Infof("peer %d reconnect event: dc=%v", peerID, dc != nil)
s.connMu.Lock()
for sid, conn := range s.connections {
@@ -303,7 +303,7 @@ func (s *Server) onData(data []byte) {
}
if control, ok := mux.ParseControlFrame(plaintext); ok && control.Type == mux.ControlResetClient {
log.Printf("Received reset signal from client (clientID=%d)", control.ClientID)
logger.Infof("Received reset signal from client (clientID=%d)", control.ClientID)
s.closeClientConnections(control.ClientID)
}
@@ -350,7 +350,7 @@ func (s *Server) shutdown() {
s.connMu.Unlock()
for i, peer := range s.peers {
log.Printf("closing peer %d", i)
logger.Infof("closing peer %d", i)
_ = peer.Close()
}
}
@@ -374,7 +374,7 @@ func (s *Server) processMuxStreams(ctx context.Context) {
var req ConnectRequest
if err := json.Unmarshal(data, &req); err == nil && req.Cmd == "connect" {
log.Printf("sid=%d connect %s:%d", sid, req.Addr, req.Port)
logger.Infof("sid=%d connect %s:%d", sid, req.Addr, req.Port)
s.closeStreamConnection(sid)
go s.handleConnect(ctx, sid, req)
}
@@ -437,7 +437,7 @@ func (s *Server) handleConnect(ctx context.Context, sid uint16, req ConnectReque
dialElapsed := time.Since(dialStart)
if err != nil {
log.Printf("sid=%d dial %s failed (%v): %v", sid, addr, dialElapsed, err)
logger.Infof("sid=%d dial %s failed (%v): %v", sid, addr, dialElapsed, err)
_ = s.mux.CloseStream(sid)
return
}
@@ -446,7 +446,7 @@ func (s *Server) handleConnect(ctx context.Context, sid uint16, req ConnectReque
s.connections[sid] = conn
s.connMu.Unlock()
log.Printf("sid=%d connected %s in %v", sid, addr, dialElapsed)
logger.Infof("sid=%d connected %s in %v", sid, addr, dialElapsed)
s.activeClients.Add(1)
_ = s.mux.SendData(sid, []byte{0x00})
@@ -504,7 +504,7 @@ func (s *Server) pumpToMux(sid uint16, conn net.Conn) {
n, err := conn.Read(buf)
if err != nil {
if totalSent > 1024*1024 {
log.Printf("sid=%d done total=%dMB", sid, totalSent/(1024*1024))
logger.Infof("sid=%d done total=%dMB", sid, totalSent/(1024*1024))
}
return
}
@@ -519,7 +519,7 @@ func (s *Server) pumpToMux(sid uint16, conn net.Conn) {
totalSent += uint64(n) //nolint:gosec
if time.Since(lastLog) > 5*time.Second {
log.Printf("sid=%d sent=%dMB", sid, totalSent/(1024*1024))
logger.Infof("sid=%d sent=%dMB", sid, totalSent/(1024*1024))
lastLog = time.Now()
}
}

View File

@@ -5,6 +5,7 @@ package mobile
import (
"context"
"errors"
"fmt"
"log"
"sync"
"time"
@@ -111,7 +112,7 @@ func Start(roomID, keyHex string, socksPort int, socksUser, socksPass string) er
"telemost",
roomURL,
keyHex,
socksPort,
fmt.Sprintf("127.0.0.1:%d", socksPort),
"",
socksUser,
socksPass,

View File

@@ -65,7 +65,6 @@ mage clean
# client ( podman, pre configured, easy, unix )
./script/cnc.sh
```
<div align="center">