diff --git a/cmd/olcrtc/main.go b/cmd/olcrtc/main.go index 3785c56..10b5068 100644 --- a/cmd/olcrtc/main.go +++ b/cmd/olcrtc/main.go @@ -3,7 +3,6 @@ package main import ( "context" - "errors" "flag" "fmt" "os" @@ -12,21 +11,16 @@ import ( "syscall" "time" - "github.com/openlibrecommunity/olcrtc/internal/client" + "github.com/openlibrecommunity/olcrtc/internal/app/session" "github.com/openlibrecommunity/olcrtc/internal/logger" "github.com/openlibrecommunity/olcrtc/internal/names" - "github.com/openlibrecommunity/olcrtc/internal/provider" - "github.com/openlibrecommunity/olcrtc/internal/provider/jazz" - "github.com/openlibrecommunity/olcrtc/internal/provider/telemost" - "github.com/openlibrecommunity/olcrtc/internal/provider/wbstream" - "github.com/openlibrecommunity/olcrtc/internal/server" - "github.com/openlibrecommunity/olcrtc/internal/transport" - "github.com/openlibrecommunity/olcrtc/internal/transport/datachannel" ) type config struct { mode string + link string transport string + carrier string roomID string provider string socksPort int @@ -39,14 +33,6 @@ type config struct { socksProxyPort int } -var ( - errRoomIDRequired = errors.New("room ID required") - errModeRequired = errors.New("specify -mode srv or -mode cnc") - errProviderRequired = errors.New("provider required (use -provider telemost or -provider jazz)") - errUnsupportedProvider = errors.New("unsupported provider") - errUnsupportedTransport = errors.New("unsupported transport") -) - func main() { if err := run(); err != nil { logger.Error(err) @@ -55,15 +41,12 @@ func main() { } func run() error { - provider.Register("jazz", jazz.New) - provider.Register("telemost", telemost.New) - provider.Register("wb_stream", wbstream.New) - transport.Register("datachannel", datachannel.New) + session.RegisterDefaults() cfg := parseFlags() configureLogging(cfg.debug) - if err := validateConfig(cfg); err != nil { + if err := session.Validate(toSessionConfig(cfg)); err != nil { return err } @@ -83,7 +66,9 @@ func run() error { signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) errCh := make(chan error, 1) - go runMode(ctx, cfg, errCh) + go func() { + errCh <- session.Run(ctx, toSessionConfig(cfg)) + }() select { case <-sigCh: @@ -99,9 +84,11 @@ func parseFlags() config { cfg := config{} flag.StringVar(&cfg.mode, "mode", "", "Mode: srv or cnc") + flag.StringVar(&cfg.link, "link", "direct", "Link: direct") flag.StringVar(&cfg.transport, "transport", "datachannel", "Transport: datachannel") + flag.StringVar(&cfg.carrier, "carrier", "", "Carrier: telemost, jazz, wb_stream") flag.StringVar(&cfg.roomID, "id", "", "Room ID") - flag.StringVar(&cfg.provider, "provider", "", "Provider: telemost or jazz (required)") + flag.StringVar(&cfg.provider, "provider", "", "Deprecated alias for -carrier") flag.IntVar(&cfg.socksPort, "socks-port", 1080, "SOCKS5 port (client only)") flag.StringVar(&cfg.socksHost, "socks-host", "127.0.0.1", "SOCKS5 listen host (client only)") flag.StringVar(&cfg.keyHex, "key", "", "Shared encryption key (hex)") @@ -121,41 +108,6 @@ func configureLogging(debug bool) { } } -func validateConfig(cfg config) error { - availableProviders := provider.Available() - validProvider := false - for _, p := range availableProviders { - if cfg.provider == p { - validProvider = true - break - } - } - - availableTransports := transport.Available() - validTransport := false - for _, t := range availableTransports { - if cfg.transport == t { - validTransport = true - break - } - } - - switch { - case cfg.provider == "": - return errProviderRequired - case !validProvider: - return fmt.Errorf("%w: %s (available: %v)", errUnsupportedProvider, cfg.provider, availableProviders) - case !validTransport: - return fmt.Errorf("%w: %s (available: %v)", errUnsupportedTransport, cfg.transport, availableTransports) - case cfg.roomID == "" && cfg.provider != "jazz": - return errRoomIDRequired - case cfg.mode != "srv" && cfg.mode != "cnc": - return errModeRequired - default: - return nil - } -} - func resolveDataDir(dataDir string) (string, error) { if filepath.IsAbs(dataDir) { return dataDir, nil @@ -179,50 +131,29 @@ func loadNames(dataDir string) error { return nil } -func runMode(ctx context.Context, cfg config, errCh chan<- error) { - roomURL := buildRoomURL(cfg.provider, cfg.roomID) - - switch cfg.mode { - case "srv": - errCh <- server.Run( - ctx, - cfg.transport, - cfg.provider, - roomURL, - cfg.keyHex, - cfg.dnsServer, - cfg.socksProxyAddr, - cfg.socksProxyPort, - ) - case "cnc": - errCh <- client.Run( - ctx, - cfg.transport, - cfg.provider, - roomURL, - cfg.keyHex, - fmt.Sprintf("%s:%d", cfg.socksHost, cfg.socksPort), - cfg.dnsServer, - "", - "", - ) +func toSessionConfig(cfg config) session.Config { + return session.Config{ + Mode: cfg.mode, + Link: cfg.link, + Transport: cfg.transport, + Carrier: firstNonEmpty(cfg.carrier, cfg.provider), + RoomID: cfg.roomID, + KeyHex: cfg.keyHex, + SOCKSHost: cfg.socksHost, + SOCKSPort: cfg.socksPort, + DNSServer: cfg.dnsServer, + SOCKSProxyAddr: cfg.socksProxyAddr, + SOCKSProxyPort: cfg.socksProxyPort, } } -func buildRoomURL(providerName, roomID string) string { - switch providerName { - case "telemost": - return "https://telemost.yandex.ru/j/" + roomID - case "jazz": - if roomID == "" { - return "any" +func firstNonEmpty(values ...string) string { + for _, value := range values { + if value != "" { + return value } - return roomID - case "wb_stream": - return roomID - default: - return roomID } + return "" } func waitForShutdown(errCh <-chan error) error { diff --git a/internal/app/session/session.go b/internal/app/session/session.go new file mode 100644 index 0000000..3f3e4dd --- /dev/null +++ b/internal/app/session/session.go @@ -0,0 +1,152 @@ +// Package session wires runtime configuration to application mode entrypoints. +package session + +import ( + "context" + "errors" + "fmt" + + "github.com/openlibrecommunity/olcrtc/internal/carrier" + "github.com/openlibrecommunity/olcrtc/internal/carrier/builtin" + "github.com/openlibrecommunity/olcrtc/internal/client" + "github.com/openlibrecommunity/olcrtc/internal/link" + "github.com/openlibrecommunity/olcrtc/internal/link/direct" + "github.com/openlibrecommunity/olcrtc/internal/server" + "github.com/openlibrecommunity/olcrtc/internal/transport" + "github.com/openlibrecommunity/olcrtc/internal/transport/datachannel" +) + +var ( + // ErrRoomIDRequired indicates that a room id is required for the selected carrier. + ErrRoomIDRequired = errors.New("room ID required") + // ErrModeRequired indicates that mode is not one of the supported values. + ErrModeRequired = errors.New("specify -mode srv or -mode cnc") + // ErrCarrierRequired indicates that no carrier was selected. + ErrCarrierRequired = errors.New("carrier required (use -carrier telemost or -carrier jazz)") + // ErrUnsupportedCarrier indicates that carrier is not registered. + ErrUnsupportedCarrier = errors.New("unsupported carrier") + // ErrUnsupportedLink indicates that link is not registered. + ErrUnsupportedLink = errors.New("unsupported link") + // ErrUnsupportedTransport indicates that transport is not registered. + ErrUnsupportedTransport = errors.New("unsupported transport") +) + +// Config holds runtime session settings. +type Config struct { + Mode string + Link string + Transport string + Carrier string + RoomID string + KeyHex string + SOCKSHost string + SOCKSPort int + DNSServer string + SOCKSProxyAddr string + SOCKSProxyPort int +} + +// RegisterDefaults registers built-in providers and transports. +func RegisterDefaults() { + builtin.Register() + link.Register("direct", direct.New) + transport.Register("datachannel", datachannel.New) +} + +// Validate verifies that the runtime config refers to registered components. +func Validate(cfg Config) error { + availableCarriers := carrier.Available() + validCarrier := false + for _, c := range availableCarriers { + if cfg.Carrier == c { + validCarrier = true + break + } + } + + availableTransports := transport.Available() + validTransport := false + for _, t := range availableTransports { + if cfg.Transport == t { + validTransport = true + break + } + } + + availableLinks := link.Available() + validLink := false + for _, l := range availableLinks { + if cfg.Link == l { + validLink = true + break + } + } + + switch { + case cfg.Carrier == "": + return ErrCarrierRequired + case !validCarrier: + return fmt.Errorf("%w: %s (available: %v)", ErrUnsupportedCarrier, cfg.Carrier, availableCarriers) + case !validLink: + return fmt.Errorf("%w: %s (available: %v)", ErrUnsupportedLink, cfg.Link, availableLinks) + case !validTransport: + return fmt.Errorf("%w: %s (available: %v)", ErrUnsupportedTransport, cfg.Transport, availableTransports) + case cfg.RoomID == "" && cfg.Carrier != "jazz": + return ErrRoomIDRequired + case cfg.Mode != "srv" && cfg.Mode != "cnc": + return ErrModeRequired + default: + return nil + } +} + +// Run starts the configured mode. +func Run(ctx context.Context, cfg Config) error { + roomURL := buildRoomURL(cfg.Carrier, cfg.RoomID) + + switch cfg.Mode { + case "srv": + return server.Run( + ctx, + cfg.Link, + cfg.Transport, + cfg.Carrier, + roomURL, + cfg.KeyHex, + cfg.DNSServer, + cfg.SOCKSProxyAddr, + cfg.SOCKSProxyPort, + ) + case "cnc": + return client.Run( + ctx, + cfg.Link, + cfg.Transport, + cfg.Carrier, + roomURL, + cfg.KeyHex, + fmt.Sprintf("%s:%d", cfg.SOCKSHost, cfg.SOCKSPort), + cfg.DNSServer, + "", + "", + ) + default: + return ErrModeRequired + } +} + +func buildRoomURL(carrierName, roomID string) string { + switch carrierName { + case "telemost": + return "https://telemost.yandex.ru/j/" + roomID + case "jazz": + if roomID == "" { + return "any" + } + return roomID + case "wb_stream": + return roomID + default: + return roomID + } +} diff --git a/internal/carrier/builtin/register.go b/internal/carrier/builtin/register.go new file mode 100644 index 0000000..dd57026 --- /dev/null +++ b/internal/carrier/builtin/register.go @@ -0,0 +1,16 @@ +// Package builtin registers the built-in carrier implementations. +package builtin + +import ( + "github.com/openlibrecommunity/olcrtc/internal/carrier" + "github.com/openlibrecommunity/olcrtc/internal/provider/jazz" + "github.com/openlibrecommunity/olcrtc/internal/provider/telemost" + "github.com/openlibrecommunity/olcrtc/internal/provider/wbstream" +) + +// Register wires the built-in legacy carriers into the carrier registry. +func Register() { + carrier.RegisterLegacy("jazz", jazz.New) + carrier.RegisterLegacy("telemost", telemost.New) + carrier.RegisterLegacy("wb_stream", wbstream.New) +} diff --git a/internal/carrier/bytestream.go b/internal/carrier/bytestream.go new file mode 100644 index 0000000..02a584d --- /dev/null +++ b/internal/carrier/bytestream.go @@ -0,0 +1,81 @@ +package carrier + +import ( + "context" + + "github.com/openlibrecommunity/olcrtc/internal/provider" + "github.com/pion/webrtc/v4" +) + +// ByteStream is a carrier capability for bidirectional byte transport. +type ByteStream interface { + Connect(ctx context.Context) error + Send(data []byte) error + Close() error + SetReconnectCallback(cb func()) + SetShouldReconnect(fn func() bool) + SetEndedCallback(cb func(string)) + WatchConnection(ctx context.Context) + CanSend() bool +} + +// VideoTrack is a carrier capability for publishing a local video track. +type VideoTrack interface { + AddTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) +} + +type legacySession struct { + provider provider.Provider +} + +// Capabilities reports the transport primitives supported by the legacy carrier. +func (s *legacySession) Capabilities() Capabilities { + caps := Capabilities{ByteStream: true} + _, caps.VideoTrack = s.provider.(provider.VideoTrackCapable) + return caps +} + +// OpenByteStream adapts the legacy provider to a generic byte stream capability. +func (s *legacySession) OpenByteStream() (ByteStream, error) { + return &legacyByteStream{provider: s.provider}, nil +} + +// OpenVideoTrack adapts a legacy provider to the generic video track capability. +func (s *legacySession) OpenVideoTrack() (VideoTrack, error) { + publisher, ok := s.provider.(provider.VideoTrackCapable) + if !ok { + return nil, ErrVideoTrackUnsupported + } + return &legacyVideoTrack{provider: publisher}, nil +} + +type legacyByteStream struct { + provider provider.Provider +} + +func (p *legacyByteStream) Connect(ctx context.Context) error { return p.provider.Connect(ctx) } +func (p *legacyByteStream) Send(data []byte) error { return p.provider.Send(data) } +func (p *legacyByteStream) Close() error { return p.provider.Close() } + +func (p *legacyByteStream) SetReconnectCallback(cb func()) { + p.provider.SetReconnectCallback(func(_ *webrtc.DataChannel) { + if cb != nil { + cb() + } + }) +} + +func (p *legacyByteStream) SetShouldReconnect(fn func() bool) { p.provider.SetShouldReconnect(fn) } +func (p *legacyByteStream) SetEndedCallback(cb func(string)) { p.provider.SetEndedCallback(cb) } +func (p *legacyByteStream) WatchConnection(ctx context.Context) { + p.provider.WatchConnection(ctx) +} +func (p *legacyByteStream) CanSend() bool { return p.provider.CanSend() } + +type legacyVideoTrack struct { + provider provider.VideoTrackCapable +} + +func (v *legacyVideoTrack) AddTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) { + return v.provider.AddVideoTrack(track) +} diff --git a/internal/carrier/carrier.go b/internal/carrier/carrier.go new file mode 100644 index 0000000..cbc8e38 --- /dev/null +++ b/internal/carrier/carrier.go @@ -0,0 +1,95 @@ +// Package carrier exposes carrier-oriented registration and construction APIs. +package carrier + +import ( + "context" + "errors" + + "github.com/openlibrecommunity/olcrtc/internal/provider" +) + +var ( + // ErrCarrierNotFound is returned when a requested carrier is not registered. + ErrCarrierNotFound = errors.New("carrier not found") + // ErrByteStreamUnsupported is returned when a carrier cannot provide a byte stream. + ErrByteStreamUnsupported = errors.New("carrier does not support byte stream") + // ErrVideoTrackUnsupported is returned when a carrier cannot publish video tracks. + ErrVideoTrackUnsupported = errors.New("carrier does not support video tracks") +) + +// Capabilities describes the transport primitives a carrier can expose. +type Capabilities struct { + ByteStream bool + VideoTrack bool +} + +// Session is the carrier-level runtime handle. +type Session interface { + Capabilities() Capabilities +} + +// ByteStreamCapable is implemented by carriers that can expose a byte stream. +type ByteStreamCapable interface { + OpenByteStream() (ByteStream, error) +} + +// VideoTrackCapable is implemented by carriers that can publish video tracks. +type VideoTrackCapable interface { + OpenVideoTrack() (VideoTrack, error) +} + +// Config holds carrier configuration. +type Config struct { + RoomURL string + Name string + OnData func([]byte) + DNSServer string + ProxyAddr string + ProxyPort int +} + +// Factory creates a new carrier session. +type Factory func(ctx context.Context, cfg Config) (Session, error) + +var registry = make(map[string]Factory) + +// Register adds a carrier factory to the registry. +func Register(name string, factory Factory) { + registry[name] = factory +} + +// RegisterLegacy adapts an existing provider factory into the carrier registry. +func RegisterLegacy(name string, factory provider.Factory) { + Register(name, func(ctx context.Context, cfg Config) (Session, error) { + legacy, err := factory(ctx, provider.Config{ + RoomURL: cfg.RoomURL, + Name: cfg.Name, + OnData: cfg.OnData, + DNSServer: cfg.DNSServer, + ProxyAddr: cfg.ProxyAddr, + ProxyPort: cfg.ProxyPort, + }) + if err != nil { + return nil, err + } + return &legacySession{provider: legacy}, nil + }) +} + +// New creates a carrier session by name. +func New(ctx context.Context, name string, cfg Config) (Session, error) { + factory, ok := registry[name] + if !ok { + return nil, ErrCarrierNotFound + } + return factory(ctx, cfg) +} + +// Available returns a list of registered carriers. +func Available() []string { + names := make([]string, 0, len(registry)) + for name := range registry { + names = append(names, name) + } + return names +} diff --git a/internal/client/client.go b/internal/client/client.go index 8d57e4a..5e30e4e 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -16,10 +16,10 @@ import ( "time" "github.com/openlibrecommunity/olcrtc/internal/crypto" + "github.com/openlibrecommunity/olcrtc/internal/link" "github.com/openlibrecommunity/olcrtc/internal/logger" "github.com/openlibrecommunity/olcrtc/internal/mux" "github.com/openlibrecommunity/olcrtc/internal/names" - "github.com/openlibrecommunity/olcrtc/internal/transport" ) var ( @@ -29,8 +29,8 @@ var ( ErrKeyStringLength = errors.New("key string length must be 32") // ErrInvalidSocks5 is returned when the SOCKS version is not 5. ErrInvalidSocks5 = errors.New("invalid SOCKS5 version") - // ErrNoPeers is returned when no peers are available for sending. - ErrNoPeers = errors.New("no peers available") + // ErrNoLinks is returned when no links are available for sending. + ErrNoLinks = errors.New("no links available") // ErrEncryptFailed is returned when encryption fails. ErrEncryptFailed = errors.New("encrypt failed") // ErrUnsupportedSocksCommand is returned when a SOCKS5 command is not supported. @@ -41,14 +41,14 @@ var ( ErrTunnelSetupFailed = errors.New("tunnel setup failed") ) -// Client handles local SOCKS5 connections and tunnels them via WebRTC. +// Client handles local SOCKS5 connections and tunnels them through the selected runtime stack. type Client struct { - transports []transport.Transport + links []link.Link cipher *crypto.Cipher mux *mux.Multiplexer connections map[uint16]net.Conn connMu sync.RWMutex - peerIdx atomic.Uint32 + linkIdx atomic.Uint32 clientID uint32 activeClients atomic.Int32 wg sync.WaitGroup @@ -58,8 +58,9 @@ type Client struct { // Run starts the client with the specified parameters. func Run( ctx context.Context, + linkName, transportName, - providerName, + carrierName, roomURL, keyHex string, localAddr string, @@ -67,14 +68,15 @@ func Run( socksUser string, socksPass string, ) error { - return RunWithReady(ctx, transportName, providerName, roomURL, keyHex, localAddr, dnsServer, socksUser, socksPass, nil) + return RunWithReady(ctx, linkName, transportName, carrierName, roomURL, keyHex, localAddr, dnsServer, socksUser, socksPass, nil) } // RunWithReady is like Run but accepts a callback that is called when the client is ready. func RunWithReady( ctx context.Context, + linkName, transportName, - providerName, + carrierName, roomURL, keyHex string, localAddr string, @@ -100,17 +102,17 @@ func RunWithReady( c := &Client{ cipher: cipher, connections: make(map[uint16]net.Conn), - transports: make([]transport.Transport, 0), + links: make([]link.Link, 0), clientID: clientID, dnsServer: dnsServer, } c.setupMux() - const peerCount = 1 - for i := range peerCount { - if err := c.addTransport(runCtx, transportName, providerName, roomURL, i, cancel, dnsServer, "", 0); err != nil { - return fmt.Errorf("addTransport failed: %w", err) + const linkCount = 1 + for i := range linkCount { + if err := c.addLink(runCtx, linkName, transportName, carrierName, roomURL, i, cancel, dnsServer, "", 0); err != nil { + return fmt.Errorf("addLink failed: %w", err) } } @@ -161,8 +163,8 @@ func (c *Client) setupMux() { c.mux = mux.New(c.clientID, func(frame []byte) error { for { canSend := true - for _, tr := range c.transports { - if !tr.CanSend() { + for _, ln := range c.links { + if !ln.CanSend() { canSend = false break } @@ -177,27 +179,29 @@ func (c *Client) setupMux() { if err != nil { return fmt.Errorf("%w: %w", ErrEncryptFailed, err) } - if len(c.transports) == 0 { - return ErrNoPeers + if len(c.links) == 0 { + return ErrNoLinks } - idx := c.peerIdx.Add(1) % uint32(len(c.transports)) //nolint:gosec - return c.transports[idx].Send(encrypted) + idx := c.linkIdx.Add(1) % uint32(len(c.links)) //nolint:gosec + return c.links[idx].Send(encrypted) }) } -func (c *Client) addTransport( +func (c *Client) addLink( ctx context.Context, + linkName, transportName, - providerName, + carrierName, roomURL string, - peerID int, + linkID int, cancel context.CancelFunc, dnsServer, socksProxyAddr string, socksProxyPort int, ) error { - tr, err := transport.New(ctx, transportName, transport.Config{ - Carrier: providerName, + ln, err := link.New(ctx, linkName, link.Config{ + Transport: transportName, + Carrier: carrierName, RoomURL: roomURL, Name: names.Generate(), OnData: c.onData, @@ -206,29 +210,29 @@ func (c *Client) addTransport( ProxyPort: socksProxyPort, }) if err != nil { - return fmt.Errorf("failed to create transport: %w", err) + return fmt.Errorf("failed to create link: %w", err) } - tr.SetEndedCallback(func(reason string) { - logger.Infof("Client transport %d reported conference end: %s", peerID, reason) + ln.SetEndedCallback(func(reason string) { + logger.Infof("Client link %d reported conference end: %s", linkID, reason) cancel() }) - c.transports = append(c.transports, tr) + c.links = append(c.links, ln) - tr.SetReconnectCallback(func() { - c.handleTransportReconnect(peerID) + ln.SetReconnectCallback(func() { + c.handleLinkReconnect(linkID) }) - logger.Infof("Connecting transport %d via %s/%s...", peerID, transportName, providerName) - if err := tr.Connect(ctx); err != nil { - return fmt.Errorf("failed to connect transport: %w", err) + logger.Infof("Connecting link %d via %s/%s/%s...", linkID, linkName, transportName, carrierName) + if err := ln.Connect(ctx); err != nil { + return fmt.Errorf("failed to connect link: %w", err) } - logger.Infof("Transport %d connected", peerID) + logger.Infof("Link %d connected", linkID) c.wg.Add(1) go func() { defer c.wg.Done() - tr.WatchConnection(ctx) + ln.WatchConnection(ctx) }() // Send initial reset to clean up any stale connections for this clientID on server @@ -239,8 +243,8 @@ func (c *Client) addTransport( return nil } -func (c *Client) handleTransportReconnect(peerID int) { - logger.Infof("transport %d reconnect event", peerID) +func (c *Client) handleLinkReconnect(linkID int) { + logger.Infof("link %d reconnect event", linkID) c.connMu.Lock() for sid, conn := range c.connections { @@ -256,11 +260,11 @@ func (c *Client) handleTransportReconnect(peerID int) { if err != nil { return fmt.Errorf("%w: %w", ErrEncryptFailed, err) } - if len(c.transports) == 0 { - return ErrNoPeers + if len(c.links) == 0 { + return ErrNoLinks } - idx := c.peerIdx.Add(1) % uint32(len(c.transports)) //nolint:gosec - return c.transports[idx].Send(encrypted) + idx := c.linkIdx.Add(1) % uint32(len(c.links)) //nolint:gosec + return c.links[idx].Send(encrypted) }) c.mux.Reset() @@ -443,8 +447,8 @@ func (c *Client) shutdown() { } c.connMu.Unlock() - for i, tr := range c.transports { - logger.Infof("closing transport %d", i) + for i, tr := range c.links { + logger.Infof("closing link %d", i) _ = tr.Close() } } @@ -516,7 +520,7 @@ func (c *Client) startStreamPump(ctx context.Context, sid uint16, conn net.Conn) } func (c *Client) canSendData() bool { - for _, tr := range c.transports { + for _, tr := range c.links { if !tr.CanSend() { return false } diff --git a/internal/link/direct/direct.go b/internal/link/direct/direct.go new file mode 100644 index 0000000..40b318b --- /dev/null +++ b/internal/link/direct/direct.go @@ -0,0 +1,43 @@ +// Package direct provides a pass-through link implementation above transports. +package direct + +import ( + "context" + "fmt" + + "github.com/openlibrecommunity/olcrtc/internal/link" + "github.com/openlibrecommunity/olcrtc/internal/transport" +) + +type directLink struct { + transport transport.Transport +} + +// New creates a direct link that forwards bytes to the selected transport. +func New(ctx context.Context, cfg link.Config) (link.Link, error) { + tr, err := transport.New(ctx, cfg.Transport, transport.Config{ + Carrier: cfg.Carrier, + RoomURL: cfg.RoomURL, + Name: cfg.Name, + OnData: cfg.OnData, + DNSServer: cfg.DNSServer, + ProxyAddr: cfg.ProxyAddr, + ProxyPort: cfg.ProxyPort, + }) + if err != nil { + return nil, fmt.Errorf("create transport for direct link: %w", err) + } + + return &directLink{transport: tr}, nil +} + +func (d *directLink) Connect(ctx context.Context) error { return d.transport.Connect(ctx) } +func (d *directLink) Send(data []byte) error { return d.transport.Send(data) } +func (d *directLink) Close() error { return d.transport.Close() } +func (d *directLink) SetReconnectCallback(cb func()) { d.transport.SetReconnectCallback(cb) } +func (d *directLink) SetShouldReconnect(fn func() bool) { d.transport.SetShouldReconnect(fn) } +func (d *directLink) SetEndedCallback(cb func(string)) { d.transport.SetEndedCallback(cb) } +func (d *directLink) WatchConnection(ctx context.Context) { + d.transport.WatchConnection(ctx) +} +func (d *directLink) CanSend() bool { return d.transport.CanSend() } diff --git a/internal/link/link.go b/internal/link/link.go new file mode 100644 index 0000000..bb86890 --- /dev/null +++ b/internal/link/link.go @@ -0,0 +1,64 @@ +// Package link defines link-layer abstractions above transports. +package link + +import ( + "context" + "errors" +) + +var ( + // ErrLinkNotFound is returned when a requested link is not registered. + ErrLinkNotFound = errors.New("link not found") +) + +// Link defines a byte link above a transport. +type Link interface { + Connect(ctx context.Context) error + Send(data []byte) error + Close() error + SetReconnectCallback(cb func()) + SetShouldReconnect(fn func() bool) + SetEndedCallback(cb func(string)) + WatchConnection(ctx context.Context) + CanSend() bool +} + +// Config holds common link configuration. +type Config struct { + Transport string + Carrier string + RoomURL string + Name string + OnData func([]byte) + DNSServer string + ProxyAddr string + ProxyPort int +} + +// Factory creates a link instance. +type Factory func(ctx context.Context, cfg Config) (Link, error) + +var registry = make(map[string]Factory) + +// Register adds a link factory to the registry. +func Register(name string, factory Factory) { + registry[name] = factory +} + +// New creates a link instance by name. +func New(ctx context.Context, name string, cfg Config) (Link, error) { + factory, ok := registry[name] + if !ok { + return nil, ErrLinkNotFound + } + return factory(ctx, cfg) +} + +// Available returns a list of registered link names. +func Available() []string { + names := make([]string, 0, len(registry)) + for name := range registry { + names = append(names, name) + } + return names +} diff --git a/internal/provider/provider.go b/internal/provider/provider.go index 09265ed..bbc7497 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -35,8 +35,10 @@ type Provider interface { CanSend() bool GetSendQueue() chan []byte GetBufferedAmount() uint64 +} - // AddVideoTrack adds a video track to the connection. +// VideoTrackCapable is implemented by providers that can publish video tracks. +type VideoTrackCapable interface { AddVideoTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) } diff --git a/internal/server/server.go b/internal/server/server.go index 0543052..6a8e25c 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -17,10 +17,10 @@ import ( "time" "github.com/openlibrecommunity/olcrtc/internal/crypto" + "github.com/openlibrecommunity/olcrtc/internal/link" "github.com/openlibrecommunity/olcrtc/internal/logger" "github.com/openlibrecommunity/olcrtc/internal/mux" "github.com/openlibrecommunity/olcrtc/internal/names" - "github.com/openlibrecommunity/olcrtc/internal/transport" ) var ( @@ -32,24 +32,24 @@ var ( ErrSocks5AuthFailed = errors.New("SOCKS5 auth failed") // ErrSocks5ConnectFailed is returned when SOCKS5 connection fails. ErrSocks5ConnectFailed = errors.New("SOCKS5 connect failed") - // ErrNoPeers is returned when no peers are available. - ErrNoPeers = errors.New("no peers available") + // ErrNoLinks is returned when no links are available. + ErrNoLinks = errors.New("no links available") // ErrDialProxy is returned when dialing the proxy fails. ErrDialProxy = errors.New("failed to dial proxy") // ErrEncryptFailed is returned when encryption fails. ErrEncryptFailed = errors.New("encrypt failed") ) -// Server handles incoming WebRTC connections and proxies their traffic. +// Server handles incoming tunnel connections and proxies their traffic. type Server struct { - transports []transport.Transport + links []link.Link cipher *crypto.Cipher mux *mux.Multiplexer connections map[uint16]net.Conn connMu sync.RWMutex streamPumps map[uint16]net.Conn pumpMu sync.Mutex - peerIdx atomic.Uint32 + linkIdx atomic.Uint32 activeClients atomic.Int32 wg sync.WaitGroup dnsServer string @@ -68,8 +68,9 @@ type ConnectRequest struct { // Run starts the server with the specified parameters. func Run( ctx context.Context, + linkName, transportName, - providerName, + carrierName, roomURL, keyHex string, dnsServer, @@ -88,7 +89,7 @@ func Run( cipher: cipher, connections: make(map[uint16]net.Conn), streamPumps: make(map[uint16]net.Conn), - transports: make([]transport.Transport, 0), + links: make([]link.Link, 0), dnsServer: dnsServer, socksProxyAddr: socksProxyAddr, socksProxyPort: socksProxyPort, @@ -101,10 +102,10 @@ func Run( s.setupResolver() s.setupMux() - const peerCount = 1 - for i := range peerCount { - if err := s.addTransport(runCtx, transportName, providerName, roomURL, i, cancel); err != nil { - return fmt.Errorf("addTransport failed: %w", err) + const linkCount = 1 + for i := range linkCount { + if err := s.addLink(runCtx, linkName, transportName, carrierName, roomURL, i, cancel); err != nil { + return fmt.Errorf("addLink failed: %w", err) } } @@ -161,8 +162,8 @@ func (s *Server) setupMux() { s.mux = mux.New(0, func(frame []byte) error { for { canSend := true - for _, tr := range s.transports { - if !tr.CanSend() { + for _, ln := range s.links { + if !ln.CanSend() { canSend = false break } @@ -177,24 +178,26 @@ func (s *Server) setupMux() { if err != nil { return fmt.Errorf("%w: %w", ErrEncryptFailed, err) } - if len(s.transports) == 0 { - return ErrNoPeers + if len(s.links) == 0 { + return ErrNoLinks } - idx := s.peerIdx.Add(1) % uint32(len(s.transports)) //nolint:gosec - return s.transports[idx].Send(encrypted) + idx := s.linkIdx.Add(1) % uint32(len(s.links)) //nolint:gosec + return s.links[idx].Send(encrypted) }) } -func (s *Server) addTransport( +func (s *Server) addLink( ctx context.Context, + linkName, transportName, - providerName, + carrierName, roomURL string, - peerID int, + linkID int, cancel context.CancelFunc, ) error { - tr, err := transport.New(ctx, transportName, transport.Config{ - Carrier: providerName, + ln, err := link.New(ctx, linkName, link.Config{ + Transport: transportName, + Carrier: carrierName, RoomURL: roomURL, Name: names.Generate(), OnData: s.onData, @@ -203,35 +206,35 @@ func (s *Server) addTransport( ProxyPort: s.socksProxyPort, }) if err != nil { - return fmt.Errorf("failed to create transport: %w", err) + return fmt.Errorf("failed to create link: %w", err) } - tr.SetEndedCallback(func(reason string) { - logger.Infof("Server transport %d reported conference end: %s", peerID, reason) + ln.SetEndedCallback(func(reason string) { + logger.Infof("Server link %d reported conference end: %s", linkID, reason) cancel() }) - s.transports = append(s.transports, tr) + s.links = append(s.links, ln) - tr.SetReconnectCallback(func() { - s.handleTransportReconnect(peerID) + ln.SetReconnectCallback(func() { + s.handleLinkReconnect(linkID) }) - logger.Infof("Connecting transport %d via %s/%s...", peerID, transportName, providerName) - if err := tr.Connect(ctx); err != nil { - return fmt.Errorf("failed to connect transport: %w", err) + logger.Infof("Connecting link %d via %s/%s/%s...", linkID, linkName, transportName, carrierName) + if err := ln.Connect(ctx); err != nil { + return fmt.Errorf("failed to connect link: %w", err) } - logger.Infof("Transport %d connected", peerID) + logger.Infof("Link %d connected", linkID) s.wg.Add(1) go func() { defer s.wg.Done() - tr.WatchConnection(ctx) + ln.WatchConnection(ctx) }() return nil } -func (s *Server) handleTransportReconnect(peerID int) { - logger.Infof("transport %d reconnect event", peerID) +func (s *Server) handleLinkReconnect(linkID int) { + logger.Infof("link %d reconnect event", linkID) s.connMu.Lock() for sid, conn := range s.connections { @@ -247,11 +250,11 @@ func (s *Server) handleTransportReconnect(peerID int) { if err != nil { return fmt.Errorf("%w: %w", ErrEncryptFailed, err) } - if len(s.transports) == 0 { - return ErrNoPeers + if len(s.links) == 0 { + return ErrNoLinks } - idx := s.peerIdx.Add(1) % uint32(len(s.transports)) //nolint:gosec - return s.transports[idx].Send(encrypted) + idx := s.linkIdx.Add(1) % uint32(len(s.links)) //nolint:gosec + return s.links[idx].Send(encrypted) }) s.mux.Reset() } @@ -349,8 +352,8 @@ func (s *Server) shutdown() { } s.connMu.Unlock() - for i, tr := range s.transports { - logger.Infof("closing transport %d", i) + for i, tr := range s.links { + logger.Infof("closing link %d", i) _ = tr.Close() } } @@ -561,7 +564,7 @@ func (s *Server) startStreamPump(ctx context.Context, sid uint16, conn net.Conn) } func (s *Server) canSendData() bool { - for _, tr := range s.transports { + for _, tr := range s.links { if !tr.CanSend() { return false } diff --git a/internal/transport/datachannel/transport.go b/internal/transport/datachannel/transport.go index dfd1b2d..8b61848 100644 --- a/internal/transport/datachannel/transport.go +++ b/internal/transport/datachannel/transport.go @@ -5,18 +5,19 @@ import ( "context" "fmt" - "github.com/openlibrecommunity/olcrtc/internal/provider" + "github.com/openlibrecommunity/olcrtc/internal/carrier" "github.com/openlibrecommunity/olcrtc/internal/transport" - "github.com/pion/webrtc/v4" ) -type providerTransport struct { - provider provider.Provider +const defaultMaxPayloadSize = 12 * 1024 + +type streamTransport struct { + stream carrier.ByteStream } // New creates a datachannel transport backed by a carrier-specific provider. func New(ctx context.Context, cfg transport.Config) (transport.Transport, error) { - p, err := provider.New(ctx, cfg.Carrier, provider.Config{ + session, err := carrier.New(ctx, cfg.Carrier, carrier.Config{ RoomURL: cfg.RoomURL, Name: cfg.Name, OnData: cfg.OnData, @@ -28,49 +29,65 @@ func New(ctx context.Context, cfg transport.Config) (transport.Transport, error) return nil, fmt.Errorf("create provider transport: %w", err) } - return &providerTransport{provider: p}, nil + streamCapable, ok := session.(carrier.ByteStreamCapable) + if !ok { + return nil, carrier.ErrByteStreamUnsupported + } + + stream, err := streamCapable.OpenByteStream() + if err != nil { + return nil, fmt.Errorf("open byte stream: %w", err) + } + + return &streamTransport{stream: stream}, nil } // Connect starts the transport connection. -func (p *providerTransport) Connect(ctx context.Context) error { - return p.provider.Connect(ctx) +func (p *streamTransport) Connect(ctx context.Context) error { + return p.stream.Connect(ctx) } // Send transmits data through the transport. -func (p *providerTransport) Send(data []byte) error { - return p.provider.Send(data) +func (p *streamTransport) Send(data []byte) error { + return p.stream.Send(data) } // Close terminates the transport. -func (p *providerTransport) Close() error { - return p.provider.Close() +func (p *streamTransport) Close() error { + return p.stream.Close() } // SetReconnectCallback registers reconnect handling. -func (p *providerTransport) SetReconnectCallback(cb func()) { - p.provider.SetReconnectCallback(func(_ *webrtc.DataChannel) { - if cb != nil { - cb() - } - }) +func (p *streamTransport) SetReconnectCallback(cb func()) { + p.stream.SetReconnectCallback(cb) } // SetShouldReconnect configures reconnect policy. -func (p *providerTransport) SetShouldReconnect(fn func() bool) { - p.provider.SetShouldReconnect(fn) +func (p *streamTransport) SetShouldReconnect(fn func() bool) { + p.stream.SetShouldReconnect(fn) } // SetEndedCallback registers end-of-session handling. -func (p *providerTransport) SetEndedCallback(cb func(string)) { - p.provider.SetEndedCallback(cb) +func (p *streamTransport) SetEndedCallback(cb func(string)) { + p.stream.SetEndedCallback(cb) } // WatchConnection monitors connection lifecycle. -func (p *providerTransport) WatchConnection(ctx context.Context) { - p.provider.WatchConnection(ctx) +func (p *streamTransport) WatchConnection(ctx context.Context) { + p.stream.WatchConnection(ctx) } // CanSend reports whether transport is ready for sending. -func (p *providerTransport) CanSend() bool { - return p.provider.CanSend() +func (p *streamTransport) CanSend() bool { + return p.stream.CanSend() +} + +// Features describes the current datachannel transport semantics. +func (p *streamTransport) Features() transport.Features { + return transport.Features{ + Reliable: true, + Ordered: true, + MessageOriented: true, + MaxPayloadSize: defaultMaxPayloadSize, + } } diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 012b3c5..74c8c4c 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -11,6 +11,14 @@ var ( ErrTransportNotFound = errors.New("transport not found") ) +// Features describes the delivery semantics of a transport. +type Features struct { + Reliable bool + Ordered bool + MessageOriented bool + MaxPayloadSize int +} + // Transport defines a byte transport independent of the underlying carrier. type Transport interface { Connect(ctx context.Context) error @@ -21,6 +29,7 @@ type Transport interface { SetEndedCallback(cb func(string)) WatchConnection(ctx context.Context) CanSend() bool + Features() Features } // Config holds common transport configuration. diff --git a/mobile/mobile.go b/mobile/mobile.go index 3032534..9fc0bd9 100644 --- a/mobile/mobile.go +++ b/mobile/mobile.go @@ -109,6 +109,7 @@ func Start(roomID, keyHex string, socksPort int, socksUser, socksPass string) er err := client.RunWithReady( ctx, + "direct", "datachannel", "telemost", roomURL,