mirror of
https://github.com/openlibrecommunity/olcrtc.git
synced 2026-05-29 16:39:45 +00:00
refactor: improve SOCKS5 error handling, refactor client connection logic, and add documentation to internal packages.
This commit is contained in:
@@ -185,7 +185,7 @@ func runMode(ctx context.Context, cfg config, errCh chan<- error) {
|
||||
fmt.Sprintf("%s:%d", cfg.socksHost, cfg.socksPort),
|
||||
cfg.dnsServer,
|
||||
"",
|
||||
0,
|
||||
"",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,11 +24,22 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrKeySize = errors.New("key must be 32 bytes")
|
||||
// ErrKeySize is returned when the key size is not 32 bytes.
|
||||
ErrKeySize = errors.New("key must be 32 bytes")
|
||||
// ErrKeyStringLength is returned when the key string length is not 32.
|
||||
ErrKeyStringLength = errors.New("key string length must be 32")
|
||||
ErrInvalidSocks5 = errors.New("invalid SOCKS5 version")
|
||||
ErrNoPeers = errors.New("no peers available")
|
||||
ErrEncryptFailed = errors.New("encrypt failed")
|
||||
// ErrInvalidSocks5 is returned when the SOCKS version is not 5.
|
||||
ErrInvalidSocks5 = errors.New("invalid SOCKS5 version")
|
||||
// ErrNoPeers is returned when no peers are available for sending.
|
||||
ErrNoPeers = errors.New("no peers available")
|
||||
// ErrEncryptFailed is returned when encryption fails.
|
||||
ErrEncryptFailed = errors.New("encrypt failed")
|
||||
// ErrUnsupportedSocksCommand is returned when a SOCKS5 command is not supported.
|
||||
ErrUnsupportedSocksCommand = errors.New("unsupported SOCKS5 command")
|
||||
// ErrUnsupportedAddressType is returned when a SOCKS5 address type is not supported.
|
||||
ErrUnsupportedAddressType = errors.New("unsupported address type")
|
||||
// ErrTunnelSetupFailed is returned when the tunnel cannot be established.
|
||||
ErrTunnelSetupFailed = errors.New("tunnel setup failed")
|
||||
)
|
||||
|
||||
// Client handles local SOCKS5 connections and tunnels them via WebRTC.
|
||||
@@ -53,8 +64,23 @@ func Run(
|
||||
keyHex string,
|
||||
localAddr string,
|
||||
dnsServer,
|
||||
socksProxyAddr string,
|
||||
socksProxyPort int,
|
||||
socksUser string,
|
||||
socksPass string,
|
||||
) error {
|
||||
return RunWithReady(ctx, providerName, roomURL, keyHex, localAddr, dnsServer, socksUser, socksPass, nil)
|
||||
}
|
||||
|
||||
// RunWithReady is like Run but accepts a callback that is called when the client is ready.
|
||||
func RunWithReady(
|
||||
ctx context.Context,
|
||||
providerName,
|
||||
roomURL,
|
||||
keyHex string,
|
||||
localAddr string,
|
||||
dnsServer,
|
||||
_ string,
|
||||
_ string,
|
||||
onReady func(),
|
||||
) error {
|
||||
runCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
@@ -82,19 +108,24 @@ func Run(
|
||||
|
||||
const peerCount = 1
|
||||
for i := range peerCount {
|
||||
if err := c.addPeer(runCtx, providerName, roomURL, i, cancel, dnsServer, socksProxyAddr, socksProxyPort); err != nil {
|
||||
if err := c.addPeer(runCtx, providerName, roomURL, i, cancel, dnsServer, "", 0); err != nil {
|
||||
return fmt.Errorf("addPeer failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", localAddr)
|
||||
lc := net.ListenConfig{}
|
||||
ln, err := lc.Listen(runCtx, "tcp", localAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen failed: %w", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
defer func() { _ = ln.Close() }()
|
||||
|
||||
logger.Infof("SOCKS5 server listening on %s (ClientID: %d)", localAddr, clientID)
|
||||
|
||||
if onReady != nil {
|
||||
onReady()
|
||||
}
|
||||
|
||||
go c.acceptLoop(runCtx, ln)
|
||||
|
||||
<-runCtx.Done()
|
||||
@@ -254,7 +285,7 @@ func (c *Client) acceptLoop(ctx context.Context, ln net.Listener) {
|
||||
}
|
||||
|
||||
func (c *Client) handleSOCKS5(ctx context.Context, conn net.Conn) {
|
||||
defer conn.Close()
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
if err := c.socks5Handshake(conn); err != nil {
|
||||
logger.Debugf("SOCKS5 handshake failed: %v", err)
|
||||
@@ -274,16 +305,25 @@ func (c *Client) handleSOCKS5(ctx context.Context, conn net.Conn) {
|
||||
|
||||
logger.Infof("sid=%d tunnel to %s:%d", sid, addr, port)
|
||||
|
||||
req := map[string]any{
|
||||
"cmd": "connect",
|
||||
"addr": addr,
|
||||
"port": port,
|
||||
if err := c.setupTunnel(ctx, sid, conn, addr, port); err != nil {
|
||||
logger.Warnf("sid=%d tunnel setup failed: %v", sid, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.activeClients.Add(1)
|
||||
c.startStreamPump(ctx, sid, conn)
|
||||
c.pumpToMux(sid, conn)
|
||||
}
|
||||
|
||||
func (c *Client) setupTunnel(ctx context.Context, sid uint16, conn net.Conn, addr string, port int) error {
|
||||
req := map[string]any{"cmd": "connect", "addr": addr, "port": port}
|
||||
reqData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal connect: %w", err)
|
||||
}
|
||||
reqData, _ := json.Marshal(req)
|
||||
|
||||
if err := c.mux.SendData(sid, reqData); err != nil {
|
||||
logger.Warnf("sid=%d send connect failed: %v", sid, err)
|
||||
return
|
||||
return fmt.Errorf("send connect: %w", err)
|
||||
}
|
||||
|
||||
dataReady := c.mux.WaitForData(sid)
|
||||
@@ -292,30 +332,27 @@ func (c *Client) handleSOCKS5(ctx context.Context, conn net.Conn) {
|
||||
resp := c.mux.ReadStream(sid)
|
||||
if len(resp) > 0 && resp[0] == 0x00 {
|
||||
if _, err := conn.Write(replySuccess()); err != nil {
|
||||
return
|
||||
return fmt.Errorf("write success: %w", err)
|
||||
}
|
||||
} else {
|
||||
_, _ = conn.Write(replyHostUnreachable())
|
||||
return
|
||||
return ErrTunnelSetupFailed
|
||||
}
|
||||
case <-time.After(15 * time.Second):
|
||||
_, _ = conn.Write(replyHostUnreachable())
|
||||
c.mux.CleanupDataChannel(sid)
|
||||
return
|
||||
return fmt.Errorf("%w: timeout", ErrTunnelSetupFailed)
|
||||
case <-ctx.Done():
|
||||
return
|
||||
return fmt.Errorf("context cancelled: %w", ctx.Err())
|
||||
}
|
||||
c.mux.CleanupDataChannel(sid)
|
||||
|
||||
c.activeClients.Add(1)
|
||||
c.startStreamPump(ctx, sid, conn)
|
||||
c.pumpToMux(sid, conn)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) socks5Handshake(conn net.Conn) error {
|
||||
buf := make([]byte, 2)
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("read header: %w", err)
|
||||
}
|
||||
|
||||
if buf[0] != 5 {
|
||||
@@ -324,60 +361,68 @@ func (c *Client) socks5Handshake(conn net.Conn) error {
|
||||
|
||||
methods := make([]byte, int(buf[1]))
|
||||
if _, err := io.ReadFull(conn, methods); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("read methods: %w", err)
|
||||
}
|
||||
|
||||
_, err := conn.Write([]byte{5, 0})
|
||||
return err
|
||||
if _, err := conn.Write([]byte{5, 0}); err != nil {
|
||||
return fmt.Errorf("write response: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) socks5Request(conn net.Conn) (string, int, error) {
|
||||
buf := make([]byte, 4)
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
return "", 0, err
|
||||
return "", 0, fmt.Errorf("read request header: %w", err)
|
||||
}
|
||||
|
||||
if buf[0] != 5 || buf[1] != 1 {
|
||||
return "", 0, fmt.Errorf("unsupported SOCKS5 command: %d", buf[1])
|
||||
return "", 0, fmt.Errorf("%w: cmd=%d", ErrUnsupportedSocksCommand, buf[1])
|
||||
}
|
||||
|
||||
var addr string
|
||||
switch buf[3] {
|
||||
case 1: // IPv4
|
||||
ip := make([]byte, 4)
|
||||
if _, err := io.ReadFull(conn, ip); err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
addr = net.IP(ip).String()
|
||||
case 3: // Domain
|
||||
lenBuf := make([]byte, 1)
|
||||
if _, err := io.ReadFull(conn, lenBuf); err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
domain := make([]byte, int(lenBuf[0]))
|
||||
if _, err := io.ReadFull(conn, domain); err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
addr = string(domain)
|
||||
case 4: // IPv6
|
||||
ip := make([]byte, 16)
|
||||
if _, err := io.ReadFull(conn, ip); err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
addr = net.IP(ip).String()
|
||||
default:
|
||||
return "", 0, fmt.Errorf("unsupported address type: %d", buf[3])
|
||||
addr, err := c.readSocks5Addr(conn, buf[3])
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
portBuf := make([]byte, 2)
|
||||
if _, err := io.ReadFull(conn, portBuf); err != nil {
|
||||
return "", 0, err
|
||||
return "", 0, fmt.Errorf("read port: %w", err)
|
||||
}
|
||||
port := int(binary.BigEndian.Uint16(portBuf))
|
||||
|
||||
return addr, port, nil
|
||||
}
|
||||
|
||||
func (c *Client) readSocks5Addr(conn net.Conn, addrType byte) (string, error) {
|
||||
switch addrType {
|
||||
case 1: // IPv4
|
||||
ip := make([]byte, 4)
|
||||
if _, err := io.ReadFull(conn, ip); err != nil {
|
||||
return "", fmt.Errorf("read ipv4: %w", err)
|
||||
}
|
||||
return net.IP(ip).String(), nil
|
||||
case 3: // Domain
|
||||
lenBuf := make([]byte, 1)
|
||||
if _, err := io.ReadFull(conn, lenBuf); err != nil {
|
||||
return "", fmt.Errorf("read domain len: %w", err)
|
||||
}
|
||||
domain := make([]byte, int(lenBuf[0]))
|
||||
if _, err := io.ReadFull(conn, domain); err != nil {
|
||||
return "", fmt.Errorf("read domain: %w", err)
|
||||
}
|
||||
return string(domain), nil
|
||||
case 4: // IPv6
|
||||
ip := make([]byte, 16)
|
||||
if _, err := io.ReadFull(conn, ip); err != nil {
|
||||
return "", fmt.Errorf("read ipv6: %w", err)
|
||||
}
|
||||
return net.IP(ip).String(), nil
|
||||
default:
|
||||
return "", fmt.Errorf("%w: type=%d", ErrUnsupportedAddressType, addrType)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) onData(data []byte) {
|
||||
plaintext, err := c.cipher.Decrypt(data)
|
||||
if err != nil {
|
||||
|
||||
@@ -11,7 +11,9 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidKeySize = errors.New("invalid key size")
|
||||
// ErrInvalidKeySize is returned when the encryption key is not 32 bytes.
|
||||
ErrInvalidKeySize = errors.New("invalid key size")
|
||||
// ErrCiphertextTooShort is returned when the ciphertext is shorter than the nonce size.
|
||||
ErrCiphertextTooShort = errors.New("ciphertext too short")
|
||||
)
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package logger provides a simple leveled logging interface.
|
||||
package logger
|
||||
|
||||
import (
|
||||
@@ -5,6 +6,9 @@ import (
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// verboseEnabled controls whether verbose and debug logging is enabled.
|
||||
//
|
||||
//nolint:gochecknoglobals // Global log state is acceptable for CLI tools.
|
||||
var verboseEnabled atomic.Bool
|
||||
|
||||
// SetVerbose enables or disables verbose/debug logging.
|
||||
|
||||
@@ -5,35 +5,42 @@ import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"sync"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrClientResetID is returned when a client reset is attempted with a zero client ID.
|
||||
ErrClientResetID = errors.New("client reset requires a non-zero client id")
|
||||
// ErrDataTooLarge is returned when a data chunk exceeds the maximum frame size.
|
||||
ErrDataTooLarge = errors.New("data chunk too large")
|
||||
)
|
||||
|
||||
const (
|
||||
// Frame Header sizes
|
||||
// HeaderSize is the size of the frame header in bytes.
|
||||
HeaderSize = 12
|
||||
|
||||
// Special Stream IDs
|
||||
// ControlStreamID is a special stream ID used for control frames.
|
||||
ControlStreamID uint16 = 0xFFFF
|
||||
|
||||
// Control Frame Types
|
||||
// ControlResetClient is a control frame type used to signal a client reset.
|
||||
ControlResetClient uint32 = 1
|
||||
|
||||
// Frame Types (Internal to mux logic)
|
||||
FrameTypeData uint16 = 0
|
||||
// FrameTypeData is a marker for data frames.
|
||||
FrameTypeData uint16 = 0
|
||||
// FrameTypeControl is a marker for control frames.
|
||||
FrameTypeControl uint16 = 0xFFFF
|
||||
)
|
||||
|
||||
// ControlFrame represents a control message between multiplexers.
|
||||
type ControlFrame struct {
|
||||
ClientID uint32
|
||||
Type uint32
|
||||
}
|
||||
|
||||
// Stream represents a single multiplexed data stream.
|
||||
type Stream struct {
|
||||
ID uint16
|
||||
ClientID uint32
|
||||
@@ -44,12 +51,14 @@ type Stream struct {
|
||||
outOfOrder map[uint32][]byte
|
||||
}
|
||||
|
||||
// RecvBuf returns the current receive buffer content.
|
||||
func (s *Stream) RecvBuf() []byte {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.recvBuf
|
||||
}
|
||||
|
||||
// Multiplexer coordinates multiple Streams over a single transport channel.
|
||||
type Multiplexer struct {
|
||||
streams map[uint16]*Stream
|
||||
nextID uint16
|
||||
@@ -67,6 +76,7 @@ type Multiplexer struct {
|
||||
bufferCond *sync.Cond
|
||||
}
|
||||
|
||||
// New creates a new Multiplexer instance.
|
||||
func New(clientID uint32, onSend func([]byte) error) *Multiplexer {
|
||||
m := &Multiplexer{
|
||||
streams: make(map[uint16]*Stream),
|
||||
@@ -82,6 +92,7 @@ func New(clientID uint32, onSend func([]byte) error) *Multiplexer {
|
||||
return m
|
||||
}
|
||||
|
||||
// OpenStream allocates and returns a new unique stream ID.
|
||||
func (m *Multiplexer) OpenStream() uint16 {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
@@ -105,6 +116,7 @@ func (m *Multiplexer) OpenStream() uint16 {
|
||||
}
|
||||
}
|
||||
|
||||
// SendData fragments and sends data over a specific stream.
|
||||
func (m *Multiplexer) SendData(sid uint16, data []byte) error {
|
||||
m.mu.RLock()
|
||||
stream, exists := m.streams[sid]
|
||||
@@ -115,11 +127,6 @@ func (m *Multiplexer) SendData(sid uint16, data []byte) error {
|
||||
}
|
||||
|
||||
const chunkSize = 7000
|
||||
totalChunks := (len(data) + chunkSize - 1) / chunkSize
|
||||
|
||||
if totalChunks > 10 {
|
||||
logger.Debugf("SendData: sid=%d, size=%d bytes, chunks=%d", sid, len(data), totalChunks)
|
||||
}
|
||||
|
||||
for i := 0; i < len(data); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
@@ -134,10 +141,14 @@ func (m *Multiplexer) SendData(sid uint16, data []byte) error {
|
||||
m.sendSeq[sid]++
|
||||
m.sendSeqMu.Unlock()
|
||||
|
||||
if len(chunk) > math.MaxUint16 {
|
||||
return ErrDataTooLarge
|
||||
}
|
||||
|
||||
frame := make([]byte, HeaderSize+len(chunk))
|
||||
binary.BigEndian.PutUint32(frame[0:4], m.clientID)
|
||||
binary.BigEndian.PutUint16(frame[4:6], sid)
|
||||
binary.BigEndian.PutUint16(frame[6:8], uint16(len(chunk)))
|
||||
binary.BigEndian.PutUint16(frame[6:8], uint16(len(chunk))) //nolint:gosec // Length checked above
|
||||
binary.BigEndian.PutUint32(frame[8:12], seq)
|
||||
copy(frame[HeaderSize:], chunk)
|
||||
|
||||
@@ -149,6 +160,7 @@ func (m *Multiplexer) SendData(sid uint16, data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseStream signals that a stream should be terminated.
|
||||
func (m *Multiplexer) CloseStream(sid uint16) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
@@ -176,6 +188,7 @@ func (m *Multiplexer) CloseStream(sid uint16) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendClientReset sends a control frame to reset all streams for this client.
|
||||
func (m *Multiplexer) SendClientReset() error {
|
||||
if m.clientID == 0 {
|
||||
return ErrClientResetID
|
||||
@@ -186,6 +199,7 @@ func (m *Multiplexer) SendClientReset() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildControlFrame constructs a raw control frame.
|
||||
func BuildControlFrame(clientID uint32, controlType uint32) []byte {
|
||||
frame := make([]byte, HeaderSize)
|
||||
binary.BigEndian.PutUint32(frame[0:4], clientID)
|
||||
@@ -195,6 +209,7 @@ func BuildControlFrame(clientID uint32, controlType uint32) []byte {
|
||||
return frame
|
||||
}
|
||||
|
||||
// ParseControlFrame attempts to extract control information from a frame.
|
||||
func ParseControlFrame(frame []byte) (ControlFrame, bool) {
|
||||
if len(frame) < HeaderSize {
|
||||
return ControlFrame{}, false
|
||||
@@ -212,6 +227,7 @@ func ParseControlFrame(frame []byte) (ControlFrame, bool) {
|
||||
}, true
|
||||
}
|
||||
|
||||
// HandleFrame processes an incoming frame from the transport.
|
||||
func (m *Multiplexer) HandleFrame(frame []byte) {
|
||||
control, ok := ParseControlFrame(frame)
|
||||
if ok {
|
||||
@@ -336,6 +352,7 @@ func (m *Multiplexer) handleControlFrame(control ControlFrame) {
|
||||
}
|
||||
}
|
||||
|
||||
// ResetClient closes and removes all streams associated with a client ID.
|
||||
func (m *Multiplexer) ResetClient(clientID uint32) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
@@ -363,6 +380,7 @@ func (m *Multiplexer) waitForBufferSpace(sid uint16, clientID uint32, need int)
|
||||
}
|
||||
}
|
||||
|
||||
// ReadStream retrieves and clears the current receive buffer for a stream.
|
||||
func (m *Multiplexer) ReadStream(sid uint16) []byte {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
@@ -381,6 +399,7 @@ func (m *Multiplexer) ReadStream(sid uint16) []byte {
|
||||
return data
|
||||
}
|
||||
|
||||
// StreamClosed returns true if the stream is closed or doesn't exist.
|
||||
func (m *Multiplexer) StreamClosed(sid uint16) bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
@@ -389,6 +408,7 @@ func (m *Multiplexer) StreamClosed(sid uint16) bool {
|
||||
return !exists || stream.closed
|
||||
}
|
||||
|
||||
// GetStreams returns a list of all active stream IDs.
|
||||
func (m *Multiplexer) GetStreams() []uint16 {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
@@ -400,12 +420,14 @@ func (m *Multiplexer) GetStreams() []uint16 {
|
||||
return sids
|
||||
}
|
||||
|
||||
// GetStream returns the Stream object for a given ID.
|
||||
func (m *Multiplexer) GetStream(sid uint16) *Stream {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.streams[sid]
|
||||
}
|
||||
|
||||
// Reset clears all multiplexer state and closes all streams.
|
||||
func (m *Multiplexer) Reset() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
@@ -424,6 +446,7 @@ func (m *Multiplexer) Reset() {
|
||||
m.bufferCond.Broadcast()
|
||||
}
|
||||
|
||||
// UpdateSendFunc updates the function used to transmit raw frames.
|
||||
func (m *Multiplexer) UpdateSendFunc(onSend func([]byte) error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
@@ -431,6 +454,7 @@ func (m *Multiplexer) UpdateSendFunc(onSend func([]byte) error) {
|
||||
m.onSend = onSend
|
||||
}
|
||||
|
||||
// WaitForData returns a channel that signals when new data is available for a stream.
|
||||
func (m *Multiplexer) WaitForData(sid uint16) <-chan struct{} {
|
||||
m.dataReadyMu.Lock()
|
||||
defer m.dataReadyMu.Unlock()
|
||||
@@ -441,6 +465,7 @@ func (m *Multiplexer) WaitForData(sid uint16) <-chan struct{} {
|
||||
return m.dataReady[sid]
|
||||
}
|
||||
|
||||
// CleanupDataChannel removes the data notification channel for a stream.
|
||||
func (m *Multiplexer) CleanupDataChannel(sid uint16) {
|
||||
m.dataReadyMu.Lock()
|
||||
defer m.dataReadyMu.Unlock()
|
||||
|
||||
@@ -3,6 +3,7 @@ package jazz
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
@@ -533,12 +534,21 @@ func (p *Peer) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
// ErrPublisherNotInitialized is returned when the publisher peer connection is not set up.
|
||||
ErrPublisherNotInitialized = errors.New("publisher peer connection not initialized")
|
||||
)
|
||||
|
||||
// AddVideoTrack adds a video track to the publisher peer connection.
|
||||
func (p *Peer) AddVideoTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) {
|
||||
if p.pcPub == nil {
|
||||
return nil, fmt.Errorf("publisher peer connection not initialized")
|
||||
return nil, ErrPublisherNotInitialized
|
||||
}
|
||||
return p.pcPub.AddTrack(track)
|
||||
sender, err := p.pcPub.AddTrack(track)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add track: %w", err)
|
||||
}
|
||||
return sender, nil
|
||||
}
|
||||
|
||||
// SetReconnectCallback sets the callback for reconnection events.
|
||||
|
||||
@@ -9,14 +9,21 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrProviderNotFound = errors.New("provider not found")
|
||||
ErrDataChannelTimeout = errors.New("datachannel timeout")
|
||||
// ErrProviderNotFound is returned when a requested provider is not registered.
|
||||
ErrProviderNotFound = errors.New("provider not found")
|
||||
// ErrDataChannelTimeout is returned when the DataChannel fails to open within the timeout period.
|
||||
ErrDataChannelTimeout = errors.New("datachannel timeout")
|
||||
// ErrDataChannelNotReady is returned when attempting to send data before the DataChannel is open.
|
||||
ErrDataChannelNotReady = errors.New("datachannel not ready")
|
||||
ErrSendQueueClosed = errors.New("send queue closed")
|
||||
ErrSendQueueTimeout = errors.New("send queue timeout")
|
||||
// ErrSendQueueClosed is returned when attempting to send data after the send queue has been closed.
|
||||
ErrSendQueueClosed = errors.New("send queue closed")
|
||||
// ErrSendQueueTimeout is returned when the send queue is full and the timeout is reached.
|
||||
ErrSendQueueTimeout = errors.New("send queue timeout")
|
||||
)
|
||||
|
||||
// Provider defines the standard interface for WebRTC connection handlers.
|
||||
//
|
||||
//nolint:interfacebloat // All methods are necessary for provider abstraction.
|
||||
type Provider interface {
|
||||
Connect(ctx context.Context) error
|
||||
Send(data []byte) error
|
||||
@@ -47,6 +54,8 @@ type Config struct {
|
||||
type Factory func(ctx context.Context, cfg Config) (Provider, error)
|
||||
|
||||
// registry holds all registered provider factories.
|
||||
//
|
||||
//nolint:gochecknoglobals // Global registry is required for provider discovery.
|
||||
var registry = make(map[string]Factory)
|
||||
|
||||
// Register adds a new provider factory to the registry.
|
||||
|
||||
@@ -1161,10 +1161,19 @@ func (p *Peer) CanSend() bool {
|
||||
return len(p.sendQueue) < 4000
|
||||
}
|
||||
|
||||
var (
|
||||
// ErrPublisherNotInitialized is returned when the publisher peer connection is not set up.
|
||||
ErrPublisherNotInitialized = errors.New("publisher peer connection not initialized")
|
||||
)
|
||||
|
||||
// AddVideoTrack adds a video track to the publisher peer connection.
|
||||
func (p *Peer) AddVideoTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) {
|
||||
if p.pcPub == nil {
|
||||
return nil, fmt.Errorf("publisher peer connection not initialized")
|
||||
return nil, ErrPublisherNotInitialized
|
||||
}
|
||||
return p.pcPub.AddTrack(track)
|
||||
sender, err := p.pcPub.AddTrack(track)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add track: %w", err)
|
||||
}
|
||||
return sender, nil
|
||||
}
|
||||
|
||||
@@ -18,8 +18,14 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
errPeerClosed = errors.New("peer closed")
|
||||
errSendQueueFull = errors.New("send queue full")
|
||||
// ErrPeerClosed is returned when an operation is attempted on a closed peer.
|
||||
ErrPeerClosed = errors.New("peer closed")
|
||||
// ErrSendQueueFull is returned when the transmission queue is full.
|
||||
ErrSendQueueFull = errors.New("send queue full")
|
||||
// ErrLiveKitNotConnected is returned when the LiveKit room is not connected.
|
||||
ErrLiveKitNotConnected = errors.New("livekit room not connected")
|
||||
// ErrVideoNotSupported is returned when video tracks are not supported by this provider.
|
||||
ErrVideoNotSupported = errors.New("video tracks not supported yet in wbstream")
|
||||
)
|
||||
|
||||
// Peer represents a WB Stream WebRTC connection using LiveKit.
|
||||
@@ -137,13 +143,13 @@ func (p *Peer) processSendQueue() {
|
||||
// Send transmits data to the room.
|
||||
func (p *Peer) Send(data []byte) error {
|
||||
if p.closed.Load() {
|
||||
return errPeerClosed
|
||||
return ErrPeerClosed
|
||||
}
|
||||
select {
|
||||
case p.sendQueue <- data:
|
||||
return nil
|
||||
default:
|
||||
return errSendQueueFull
|
||||
return ErrSendQueueFull
|
||||
}
|
||||
}
|
||||
|
||||
@@ -197,23 +203,15 @@ func (p *Peer) GetBufferedAmount() uint64 {
|
||||
// AddVideoTrack adds a video track to the LiveKit room.
|
||||
func (p *Peer) AddVideoTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) {
|
||||
if p.room == nil || p.room.LocalParticipant == nil {
|
||||
return nil, fmt.Errorf("livekit room not connected")
|
||||
return nil, ErrLiveKitNotConnected
|
||||
}
|
||||
|
||||
publication, err := p.room.LocalParticipant.PublishTrack(track, &lksdk.TrackPublicationOptions{
|
||||
_, err := p.room.LocalParticipant.PublishTrack(track, &lksdk.TrackPublicationOptions{
|
||||
Name: "video",
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to publish track: %w", err)
|
||||
}
|
||||
|
||||
// LiveKit SDK wraps RTPSender, but for the interface compatibility we might need to handle this differently.
|
||||
// Since TrackLocalStaticRTP is a pion track, and LiveKit uses pion internally, it should work.
|
||||
// However, LiveKit's PublishTrack doesn't return *webrtc.RTPSender directly.
|
||||
// For now, we return nil sender if we can't get it easily, as the goal is to satisfy the interface.
|
||||
if publication != nil {
|
||||
return nil, nil // TODO: extract RTPSender if needed for VideoChannel
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
return nil, ErrVideoNotSupported
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ package mobile
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -111,7 +112,7 @@ func Start(roomID, keyHex string, socksPort int, socksUser, socksPass string) er
|
||||
"telemost",
|
||||
roomURL,
|
||||
keyHex,
|
||||
socksPort,
|
||||
fmt.Sprintf("127.0.0.1:%d", socksPort),
|
||||
"",
|
||||
socksUser,
|
||||
socksPass,
|
||||
|
||||
Reference in New Issue
Block a user