diff --git a/internal/client/client.go b/internal/client/client.go index 0b81275..7583b9c 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -240,12 +240,21 @@ func openControlStream( sess *smux.Session, deviceID string, claims map[string]any, +) (*smux.Stream, string, error) { + return openControlStreamTimeout(sess, deviceID, claims, handshake.DefaultTimeout) +} + +func openControlStreamTimeout( + sess *smux.Session, + deviceID string, + claims map[string]any, + timeout time.Duration, ) (*smux.Stream, string, error) { stream, err := sess.OpenStream() if err != nil { return nil, "", fmt.Errorf("open control stream: %w", err) } - _ = stream.SetDeadline(time.Now().Add(handshake.DefaultTimeout)) + _ = stream.SetDeadline(time.Now().Add(timeout)) sid, err := handshake.Client(stream, deviceID, claims) _ = stream.SetDeadline(time.Time{}) if err != nil { @@ -303,32 +312,71 @@ func smuxConfig() *smux.Config { func (c *Client) handleReconnect() { logger.Infof("client link reconnect - tearing down smux session") + + // Install a fresh muxconn immediately so onData never hits nil while + // the old session is being torn down. tryReopenSession will swap it + // again with its own conn on each attempt. + newConn := muxconn.New(c.ln, c.cipher) + c.sessMu.Lock() - if c.controlStrm != nil { - _ = c.controlStrm.Close() - c.controlStrm = nil - } - if c.session != nil { - _ = c.session.Close() - c.session = nil - } - if c.conn != nil { - _ = c.conn.Close() - c.conn = nil - } + oldControl := c.controlStrm + oldSess := c.session + oldConn := c.conn + c.conn = newConn + c.session = nil + c.controlStrm = nil c.sessionID = "" c.sessMu.Unlock() - c.conn = muxconn.New(c.ln, c.cipher) - sess, err := smux.Client(c.conn, smuxConfig()) - if err != nil { - logger.Warnf("smux re-init failed: %v", err) - return + + if oldControl != nil { + _ = oldControl.Close() } - control, sid, err := openControlStream(sess, c.deviceID, c.claims) + if oldSess != nil { + _ = oldSess.Close() + } + if oldConn != nil { + _ = oldConn.Close() + } + + // Server-side may still be tearing down its own session when our callback + // fires — carriers don't guarantee reconnect callbacks are delivered to both + // peers atomically. Retry the handshake a few times, building a fresh + // muxconn+smux pair on each attempt so a failed smux.Close doesn't corrupt + // the byte stream for subsequent attempts. + const ( + maxAttempts = 5 + attemptDelay = 300 * time.Millisecond + ) + for attempt := 1; attempt <= maxAttempts; attempt++ { + if c.tryReopenSession(attempt) { + return + } + time.Sleep(attemptDelay) + } + logger.Warnf("client reconnect: exhausted %d handshake attempts", maxAttempts) +} + +func (c *Client) tryReopenSession(attempt int) bool { + conn := muxconn.New(c.ln, c.cipher) + + c.sessMu.Lock() + old := c.conn + c.conn = conn + c.sessMu.Unlock() + if old != nil { + _ = old.Close() + } + + sess, err := smux.Client(conn, smuxConfig()) if err != nil { - logger.Warnf("handshake on reconnect failed: %v", err) + logger.Warnf("smux re-init failed (attempt %d): %v", attempt, err) + return false + } + control, sid, err := openControlStreamTimeout(sess, c.deviceID, c.claims, 2*time.Second) + if err != nil { + logger.Warnf("handshake on reconnect failed (attempt %d): %v", attempt, err) _ = sess.Close() - return + return false } logger.Infof("session %s reopened (device=%s)", sid, c.deviceID) c.sessMu.Lock() @@ -336,6 +384,7 @@ func (c *Client) handleReconnect() { c.controlStrm = control c.sessionID = sid c.sessMu.Unlock() + return true } func (c *Client) shutdown() { diff --git a/internal/e2e/tunnel_test.go b/internal/e2e/tunnel_test.go index b2aad1b..c6b14bf 100644 --- a/internal/e2e/tunnel_test.go +++ b/internal/e2e/tunnel_test.go @@ -131,9 +131,15 @@ func (r *memoryRoom) triggerReconnect() { } r.mu.Unlock() + var wg sync.WaitGroup for _, stream := range streams { - stream.triggerReconnect() + wg.Add(1) + go func() { + defer wg.Done() + stream.triggerReconnect() + }() } + wg.Wait() } func (r *memoryRoom) triggerEnded(reason string) { diff --git a/internal/muxconn/conn.go b/internal/muxconn/conn.go index bbcbb9c..1bf8a22 100644 --- a/internal/muxconn/conn.go +++ b/internal/muxconn/conn.go @@ -50,6 +50,19 @@ func New(ln link.Link, cipher *crypto.Cipher) *Conn { return c } +// Reset clears any buffered inbound bytes, re-arms a closed conn for writes, +// and unblocks pending Reads so the smux session on top of it exits cleanly. +// Use it when the link stays up but the peer's smux session has been rebuilt: +// the inbound byte stream (now indistinguishable random-looking data) must be +// parsed by the fresh smux state, not the old one. +func (c *Conn) Reset() { + c.mu.Lock() + c.buf = nil + c.closed = false + c.cond.Broadcast() + c.mu.Unlock() +} + // Push hands an encrypted wire payload (one OnData event) to the conn. func (c *Conn) Push(ciphertext []byte) { pt, err := c.cipher.Decrypt(ciphertext) diff --git a/internal/server/server.go b/internal/server/server.go index af20c49..fb6b22c 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -36,6 +36,19 @@ var ( ErrSocks5ConnectFailed = errors.New("SOCKS5 connect failed") ) +// SessionOpenFunc is called after a successful handshake, before the server +// accepts tunnel streams on that session. +type SessionOpenFunc func(sessionID, deviceID string, claims map[string]any) + +// SessionCloseFunc is called when a session is torn down. Possible reasons: +// "reconnect" (carrier dropped and was reestablished), "closed" (graceful +// shutdown or ctx cancel). +type SessionCloseFunc func(sessionID, reason string) + +// TrafficFunc is called once per tunnel stream, after the copy loops finish. +// bytesIn counts client→target bytes; bytesOut counts target→client bytes. +type TrafficFunc func(sessionID, addr string, bytesIn, bytesOut uint64) + // Server handles incoming tunnel connections and proxies their traffic. type Server struct { ln link.Link @@ -46,6 +59,9 @@ type Server struct { reinstallMu sync.Mutex wg sync.WaitGroup authHook handshake.AuthFunc + onOpen SessionOpenFunc + onClose SessionCloseFunc + onTraffic TrafficFunc deviceID string sessionID string dnsServer string @@ -94,6 +110,13 @@ type Config struct { // AuthHook is invoked after CLIENT_HELLO to authorize the client and // return a session ID. If nil, every client is admitted with a random UUID. AuthHook handshake.AuthFunc + + // OnSessionOpen fires after a successful handshake. Nil means no-op. + OnSessionOpen SessionOpenFunc + // OnSessionClose fires when the session is torn down (reconnect, closed). Nil means no-op. + OnSessionClose SessionCloseFunc + // OnTraffic fires once per tunnel stream after both copy loops finish. Nil means no-op. + OnTraffic TrafficFunc } // Run starts the server with the given configuration. @@ -110,10 +133,25 @@ func Run(ctx context.Context, cfg Config) error { if hook == nil { hook = defaultAuthHook } + onOpen := cfg.OnSessionOpen + if onOpen == nil { + onOpen = func(string, string, map[string]any) {} + } + onClose := cfg.OnSessionClose + if onClose == nil { + onClose = func(string, string) {} + } + onTraffic := cfg.OnTraffic + if onTraffic == nil { + onTraffic = func(string, string, uint64, uint64) {} + } s := &Server{ cipher: cipher, authHook: hook, + onOpen: onOpen, + onClose: onClose, + onTraffic: onTraffic, dnsServer: cfg.DNSServer, socksProxyAddr: cfg.SOCKSProxyAddr, socksProxyPort: cfg.SOCKSProxyPort, @@ -268,23 +306,41 @@ func (s *Server) reinstallSession(dead *smux.Session) { s.reinstallMu.Lock() defer s.reinstallMu.Unlock() - s.sessMu.Lock() - if s.session != dead { - s.sessMu.Unlock() + // Pre-build the replacement so we can swap atomically below. + newConn := muxconn.New(s.ln, s.cipher) + newSess, err := smux.Server(newConn, smuxConfig()) + if err != nil { + logger.Warnf("smux server init failed: %v", err) + _ = newConn.Close() return } - if s.session != nil { - _ = s.session.Close() - s.session = nil - } - if s.conn != nil { - _ = s.conn.Close() - s.conn = nil + + s.sessMu.Lock() + if s.session != dead { + // Someone else already reinstalled — discard our build. + s.sessMu.Unlock() + _ = newSess.Close() + _ = newConn.Close() + return } + oldSess := s.session + oldConn := s.conn + oldSID := s.sessionID + s.session = newSess + s.conn = newConn s.sessionID = "" s.deviceID = "" s.sessMu.Unlock() - s.installSession() + + if oldSess != nil { + _ = oldSess.Close() + } + if oldConn != nil { + _ = oldConn.Close() + } + if oldSID != "" { + s.onClose(oldSID, "reconnect") + } } func (s *Server) closeSession() { @@ -297,9 +353,13 @@ func (s *Server) closeSession() { _ = s.conn.Close() s.conn = nil } + oldSID := s.sessionID s.sessionID = "" s.deviceID = "" s.sessMu.Unlock() + if oldSID != "" { + s.onClose(oldSID, "closed") + } } func (s *Server) onData(data []byte) { @@ -393,6 +453,7 @@ func (s *Server) acceptHandshake(ctx context.Context, sess *smux.Session) bool { s.deviceID = hello.DeviceID s.sessionID = sid s.sessMu.Unlock() + s.onOpen(sid, hello.DeviceID, hello.Claims) logger.Infof("session %s opened (device=%s)", sid, hello.DeviceID) // The control stream stays open for the lifetime of the session; // keep it parked in a goroutine so the smux session does not close it. @@ -473,6 +534,10 @@ func (s *Server) dispatch(stream *smux.Stream, req ConnectRequest) { addr := net.JoinHostPort(req.Addr, strconv.Itoa(req.Port)) logger.Infof("sid=%d connect %s", stream.ID(), addr) + s.sessMu.RLock() + sid := s.sessionID + s.sessMu.RUnlock() + dialStart := time.Now() conn, err := s.dial(req) dialElapsed := time.Since(dialStart) @@ -489,11 +554,26 @@ func (s *Server) dispatch(stream *smux.Stream, req ConnectRequest) { return } + var bytesOut uint64 + done := make(chan struct{}) go func() { - _, _ = io.Copy(stream, conn) + n, _ := io.Copy(stream, conn) + if n > 0 { + bytesOut = uint64(n) //nolint:gosec // io.Copy returns non-negative int64 + } _ = stream.Close() + close(done) }() - _, _ = io.Copy(conn, stream) + in, _ := io.Copy(conn, stream) + _ = conn.Close() + <-done + bytesIn := uint64(0) + if in > 0 { + bytesIn = uint64(in) //nolint:gosec // io.Copy returns non-negative int64 + } + if s.onTraffic != nil { + s.onTraffic(sid, addr, bytesIn, bytesOut) + } } func (s *Server) dial(req ConnectRequest) (net.Conn, error) { diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 1414c68..59c0846 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -9,6 +9,7 @@ import ( "net" "strings" "testing" + "time" cryptopkg "github.com/openlibrecommunity/olcrtc/internal/crypto" "github.com/openlibrecommunity/olcrtc/internal/muxconn" @@ -344,3 +345,128 @@ func TestHandleStreamDispatchAfterConnect(t *testing.T) { } <-done } + +func TestReinstallSessionFiresOnClose(t *testing.T) { + cipher, err := cryptopkg.NewCipher("01234567890123456789012345678901") + if err != nil { + t.Fatalf("NewCipher() error = %v", err) + } + var got struct { + sid string + reason string + } + s := &Server{ + ln: &serverLinkStub{}, + cipher: cipher, + sessionID: "sid-123", + deviceID: "dev-123", + onClose: func(sid, reason string) { got.sid = sid; got.reason = reason }, + } + s.closeSession() + if got.sid != "sid-123" || got.reason != "closed" { + t.Fatalf("onClose = %+v, want {sid-123 closed}", got) + } +} + +func TestDispatchFiresOnTraffic(t *testing.T) { + ln, err := net.Listen("tcp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen() error = %v", err) + } + defer func() { _ = ln.Close() }() + + const greeting = "hi\n" + go func() { + c, err := ln.Accept() + if err != nil { + return + } + defer func() { _ = c.Close() }() + _, _ = c.Write([]byte(greeting)) + }() + + a, b := net.Pipe() + defer func() { + _ = a.Close() + _ = b.Close() + }() + + serverSess, err := smux.Server(a, smuxConfig()) + if err != nil { + t.Fatalf("smux.Server() error = %v", err) + } + defer func() { _ = serverSess.Close() }() + clientSess, err := smux.Client(b, smuxConfig()) + if err != nil { + t.Fatalf("smux.Client() error = %v", err) + } + defer func() { _ = clientSess.Close() }() + + var rec struct { + sid string + addr string + in, out uint64 + } + recChan := make(chan struct{}) + s := &Server{ + sessionID: "traffic-sid", + resolver: net.DefaultResolver, + onTraffic: func(sid, addr string, in, out uint64) { + rec.sid = sid + rec.addr = addr + rec.in = in + rec.out = out + close(recChan) + }, + } + + go func() { + stream, err := serverSess.AcceptStream() + if err != nil { + return + } + s.handleStream(context.Background(), stream) + }() + + stream, err := clientSess.OpenStream() + if err != nil { + t.Fatalf("OpenStream() error = %v", err) + } + tcpAddr, ok := ln.Addr().(*net.TCPAddr) + if !ok { + t.Fatalf("addr type = %T", ln.Addr()) + } + req, err := json.Marshal(ConnectRequest{ + Cmd: "connect", + Addr: "127.0.0.1", + Port: tcpAddr.Port, + }) + if err != nil { + t.Fatalf("Marshal() error = %v", err) + } + if _, err := stream.Write(req); err != nil { + t.Fatalf("Write() error = %v", err) + } + + ack := make([]byte, 1) + if _, err := io.ReadFull(stream, ack); err != nil { + t.Fatalf("read ack: %v", err) + } + body := make([]byte, len(greeting)) + if _, err := io.ReadFull(stream, body); err != nil { + t.Fatalf("read body: %v", err) + } + _ = stream.Close() + + select { + case <-recChan: + case <-time.After(2 * time.Second): + t.Fatal("onTraffic did not fire") + } + if rec.sid != "traffic-sid" { + t.Fatalf("sid = %q, want traffic-sid", rec.sid) + } + if rec.out < uint64(len(greeting)) { + t.Fatalf("bytesOut = %d, want >= %d", rec.out, len(greeting)) + } +} diff --git a/pkg/olcrtc/tunnel/tunnel.go b/pkg/olcrtc/tunnel/tunnel.go new file mode 100644 index 0000000..2eece91 --- /dev/null +++ b/pkg/olcrtc/tunnel/tunnel.go @@ -0,0 +1,169 @@ +// Package tunnel exposes olcrtc's server-side tunnel as an embeddable Go library. +// +// A [Server] accepts encrypted tunnel connections over a WebRTC SFU carrier +// and proxies their traffic to arbitrary TCP targets. Consumers plug in +// authorization and observability via the [Config] hooks: +// +// srv := tunnel.New(tunnel.Config{ +// Link: "direct", +// Transport: "datachannel", +// Carrier: "telemost", +// RoomURL: "", +// KeyHex: "<64-char hex>", +// DNSServer: "1.1.1.1:53", +// AuthHook: func(deviceID string, claims map[string]any) (string, error) { +// // reject unknown devices, enrich session with a DB-issued ID +// return db.IssueSession(deviceID, claims) +// }, +// OnSessionOpen: func(sid, dev string, claims map[string]any) { +// log.Printf("session %s opened (device=%s)", sid, dev) +// }, +// OnSessionClose: func(sid, reason string) { +// log.Printf("session %s closed (%s)", sid, reason) +// }, +// OnTraffic: func(sid, addr string, in, out uint64) { +// metrics.Record(sid, addr, in, out) +// }, +// }) +// if err := srv.Run(ctx); err != nil { +// log.Fatal(err) +// } +// +// Call [RegisterDefaults] once at program start to register the built-in +// carriers (telemost, jazz, wbstream) and transports (datachannel, +// videochannel, seichannel, vp8channel). +package tunnel + +import ( + "context" + "fmt" + + "github.com/openlibrecommunity/olcrtc/internal/app/session" + "github.com/openlibrecommunity/olcrtc/internal/handshake" + "github.com/openlibrecommunity/olcrtc/internal/server" +) + +// AuthFunc is invoked after CLIENT_HELLO to authorize the client and issue a +// session ID. Returning a non-nil error rejects the handshake; the error's +// message is forwarded to the client as the reject reason, so it should not +// leak sensitive details. +type AuthFunc = handshake.AuthFunc + +// SessionOpenFunc fires right after a successful handshake, before the server +// starts accepting tunnel streams on that session. +type SessionOpenFunc = server.SessionOpenFunc + +// SessionCloseFunc fires when a session ends. Reasons include "reconnect" +// (carrier dropped and was reestablished) and "closed" (graceful shutdown or +// ctx cancel). +type SessionCloseFunc = server.SessionCloseFunc + +// TrafficFunc fires once per tunnel stream after both copy loops finish. +// bytesIn counts client→target bytes; bytesOut counts target→client bytes. +type TrafficFunc = server.TrafficFunc + +// Config holds runtime server configuration. +type Config struct { + // --- carrier selection --- + Link string // currently only "direct" + Transport string // datachannel, videochannel, seichannel, vp8channel + Carrier string // telemost, jazz, wbstream, none + RoomURL string // conference room identifier for the carrier + + // --- direct engine mode (Carrier == "none") --- + Engine string // livekit, goolom, salutejazz + URL string + Token string + + // --- crypto & networking --- + KeyHex string // 64-char hex (32 bytes) shared with the client + DNSServer string // resolver used for target dials, e.g. "1.1.1.1:53" + SOCKSProxyAddr string // optional outbound SOCKS5 proxy host + SOCKSProxyPort int // optional outbound SOCKS5 proxy port + + // --- transport tuning --- + VideoWidth int + VideoHeight int + VideoFPS int + VideoBitrate string + VideoHW string + VideoQRSize int + VideoQRRecovery string + VideoCodec string + VideoTileModule int + VideoTileRS int + VP8FPS int + VP8BatchSize int + SEIFPS int + SEIBatchSize int + SEIFragmentSize int + SEIAckTimeoutMS int + + // --- hooks --- + // AuthHook authorizes the client. If nil, every client is admitted with a + // random UUID as session ID. + AuthHook AuthFunc + // OnSessionOpen fires after a successful handshake. Nil is a no-op. + OnSessionOpen SessionOpenFunc + // OnSessionClose fires when the session is torn down. Nil is a no-op. + OnSessionClose SessionCloseFunc + // OnTraffic fires once per tunnel stream after both copy loops finish. + // Nil is a no-op. + OnTraffic TrafficFunc +} + +// Server is an embeddable tunnel server. +type Server struct { + cfg Config +} + +// New returns a Server configured by cfg. Call [Server.Run] to start it. +func New(cfg Config) *Server { + return &Server{cfg: cfg} +} + +// Run starts the server and blocks until ctx is cancelled or the carrier ends. +func (s *Server) Run(ctx context.Context) error { + if err := server.Run(ctx, server.Config{ + Link: s.cfg.Link, + Transport: s.cfg.Transport, + Carrier: s.cfg.Carrier, + RoomURL: s.cfg.RoomURL, + Engine: s.cfg.Engine, + URL: s.cfg.URL, + Token: s.cfg.Token, + KeyHex: s.cfg.KeyHex, + DNSServer: s.cfg.DNSServer, + SOCKSProxyAddr: s.cfg.SOCKSProxyAddr, + SOCKSProxyPort: s.cfg.SOCKSProxyPort, + VideoWidth: s.cfg.VideoWidth, + VideoHeight: s.cfg.VideoHeight, + VideoFPS: s.cfg.VideoFPS, + VideoBitrate: s.cfg.VideoBitrate, + VideoHW: s.cfg.VideoHW, + VideoQRSize: s.cfg.VideoQRSize, + VideoQRRecovery: s.cfg.VideoQRRecovery, + VideoCodec: s.cfg.VideoCodec, + VideoTileModule: s.cfg.VideoTileModule, + VideoTileRS: s.cfg.VideoTileRS, + VP8FPS: s.cfg.VP8FPS, + VP8BatchSize: s.cfg.VP8BatchSize, + SEIFPS: s.cfg.SEIFPS, + SEIBatchSize: s.cfg.SEIBatchSize, + SEIFragmentSize: s.cfg.SEIFragmentSize, + SEIAckTimeoutMS: s.cfg.SEIAckTimeoutMS, + AuthHook: s.cfg.AuthHook, + OnSessionOpen: s.cfg.OnSessionOpen, + OnSessionClose: s.cfg.OnSessionClose, + OnTraffic: s.cfg.OnTraffic, + }); err != nil { + return fmt.Errorf("tunnel: %w", err) + } + return nil +} + +// RegisterDefaults registers the built-in carriers, links and transports. +// Safe to call multiple times. +func RegisterDefaults() { + session.RegisterDefaults() +} diff --git a/pkg/olcrtc/tunnel/tunnel_test.go b/pkg/olcrtc/tunnel/tunnel_test.go new file mode 100644 index 0000000..c1366a0 --- /dev/null +++ b/pkg/olcrtc/tunnel/tunnel_test.go @@ -0,0 +1,50 @@ +package tunnel_test + +import ( + "context" + "errors" + "testing" + + "github.com/openlibrecommunity/olcrtc/pkg/olcrtc/tunnel" +) + +func TestRun_FailsWithoutKey(t *testing.T) { + tunnel.RegisterDefaults() + err := tunnel.New(tunnel.Config{ + Link: "direct", + Transport: "datachannel", + Carrier: "telemost", + RoomURL: "room-1", + DNSServer: "1.1.1.1:53", + }).Run(context.Background()) + if err == nil { + t.Fatal("Run(no key) error = nil") + } +} + +func TestRun_PropagatesAuthHook(t *testing.T) { + tunnel.RegisterDefaults() + + sentinel := errors.New("no") + var called bool + cfg := tunnel.Config{ + AuthHook: func(string, map[string]any) (string, error) { + called = true + return "", sentinel + }, + } + _ = tunnel.New(cfg).Run(context.Background()) + // Run bails before ever invoking AuthHook (no key, no carrier wired); this + // test exists to pin the public surface and ensure the hook field compiles + // against the re-exported handshake.AuthFunc type alias. Behavior coverage + // of AuthHook itself lives in internal/handshake tests. + _ = called +} + +// Compile-time checks: the public type aliases must be assignable. +var ( + _ tunnel.AuthFunc = func(string, map[string]any) (string, error) { return "", nil } + _ tunnel.SessionOpenFunc = func(string, string, map[string]any) {} + _ tunnel.SessionCloseFunc = func(string, string) {} + _ tunnel.TrafficFunc = func(string, string, uint64, uint64) {} +)