From fffb90e3213f74e46add9ab96b4c3a1df74a7e4d Mon Sep 17 00:00:00 2001 From: zarazaex69 Date: Mon, 20 Apr 2026 20:05:23 +0300 Subject: [PATCH] refactor: introduce transport layer --- cmd/olcrtc/main.go | 33 ++++-- internal/client/client.go | 115 ++++++++++---------- internal/server/server.go | 90 +++++++-------- internal/transport/datachannel/transport.go | 76 +++++++++++++ internal/transport/transport.go | 63 +++++++++++ mobile/mobile.go | 1 + 6 files changed, 269 insertions(+), 109 deletions(-) create mode 100644 internal/transport/datachannel/transport.go create mode 100644 internal/transport/transport.go diff --git a/cmd/olcrtc/main.go b/cmd/olcrtc/main.go index 6abbf1f..3785c56 100644 --- a/cmd/olcrtc/main.go +++ b/cmd/olcrtc/main.go @@ -20,10 +20,13 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/provider/telemost" "github.com/openlibrecommunity/olcrtc/internal/provider/wbstream" "github.com/openlibrecommunity/olcrtc/internal/server" + "github.com/openlibrecommunity/olcrtc/internal/transport" + "github.com/openlibrecommunity/olcrtc/internal/transport/datachannel" ) type config struct { mode string + transport string roomID string provider string socksPort int @@ -37,10 +40,11 @@ type config struct { } var ( - errRoomIDRequired = errors.New("room ID required") - errModeRequired = errors.New("specify -mode srv or -mode cnc") - errProviderRequired = errors.New("provider required (use -provider telemost or -provider jazz)") - errUnsupportedProvider = errors.New("unsupported provider") + errRoomIDRequired = errors.New("room ID required") + errModeRequired = errors.New("specify -mode srv or -mode cnc") + errProviderRequired = errors.New("provider required (use -provider telemost or -provider jazz)") + errUnsupportedProvider = errors.New("unsupported provider") + errUnsupportedTransport = errors.New("unsupported transport") ) func main() { @@ -54,6 +58,7 @@ func run() error { provider.Register("jazz", jazz.New) provider.Register("telemost", telemost.New) provider.Register("wb_stream", wbstream.New) + transport.Register("datachannel", datachannel.New) cfg := parseFlags() configureLogging(cfg.debug) @@ -94,6 +99,7 @@ func parseFlags() config { cfg := config{} flag.StringVar(&cfg.mode, "mode", "", "Mode: srv or cnc") + flag.StringVar(&cfg.transport, "transport", "datachannel", "Transport: datachannel") flag.StringVar(&cfg.roomID, "id", "", "Room ID") flag.StringVar(&cfg.provider, "provider", "", "Provider: telemost or jazz (required)") flag.IntVar(&cfg.socksPort, "socks-port", 1080, "SOCKS5 port (client only)") @@ -116,20 +122,31 @@ func configureLogging(debug bool) { } func validateConfig(cfg config) error { - available := provider.Available() + availableProviders := provider.Available() validProvider := false - for _, p := range available { + for _, p := range availableProviders { if cfg.provider == p { validProvider = true break } } + availableTransports := transport.Available() + validTransport := false + for _, t := range availableTransports { + if cfg.transport == t { + validTransport = true + break + } + } + switch { case cfg.provider == "": return errProviderRequired case !validProvider: - return fmt.Errorf("%w: %s (available: %v)", errUnsupportedProvider, cfg.provider, available) + return fmt.Errorf("%w: %s (available: %v)", errUnsupportedProvider, cfg.provider, availableProviders) + case !validTransport: + return fmt.Errorf("%w: %s (available: %v)", errUnsupportedTransport, cfg.transport, availableTransports) case cfg.roomID == "" && cfg.provider != "jazz": return errRoomIDRequired case cfg.mode != "srv" && cfg.mode != "cnc": @@ -169,6 +186,7 @@ func runMode(ctx context.Context, cfg config, errCh chan<- error) { case "srv": errCh <- server.Run( ctx, + cfg.transport, cfg.provider, roomURL, cfg.keyHex, @@ -179,6 +197,7 @@ func runMode(ctx context.Context, cfg config, errCh chan<- error) { case "cnc": errCh <- client.Run( ctx, + cfg.transport, cfg.provider, roomURL, cfg.keyHex, diff --git a/internal/client/client.go b/internal/client/client.go index 024fedc..8d57e4a 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -19,8 +19,7 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/logger" "github.com/openlibrecommunity/olcrtc/internal/mux" "github.com/openlibrecommunity/olcrtc/internal/names" - "github.com/openlibrecommunity/olcrtc/internal/provider" - "github.com/pion/webrtc/v4" + "github.com/openlibrecommunity/olcrtc/internal/transport" ) var ( @@ -44,21 +43,22 @@ var ( // Client handles local SOCKS5 connections and tunnels them via WebRTC. type Client struct { - peers []provider.Provider - cipher *crypto.Cipher - mux *mux.Multiplexer - connections map[uint16]net.Conn - connMu sync.RWMutex - peerIdx atomic.Uint32 - clientID uint32 - activeClients atomic.Int32 - wg sync.WaitGroup - dnsServer string + transports []transport.Transport + cipher *crypto.Cipher + mux *mux.Multiplexer + connections map[uint16]net.Conn + connMu sync.RWMutex + peerIdx atomic.Uint32 + clientID uint32 + activeClients atomic.Int32 + wg sync.WaitGroup + dnsServer string } // Run starts the client with the specified parameters. func Run( ctx context.Context, + transportName, providerName, roomURL, keyHex string, @@ -67,12 +67,13 @@ func Run( socksUser string, socksPass string, ) error { - return RunWithReady(ctx, providerName, roomURL, keyHex, localAddr, dnsServer, socksUser, socksPass, nil) + return RunWithReady(ctx, transportName, providerName, roomURL, keyHex, localAddr, dnsServer, socksUser, socksPass, nil) } // RunWithReady is like Run but accepts a callback that is called when the client is ready. func RunWithReady( ctx context.Context, + transportName, providerName, roomURL, keyHex string, @@ -99,7 +100,7 @@ func RunWithReady( c := &Client{ cipher: cipher, connections: make(map[uint16]net.Conn), - peers: make([]provider.Provider, 0), + transports: make([]transport.Transport, 0), clientID: clientID, dnsServer: dnsServer, } @@ -108,8 +109,8 @@ func RunWithReady( const peerCount = 1 for i := range peerCount { - if err := c.addPeer(runCtx, providerName, roomURL, i, cancel, dnsServer, "", 0); err != nil { - return fmt.Errorf("addPeer failed: %w", err) + if err := c.addTransport(runCtx, transportName, providerName, roomURL, i, cancel, dnsServer, "", 0); err != nil { + return fmt.Errorf("addTransport failed: %w", err) } } @@ -160,8 +161,8 @@ func (c *Client) setupMux() { c.mux = mux.New(c.clientID, func(frame []byte) error { for { canSend := true - for _, peer := range c.peers { - if !peer.CanSend() { + for _, tr := range c.transports { + if !tr.CanSend() { canSend = false break } @@ -176,16 +177,17 @@ func (c *Client) setupMux() { if err != nil { return fmt.Errorf("%w: %w", ErrEncryptFailed, err) } - if len(c.peers) == 0 { + if len(c.transports) == 0 { return ErrNoPeers } - idx := c.peerIdx.Add(1) % uint32(len(c.peers)) //nolint:gosec - return c.peers[idx].Send(encrypted) + idx := c.peerIdx.Add(1) % uint32(len(c.transports)) //nolint:gosec + return c.transports[idx].Send(encrypted) }) } -func (c *Client) addPeer( +func (c *Client) addTransport( ctx context.Context, + transportName, providerName, roomURL string, peerID int, @@ -194,7 +196,8 @@ func (c *Client) addPeer( socksProxyAddr string, socksProxyPort int, ) error { - peer, err := provider.New(ctx, providerName, provider.Config{ + tr, err := transport.New(ctx, transportName, transport.Config{ + Carrier: providerName, RoomURL: roomURL, Name: names.Generate(), OnData: c.onData, @@ -203,29 +206,29 @@ func (c *Client) addPeer( ProxyPort: socksProxyPort, }) if err != nil { - return fmt.Errorf("failed to create peer: %w", err) + return fmt.Errorf("failed to create transport: %w", err) } - peer.SetEndedCallback(func(reason string) { - logger.Infof("Client peer %d reported conference end: %s", peerID, reason) + tr.SetEndedCallback(func(reason string) { + logger.Infof("Client transport %d reported conference end: %s", peerID, reason) cancel() }) - c.peers = append(c.peers, peer) + c.transports = append(c.transports, tr) - peer.SetReconnectCallback(func(dc *webrtc.DataChannel) { - c.handlePeerReconnect(peerID, dc) + tr.SetReconnectCallback(func() { + c.handleTransportReconnect(peerID) }) - logger.Infof("Connecting peer %d to %s...", peerID, providerName) - if err := peer.Connect(ctx); err != nil { - return fmt.Errorf("failed to connect peer: %w", err) + logger.Infof("Connecting transport %d via %s/%s...", peerID, transportName, providerName) + if err := tr.Connect(ctx); err != nil { + return fmt.Errorf("failed to connect transport: %w", err) } - logger.Infof("Peer %d connected", peerID) + logger.Infof("Transport %d connected", peerID) c.wg.Add(1) go func() { defer c.wg.Done() - peer.WatchConnection(ctx) + tr.WatchConnection(ctx) }() // Send initial reset to clean up any stale connections for this clientID on server @@ -236,8 +239,8 @@ func (c *Client) addPeer( return nil } -func (c *Client) handlePeerReconnect(peerID int, dc *webrtc.DataChannel) { - logger.Infof("peer %d reconnect event: dc=%v", peerID, dc != nil) +func (c *Client) handleTransportReconnect(peerID int) { + logger.Infof("transport %d reconnect event", peerID) c.connMu.Lock() for sid, conn := range c.connections { @@ -248,23 +251,21 @@ func (c *Client) handlePeerReconnect(peerID int, dc *webrtc.DataChannel) { } c.connMu.Unlock() - if dc != nil { - c.mux.UpdateSendFunc(func(frame []byte) error { - encrypted, err := c.cipher.Encrypt(frame) - if err != nil { - return fmt.Errorf("%w: %w", ErrEncryptFailed, err) - } - if len(c.peers) == 0 { - return ErrNoPeers - } - idx := c.peerIdx.Add(1) % uint32(len(c.peers)) //nolint:gosec - return c.peers[idx].Send(encrypted) - }) - c.mux.Reset() - - if err := c.mux.SendClientReset(); err != nil { - logger.Warnf("Failed to send client reset after reconnect: %v", err) + c.mux.UpdateSendFunc(func(frame []byte) error { + encrypted, err := c.cipher.Encrypt(frame) + if err != nil { + return fmt.Errorf("%w: %w", ErrEncryptFailed, err) } + if len(c.transports) == 0 { + return ErrNoPeers + } + idx := c.peerIdx.Add(1) % uint32(len(c.transports)) //nolint:gosec + return c.transports[idx].Send(encrypted) + }) + c.mux.Reset() + + if err := c.mux.SendClientReset(); err != nil { + logger.Warnf("Failed to send client reset after reconnect: %v", err) } } @@ -442,9 +443,9 @@ func (c *Client) shutdown() { } c.connMu.Unlock() - for i, peer := range c.peers { - logger.Infof("closing peer %d", i) - _ = peer.Close() + for i, tr := range c.transports { + logger.Infof("closing transport %d", i) + _ = tr.Close() } } @@ -515,8 +516,8 @@ func (c *Client) startStreamPump(ctx context.Context, sid uint16, conn net.Conn) } func (c *Client) canSendData() bool { - for _, peer := range c.peers { - if !peer.CanSend() { + for _, tr := range c.transports { + if !tr.CanSend() { return false } } diff --git a/internal/server/server.go b/internal/server/server.go index a725393..0543052 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -20,8 +20,7 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/logger" "github.com/openlibrecommunity/olcrtc/internal/mux" "github.com/openlibrecommunity/olcrtc/internal/names" - "github.com/openlibrecommunity/olcrtc/internal/provider" - "github.com/pion/webrtc/v4" + "github.com/openlibrecommunity/olcrtc/internal/transport" ) var ( @@ -43,7 +42,7 @@ var ( // Server handles incoming WebRTC connections and proxies their traffic. type Server struct { - peers []provider.Provider + transports []transport.Transport cipher *crypto.Cipher mux *mux.Multiplexer connections map[uint16]net.Conn @@ -69,6 +68,7 @@ type ConnectRequest struct { // Run starts the server with the specified parameters. func Run( ctx context.Context, + transportName, providerName, roomURL, keyHex string, @@ -88,7 +88,7 @@ func Run( cipher: cipher, connections: make(map[uint16]net.Conn), streamPumps: make(map[uint16]net.Conn), - peers: make([]provider.Provider, 0), + transports: make([]transport.Transport, 0), dnsServer: dnsServer, socksProxyAddr: socksProxyAddr, socksProxyPort: socksProxyPort, @@ -103,8 +103,8 @@ func Run( const peerCount = 1 for i := range peerCount { - if err := s.addPeer(runCtx, providerName, roomURL, i, cancel); err != nil { - return fmt.Errorf("addPeer failed: %w", err) + if err := s.addTransport(runCtx, transportName, providerName, roomURL, i, cancel); err != nil { + return fmt.Errorf("addTransport failed: %w", err) } } @@ -161,8 +161,8 @@ func (s *Server) setupMux() { s.mux = mux.New(0, func(frame []byte) error { for { canSend := true - for _, peer := range s.peers { - if !peer.CanSend() { + for _, tr := range s.transports { + if !tr.CanSend() { canSend = false break } @@ -177,22 +177,24 @@ func (s *Server) setupMux() { if err != nil { return fmt.Errorf("%w: %w", ErrEncryptFailed, err) } - if len(s.peers) == 0 { + if len(s.transports) == 0 { return ErrNoPeers } - idx := s.peerIdx.Add(1) % uint32(len(s.peers)) //nolint:gosec - return s.peers[idx].Send(encrypted) + idx := s.peerIdx.Add(1) % uint32(len(s.transports)) //nolint:gosec + return s.transports[idx].Send(encrypted) }) } -func (s *Server) addPeer( +func (s *Server) addTransport( ctx context.Context, + transportName, providerName, roomURL string, peerID int, cancel context.CancelFunc, ) error { - peer, err := provider.New(ctx, providerName, provider.Config{ + tr, err := transport.New(ctx, transportName, transport.Config{ + Carrier: providerName, RoomURL: roomURL, Name: names.Generate(), OnData: s.onData, @@ -201,35 +203,35 @@ func (s *Server) addPeer( ProxyPort: s.socksProxyPort, }) if err != nil { - return fmt.Errorf("failed to create peer: %w", err) + return fmt.Errorf("failed to create transport: %w", err) } - peer.SetEndedCallback(func(reason string) { - logger.Infof("Server peer %d reported conference end: %s", peerID, reason) + tr.SetEndedCallback(func(reason string) { + logger.Infof("Server transport %d reported conference end: %s", peerID, reason) cancel() }) - s.peers = append(s.peers, peer) + s.transports = append(s.transports, tr) - peer.SetReconnectCallback(func(dc *webrtc.DataChannel) { - s.handlePeerReconnect(peerID, dc) + tr.SetReconnectCallback(func() { + s.handleTransportReconnect(peerID) }) - logger.Infof("Connecting peer %d to %s...", peerID, providerName) - if err := peer.Connect(ctx); err != nil { - return fmt.Errorf("failed to connect peer: %w", err) + logger.Infof("Connecting transport %d via %s/%s...", peerID, transportName, providerName) + if err := tr.Connect(ctx); err != nil { + return fmt.Errorf("failed to connect transport: %w", err) } - logger.Infof("Peer %d connected", peerID) + logger.Infof("Transport %d connected", peerID) s.wg.Add(1) go func() { defer s.wg.Done() - peer.WatchConnection(ctx) + tr.WatchConnection(ctx) }() return nil } -func (s *Server) handlePeerReconnect(peerID int, dc *webrtc.DataChannel) { - logger.Infof("peer %d reconnect event: dc=%v", peerID, dc != nil) +func (s *Server) handleTransportReconnect(peerID int) { + logger.Infof("transport %d reconnect event", peerID) s.connMu.Lock() for sid, conn := range s.connections { @@ -240,20 +242,18 @@ func (s *Server) handlePeerReconnect(peerID int, dc *webrtc.DataChannel) { } s.connMu.Unlock() - if dc != nil { - s.mux.UpdateSendFunc(func(frame []byte) error { - encrypted, err := s.cipher.Encrypt(frame) - if err != nil { - return fmt.Errorf("%w: %w", ErrEncryptFailed, err) - } - if len(s.peers) == 0 { - return ErrNoPeers - } - idx := s.peerIdx.Add(1) % uint32(len(s.peers)) //nolint:gosec - return s.peers[idx].Send(encrypted) - }) - s.mux.Reset() - } + s.mux.UpdateSendFunc(func(frame []byte) error { + encrypted, err := s.cipher.Encrypt(frame) + if err != nil { + return fmt.Errorf("%w: %w", ErrEncryptFailed, err) + } + if len(s.transports) == 0 { + return ErrNoPeers + } + idx := s.peerIdx.Add(1) % uint32(len(s.transports)) //nolint:gosec + return s.transports[idx].Send(encrypted) + }) + s.mux.Reset() } func (s *Server) socks5Connect(conn net.Conn, targetAddr string, targetPort int) error { @@ -349,9 +349,9 @@ func (s *Server) shutdown() { } s.connMu.Unlock() - for i, peer := range s.peers { - logger.Infof("closing peer %d", i) - _ = peer.Close() + for i, tr := range s.transports { + logger.Infof("closing transport %d", i) + _ = tr.Close() } } @@ -561,8 +561,8 @@ func (s *Server) startStreamPump(ctx context.Context, sid uint16, conn net.Conn) } func (s *Server) canSendData() bool { - for _, peer := range s.peers { - if !peer.CanSend() { + for _, tr := range s.transports { + if !tr.CanSend() { return false } } diff --git a/internal/transport/datachannel/transport.go b/internal/transport/datachannel/transport.go new file mode 100644 index 0000000..dfd1b2d --- /dev/null +++ b/internal/transport/datachannel/transport.go @@ -0,0 +1,76 @@ +// Package datachannel provides a transport backed by the current WebRTC providers. +package datachannel + +import ( + "context" + "fmt" + + "github.com/openlibrecommunity/olcrtc/internal/provider" + "github.com/openlibrecommunity/olcrtc/internal/transport" + "github.com/pion/webrtc/v4" +) + +type providerTransport struct { + provider provider.Provider +} + +// New creates a datachannel transport backed by a carrier-specific provider. +func New(ctx context.Context, cfg transport.Config) (transport.Transport, error) { + p, err := provider.New(ctx, cfg.Carrier, provider.Config{ + RoomURL: cfg.RoomURL, + Name: cfg.Name, + OnData: cfg.OnData, + DNSServer: cfg.DNSServer, + ProxyAddr: cfg.ProxyAddr, + ProxyPort: cfg.ProxyPort, + }) + if err != nil { + return nil, fmt.Errorf("create provider transport: %w", err) + } + + return &providerTransport{provider: p}, nil +} + +// Connect starts the transport connection. +func (p *providerTransport) Connect(ctx context.Context) error { + return p.provider.Connect(ctx) +} + +// Send transmits data through the transport. +func (p *providerTransport) Send(data []byte) error { + return p.provider.Send(data) +} + +// Close terminates the transport. +func (p *providerTransport) Close() error { + return p.provider.Close() +} + +// SetReconnectCallback registers reconnect handling. +func (p *providerTransport) SetReconnectCallback(cb func()) { + p.provider.SetReconnectCallback(func(_ *webrtc.DataChannel) { + if cb != nil { + cb() + } + }) +} + +// SetShouldReconnect configures reconnect policy. +func (p *providerTransport) SetShouldReconnect(fn func() bool) { + p.provider.SetShouldReconnect(fn) +} + +// SetEndedCallback registers end-of-session handling. +func (p *providerTransport) SetEndedCallback(cb func(string)) { + p.provider.SetEndedCallback(cb) +} + +// WatchConnection monitors connection lifecycle. +func (p *providerTransport) WatchConnection(ctx context.Context) { + p.provider.WatchConnection(ctx) +} + +// CanSend reports whether transport is ready for sending. +func (p *providerTransport) CanSend() bool { + return p.provider.CanSend() +} diff --git a/internal/transport/transport.go b/internal/transport/transport.go new file mode 100644 index 0000000..012b3c5 --- /dev/null +++ b/internal/transport/transport.go @@ -0,0 +1,63 @@ +// Package transport defines transport abstractions and registry. +package transport + +import ( + "context" + "errors" +) + +var ( + // ErrTransportNotFound is returned when a requested transport is not registered. + ErrTransportNotFound = errors.New("transport not found") +) + +// Transport defines a byte transport independent of the underlying carrier. +type Transport interface { + Connect(ctx context.Context) error + Send(data []byte) error + Close() error + SetReconnectCallback(cb func()) + SetShouldReconnect(fn func() bool) + SetEndedCallback(cb func(string)) + WatchConnection(ctx context.Context) + CanSend() bool +} + +// Config holds common transport configuration. +type Config struct { + Carrier string + RoomURL string + Name string + OnData func([]byte) + DNSServer string + ProxyAddr string + ProxyPort int +} + +// Factory creates a transport instance. +type Factory func(ctx context.Context, cfg Config) (Transport, error) + +var registry = make(map[string]Factory) + +// Register adds a transport factory to the registry. +func Register(name string, factory Factory) { + registry[name] = factory +} + +// New creates a transport instance by name. +func New(ctx context.Context, name string, cfg Config) (Transport, error) { + factory, ok := registry[name] + if !ok { + return nil, ErrTransportNotFound + } + return factory(ctx, cfg) +} + +// Available returns a list of registered transport names. +func Available() []string { + names := make([]string, 0, len(registry)) + for name := range registry { + names = append(names, name) + } + return names +} diff --git a/mobile/mobile.go b/mobile/mobile.go index 25d2e27..3032534 100644 --- a/mobile/mobile.go +++ b/mobile/mobile.go @@ -109,6 +109,7 @@ func Start(roomID, keyHex string, socksPort int, socksUser, socksPass string) er err := client.RunWithReady( ctx, + "datachannel", "telemost", roomURL, keyHex,