refactor: introduce transport layer

This commit is contained in:
zarazaex69
2026-04-20 20:05:23 +03:00
parent ce87b017f1
commit fffb90e321
6 changed files with 269 additions and 109 deletions

View File

@@ -20,10 +20,13 @@ import (
"github.com/openlibrecommunity/olcrtc/internal/provider/telemost"
"github.com/openlibrecommunity/olcrtc/internal/provider/wbstream"
"github.com/openlibrecommunity/olcrtc/internal/server"
"github.com/openlibrecommunity/olcrtc/internal/transport"
"github.com/openlibrecommunity/olcrtc/internal/transport/datachannel"
)
type config struct {
mode string
transport string
roomID string
provider string
socksPort int
@@ -37,10 +40,11 @@ type config struct {
}
var (
errRoomIDRequired = errors.New("room ID required")
errModeRequired = errors.New("specify -mode srv or -mode cnc")
errProviderRequired = errors.New("provider required (use -provider telemost or -provider jazz)")
errUnsupportedProvider = errors.New("unsupported provider")
errRoomIDRequired = errors.New("room ID required")
errModeRequired = errors.New("specify -mode srv or -mode cnc")
errProviderRequired = errors.New("provider required (use -provider telemost or -provider jazz)")
errUnsupportedProvider = errors.New("unsupported provider")
errUnsupportedTransport = errors.New("unsupported transport")
)
func main() {
@@ -54,6 +58,7 @@ func run() error {
provider.Register("jazz", jazz.New)
provider.Register("telemost", telemost.New)
provider.Register("wb_stream", wbstream.New)
transport.Register("datachannel", datachannel.New)
cfg := parseFlags()
configureLogging(cfg.debug)
@@ -94,6 +99,7 @@ func parseFlags() config {
cfg := config{}
flag.StringVar(&cfg.mode, "mode", "", "Mode: srv or cnc")
flag.StringVar(&cfg.transport, "transport", "datachannel", "Transport: datachannel")
flag.StringVar(&cfg.roomID, "id", "", "Room ID")
flag.StringVar(&cfg.provider, "provider", "", "Provider: telemost or jazz (required)")
flag.IntVar(&cfg.socksPort, "socks-port", 1080, "SOCKS5 port (client only)")
@@ -116,20 +122,31 @@ func configureLogging(debug bool) {
}
func validateConfig(cfg config) error {
available := provider.Available()
availableProviders := provider.Available()
validProvider := false
for _, p := range available {
for _, p := range availableProviders {
if cfg.provider == p {
validProvider = true
break
}
}
availableTransports := transport.Available()
validTransport := false
for _, t := range availableTransports {
if cfg.transport == t {
validTransport = true
break
}
}
switch {
case cfg.provider == "":
return errProviderRequired
case !validProvider:
return fmt.Errorf("%w: %s (available: %v)", errUnsupportedProvider, cfg.provider, available)
return fmt.Errorf("%w: %s (available: %v)", errUnsupportedProvider, cfg.provider, availableProviders)
case !validTransport:
return fmt.Errorf("%w: %s (available: %v)", errUnsupportedTransport, cfg.transport, availableTransports)
case cfg.roomID == "" && cfg.provider != "jazz":
return errRoomIDRequired
case cfg.mode != "srv" && cfg.mode != "cnc":
@@ -169,6 +186,7 @@ func runMode(ctx context.Context, cfg config, errCh chan<- error) {
case "srv":
errCh <- server.Run(
ctx,
cfg.transport,
cfg.provider,
roomURL,
cfg.keyHex,
@@ -179,6 +197,7 @@ func runMode(ctx context.Context, cfg config, errCh chan<- error) {
case "cnc":
errCh <- client.Run(
ctx,
cfg.transport,
cfg.provider,
roomURL,
cfg.keyHex,

View File

@@ -19,8 +19,7 @@ import (
"github.com/openlibrecommunity/olcrtc/internal/logger"
"github.com/openlibrecommunity/olcrtc/internal/mux"
"github.com/openlibrecommunity/olcrtc/internal/names"
"github.com/openlibrecommunity/olcrtc/internal/provider"
"github.com/pion/webrtc/v4"
"github.com/openlibrecommunity/olcrtc/internal/transport"
)
var (
@@ -44,21 +43,22 @@ var (
// Client handles local SOCKS5 connections and tunnels them via WebRTC.
type Client struct {
peers []provider.Provider
cipher *crypto.Cipher
mux *mux.Multiplexer
connections map[uint16]net.Conn
connMu sync.RWMutex
peerIdx atomic.Uint32
clientID uint32
activeClients atomic.Int32
wg sync.WaitGroup
dnsServer string
transports []transport.Transport
cipher *crypto.Cipher
mux *mux.Multiplexer
connections map[uint16]net.Conn
connMu sync.RWMutex
peerIdx atomic.Uint32
clientID uint32
activeClients atomic.Int32
wg sync.WaitGroup
dnsServer string
}
// Run starts the client with the specified parameters.
func Run(
ctx context.Context,
transportName,
providerName,
roomURL,
keyHex string,
@@ -67,12 +67,13 @@ func Run(
socksUser string,
socksPass string,
) error {
return RunWithReady(ctx, providerName, roomURL, keyHex, localAddr, dnsServer, socksUser, socksPass, nil)
return RunWithReady(ctx, transportName, providerName, roomURL, keyHex, localAddr, dnsServer, socksUser, socksPass, nil)
}
// RunWithReady is like Run but accepts a callback that is called when the client is ready.
func RunWithReady(
ctx context.Context,
transportName,
providerName,
roomURL,
keyHex string,
@@ -99,7 +100,7 @@ func RunWithReady(
c := &Client{
cipher: cipher,
connections: make(map[uint16]net.Conn),
peers: make([]provider.Provider, 0),
transports: make([]transport.Transport, 0),
clientID: clientID,
dnsServer: dnsServer,
}
@@ -108,8 +109,8 @@ func RunWithReady(
const peerCount = 1
for i := range peerCount {
if err := c.addPeer(runCtx, providerName, roomURL, i, cancel, dnsServer, "", 0); err != nil {
return fmt.Errorf("addPeer failed: %w", err)
if err := c.addTransport(runCtx, transportName, providerName, roomURL, i, cancel, dnsServer, "", 0); err != nil {
return fmt.Errorf("addTransport failed: %w", err)
}
}
@@ -160,8 +161,8 @@ func (c *Client) setupMux() {
c.mux = mux.New(c.clientID, func(frame []byte) error {
for {
canSend := true
for _, peer := range c.peers {
if !peer.CanSend() {
for _, tr := range c.transports {
if !tr.CanSend() {
canSend = false
break
}
@@ -176,16 +177,17 @@ func (c *Client) setupMux() {
if err != nil {
return fmt.Errorf("%w: %w", ErrEncryptFailed, err)
}
if len(c.peers) == 0 {
if len(c.transports) == 0 {
return ErrNoPeers
}
idx := c.peerIdx.Add(1) % uint32(len(c.peers)) //nolint:gosec
return c.peers[idx].Send(encrypted)
idx := c.peerIdx.Add(1) % uint32(len(c.transports)) //nolint:gosec
return c.transports[idx].Send(encrypted)
})
}
func (c *Client) addPeer(
func (c *Client) addTransport(
ctx context.Context,
transportName,
providerName,
roomURL string,
peerID int,
@@ -194,7 +196,8 @@ func (c *Client) addPeer(
socksProxyAddr string,
socksProxyPort int,
) error {
peer, err := provider.New(ctx, providerName, provider.Config{
tr, err := transport.New(ctx, transportName, transport.Config{
Carrier: providerName,
RoomURL: roomURL,
Name: names.Generate(),
OnData: c.onData,
@@ -203,29 +206,29 @@ func (c *Client) addPeer(
ProxyPort: socksProxyPort,
})
if err != nil {
return fmt.Errorf("failed to create peer: %w", err)
return fmt.Errorf("failed to create transport: %w", err)
}
peer.SetEndedCallback(func(reason string) {
logger.Infof("Client peer %d reported conference end: %s", peerID, reason)
tr.SetEndedCallback(func(reason string) {
logger.Infof("Client transport %d reported conference end: %s", peerID, reason)
cancel()
})
c.peers = append(c.peers, peer)
c.transports = append(c.transports, tr)
peer.SetReconnectCallback(func(dc *webrtc.DataChannel) {
c.handlePeerReconnect(peerID, dc)
tr.SetReconnectCallback(func() {
c.handleTransportReconnect(peerID)
})
logger.Infof("Connecting peer %d to %s...", peerID, providerName)
if err := peer.Connect(ctx); err != nil {
return fmt.Errorf("failed to connect peer: %w", err)
logger.Infof("Connecting transport %d via %s/%s...", peerID, transportName, providerName)
if err := tr.Connect(ctx); err != nil {
return fmt.Errorf("failed to connect transport: %w", err)
}
logger.Infof("Peer %d connected", peerID)
logger.Infof("Transport %d connected", peerID)
c.wg.Add(1)
go func() {
defer c.wg.Done()
peer.WatchConnection(ctx)
tr.WatchConnection(ctx)
}()
// Send initial reset to clean up any stale connections for this clientID on server
@@ -236,8 +239,8 @@ func (c *Client) addPeer(
return nil
}
func (c *Client) handlePeerReconnect(peerID int, dc *webrtc.DataChannel) {
logger.Infof("peer %d reconnect event: dc=%v", peerID, dc != nil)
func (c *Client) handleTransportReconnect(peerID int) {
logger.Infof("transport %d reconnect event", peerID)
c.connMu.Lock()
for sid, conn := range c.connections {
@@ -248,23 +251,21 @@ func (c *Client) handlePeerReconnect(peerID int, dc *webrtc.DataChannel) {
}
c.connMu.Unlock()
if dc != nil {
c.mux.UpdateSendFunc(func(frame []byte) error {
encrypted, err := c.cipher.Encrypt(frame)
if err != nil {
return fmt.Errorf("%w: %w", ErrEncryptFailed, err)
}
if len(c.peers) == 0 {
return ErrNoPeers
}
idx := c.peerIdx.Add(1) % uint32(len(c.peers)) //nolint:gosec
return c.peers[idx].Send(encrypted)
})
c.mux.Reset()
if err := c.mux.SendClientReset(); err != nil {
logger.Warnf("Failed to send client reset after reconnect: %v", err)
c.mux.UpdateSendFunc(func(frame []byte) error {
encrypted, err := c.cipher.Encrypt(frame)
if err != nil {
return fmt.Errorf("%w: %w", ErrEncryptFailed, err)
}
if len(c.transports) == 0 {
return ErrNoPeers
}
idx := c.peerIdx.Add(1) % uint32(len(c.transports)) //nolint:gosec
return c.transports[idx].Send(encrypted)
})
c.mux.Reset()
if err := c.mux.SendClientReset(); err != nil {
logger.Warnf("Failed to send client reset after reconnect: %v", err)
}
}
@@ -442,9 +443,9 @@ func (c *Client) shutdown() {
}
c.connMu.Unlock()
for i, peer := range c.peers {
logger.Infof("closing peer %d", i)
_ = peer.Close()
for i, tr := range c.transports {
logger.Infof("closing transport %d", i)
_ = tr.Close()
}
}
@@ -515,8 +516,8 @@ func (c *Client) startStreamPump(ctx context.Context, sid uint16, conn net.Conn)
}
func (c *Client) canSendData() bool {
for _, peer := range c.peers {
if !peer.CanSend() {
for _, tr := range c.transports {
if !tr.CanSend() {
return false
}
}

View File

@@ -20,8 +20,7 @@ import (
"github.com/openlibrecommunity/olcrtc/internal/logger"
"github.com/openlibrecommunity/olcrtc/internal/mux"
"github.com/openlibrecommunity/olcrtc/internal/names"
"github.com/openlibrecommunity/olcrtc/internal/provider"
"github.com/pion/webrtc/v4"
"github.com/openlibrecommunity/olcrtc/internal/transport"
)
var (
@@ -43,7 +42,7 @@ var (
// Server handles incoming WebRTC connections and proxies their traffic.
type Server struct {
peers []provider.Provider
transports []transport.Transport
cipher *crypto.Cipher
mux *mux.Multiplexer
connections map[uint16]net.Conn
@@ -69,6 +68,7 @@ type ConnectRequest struct {
// Run starts the server with the specified parameters.
func Run(
ctx context.Context,
transportName,
providerName,
roomURL,
keyHex string,
@@ -88,7 +88,7 @@ func Run(
cipher: cipher,
connections: make(map[uint16]net.Conn),
streamPumps: make(map[uint16]net.Conn),
peers: make([]provider.Provider, 0),
transports: make([]transport.Transport, 0),
dnsServer: dnsServer,
socksProxyAddr: socksProxyAddr,
socksProxyPort: socksProxyPort,
@@ -103,8 +103,8 @@ func Run(
const peerCount = 1
for i := range peerCount {
if err := s.addPeer(runCtx, providerName, roomURL, i, cancel); err != nil {
return fmt.Errorf("addPeer failed: %w", err)
if err := s.addTransport(runCtx, transportName, providerName, roomURL, i, cancel); err != nil {
return fmt.Errorf("addTransport failed: %w", err)
}
}
@@ -161,8 +161,8 @@ func (s *Server) setupMux() {
s.mux = mux.New(0, func(frame []byte) error {
for {
canSend := true
for _, peer := range s.peers {
if !peer.CanSend() {
for _, tr := range s.transports {
if !tr.CanSend() {
canSend = false
break
}
@@ -177,22 +177,24 @@ func (s *Server) setupMux() {
if err != nil {
return fmt.Errorf("%w: %w", ErrEncryptFailed, err)
}
if len(s.peers) == 0 {
if len(s.transports) == 0 {
return ErrNoPeers
}
idx := s.peerIdx.Add(1) % uint32(len(s.peers)) //nolint:gosec
return s.peers[idx].Send(encrypted)
idx := s.peerIdx.Add(1) % uint32(len(s.transports)) //nolint:gosec
return s.transports[idx].Send(encrypted)
})
}
func (s *Server) addPeer(
func (s *Server) addTransport(
ctx context.Context,
transportName,
providerName,
roomURL string,
peerID int,
cancel context.CancelFunc,
) error {
peer, err := provider.New(ctx, providerName, provider.Config{
tr, err := transport.New(ctx, transportName, transport.Config{
Carrier: providerName,
RoomURL: roomURL,
Name: names.Generate(),
OnData: s.onData,
@@ -201,35 +203,35 @@ func (s *Server) addPeer(
ProxyPort: s.socksProxyPort,
})
if err != nil {
return fmt.Errorf("failed to create peer: %w", err)
return fmt.Errorf("failed to create transport: %w", err)
}
peer.SetEndedCallback(func(reason string) {
logger.Infof("Server peer %d reported conference end: %s", peerID, reason)
tr.SetEndedCallback(func(reason string) {
logger.Infof("Server transport %d reported conference end: %s", peerID, reason)
cancel()
})
s.peers = append(s.peers, peer)
s.transports = append(s.transports, tr)
peer.SetReconnectCallback(func(dc *webrtc.DataChannel) {
s.handlePeerReconnect(peerID, dc)
tr.SetReconnectCallback(func() {
s.handleTransportReconnect(peerID)
})
logger.Infof("Connecting peer %d to %s...", peerID, providerName)
if err := peer.Connect(ctx); err != nil {
return fmt.Errorf("failed to connect peer: %w", err)
logger.Infof("Connecting transport %d via %s/%s...", peerID, transportName, providerName)
if err := tr.Connect(ctx); err != nil {
return fmt.Errorf("failed to connect transport: %w", err)
}
logger.Infof("Peer %d connected", peerID)
logger.Infof("Transport %d connected", peerID)
s.wg.Add(1)
go func() {
defer s.wg.Done()
peer.WatchConnection(ctx)
tr.WatchConnection(ctx)
}()
return nil
}
func (s *Server) handlePeerReconnect(peerID int, dc *webrtc.DataChannel) {
logger.Infof("peer %d reconnect event: dc=%v", peerID, dc != nil)
func (s *Server) handleTransportReconnect(peerID int) {
logger.Infof("transport %d reconnect event", peerID)
s.connMu.Lock()
for sid, conn := range s.connections {
@@ -240,20 +242,18 @@ func (s *Server) handlePeerReconnect(peerID int, dc *webrtc.DataChannel) {
}
s.connMu.Unlock()
if dc != nil {
s.mux.UpdateSendFunc(func(frame []byte) error {
encrypted, err := s.cipher.Encrypt(frame)
if err != nil {
return fmt.Errorf("%w: %w", ErrEncryptFailed, err)
}
if len(s.peers) == 0 {
return ErrNoPeers
}
idx := s.peerIdx.Add(1) % uint32(len(s.peers)) //nolint:gosec
return s.peers[idx].Send(encrypted)
})
s.mux.Reset()
}
s.mux.UpdateSendFunc(func(frame []byte) error {
encrypted, err := s.cipher.Encrypt(frame)
if err != nil {
return fmt.Errorf("%w: %w", ErrEncryptFailed, err)
}
if len(s.transports) == 0 {
return ErrNoPeers
}
idx := s.peerIdx.Add(1) % uint32(len(s.transports)) //nolint:gosec
return s.transports[idx].Send(encrypted)
})
s.mux.Reset()
}
func (s *Server) socks5Connect(conn net.Conn, targetAddr string, targetPort int) error {
@@ -349,9 +349,9 @@ func (s *Server) shutdown() {
}
s.connMu.Unlock()
for i, peer := range s.peers {
logger.Infof("closing peer %d", i)
_ = peer.Close()
for i, tr := range s.transports {
logger.Infof("closing transport %d", i)
_ = tr.Close()
}
}
@@ -561,8 +561,8 @@ func (s *Server) startStreamPump(ctx context.Context, sid uint16, conn net.Conn)
}
func (s *Server) canSendData() bool {
for _, peer := range s.peers {
if !peer.CanSend() {
for _, tr := range s.transports {
if !tr.CanSend() {
return false
}
}

View File

@@ -0,0 +1,76 @@
// Package datachannel provides a transport backed by the current WebRTC providers.
package datachannel
import (
"context"
"fmt"
"github.com/openlibrecommunity/olcrtc/internal/provider"
"github.com/openlibrecommunity/olcrtc/internal/transport"
"github.com/pion/webrtc/v4"
)
type providerTransport struct {
provider provider.Provider
}
// New creates a datachannel transport backed by a carrier-specific provider.
func New(ctx context.Context, cfg transport.Config) (transport.Transport, error) {
p, err := provider.New(ctx, cfg.Carrier, provider.Config{
RoomURL: cfg.RoomURL,
Name: cfg.Name,
OnData: cfg.OnData,
DNSServer: cfg.DNSServer,
ProxyAddr: cfg.ProxyAddr,
ProxyPort: cfg.ProxyPort,
})
if err != nil {
return nil, fmt.Errorf("create provider transport: %w", err)
}
return &providerTransport{provider: p}, nil
}
// Connect starts the transport connection.
func (p *providerTransport) Connect(ctx context.Context) error {
return p.provider.Connect(ctx)
}
// Send transmits data through the transport.
func (p *providerTransport) Send(data []byte) error {
return p.provider.Send(data)
}
// Close terminates the transport.
func (p *providerTransport) Close() error {
return p.provider.Close()
}
// SetReconnectCallback registers reconnect handling.
func (p *providerTransport) SetReconnectCallback(cb func()) {
p.provider.SetReconnectCallback(func(_ *webrtc.DataChannel) {
if cb != nil {
cb()
}
})
}
// SetShouldReconnect configures reconnect policy.
func (p *providerTransport) SetShouldReconnect(fn func() bool) {
p.provider.SetShouldReconnect(fn)
}
// SetEndedCallback registers end-of-session handling.
func (p *providerTransport) SetEndedCallback(cb func(string)) {
p.provider.SetEndedCallback(cb)
}
// WatchConnection monitors connection lifecycle.
func (p *providerTransport) WatchConnection(ctx context.Context) {
p.provider.WatchConnection(ctx)
}
// CanSend reports whether transport is ready for sending.
func (p *providerTransport) CanSend() bool {
return p.provider.CanSend()
}

View File

@@ -0,0 +1,63 @@
// Package transport defines transport abstractions and registry.
package transport
import (
"context"
"errors"
)
var (
// ErrTransportNotFound is returned when a requested transport is not registered.
ErrTransportNotFound = errors.New("transport not found")
)
// Transport defines a byte transport independent of the underlying carrier.
type Transport interface {
Connect(ctx context.Context) error
Send(data []byte) error
Close() error
SetReconnectCallback(cb func())
SetShouldReconnect(fn func() bool)
SetEndedCallback(cb func(string))
WatchConnection(ctx context.Context)
CanSend() bool
}
// Config holds common transport configuration.
type Config struct {
Carrier string
RoomURL string
Name string
OnData func([]byte)
DNSServer string
ProxyAddr string
ProxyPort int
}
// Factory creates a transport instance.
type Factory func(ctx context.Context, cfg Config) (Transport, error)
var registry = make(map[string]Factory)
// Register adds a transport factory to the registry.
func Register(name string, factory Factory) {
registry[name] = factory
}
// New creates a transport instance by name.
func New(ctx context.Context, name string, cfg Config) (Transport, error) {
factory, ok := registry[name]
if !ok {
return nil, ErrTransportNotFound
}
return factory(ctx, cfg)
}
// Available returns a list of registered transport names.
func Available() []string {
names := make([]string, 0, len(registry))
for name := range registry {
names = append(names, name)
}
return names
}

View File

@@ -109,6 +109,7 @@ func Start(roomID, keyHex string, socksPort int, socksUser, socksPass string) er
err := client.RunWithReady(
ctx,
"datachannel",
"telemost",
roomURL,
keyHex,