mirror of
https://github.com/openlibrecommunity/olcrtc.git
synced 2026-05-26 07:08:11 +00:00
feat: add timeout to openControlStream function
This commit is contained in:
@@ -240,12 +240,21 @@ func openControlStream(
|
||||
sess *smux.Session,
|
||||
deviceID string,
|
||||
claims map[string]any,
|
||||
) (*smux.Stream, string, error) {
|
||||
return openControlStreamTimeout(sess, deviceID, claims, handshake.DefaultTimeout)
|
||||
}
|
||||
|
||||
func openControlStreamTimeout(
|
||||
sess *smux.Session,
|
||||
deviceID string,
|
||||
claims map[string]any,
|
||||
timeout time.Duration,
|
||||
) (*smux.Stream, string, error) {
|
||||
stream, err := sess.OpenStream()
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("open control stream: %w", err)
|
||||
}
|
||||
_ = stream.SetDeadline(time.Now().Add(handshake.DefaultTimeout))
|
||||
_ = stream.SetDeadline(time.Now().Add(timeout))
|
||||
sid, err := handshake.Client(stream, deviceID, claims)
|
||||
_ = stream.SetDeadline(time.Time{})
|
||||
if err != nil {
|
||||
@@ -303,32 +312,71 @@ func smuxConfig() *smux.Config {
|
||||
|
||||
func (c *Client) handleReconnect() {
|
||||
logger.Infof("client link reconnect - tearing down smux session")
|
||||
|
||||
// Install a fresh muxconn immediately so onData never hits nil while
|
||||
// the old session is being torn down. tryReopenSession will swap it
|
||||
// again with its own conn on each attempt.
|
||||
newConn := muxconn.New(c.ln, c.cipher)
|
||||
|
||||
c.sessMu.Lock()
|
||||
if c.controlStrm != nil {
|
||||
_ = c.controlStrm.Close()
|
||||
c.controlStrm = nil
|
||||
}
|
||||
if c.session != nil {
|
||||
_ = c.session.Close()
|
||||
c.session = nil
|
||||
}
|
||||
if c.conn != nil {
|
||||
_ = c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
oldControl := c.controlStrm
|
||||
oldSess := c.session
|
||||
oldConn := c.conn
|
||||
c.conn = newConn
|
||||
c.session = nil
|
||||
c.controlStrm = nil
|
||||
c.sessionID = ""
|
||||
c.sessMu.Unlock()
|
||||
c.conn = muxconn.New(c.ln, c.cipher)
|
||||
sess, err := smux.Client(c.conn, smuxConfig())
|
||||
if err != nil {
|
||||
logger.Warnf("smux re-init failed: %v", err)
|
||||
return
|
||||
|
||||
if oldControl != nil {
|
||||
_ = oldControl.Close()
|
||||
}
|
||||
control, sid, err := openControlStream(sess, c.deviceID, c.claims)
|
||||
if oldSess != nil {
|
||||
_ = oldSess.Close()
|
||||
}
|
||||
if oldConn != nil {
|
||||
_ = oldConn.Close()
|
||||
}
|
||||
|
||||
// Server-side may still be tearing down its own session when our callback
|
||||
// fires — carriers don't guarantee reconnect callbacks are delivered to both
|
||||
// peers atomically. Retry the handshake a few times, building a fresh
|
||||
// muxconn+smux pair on each attempt so a failed smux.Close doesn't corrupt
|
||||
// the byte stream for subsequent attempts.
|
||||
const (
|
||||
maxAttempts = 5
|
||||
attemptDelay = 300 * time.Millisecond
|
||||
)
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
if c.tryReopenSession(attempt) {
|
||||
return
|
||||
}
|
||||
time.Sleep(attemptDelay)
|
||||
}
|
||||
logger.Warnf("client reconnect: exhausted %d handshake attempts", maxAttempts)
|
||||
}
|
||||
|
||||
func (c *Client) tryReopenSession(attempt int) bool {
|
||||
conn := muxconn.New(c.ln, c.cipher)
|
||||
|
||||
c.sessMu.Lock()
|
||||
old := c.conn
|
||||
c.conn = conn
|
||||
c.sessMu.Unlock()
|
||||
if old != nil {
|
||||
_ = old.Close()
|
||||
}
|
||||
|
||||
sess, err := smux.Client(conn, smuxConfig())
|
||||
if err != nil {
|
||||
logger.Warnf("handshake on reconnect failed: %v", err)
|
||||
logger.Warnf("smux re-init failed (attempt %d): %v", attempt, err)
|
||||
return false
|
||||
}
|
||||
control, sid, err := openControlStreamTimeout(sess, c.deviceID, c.claims, 2*time.Second)
|
||||
if err != nil {
|
||||
logger.Warnf("handshake on reconnect failed (attempt %d): %v", attempt, err)
|
||||
_ = sess.Close()
|
||||
return
|
||||
return false
|
||||
}
|
||||
logger.Infof("session %s reopened (device=%s)", sid, c.deviceID)
|
||||
c.sessMu.Lock()
|
||||
@@ -336,6 +384,7 @@ func (c *Client) handleReconnect() {
|
||||
c.controlStrm = control
|
||||
c.sessionID = sid
|
||||
c.sessMu.Unlock()
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *Client) shutdown() {
|
||||
|
||||
@@ -131,9 +131,15 @@ func (r *memoryRoom) triggerReconnect() {
|
||||
}
|
||||
r.mu.Unlock()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, stream := range streams {
|
||||
stream.triggerReconnect()
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
stream.triggerReconnect()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (r *memoryRoom) triggerEnded(reason string) {
|
||||
|
||||
@@ -50,6 +50,19 @@ func New(ln link.Link, cipher *crypto.Cipher) *Conn {
|
||||
return c
|
||||
}
|
||||
|
||||
// Reset clears any buffered inbound bytes, re-arms a closed conn for writes,
|
||||
// and unblocks pending Reads so the smux session on top of it exits cleanly.
|
||||
// Use it when the link stays up but the peer's smux session has been rebuilt:
|
||||
// the inbound byte stream (now indistinguishable random-looking data) must be
|
||||
// parsed by the fresh smux state, not the old one.
|
||||
func (c *Conn) Reset() {
|
||||
c.mu.Lock()
|
||||
c.buf = nil
|
||||
c.closed = false
|
||||
c.cond.Broadcast()
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// Push hands an encrypted wire payload (one OnData event) to the conn.
|
||||
func (c *Conn) Push(ciphertext []byte) {
|
||||
pt, err := c.cipher.Decrypt(ciphertext)
|
||||
|
||||
@@ -36,6 +36,19 @@ var (
|
||||
ErrSocks5ConnectFailed = errors.New("SOCKS5 connect failed")
|
||||
)
|
||||
|
||||
// SessionOpenFunc is called after a successful handshake, before the server
|
||||
// accepts tunnel streams on that session.
|
||||
type SessionOpenFunc func(sessionID, deviceID string, claims map[string]any)
|
||||
|
||||
// SessionCloseFunc is called when a session is torn down. Possible reasons:
|
||||
// "reconnect" (carrier dropped and was reestablished), "closed" (graceful
|
||||
// shutdown or ctx cancel).
|
||||
type SessionCloseFunc func(sessionID, reason string)
|
||||
|
||||
// TrafficFunc is called once per tunnel stream, after the copy loops finish.
|
||||
// bytesIn counts client→target bytes; bytesOut counts target→client bytes.
|
||||
type TrafficFunc func(sessionID, addr string, bytesIn, bytesOut uint64)
|
||||
|
||||
// Server handles incoming tunnel connections and proxies their traffic.
|
||||
type Server struct {
|
||||
ln link.Link
|
||||
@@ -46,6 +59,9 @@ type Server struct {
|
||||
reinstallMu sync.Mutex
|
||||
wg sync.WaitGroup
|
||||
authHook handshake.AuthFunc
|
||||
onOpen SessionOpenFunc
|
||||
onClose SessionCloseFunc
|
||||
onTraffic TrafficFunc
|
||||
deviceID string
|
||||
sessionID string
|
||||
dnsServer string
|
||||
@@ -94,6 +110,13 @@ type Config struct {
|
||||
// AuthHook is invoked after CLIENT_HELLO to authorize the client and
|
||||
// return a session ID. If nil, every client is admitted with a random UUID.
|
||||
AuthHook handshake.AuthFunc
|
||||
|
||||
// OnSessionOpen fires after a successful handshake. Nil means no-op.
|
||||
OnSessionOpen SessionOpenFunc
|
||||
// OnSessionClose fires when the session is torn down (reconnect, closed). Nil means no-op.
|
||||
OnSessionClose SessionCloseFunc
|
||||
// OnTraffic fires once per tunnel stream after both copy loops finish. Nil means no-op.
|
||||
OnTraffic TrafficFunc
|
||||
}
|
||||
|
||||
// Run starts the server with the given configuration.
|
||||
@@ -110,10 +133,25 @@ func Run(ctx context.Context, cfg Config) error {
|
||||
if hook == nil {
|
||||
hook = defaultAuthHook
|
||||
}
|
||||
onOpen := cfg.OnSessionOpen
|
||||
if onOpen == nil {
|
||||
onOpen = func(string, string, map[string]any) {}
|
||||
}
|
||||
onClose := cfg.OnSessionClose
|
||||
if onClose == nil {
|
||||
onClose = func(string, string) {}
|
||||
}
|
||||
onTraffic := cfg.OnTraffic
|
||||
if onTraffic == nil {
|
||||
onTraffic = func(string, string, uint64, uint64) {}
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
cipher: cipher,
|
||||
authHook: hook,
|
||||
onOpen: onOpen,
|
||||
onClose: onClose,
|
||||
onTraffic: onTraffic,
|
||||
dnsServer: cfg.DNSServer,
|
||||
socksProxyAddr: cfg.SOCKSProxyAddr,
|
||||
socksProxyPort: cfg.SOCKSProxyPort,
|
||||
@@ -268,23 +306,41 @@ func (s *Server) reinstallSession(dead *smux.Session) {
|
||||
s.reinstallMu.Lock()
|
||||
defer s.reinstallMu.Unlock()
|
||||
|
||||
s.sessMu.Lock()
|
||||
if s.session != dead {
|
||||
s.sessMu.Unlock()
|
||||
// Pre-build the replacement so we can swap atomically below.
|
||||
newConn := muxconn.New(s.ln, s.cipher)
|
||||
newSess, err := smux.Server(newConn, smuxConfig())
|
||||
if err != nil {
|
||||
logger.Warnf("smux server init failed: %v", err)
|
||||
_ = newConn.Close()
|
||||
return
|
||||
}
|
||||
if s.session != nil {
|
||||
_ = s.session.Close()
|
||||
s.session = nil
|
||||
}
|
||||
if s.conn != nil {
|
||||
_ = s.conn.Close()
|
||||
s.conn = nil
|
||||
|
||||
s.sessMu.Lock()
|
||||
if s.session != dead {
|
||||
// Someone else already reinstalled — discard our build.
|
||||
s.sessMu.Unlock()
|
||||
_ = newSess.Close()
|
||||
_ = newConn.Close()
|
||||
return
|
||||
}
|
||||
oldSess := s.session
|
||||
oldConn := s.conn
|
||||
oldSID := s.sessionID
|
||||
s.session = newSess
|
||||
s.conn = newConn
|
||||
s.sessionID = ""
|
||||
s.deviceID = ""
|
||||
s.sessMu.Unlock()
|
||||
s.installSession()
|
||||
|
||||
if oldSess != nil {
|
||||
_ = oldSess.Close()
|
||||
}
|
||||
if oldConn != nil {
|
||||
_ = oldConn.Close()
|
||||
}
|
||||
if oldSID != "" {
|
||||
s.onClose(oldSID, "reconnect")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) closeSession() {
|
||||
@@ -297,9 +353,13 @@ func (s *Server) closeSession() {
|
||||
_ = s.conn.Close()
|
||||
s.conn = nil
|
||||
}
|
||||
oldSID := s.sessionID
|
||||
s.sessionID = ""
|
||||
s.deviceID = ""
|
||||
s.sessMu.Unlock()
|
||||
if oldSID != "" {
|
||||
s.onClose(oldSID, "closed")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) onData(data []byte) {
|
||||
@@ -393,6 +453,7 @@ func (s *Server) acceptHandshake(ctx context.Context, sess *smux.Session) bool {
|
||||
s.deviceID = hello.DeviceID
|
||||
s.sessionID = sid
|
||||
s.sessMu.Unlock()
|
||||
s.onOpen(sid, hello.DeviceID, hello.Claims)
|
||||
logger.Infof("session %s opened (device=%s)", sid, hello.DeviceID)
|
||||
// The control stream stays open for the lifetime of the session;
|
||||
// keep it parked in a goroutine so the smux session does not close it.
|
||||
@@ -473,6 +534,10 @@ func (s *Server) dispatch(stream *smux.Stream, req ConnectRequest) {
|
||||
addr := net.JoinHostPort(req.Addr, strconv.Itoa(req.Port))
|
||||
logger.Infof("sid=%d connect %s", stream.ID(), addr)
|
||||
|
||||
s.sessMu.RLock()
|
||||
sid := s.sessionID
|
||||
s.sessMu.RUnlock()
|
||||
|
||||
dialStart := time.Now()
|
||||
conn, err := s.dial(req)
|
||||
dialElapsed := time.Since(dialStart)
|
||||
@@ -489,11 +554,26 @@ func (s *Server) dispatch(stream *smux.Stream, req ConnectRequest) {
|
||||
return
|
||||
}
|
||||
|
||||
var bytesOut uint64
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
_, _ = io.Copy(stream, conn)
|
||||
n, _ := io.Copy(stream, conn)
|
||||
if n > 0 {
|
||||
bytesOut = uint64(n) //nolint:gosec // io.Copy returns non-negative int64
|
||||
}
|
||||
_ = stream.Close()
|
||||
close(done)
|
||||
}()
|
||||
_, _ = io.Copy(conn, stream)
|
||||
in, _ := io.Copy(conn, stream)
|
||||
_ = conn.Close()
|
||||
<-done
|
||||
bytesIn := uint64(0)
|
||||
if in > 0 {
|
||||
bytesIn = uint64(in) //nolint:gosec // io.Copy returns non-negative int64
|
||||
}
|
||||
if s.onTraffic != nil {
|
||||
s.onTraffic(sid, addr, bytesIn, bytesOut)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) dial(req ConnectRequest) (net.Conn, error) {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
cryptopkg "github.com/openlibrecommunity/olcrtc/internal/crypto"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
|
||||
@@ -344,3 +345,128 @@ func TestHandleStreamDispatchAfterConnect(t *testing.T) {
|
||||
}
|
||||
<-done
|
||||
}
|
||||
|
||||
func TestReinstallSessionFiresOnClose(t *testing.T) {
|
||||
cipher, err := cryptopkg.NewCipher("01234567890123456789012345678901")
|
||||
if err != nil {
|
||||
t.Fatalf("NewCipher() error = %v", err)
|
||||
}
|
||||
var got struct {
|
||||
sid string
|
||||
reason string
|
||||
}
|
||||
s := &Server{
|
||||
ln: &serverLinkStub{},
|
||||
cipher: cipher,
|
||||
sessionID: "sid-123",
|
||||
deviceID: "dev-123",
|
||||
onClose: func(sid, reason string) { got.sid = sid; got.reason = reason },
|
||||
}
|
||||
s.closeSession()
|
||||
if got.sid != "sid-123" || got.reason != "closed" {
|
||||
t.Fatalf("onClose = %+v, want {sid-123 closed}", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDispatchFiresOnTraffic(t *testing.T) {
|
||||
ln, err := net.Listen("tcp4", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Listen() error = %v", err)
|
||||
}
|
||||
defer func() { _ = ln.Close() }()
|
||||
|
||||
const greeting = "hi\n"
|
||||
go func() {
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer func() { _ = c.Close() }()
|
||||
_, _ = c.Write([]byte(greeting))
|
||||
}()
|
||||
|
||||
a, b := net.Pipe()
|
||||
defer func() {
|
||||
_ = a.Close()
|
||||
_ = b.Close()
|
||||
}()
|
||||
|
||||
serverSess, err := smux.Server(a, smuxConfig())
|
||||
if err != nil {
|
||||
t.Fatalf("smux.Server() error = %v", err)
|
||||
}
|
||||
defer func() { _ = serverSess.Close() }()
|
||||
clientSess, err := smux.Client(b, smuxConfig())
|
||||
if err != nil {
|
||||
t.Fatalf("smux.Client() error = %v", err)
|
||||
}
|
||||
defer func() { _ = clientSess.Close() }()
|
||||
|
||||
var rec struct {
|
||||
sid string
|
||||
addr string
|
||||
in, out uint64
|
||||
}
|
||||
recChan := make(chan struct{})
|
||||
s := &Server{
|
||||
sessionID: "traffic-sid",
|
||||
resolver: net.DefaultResolver,
|
||||
onTraffic: func(sid, addr string, in, out uint64) {
|
||||
rec.sid = sid
|
||||
rec.addr = addr
|
||||
rec.in = in
|
||||
rec.out = out
|
||||
close(recChan)
|
||||
},
|
||||
}
|
||||
|
||||
go func() {
|
||||
stream, err := serverSess.AcceptStream()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s.handleStream(context.Background(), stream)
|
||||
}()
|
||||
|
||||
stream, err := clientSess.OpenStream()
|
||||
if err != nil {
|
||||
t.Fatalf("OpenStream() error = %v", err)
|
||||
}
|
||||
tcpAddr, ok := ln.Addr().(*net.TCPAddr)
|
||||
if !ok {
|
||||
t.Fatalf("addr type = %T", ln.Addr())
|
||||
}
|
||||
req, err := json.Marshal(ConnectRequest{
|
||||
Cmd: "connect",
|
||||
Addr: "127.0.0.1",
|
||||
Port: tcpAddr.Port,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal() error = %v", err)
|
||||
}
|
||||
if _, err := stream.Write(req); err != nil {
|
||||
t.Fatalf("Write() error = %v", err)
|
||||
}
|
||||
|
||||
ack := make([]byte, 1)
|
||||
if _, err := io.ReadFull(stream, ack); err != nil {
|
||||
t.Fatalf("read ack: %v", err)
|
||||
}
|
||||
body := make([]byte, len(greeting))
|
||||
if _, err := io.ReadFull(stream, body); err != nil {
|
||||
t.Fatalf("read body: %v", err)
|
||||
}
|
||||
_ = stream.Close()
|
||||
|
||||
select {
|
||||
case <-recChan:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("onTraffic did not fire")
|
||||
}
|
||||
if rec.sid != "traffic-sid" {
|
||||
t.Fatalf("sid = %q, want traffic-sid", rec.sid)
|
||||
}
|
||||
if rec.out < uint64(len(greeting)) {
|
||||
t.Fatalf("bytesOut = %d, want >= %d", rec.out, len(greeting))
|
||||
}
|
||||
}
|
||||
|
||||
169
pkg/olcrtc/tunnel/tunnel.go
Normal file
169
pkg/olcrtc/tunnel/tunnel.go
Normal file
@@ -0,0 +1,169 @@
|
||||
// Package tunnel exposes olcrtc's server-side tunnel as an embeddable Go library.
|
||||
//
|
||||
// A [Server] accepts encrypted tunnel connections over a WebRTC SFU carrier
|
||||
// and proxies their traffic to arbitrary TCP targets. Consumers plug in
|
||||
// authorization and observability via the [Config] hooks:
|
||||
//
|
||||
// srv := tunnel.New(tunnel.Config{
|
||||
// Link: "direct",
|
||||
// Transport: "datachannel",
|
||||
// Carrier: "telemost",
|
||||
// RoomURL: "<room-id>",
|
||||
// KeyHex: "<64-char hex>",
|
||||
// DNSServer: "1.1.1.1:53",
|
||||
// AuthHook: func(deviceID string, claims map[string]any) (string, error) {
|
||||
// // reject unknown devices, enrich session with a DB-issued ID
|
||||
// return db.IssueSession(deviceID, claims)
|
||||
// },
|
||||
// OnSessionOpen: func(sid, dev string, claims map[string]any) {
|
||||
// log.Printf("session %s opened (device=%s)", sid, dev)
|
||||
// },
|
||||
// OnSessionClose: func(sid, reason string) {
|
||||
// log.Printf("session %s closed (%s)", sid, reason)
|
||||
// },
|
||||
// OnTraffic: func(sid, addr string, in, out uint64) {
|
||||
// metrics.Record(sid, addr, in, out)
|
||||
// },
|
||||
// })
|
||||
// if err := srv.Run(ctx); err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// Call [RegisterDefaults] once at program start to register the built-in
|
||||
// carriers (telemost, jazz, wbstream) and transports (datachannel,
|
||||
// videochannel, seichannel, vp8channel).
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/app/session"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/handshake"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/server"
|
||||
)
|
||||
|
||||
// AuthFunc is invoked after CLIENT_HELLO to authorize the client and issue a
|
||||
// session ID. Returning a non-nil error rejects the handshake; the error's
|
||||
// message is forwarded to the client as the reject reason, so it should not
|
||||
// leak sensitive details.
|
||||
type AuthFunc = handshake.AuthFunc
|
||||
|
||||
// SessionOpenFunc fires right after a successful handshake, before the server
|
||||
// starts accepting tunnel streams on that session.
|
||||
type SessionOpenFunc = server.SessionOpenFunc
|
||||
|
||||
// SessionCloseFunc fires when a session ends. Reasons include "reconnect"
|
||||
// (carrier dropped and was reestablished) and "closed" (graceful shutdown or
|
||||
// ctx cancel).
|
||||
type SessionCloseFunc = server.SessionCloseFunc
|
||||
|
||||
// TrafficFunc fires once per tunnel stream after both copy loops finish.
|
||||
// bytesIn counts client→target bytes; bytesOut counts target→client bytes.
|
||||
type TrafficFunc = server.TrafficFunc
|
||||
|
||||
// Config holds runtime server configuration.
|
||||
type Config struct {
|
||||
// --- carrier selection ---
|
||||
Link string // currently only "direct"
|
||||
Transport string // datachannel, videochannel, seichannel, vp8channel
|
||||
Carrier string // telemost, jazz, wbstream, none
|
||||
RoomURL string // conference room identifier for the carrier
|
||||
|
||||
// --- direct engine mode (Carrier == "none") ---
|
||||
Engine string // livekit, goolom, salutejazz
|
||||
URL string
|
||||
Token string
|
||||
|
||||
// --- crypto & networking ---
|
||||
KeyHex string // 64-char hex (32 bytes) shared with the client
|
||||
DNSServer string // resolver used for target dials, e.g. "1.1.1.1:53"
|
||||
SOCKSProxyAddr string // optional outbound SOCKS5 proxy host
|
||||
SOCKSProxyPort int // optional outbound SOCKS5 proxy port
|
||||
|
||||
// --- transport tuning ---
|
||||
VideoWidth int
|
||||
VideoHeight int
|
||||
VideoFPS int
|
||||
VideoBitrate string
|
||||
VideoHW string
|
||||
VideoQRSize int
|
||||
VideoQRRecovery string
|
||||
VideoCodec string
|
||||
VideoTileModule int
|
||||
VideoTileRS int
|
||||
VP8FPS int
|
||||
VP8BatchSize int
|
||||
SEIFPS int
|
||||
SEIBatchSize int
|
||||
SEIFragmentSize int
|
||||
SEIAckTimeoutMS int
|
||||
|
||||
// --- hooks ---
|
||||
// AuthHook authorizes the client. If nil, every client is admitted with a
|
||||
// random UUID as session ID.
|
||||
AuthHook AuthFunc
|
||||
// OnSessionOpen fires after a successful handshake. Nil is a no-op.
|
||||
OnSessionOpen SessionOpenFunc
|
||||
// OnSessionClose fires when the session is torn down. Nil is a no-op.
|
||||
OnSessionClose SessionCloseFunc
|
||||
// OnTraffic fires once per tunnel stream after both copy loops finish.
|
||||
// Nil is a no-op.
|
||||
OnTraffic TrafficFunc
|
||||
}
|
||||
|
||||
// Server is an embeddable tunnel server.
|
||||
type Server struct {
|
||||
cfg Config
|
||||
}
|
||||
|
||||
// New returns a Server configured by cfg. Call [Server.Run] to start it.
|
||||
func New(cfg Config) *Server {
|
||||
return &Server{cfg: cfg}
|
||||
}
|
||||
|
||||
// Run starts the server and blocks until ctx is cancelled or the carrier ends.
|
||||
func (s *Server) Run(ctx context.Context) error {
|
||||
if err := server.Run(ctx, server.Config{
|
||||
Link: s.cfg.Link,
|
||||
Transport: s.cfg.Transport,
|
||||
Carrier: s.cfg.Carrier,
|
||||
RoomURL: s.cfg.RoomURL,
|
||||
Engine: s.cfg.Engine,
|
||||
URL: s.cfg.URL,
|
||||
Token: s.cfg.Token,
|
||||
KeyHex: s.cfg.KeyHex,
|
||||
DNSServer: s.cfg.DNSServer,
|
||||
SOCKSProxyAddr: s.cfg.SOCKSProxyAddr,
|
||||
SOCKSProxyPort: s.cfg.SOCKSProxyPort,
|
||||
VideoWidth: s.cfg.VideoWidth,
|
||||
VideoHeight: s.cfg.VideoHeight,
|
||||
VideoFPS: s.cfg.VideoFPS,
|
||||
VideoBitrate: s.cfg.VideoBitrate,
|
||||
VideoHW: s.cfg.VideoHW,
|
||||
VideoQRSize: s.cfg.VideoQRSize,
|
||||
VideoQRRecovery: s.cfg.VideoQRRecovery,
|
||||
VideoCodec: s.cfg.VideoCodec,
|
||||
VideoTileModule: s.cfg.VideoTileModule,
|
||||
VideoTileRS: s.cfg.VideoTileRS,
|
||||
VP8FPS: s.cfg.VP8FPS,
|
||||
VP8BatchSize: s.cfg.VP8BatchSize,
|
||||
SEIFPS: s.cfg.SEIFPS,
|
||||
SEIBatchSize: s.cfg.SEIBatchSize,
|
||||
SEIFragmentSize: s.cfg.SEIFragmentSize,
|
||||
SEIAckTimeoutMS: s.cfg.SEIAckTimeoutMS,
|
||||
AuthHook: s.cfg.AuthHook,
|
||||
OnSessionOpen: s.cfg.OnSessionOpen,
|
||||
OnSessionClose: s.cfg.OnSessionClose,
|
||||
OnTraffic: s.cfg.OnTraffic,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("tunnel: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterDefaults registers the built-in carriers, links and transports.
|
||||
// Safe to call multiple times.
|
||||
func RegisterDefaults() {
|
||||
session.RegisterDefaults()
|
||||
}
|
||||
50
pkg/olcrtc/tunnel/tunnel_test.go
Normal file
50
pkg/olcrtc/tunnel/tunnel_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package tunnel_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/pkg/olcrtc/tunnel"
|
||||
)
|
||||
|
||||
func TestRun_FailsWithoutKey(t *testing.T) {
|
||||
tunnel.RegisterDefaults()
|
||||
err := tunnel.New(tunnel.Config{
|
||||
Link: "direct",
|
||||
Transport: "datachannel",
|
||||
Carrier: "telemost",
|
||||
RoomURL: "room-1",
|
||||
DNSServer: "1.1.1.1:53",
|
||||
}).Run(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("Run(no key) error = nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_PropagatesAuthHook(t *testing.T) {
|
||||
tunnel.RegisterDefaults()
|
||||
|
||||
sentinel := errors.New("no")
|
||||
var called bool
|
||||
cfg := tunnel.Config{
|
||||
AuthHook: func(string, map[string]any) (string, error) {
|
||||
called = true
|
||||
return "", sentinel
|
||||
},
|
||||
}
|
||||
_ = tunnel.New(cfg).Run(context.Background())
|
||||
// Run bails before ever invoking AuthHook (no key, no carrier wired); this
|
||||
// test exists to pin the public surface and ensure the hook field compiles
|
||||
// against the re-exported handshake.AuthFunc type alias. Behavior coverage
|
||||
// of AuthHook itself lives in internal/handshake tests.
|
||||
_ = called
|
||||
}
|
||||
|
||||
// Compile-time checks: the public type aliases must be assignable.
|
||||
var (
|
||||
_ tunnel.AuthFunc = func(string, map[string]any) (string, error) { return "", nil }
|
||||
_ tunnel.SessionOpenFunc = func(string, string, map[string]any) {}
|
||||
_ tunnel.SessionCloseFunc = func(string, string) {}
|
||||
_ tunnel.TrafficFunc = func(string, string, uint64, uint64) {}
|
||||
)
|
||||
Reference in New Issue
Block a user