From 9bd9503daa329e8a9a15616ca7f4d755e65e3944 Mon Sep 17 00:00:00 2001 From: zarazaex69 Date: Mon, 20 Apr 2026 20:13:49 +0300 Subject: [PATCH] refactor: add direct link layer --- internal/app/session/session.go | 3 ++ internal/client/client.go | 41 ++++++++++++------------ internal/link/direct/direct.go | 43 ++++++++++++++++++++++++++ internal/link/link.go | 55 +++++++++++++++++++++++++++++++++ internal/server/server.go | 41 ++++++++++++------------ 5 files changed, 143 insertions(+), 40 deletions(-) create mode 100644 internal/link/direct/direct.go create mode 100644 internal/link/link.go diff --git a/internal/app/session/session.go b/internal/app/session/session.go index 5685d98..9127d26 100644 --- a/internal/app/session/session.go +++ b/internal/app/session/session.go @@ -8,6 +8,8 @@ import ( "github.com/openlibrecommunity/olcrtc/internal/carrier" "github.com/openlibrecommunity/olcrtc/internal/client" + "github.com/openlibrecommunity/olcrtc/internal/link" + "github.com/openlibrecommunity/olcrtc/internal/link/direct" "github.com/openlibrecommunity/olcrtc/internal/provider/jazz" "github.com/openlibrecommunity/olcrtc/internal/provider/telemost" "github.com/openlibrecommunity/olcrtc/internal/provider/wbstream" @@ -49,6 +51,7 @@ func RegisterDefaults() { carrier.Register("telemost", telemost.New) carrier.Register("wb_stream", wbstream.New) + link.Register("direct", direct.New) transport.Register("datachannel", datachannel.New) } diff --git a/internal/client/client.go b/internal/client/client.go index 8d57e4a..0aa38a2 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -16,10 +16,10 @@ import ( "time" "github.com/openlibrecommunity/olcrtc/internal/crypto" + "github.com/openlibrecommunity/olcrtc/internal/link" "github.com/openlibrecommunity/olcrtc/internal/logger" "github.com/openlibrecommunity/olcrtc/internal/mux" "github.com/openlibrecommunity/olcrtc/internal/names" - "github.com/openlibrecommunity/olcrtc/internal/transport" ) var ( @@ -43,7 +43,7 @@ var ( // Client handles local SOCKS5 connections and tunnels them via WebRTC. type Client struct { - transports []transport.Transport + links []link.Link cipher *crypto.Cipher mux *mux.Multiplexer connections map[uint16]net.Conn @@ -100,7 +100,7 @@ func RunWithReady( c := &Client{ cipher: cipher, connections: make(map[uint16]net.Conn), - transports: make([]transport.Transport, 0), + links: make([]link.Link, 0), clientID: clientID, dnsServer: dnsServer, } @@ -161,8 +161,8 @@ func (c *Client) setupMux() { c.mux = mux.New(c.clientID, func(frame []byte) error { for { canSend := true - for _, tr := range c.transports { - if !tr.CanSend() { + for _, ln := range c.links { + if !ln.CanSend() { canSend = false break } @@ -177,11 +177,11 @@ func (c *Client) setupMux() { if err != nil { return fmt.Errorf("%w: %w", ErrEncryptFailed, err) } - if len(c.transports) == 0 { + if len(c.links) == 0 { return ErrNoPeers } - idx := c.peerIdx.Add(1) % uint32(len(c.transports)) //nolint:gosec - return c.transports[idx].Send(encrypted) + idx := c.peerIdx.Add(1) % uint32(len(c.links)) //nolint:gosec + return c.links[idx].Send(encrypted) }) } @@ -196,7 +196,8 @@ func (c *Client) addTransport( socksProxyAddr string, socksProxyPort int, ) error { - tr, err := transport.New(ctx, transportName, transport.Config{ + ln, err := link.New(ctx, "direct", link.Config{ + Transport: transportName, Carrier: providerName, RoomURL: roomURL, Name: names.Generate(), @@ -206,21 +207,21 @@ func (c *Client) addTransport( ProxyPort: socksProxyPort, }) if err != nil { - return fmt.Errorf("failed to create transport: %w", err) + return fmt.Errorf("failed to create link: %w", err) } - tr.SetEndedCallback(func(reason string) { + ln.SetEndedCallback(func(reason string) { logger.Infof("Client transport %d reported conference end: %s", peerID, reason) cancel() }) - c.transports = append(c.transports, tr) + c.links = append(c.links, ln) - tr.SetReconnectCallback(func() { + ln.SetReconnectCallback(func() { c.handleTransportReconnect(peerID) }) logger.Infof("Connecting transport %d via %s/%s...", peerID, transportName, providerName) - if err := tr.Connect(ctx); err != nil { + if err := ln.Connect(ctx); err != nil { return fmt.Errorf("failed to connect transport: %w", err) } logger.Infof("Transport %d connected", peerID) @@ -228,7 +229,7 @@ func (c *Client) addTransport( c.wg.Add(1) go func() { defer c.wg.Done() - tr.WatchConnection(ctx) + ln.WatchConnection(ctx) }() // Send initial reset to clean up any stale connections for this clientID on server @@ -256,11 +257,11 @@ func (c *Client) handleTransportReconnect(peerID int) { if err != nil { return fmt.Errorf("%w: %w", ErrEncryptFailed, err) } - if len(c.transports) == 0 { + if len(c.links) == 0 { return ErrNoPeers } - idx := c.peerIdx.Add(1) % uint32(len(c.transports)) //nolint:gosec - return c.transports[idx].Send(encrypted) + idx := c.peerIdx.Add(1) % uint32(len(c.links)) //nolint:gosec + return c.links[idx].Send(encrypted) }) c.mux.Reset() @@ -443,7 +444,7 @@ func (c *Client) shutdown() { } c.connMu.Unlock() - for i, tr := range c.transports { + for i, tr := range c.links { logger.Infof("closing transport %d", i) _ = tr.Close() } @@ -516,7 +517,7 @@ func (c *Client) startStreamPump(ctx context.Context, sid uint16, conn net.Conn) } func (c *Client) canSendData() bool { - for _, tr := range c.transports { + for _, tr := range c.links { if !tr.CanSend() { return false } diff --git a/internal/link/direct/direct.go b/internal/link/direct/direct.go new file mode 100644 index 0000000..40b318b --- /dev/null +++ b/internal/link/direct/direct.go @@ -0,0 +1,43 @@ +// Package direct provides a pass-through link implementation above transports. +package direct + +import ( + "context" + "fmt" + + "github.com/openlibrecommunity/olcrtc/internal/link" + "github.com/openlibrecommunity/olcrtc/internal/transport" +) + +type directLink struct { + transport transport.Transport +} + +// New creates a direct link that forwards bytes to the selected transport. +func New(ctx context.Context, cfg link.Config) (link.Link, error) { + tr, err := transport.New(ctx, cfg.Transport, transport.Config{ + Carrier: cfg.Carrier, + RoomURL: cfg.RoomURL, + Name: cfg.Name, + OnData: cfg.OnData, + DNSServer: cfg.DNSServer, + ProxyAddr: cfg.ProxyAddr, + ProxyPort: cfg.ProxyPort, + }) + if err != nil { + return nil, fmt.Errorf("create transport for direct link: %w", err) + } + + return &directLink{transport: tr}, nil +} + +func (d *directLink) Connect(ctx context.Context) error { return d.transport.Connect(ctx) } +func (d *directLink) Send(data []byte) error { return d.transport.Send(data) } +func (d *directLink) Close() error { return d.transport.Close() } +func (d *directLink) SetReconnectCallback(cb func()) { d.transport.SetReconnectCallback(cb) } +func (d *directLink) SetShouldReconnect(fn func() bool) { d.transport.SetShouldReconnect(fn) } +func (d *directLink) SetEndedCallback(cb func(string)) { d.transport.SetEndedCallback(cb) } +func (d *directLink) WatchConnection(ctx context.Context) { + d.transport.WatchConnection(ctx) +} +func (d *directLink) CanSend() bool { return d.transport.CanSend() } diff --git a/internal/link/link.go b/internal/link/link.go new file mode 100644 index 0000000..7d671c6 --- /dev/null +++ b/internal/link/link.go @@ -0,0 +1,55 @@ +// Package link defines link-layer abstractions above transports. +package link + +import ( + "context" + "errors" +) + +var ( + // ErrLinkNotFound is returned when a requested link is not registered. + ErrLinkNotFound = errors.New("link not found") +) + +// Link defines a byte link above a transport. +type Link interface { + Connect(ctx context.Context) error + Send(data []byte) error + Close() error + SetReconnectCallback(cb func()) + SetShouldReconnect(fn func() bool) + SetEndedCallback(cb func(string)) + WatchConnection(ctx context.Context) + CanSend() bool +} + +// Config holds common link configuration. +type Config struct { + Transport string + Carrier string + RoomURL string + Name string + OnData func([]byte) + DNSServer string + ProxyAddr string + ProxyPort int +} + +// Factory creates a link instance. +type Factory func(ctx context.Context, cfg Config) (Link, error) + +var registry = make(map[string]Factory) + +// Register adds a link factory to the registry. +func Register(name string, factory Factory) { + registry[name] = factory +} + +// New creates a link instance by name. +func New(ctx context.Context, name string, cfg Config) (Link, error) { + factory, ok := registry[name] + if !ok { + return nil, ErrLinkNotFound + } + return factory(ctx, cfg) +} diff --git a/internal/server/server.go b/internal/server/server.go index 0543052..e17c650 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -17,10 +17,10 @@ import ( "time" "github.com/openlibrecommunity/olcrtc/internal/crypto" + "github.com/openlibrecommunity/olcrtc/internal/link" "github.com/openlibrecommunity/olcrtc/internal/logger" "github.com/openlibrecommunity/olcrtc/internal/mux" "github.com/openlibrecommunity/olcrtc/internal/names" - "github.com/openlibrecommunity/olcrtc/internal/transport" ) var ( @@ -42,7 +42,7 @@ var ( // Server handles incoming WebRTC connections and proxies their traffic. type Server struct { - transports []transport.Transport + links []link.Link cipher *crypto.Cipher mux *mux.Multiplexer connections map[uint16]net.Conn @@ -88,7 +88,7 @@ func Run( cipher: cipher, connections: make(map[uint16]net.Conn), streamPumps: make(map[uint16]net.Conn), - transports: make([]transport.Transport, 0), + links: make([]link.Link, 0), dnsServer: dnsServer, socksProxyAddr: socksProxyAddr, socksProxyPort: socksProxyPort, @@ -161,8 +161,8 @@ func (s *Server) setupMux() { s.mux = mux.New(0, func(frame []byte) error { for { canSend := true - for _, tr := range s.transports { - if !tr.CanSend() { + for _, ln := range s.links { + if !ln.CanSend() { canSend = false break } @@ -177,11 +177,11 @@ func (s *Server) setupMux() { if err != nil { return fmt.Errorf("%w: %w", ErrEncryptFailed, err) } - if len(s.transports) == 0 { + if len(s.links) == 0 { return ErrNoPeers } - idx := s.peerIdx.Add(1) % uint32(len(s.transports)) //nolint:gosec - return s.transports[idx].Send(encrypted) + idx := s.peerIdx.Add(1) % uint32(len(s.links)) //nolint:gosec + return s.links[idx].Send(encrypted) }) } @@ -193,7 +193,8 @@ func (s *Server) addTransport( peerID int, cancel context.CancelFunc, ) error { - tr, err := transport.New(ctx, transportName, transport.Config{ + ln, err := link.New(ctx, "direct", link.Config{ + Transport: transportName, Carrier: providerName, RoomURL: roomURL, Name: names.Generate(), @@ -203,21 +204,21 @@ func (s *Server) addTransport( ProxyPort: s.socksProxyPort, }) if err != nil { - return fmt.Errorf("failed to create transport: %w", err) + return fmt.Errorf("failed to create link: %w", err) } - tr.SetEndedCallback(func(reason string) { + ln.SetEndedCallback(func(reason string) { logger.Infof("Server transport %d reported conference end: %s", peerID, reason) cancel() }) - s.transports = append(s.transports, tr) + s.links = append(s.links, ln) - tr.SetReconnectCallback(func() { + ln.SetReconnectCallback(func() { s.handleTransportReconnect(peerID) }) logger.Infof("Connecting transport %d via %s/%s...", peerID, transportName, providerName) - if err := tr.Connect(ctx); err != nil { + if err := ln.Connect(ctx); err != nil { return fmt.Errorf("failed to connect transport: %w", err) } logger.Infof("Transport %d connected", peerID) @@ -225,7 +226,7 @@ func (s *Server) addTransport( s.wg.Add(1) go func() { defer s.wg.Done() - tr.WatchConnection(ctx) + ln.WatchConnection(ctx) }() return nil } @@ -247,11 +248,11 @@ func (s *Server) handleTransportReconnect(peerID int) { if err != nil { return fmt.Errorf("%w: %w", ErrEncryptFailed, err) } - if len(s.transports) == 0 { + if len(s.links) == 0 { return ErrNoPeers } - idx := s.peerIdx.Add(1) % uint32(len(s.transports)) //nolint:gosec - return s.transports[idx].Send(encrypted) + idx := s.peerIdx.Add(1) % uint32(len(s.links)) //nolint:gosec + return s.links[idx].Send(encrypted) }) s.mux.Reset() } @@ -349,7 +350,7 @@ func (s *Server) shutdown() { } s.connMu.Unlock() - for i, tr := range s.transports { + for i, tr := range s.links { logger.Infof("closing transport %d", i) _ = tr.Close() } @@ -561,7 +562,7 @@ func (s *Server) startStreamPump(ctx context.Context, sid uint16, conn net.Conn) } func (s *Server) canSendData() bool { - for _, tr := range s.transports { + for _, tr := range s.links { if !tr.CanSend() { return false }