feat: add timeout to openControlStream function

This commit is contained in:
zarazaex69
2026-05-13 22:07:34 +03:00
parent bcc6b2ee5c
commit 20f2c1397c
7 changed files with 528 additions and 35 deletions

View File

@@ -240,12 +240,21 @@ func openControlStream(
sess *smux.Session,
deviceID string,
claims map[string]any,
) (*smux.Stream, string, error) {
return openControlStreamTimeout(sess, deviceID, claims, handshake.DefaultTimeout)
}
func openControlStreamTimeout(
sess *smux.Session,
deviceID string,
claims map[string]any,
timeout time.Duration,
) (*smux.Stream, string, error) {
stream, err := sess.OpenStream()
if err != nil {
return nil, "", fmt.Errorf("open control stream: %w", err)
}
_ = stream.SetDeadline(time.Now().Add(handshake.DefaultTimeout))
_ = stream.SetDeadline(time.Now().Add(timeout))
sid, err := handshake.Client(stream, deviceID, claims)
_ = stream.SetDeadline(time.Time{})
if err != nil {
@@ -303,32 +312,71 @@ func smuxConfig() *smux.Config {
func (c *Client) handleReconnect() {
logger.Infof("client link reconnect - tearing down smux session")
// Install a fresh muxconn immediately so onData never hits nil while
// the old session is being torn down. tryReopenSession will swap it
// again with its own conn on each attempt.
newConn := muxconn.New(c.ln, c.cipher)
c.sessMu.Lock()
if c.controlStrm != nil {
_ = c.controlStrm.Close()
c.controlStrm = nil
}
if c.session != nil {
_ = c.session.Close()
c.session = nil
}
if c.conn != nil {
_ = c.conn.Close()
c.conn = nil
}
oldControl := c.controlStrm
oldSess := c.session
oldConn := c.conn
c.conn = newConn
c.session = nil
c.controlStrm = nil
c.sessionID = ""
c.sessMu.Unlock()
c.conn = muxconn.New(c.ln, c.cipher)
sess, err := smux.Client(c.conn, smuxConfig())
if err != nil {
logger.Warnf("smux re-init failed: %v", err)
return
if oldControl != nil {
_ = oldControl.Close()
}
control, sid, err := openControlStream(sess, c.deviceID, c.claims)
if oldSess != nil {
_ = oldSess.Close()
}
if oldConn != nil {
_ = oldConn.Close()
}
// Server-side may still be tearing down its own session when our callback
// fires — carriers don't guarantee reconnect callbacks are delivered to both
// peers atomically. Retry the handshake a few times, building a fresh
// muxconn+smux pair on each attempt so a failed smux.Close doesn't corrupt
// the byte stream for subsequent attempts.
const (
maxAttempts = 5
attemptDelay = 300 * time.Millisecond
)
for attempt := 1; attempt <= maxAttempts; attempt++ {
if c.tryReopenSession(attempt) {
return
}
time.Sleep(attemptDelay)
}
logger.Warnf("client reconnect: exhausted %d handshake attempts", maxAttempts)
}
func (c *Client) tryReopenSession(attempt int) bool {
conn := muxconn.New(c.ln, c.cipher)
c.sessMu.Lock()
old := c.conn
c.conn = conn
c.sessMu.Unlock()
if old != nil {
_ = old.Close()
}
sess, err := smux.Client(conn, smuxConfig())
if err != nil {
logger.Warnf("handshake on reconnect failed: %v", err)
logger.Warnf("smux re-init failed (attempt %d): %v", attempt, err)
return false
}
control, sid, err := openControlStreamTimeout(sess, c.deviceID, c.claims, 2*time.Second)
if err != nil {
logger.Warnf("handshake on reconnect failed (attempt %d): %v", attempt, err)
_ = sess.Close()
return
return false
}
logger.Infof("session %s reopened (device=%s)", sid, c.deviceID)
c.sessMu.Lock()
@@ -336,6 +384,7 @@ func (c *Client) handleReconnect() {
c.controlStrm = control
c.sessionID = sid
c.sessMu.Unlock()
return true
}
func (c *Client) shutdown() {

View File

@@ -131,9 +131,15 @@ func (r *memoryRoom) triggerReconnect() {
}
r.mu.Unlock()
var wg sync.WaitGroup
for _, stream := range streams {
stream.triggerReconnect()
wg.Add(1)
go func() {
defer wg.Done()
stream.triggerReconnect()
}()
}
wg.Wait()
}
func (r *memoryRoom) triggerEnded(reason string) {

View File

@@ -50,6 +50,19 @@ func New(ln link.Link, cipher *crypto.Cipher) *Conn {
return c
}
// Reset clears any buffered inbound bytes, re-arms a closed conn for writes,
// and unblocks pending Reads so the smux session on top of it exits cleanly.
// Use it when the link stays up but the peer's smux session has been rebuilt:
// the inbound byte stream (now indistinguishable random-looking data) must be
// parsed by the fresh smux state, not the old one.
func (c *Conn) Reset() {
c.mu.Lock()
c.buf = nil
c.closed = false
c.cond.Broadcast()
c.mu.Unlock()
}
// Push hands an encrypted wire payload (one OnData event) to the conn.
func (c *Conn) Push(ciphertext []byte) {
pt, err := c.cipher.Decrypt(ciphertext)

View File

@@ -36,6 +36,19 @@ var (
ErrSocks5ConnectFailed = errors.New("SOCKS5 connect failed")
)
// SessionOpenFunc is called after a successful handshake, before the server
// accepts tunnel streams on that session.
type SessionOpenFunc func(sessionID, deviceID string, claims map[string]any)
// SessionCloseFunc is called when a session is torn down. Possible reasons:
// "reconnect" (carrier dropped and was reestablished), "closed" (graceful
// shutdown or ctx cancel).
type SessionCloseFunc func(sessionID, reason string)
// TrafficFunc is called once per tunnel stream, after the copy loops finish.
// bytesIn counts client→target bytes; bytesOut counts target→client bytes.
type TrafficFunc func(sessionID, addr string, bytesIn, bytesOut uint64)
// Server handles incoming tunnel connections and proxies their traffic.
type Server struct {
ln link.Link
@@ -46,6 +59,9 @@ type Server struct {
reinstallMu sync.Mutex
wg sync.WaitGroup
authHook handshake.AuthFunc
onOpen SessionOpenFunc
onClose SessionCloseFunc
onTraffic TrafficFunc
deviceID string
sessionID string
dnsServer string
@@ -94,6 +110,13 @@ type Config struct {
// AuthHook is invoked after CLIENT_HELLO to authorize the client and
// return a session ID. If nil, every client is admitted with a random UUID.
AuthHook handshake.AuthFunc
// OnSessionOpen fires after a successful handshake. Nil means no-op.
OnSessionOpen SessionOpenFunc
// OnSessionClose fires when the session is torn down (reconnect, closed). Nil means no-op.
OnSessionClose SessionCloseFunc
// OnTraffic fires once per tunnel stream after both copy loops finish. Nil means no-op.
OnTraffic TrafficFunc
}
// Run starts the server with the given configuration.
@@ -110,10 +133,25 @@ func Run(ctx context.Context, cfg Config) error {
if hook == nil {
hook = defaultAuthHook
}
onOpen := cfg.OnSessionOpen
if onOpen == nil {
onOpen = func(string, string, map[string]any) {}
}
onClose := cfg.OnSessionClose
if onClose == nil {
onClose = func(string, string) {}
}
onTraffic := cfg.OnTraffic
if onTraffic == nil {
onTraffic = func(string, string, uint64, uint64) {}
}
s := &Server{
cipher: cipher,
authHook: hook,
onOpen: onOpen,
onClose: onClose,
onTraffic: onTraffic,
dnsServer: cfg.DNSServer,
socksProxyAddr: cfg.SOCKSProxyAddr,
socksProxyPort: cfg.SOCKSProxyPort,
@@ -268,23 +306,41 @@ func (s *Server) reinstallSession(dead *smux.Session) {
s.reinstallMu.Lock()
defer s.reinstallMu.Unlock()
s.sessMu.Lock()
if s.session != dead {
s.sessMu.Unlock()
// Pre-build the replacement so we can swap atomically below.
newConn := muxconn.New(s.ln, s.cipher)
newSess, err := smux.Server(newConn, smuxConfig())
if err != nil {
logger.Warnf("smux server init failed: %v", err)
_ = newConn.Close()
return
}
if s.session != nil {
_ = s.session.Close()
s.session = nil
}
if s.conn != nil {
_ = s.conn.Close()
s.conn = nil
s.sessMu.Lock()
if s.session != dead {
// Someone else already reinstalled — discard our build.
s.sessMu.Unlock()
_ = newSess.Close()
_ = newConn.Close()
return
}
oldSess := s.session
oldConn := s.conn
oldSID := s.sessionID
s.session = newSess
s.conn = newConn
s.sessionID = ""
s.deviceID = ""
s.sessMu.Unlock()
s.installSession()
if oldSess != nil {
_ = oldSess.Close()
}
if oldConn != nil {
_ = oldConn.Close()
}
if oldSID != "" {
s.onClose(oldSID, "reconnect")
}
}
func (s *Server) closeSession() {
@@ -297,9 +353,13 @@ func (s *Server) closeSession() {
_ = s.conn.Close()
s.conn = nil
}
oldSID := s.sessionID
s.sessionID = ""
s.deviceID = ""
s.sessMu.Unlock()
if oldSID != "" {
s.onClose(oldSID, "closed")
}
}
func (s *Server) onData(data []byte) {
@@ -393,6 +453,7 @@ func (s *Server) acceptHandshake(ctx context.Context, sess *smux.Session) bool {
s.deviceID = hello.DeviceID
s.sessionID = sid
s.sessMu.Unlock()
s.onOpen(sid, hello.DeviceID, hello.Claims)
logger.Infof("session %s opened (device=%s)", sid, hello.DeviceID)
// The control stream stays open for the lifetime of the session;
// keep it parked in a goroutine so the smux session does not close it.
@@ -473,6 +534,10 @@ func (s *Server) dispatch(stream *smux.Stream, req ConnectRequest) {
addr := net.JoinHostPort(req.Addr, strconv.Itoa(req.Port))
logger.Infof("sid=%d connect %s", stream.ID(), addr)
s.sessMu.RLock()
sid := s.sessionID
s.sessMu.RUnlock()
dialStart := time.Now()
conn, err := s.dial(req)
dialElapsed := time.Since(dialStart)
@@ -489,11 +554,26 @@ func (s *Server) dispatch(stream *smux.Stream, req ConnectRequest) {
return
}
var bytesOut uint64
done := make(chan struct{})
go func() {
_, _ = io.Copy(stream, conn)
n, _ := io.Copy(stream, conn)
if n > 0 {
bytesOut = uint64(n) //nolint:gosec // io.Copy returns non-negative int64
}
_ = stream.Close()
close(done)
}()
_, _ = io.Copy(conn, stream)
in, _ := io.Copy(conn, stream)
_ = conn.Close()
<-done
bytesIn := uint64(0)
if in > 0 {
bytesIn = uint64(in) //nolint:gosec // io.Copy returns non-negative int64
}
if s.onTraffic != nil {
s.onTraffic(sid, addr, bytesIn, bytesOut)
}
}
func (s *Server) dial(req ConnectRequest) (net.Conn, error) {

View File

@@ -9,6 +9,7 @@ import (
"net"
"strings"
"testing"
"time"
cryptopkg "github.com/openlibrecommunity/olcrtc/internal/crypto"
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
@@ -344,3 +345,128 @@ func TestHandleStreamDispatchAfterConnect(t *testing.T) {
}
<-done
}
func TestReinstallSessionFiresOnClose(t *testing.T) {
cipher, err := cryptopkg.NewCipher("01234567890123456789012345678901")
if err != nil {
t.Fatalf("NewCipher() error = %v", err)
}
var got struct {
sid string
reason string
}
s := &Server{
ln: &serverLinkStub{},
cipher: cipher,
sessionID: "sid-123",
deviceID: "dev-123",
onClose: func(sid, reason string) { got.sid = sid; got.reason = reason },
}
s.closeSession()
if got.sid != "sid-123" || got.reason != "closed" {
t.Fatalf("onClose = %+v, want {sid-123 closed}", got)
}
}
func TestDispatchFiresOnTraffic(t *testing.T) {
ln, err := net.Listen("tcp4", "127.0.0.1:0")
if err != nil {
t.Fatalf("Listen() error = %v", err)
}
defer func() { _ = ln.Close() }()
const greeting = "hi\n"
go func() {
c, err := ln.Accept()
if err != nil {
return
}
defer func() { _ = c.Close() }()
_, _ = c.Write([]byte(greeting))
}()
a, b := net.Pipe()
defer func() {
_ = a.Close()
_ = b.Close()
}()
serverSess, err := smux.Server(a, smuxConfig())
if err != nil {
t.Fatalf("smux.Server() error = %v", err)
}
defer func() { _ = serverSess.Close() }()
clientSess, err := smux.Client(b, smuxConfig())
if err != nil {
t.Fatalf("smux.Client() error = %v", err)
}
defer func() { _ = clientSess.Close() }()
var rec struct {
sid string
addr string
in, out uint64
}
recChan := make(chan struct{})
s := &Server{
sessionID: "traffic-sid",
resolver: net.DefaultResolver,
onTraffic: func(sid, addr string, in, out uint64) {
rec.sid = sid
rec.addr = addr
rec.in = in
rec.out = out
close(recChan)
},
}
go func() {
stream, err := serverSess.AcceptStream()
if err != nil {
return
}
s.handleStream(context.Background(), stream)
}()
stream, err := clientSess.OpenStream()
if err != nil {
t.Fatalf("OpenStream() error = %v", err)
}
tcpAddr, ok := ln.Addr().(*net.TCPAddr)
if !ok {
t.Fatalf("addr type = %T", ln.Addr())
}
req, err := json.Marshal(ConnectRequest{
Cmd: "connect",
Addr: "127.0.0.1",
Port: tcpAddr.Port,
})
if err != nil {
t.Fatalf("Marshal() error = %v", err)
}
if _, err := stream.Write(req); err != nil {
t.Fatalf("Write() error = %v", err)
}
ack := make([]byte, 1)
if _, err := io.ReadFull(stream, ack); err != nil {
t.Fatalf("read ack: %v", err)
}
body := make([]byte, len(greeting))
if _, err := io.ReadFull(stream, body); err != nil {
t.Fatalf("read body: %v", err)
}
_ = stream.Close()
select {
case <-recChan:
case <-time.After(2 * time.Second):
t.Fatal("onTraffic did not fire")
}
if rec.sid != "traffic-sid" {
t.Fatalf("sid = %q, want traffic-sid", rec.sid)
}
if rec.out < uint64(len(greeting)) {
t.Fatalf("bytesOut = %d, want >= %d", rec.out, len(greeting))
}
}

169
pkg/olcrtc/tunnel/tunnel.go Normal file
View File

@@ -0,0 +1,169 @@
// Package tunnel exposes olcrtc's server-side tunnel as an embeddable Go library.
//
// A [Server] accepts encrypted tunnel connections over a WebRTC SFU carrier
// and proxies their traffic to arbitrary TCP targets. Consumers plug in
// authorization and observability via the [Config] hooks:
//
// srv := tunnel.New(tunnel.Config{
// Link: "direct",
// Transport: "datachannel",
// Carrier: "telemost",
// RoomURL: "<room-id>",
// KeyHex: "<64-char hex>",
// DNSServer: "1.1.1.1:53",
// AuthHook: func(deviceID string, claims map[string]any) (string, error) {
// // reject unknown devices, enrich session with a DB-issued ID
// return db.IssueSession(deviceID, claims)
// },
// OnSessionOpen: func(sid, dev string, claims map[string]any) {
// log.Printf("session %s opened (device=%s)", sid, dev)
// },
// OnSessionClose: func(sid, reason string) {
// log.Printf("session %s closed (%s)", sid, reason)
// },
// OnTraffic: func(sid, addr string, in, out uint64) {
// metrics.Record(sid, addr, in, out)
// },
// })
// if err := srv.Run(ctx); err != nil {
// log.Fatal(err)
// }
//
// Call [RegisterDefaults] once at program start to register the built-in
// carriers (telemost, jazz, wbstream) and transports (datachannel,
// videochannel, seichannel, vp8channel).
package tunnel
import (
"context"
"fmt"
"github.com/openlibrecommunity/olcrtc/internal/app/session"
"github.com/openlibrecommunity/olcrtc/internal/handshake"
"github.com/openlibrecommunity/olcrtc/internal/server"
)
// AuthFunc is invoked after CLIENT_HELLO to authorize the client and issue a
// session ID. Returning a non-nil error rejects the handshake; the error's
// message is forwarded to the client as the reject reason, so it should not
// leak sensitive details.
type AuthFunc = handshake.AuthFunc
// SessionOpenFunc fires right after a successful handshake, before the server
// starts accepting tunnel streams on that session.
type SessionOpenFunc = server.SessionOpenFunc
// SessionCloseFunc fires when a session ends. Reasons include "reconnect"
// (carrier dropped and was reestablished) and "closed" (graceful shutdown or
// ctx cancel).
type SessionCloseFunc = server.SessionCloseFunc
// TrafficFunc fires once per tunnel stream after both copy loops finish.
// bytesIn counts client→target bytes; bytesOut counts target→client bytes.
type TrafficFunc = server.TrafficFunc
// Config holds runtime server configuration.
type Config struct {
// --- carrier selection ---
Link string // currently only "direct"
Transport string // datachannel, videochannel, seichannel, vp8channel
Carrier string // telemost, jazz, wbstream, none
RoomURL string // conference room identifier for the carrier
// --- direct engine mode (Carrier == "none") ---
Engine string // livekit, goolom, salutejazz
URL string
Token string
// --- crypto & networking ---
KeyHex string // 64-char hex (32 bytes) shared with the client
DNSServer string // resolver used for target dials, e.g. "1.1.1.1:53"
SOCKSProxyAddr string // optional outbound SOCKS5 proxy host
SOCKSProxyPort int // optional outbound SOCKS5 proxy port
// --- transport tuning ---
VideoWidth int
VideoHeight int
VideoFPS int
VideoBitrate string
VideoHW string
VideoQRSize int
VideoQRRecovery string
VideoCodec string
VideoTileModule int
VideoTileRS int
VP8FPS int
VP8BatchSize int
SEIFPS int
SEIBatchSize int
SEIFragmentSize int
SEIAckTimeoutMS int
// --- hooks ---
// AuthHook authorizes the client. If nil, every client is admitted with a
// random UUID as session ID.
AuthHook AuthFunc
// OnSessionOpen fires after a successful handshake. Nil is a no-op.
OnSessionOpen SessionOpenFunc
// OnSessionClose fires when the session is torn down. Nil is a no-op.
OnSessionClose SessionCloseFunc
// OnTraffic fires once per tunnel stream after both copy loops finish.
// Nil is a no-op.
OnTraffic TrafficFunc
}
// Server is an embeddable tunnel server.
type Server struct {
cfg Config
}
// New returns a Server configured by cfg. Call [Server.Run] to start it.
func New(cfg Config) *Server {
return &Server{cfg: cfg}
}
// Run starts the server and blocks until ctx is cancelled or the carrier ends.
func (s *Server) Run(ctx context.Context) error {
if err := server.Run(ctx, server.Config{
Link: s.cfg.Link,
Transport: s.cfg.Transport,
Carrier: s.cfg.Carrier,
RoomURL: s.cfg.RoomURL,
Engine: s.cfg.Engine,
URL: s.cfg.URL,
Token: s.cfg.Token,
KeyHex: s.cfg.KeyHex,
DNSServer: s.cfg.DNSServer,
SOCKSProxyAddr: s.cfg.SOCKSProxyAddr,
SOCKSProxyPort: s.cfg.SOCKSProxyPort,
VideoWidth: s.cfg.VideoWidth,
VideoHeight: s.cfg.VideoHeight,
VideoFPS: s.cfg.VideoFPS,
VideoBitrate: s.cfg.VideoBitrate,
VideoHW: s.cfg.VideoHW,
VideoQRSize: s.cfg.VideoQRSize,
VideoQRRecovery: s.cfg.VideoQRRecovery,
VideoCodec: s.cfg.VideoCodec,
VideoTileModule: s.cfg.VideoTileModule,
VideoTileRS: s.cfg.VideoTileRS,
VP8FPS: s.cfg.VP8FPS,
VP8BatchSize: s.cfg.VP8BatchSize,
SEIFPS: s.cfg.SEIFPS,
SEIBatchSize: s.cfg.SEIBatchSize,
SEIFragmentSize: s.cfg.SEIFragmentSize,
SEIAckTimeoutMS: s.cfg.SEIAckTimeoutMS,
AuthHook: s.cfg.AuthHook,
OnSessionOpen: s.cfg.OnSessionOpen,
OnSessionClose: s.cfg.OnSessionClose,
OnTraffic: s.cfg.OnTraffic,
}); err != nil {
return fmt.Errorf("tunnel: %w", err)
}
return nil
}
// RegisterDefaults registers the built-in carriers, links and transports.
// Safe to call multiple times.
func RegisterDefaults() {
session.RegisterDefaults()
}

View File

@@ -0,0 +1,50 @@
package tunnel_test
import (
"context"
"errors"
"testing"
"github.com/openlibrecommunity/olcrtc/pkg/olcrtc/tunnel"
)
func TestRun_FailsWithoutKey(t *testing.T) {
tunnel.RegisterDefaults()
err := tunnel.New(tunnel.Config{
Link: "direct",
Transport: "datachannel",
Carrier: "telemost",
RoomURL: "room-1",
DNSServer: "1.1.1.1:53",
}).Run(context.Background())
if err == nil {
t.Fatal("Run(no key) error = nil")
}
}
func TestRun_PropagatesAuthHook(t *testing.T) {
tunnel.RegisterDefaults()
sentinel := errors.New("no")
var called bool
cfg := tunnel.Config{
AuthHook: func(string, map[string]any) (string, error) {
called = true
return "", sentinel
},
}
_ = tunnel.New(cfg).Run(context.Background())
// Run bails before ever invoking AuthHook (no key, no carrier wired); this
// test exists to pin the public surface and ensure the hook field compiles
// against the re-exported handshake.AuthFunc type alias. Behavior coverage
// of AuthHook itself lives in internal/handshake tests.
_ = called
}
// Compile-time checks: the public type aliases must be assignable.
var (
_ tunnel.AuthFunc = func(string, map[string]any) (string, error) { return "", nil }
_ tunnel.SessionOpenFunc = func(string, string, map[string]any) {}
_ tunnel.SessionCloseFunc = func(string, string) {}
_ tunnel.TrafficFunc = func(string, string, uint64, uint64) {}
)