From c6c301c0587a89ef6e62bc113124152b79c6bd68 Mon Sep 17 00:00:00 2001 From: zarazaex69 Date: Sun, 17 May 2026 18:35:05 +0300 Subject: [PATCH] fix: handle graceful control shutdown and reconnects --- cmd/olcrtc/main.go | 2 +- internal/client/client.go | 43 ++++++++++++- internal/client/client_test.go | 73 ++++++++++++++++++---- internal/control/control.go | 13 +++- internal/control/control_test.go | 20 ++++++ internal/runtime/runtime.go | 10 +-- internal/runtime/runtime_test.go | 4 ++ internal/server/server.go | 23 +++++++ internal/server/server_test.go | 101 +++++++++++++++++++++++++++---- 9 files changed, 256 insertions(+), 33 deletions(-) diff --git a/cmd/olcrtc/main.go b/cmd/olcrtc/main.go index 45662af..1ff7b8e 100644 --- a/cmd/olcrtc/main.go +++ b/cmd/olcrtc/main.go @@ -334,13 +334,13 @@ func (f filteredWriter) Write(p []byte) (int, error) { } func configureLogging(debug bool) { + log.SetOutput(filteredWriter{w: os.Stderr}) if debug { logger.SetVerbose(true) return } _ = os.Setenv("PION_LOG_DISABLE", "all") lksdk.SetLogger(protoLogger.GetDiscardLogger()) - log.SetOutput(filteredWriter{w: os.Stderr}) } func resolveDataDir(dataDir string) (string, error) { diff --git a/internal/client/client.go b/internal/client/client.go index cfd489e..5a92388 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -217,7 +217,7 @@ func (c *Client) bringUpLink( return fmt.Errorf("smux client: %w", err) } - control, sid, err := openControlStream(sess, c.deviceID, c.claims) + control, sid, err := openControlStream(ctx, sess, c.deviceID, c.claims) if err != nil { _ = sess.Close() _ = c.conn.Close() @@ -241,14 +241,16 @@ func (c *Client) bringUpLink( // The stream stays open for the lifetime of the smux session and carries // post-handshake control messages. func openControlStream( + ctx context.Context, sess *smux.Session, deviceID string, claims map[string]any, ) (*smux.Stream, string, error) { - return openControlStreamTimeout(sess, deviceID, claims, handshake.DefaultTimeout) + return openControlStreamTimeout(ctx, sess, deviceID, claims, handshake.DefaultTimeout) } func openControlStreamTimeout( + ctx context.Context, sess *smux.Session, deviceID string, claims map[string]any, @@ -258,11 +260,23 @@ func openControlStreamTimeout( if err != nil { return nil, "", fmt.Errorf("open control stream: %w", err) } + done := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + _ = stream.Close() + case <-done: + } + }() + defer close(done) _ = stream.SetDeadline(time.Now().Add(timeout)) sid, err := handshake.Client(stream, deviceID, claims) _ = stream.SetDeadline(time.Time{}) if err != nil { _ = stream.Close() + if ctx.Err() != nil { + return nil, "", fmt.Errorf("handshake client: %w", ctx.Err()) + } return nil, "", fmt.Errorf("handshake client: %w", err) } return stream, sid, nil @@ -316,6 +330,7 @@ func (c *Client) handleReconnect(ctx context.Context, cfg Config, cancel context c.recordReconnect() logger.Infof("client reconnect reason=%s - tearing down smux session", reason) + c.resetLinkPeer() // Install a fresh muxconn immediately so onData never hits nil while // the old session is being torn down. tryReopenSession will swap it @@ -371,6 +386,15 @@ func (c *Client) handleReconnect(ctx context.Context, cfg Config, cancel context return false } +func (c *Client) resetLinkPeer() { + c.sessMu.RLock() + ln := c.ln + c.sessMu.RUnlock() + if resetter, ok := ln.(interface{ ResetPeer() }); ok { + resetter.ResetPeer() + } +} + func (c *Client) tryReopenSession( ctx context.Context, cfg Config, @@ -392,7 +416,7 @@ func (c *Client) tryReopenSession( logger.Warnf("smux re-init failed (attempt %d): %v", attempt, err) return false } - control, sid, err := openControlStreamTimeout(sess, c.deviceID, c.claims, 2*time.Second) + control, sid, err := openControlStreamTimeout(ctx, sess, c.deviceID, c.claims, 2*time.Second) if err != nil { logger.Warnf("handshake on reconnect failed (attempt %d): %v", attempt, err) _ = sess.Close() @@ -486,6 +510,7 @@ func (c *Client) shutdown() { c.conn = nil c.sessMu.Unlock() + notifyControlClose(control) if controlStop != nil { controlStop() } @@ -503,6 +528,18 @@ func (c *Client) shutdown() { } } +func notifyControlClose(stream *smux.Stream) { + if stream == nil { + return + } + _ = stream.SetWriteDeadline(time.Now().Add(2 * time.Second)) + if err := control.SendClose(stream); err == nil { + time.Sleep(200 * time.Millisecond) + } + _ = stream.SetWriteDeadline(time.Time{}) + _ = stream.CloseWrite() +} + func setupCipher(keyHex string) (*crypto.Cipher, error) { cipher, err := runtime.SetupCipher(keyHex) if err != nil { diff --git a/internal/client/client_test.go b/internal/client/client_test.go index 590d63e..ed249d7 100644 --- a/internal/client/client_test.go +++ b/internal/client/client_test.go @@ -48,7 +48,7 @@ func TestSetupCipherRejectsBadInput(t *testing.T) { func TestSmuxConfig(t *testing.T) { cfg := smuxConfig(0) - if cfg.Version != 2 || !cfg.KeepAliveDisabled || cfg.MaxFrameSize != 32768 || cfg.MaxReceiveBuffer != 16*1024*1024 { + if cfg.Version != 2 || cfg.KeepAliveDisabled || cfg.MaxFrameSize != 32768 || cfg.MaxReceiveBuffer != 16*1024*1024 { t.Fatalf("smuxConfig(0) = %+v", cfg) } capped := smuxConfig(4096) @@ -491,19 +491,59 @@ func TestSendConnectRequestRejectsBadAck(t *testing.T) { } } -type closerLinkStub struct { - closed bool +func TestOpenControlStreamStopsOnContextCancel(t *testing.T) { + a, b := net.Pipe() + defer func() { + _ = a.Close() + _ = b.Close() + }() + + serverSess, err := smux.Server(a, smuxConfig(0)) + if err != nil { + t.Fatalf("smux.Server() error = %v", err) + } + defer func() { _ = serverSess.Close() }() + clientSess, err := smux.Client(b, smuxConfig(0)) + if err != nil { + t.Fatalf("smux.Client() error = %v", err) + } + defer func() { _ = clientSess.Close() }() + + ctx, cancel := context.WithCancel(context.Background()) + errCh := make(chan error, 1) + go func() { + _, _, err := openControlStreamTimeout(ctx, clientSess, "dev", nil, time.Hour) + errCh <- err + }() + + time.Sleep(20 * time.Millisecond) + cancel() + + select { + case err := <-errCh: + if !errors.Is(err, context.Canceled) { + t.Fatalf("openControlStreamTimeout() error = %v, want context.Canceled", err) + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for context cancellation") + } } -func (s *closerLinkStub) Connect(context.Context) error { return nil } -func (s *closerLinkStub) Send([]byte) error { return nil } -func (s *closerLinkStub) Close() error { s.closed = true; return nil } -func (s *closerLinkStub) SetReconnectCallback(func()) {} -func (s *closerLinkStub) SetShouldReconnect(func() bool) {} -func (s *closerLinkStub) SetEndedCallback(func(string)) {} -func (s *closerLinkStub) WatchConnection(context.Context) {} -func (s *closerLinkStub) CanSend() bool { return true } -func (s *closerLinkStub) Features() transport.Features { return transport.Features{} } +type closerLinkStub struct { + closed bool + resetCount int +} + +func (s *closerLinkStub) Connect(context.Context) error { return nil } +func (s *closerLinkStub) Send([]byte) error { return nil } +func (s *closerLinkStub) Close() error { s.closed = true; return nil } +func (s *closerLinkStub) SetReconnectCallback(func()) {} +func (s *closerLinkStub) SetShouldReconnect(func() bool) {} +func (s *closerLinkStub) SetEndedCallback(func(string)) {} +func (s *closerLinkStub) WatchConnection(context.Context) {} +func (s *closerLinkStub) CanSend() bool { return true } +func (s *closerLinkStub) Features() transport.Features { return transport.Features{} } +func (s *closerLinkStub) ResetPeer() { s.resetCount++ } func TestOnDataWithNilConn(_ *testing.T) { c := &Client{} @@ -527,6 +567,15 @@ func TestShutdownClosesLinkAndConn(t *testing.T) { } } +func TestResetLinkPeer(t *testing.T) { + ln := &closerLinkStub{} + c := &Client{ln: ln} + c.resetLinkPeer() + if ln.resetCount != 1 { + t.Fatalf("ResetPeer calls = %d, want 1", ln.resetCount) + } +} + //nolint:cyclop // integration-style control loop test needs setup and async assertions together func TestStartControlLoopReportsPong(t *testing.T) { a, b := net.Pipe() diff --git a/internal/control/control.go b/internal/control/control.go index d208afb..de4f521 100644 --- a/internal/control/control.go +++ b/internal/control/control.go @@ -44,11 +44,15 @@ const ( TypePing MsgType = "CONTROL_PING" // TypePong replies to a ping with the same sequence and timestamp. TypePong MsgType = "CONTROL_PONG" + // TypeClose tells the peer this control session is intentionally closing. + TypeClose MsgType = "CONTROL_CLOSE" ) var ( // ErrUnhealthy is returned when the stream misses too many pong replies. ErrUnhealthy = errors.New("control stream unhealthy") + // ErrClosedByPeer is returned when the peer gracefully closes the control session. + ErrClosedByPeer = errors.New("control stream closed by peer") // ErrProtocolVersion is returned when the peer announces an incompatible version. ErrProtocolVersion = errors.New("incompatible control protocol version") // ErrUnexpectedMessage is returned for unknown or malformed control message types. @@ -184,6 +188,8 @@ func (s *state) readLoop(ctx context.Context) error { } case TypePong: s.handlePong(msg) + case TypeClose: + return ErrClosedByPeer default: return fmt.Errorf("%w: got %q", ErrUnexpectedMessage, msg.Type) } @@ -302,12 +308,17 @@ func parseMessage(raw []byte) (Message, error) { return Message{}, fmt.Errorf("%w: peer v%d, local v%d", ErrProtocolVersion, msg.Version, ProtoVersion) } - if msg.Type != TypePing && msg.Type != TypePong { + if msg.Type != TypePing && msg.Type != TypePong && msg.Type != TypeClose { return Message{}, fmt.Errorf("%w: got %q", ErrUnexpectedMessage, msg.Type) } return msg, nil } +// SendClose sends a best-effort graceful close notification on the control stream. +func SendClose(w io.Writer) error { + return writeFrame(w, Message{Version: ProtoVersion, Type: TypeClose}) +} + func writeFrame(w io.Writer, msg Message) error { if err := framing.WriteJSON(w, msg, MaxMessageSize); err != nil { return fmt.Errorf("control: %w", err) diff --git a/internal/control/control_test.go b/internal/control/control_test.go index 8700027..ea65503 100644 --- a/internal/control/control_test.go +++ b/internal/control/control_test.go @@ -124,6 +124,26 @@ func TestRunRejectsBadProtocolVersion(t *testing.T) { } } +func TestRunStopsOnPeerClose(t *testing.T) { + a, b := controlPair(t) + errCh := make(chan error, 1) + go func() { + errCh <- Run(context.Background(), a, Config{Interval: time.Hour}) + }() + if err := SendClose(b); err != nil { + t.Fatalf("SendClose() error = %v", err) + } + + select { + case err := <-errCh: + if !errors.Is(err, ErrClosedByPeer) { + t.Fatalf("Run() error = %v, want ErrClosedByPeer", err) + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for peer close") + } +} + func TestReadFrameRejectsTooLarge(t *testing.T) { a, b := controlPair(t) go func() { diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index 1f9b838..8a2af3e 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -49,7 +49,7 @@ func SetupCipher(keyHex string) (*crypto.Cipher, error) { func SmuxConfig(maxWirePayload int) *smux.Config { cfg := smux.DefaultConfig() cfg.Version = 2 - cfg.KeepAliveDisabled = true + cfg.KeepAliveDisabled = false cfg.MaxFrameSize = 32768 if maxWirePayload > crypto.WireOverhead { maxFrameSize := maxWirePayload - crypto.WireOverhead @@ -60,7 +60,7 @@ func SmuxConfig(maxWirePayload int) *smux.Config { cfg.MaxReceiveBuffer = 16 * 1024 * 1024 cfg.MaxStreamBuffer = 1024 * 1024 cfg.KeepAliveInterval = 10 * time.Second - cfg.KeepAliveTimeout = 60 * time.Second + cfg.KeepAliveTimeout = 30 * time.Second return cfg } @@ -76,9 +76,9 @@ func MaxPayload(tr transport.Transport) int { // Server and client both embed a HealthTracker to avoid open-coding the // same record* methods on both sides. type HealthTracker struct { - mu sync.RWMutex - status control.Status - notify func(control.Status) + mu sync.RWMutex + status control.Status + notify func(control.Status) } // NewHealthTracker creates a HealthTracker that publishes the latest diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go index a0f44eb..7d18bbe 100644 --- a/internal/runtime/runtime_test.go +++ b/internal/runtime/runtime_test.go @@ -37,6 +37,10 @@ func TestSmuxConfigDefault(t *testing.T) { if cfg.Version != 2 || cfg.MaxFrameSize != 32768 { t.Fatalf("SmuxConfig(0) = %+v", cfg) } + if cfg.KeepAliveDisabled || cfg.KeepAliveInterval != 10*time.Second || + cfg.KeepAliveTimeout != 30*time.Second { + t.Fatalf("SmuxConfig(0) keepalive = %+v", cfg) + } } func TestSmuxConfigShrinks(t *testing.T) { diff --git a/internal/server/server.go b/internal/server/server.go index 53e35eb..84d7299 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -60,6 +60,7 @@ type Server struct { cipher *crypto.Cipher conn *muxconn.Conn session *smux.Session + controlStrm *smux.Stream controlStop context.CancelFunc sessMu sync.RWMutex reinstallMu sync.Mutex @@ -307,10 +308,12 @@ func (s *Server) reinstallSession(dead *smux.Session) { } oldSess := s.session oldConn := s.conn + oldControl := s.controlStrm oldControlStop := s.controlStop oldSID := s.sessionID s.session = newSess s.conn = newConn + s.controlStrm = nil s.controlStop = nil s.sessionID = "" s.deviceID = "" @@ -325,6 +328,9 @@ func (s *Server) reinstallSession(dead *smux.Session) { if oldConn != nil { _ = oldConn.Close() } + if oldControl != nil { + _ = oldControl.Close() + } if oldSID != "" { s.onClose(oldSID, "reconnect") } @@ -334,15 +340,18 @@ func (s *Server) closeSession() { s.sessMu.Lock() sess := s.session conn := s.conn + control := s.controlStrm controlStop := s.controlStop s.session = nil s.conn = nil + s.controlStrm = nil s.controlStop = nil oldSID := s.sessionID s.sessionID = "" s.deviceID = "" s.sessMu.Unlock() + notifyControlClose(control) if controlStop != nil { controlStop() } @@ -357,6 +366,18 @@ func (s *Server) closeSession() { } } +func notifyControlClose(stream *smux.Stream) { + if stream == nil { + return + } + _ = stream.SetWriteDeadline(time.Now().Add(2 * time.Second)) + if err := control.SendClose(stream); err == nil { + time.Sleep(200 * time.Millisecond) + } + _ = stream.SetWriteDeadline(time.Time{}) + _ = stream.CloseWrite() +} + func (s *Server) onData(data []byte) { s.sessMu.RLock() conn := s.conn @@ -474,6 +495,7 @@ func (s *Server) resetLinkPeer() { func (s *Server) startControlLoop(ctx context.Context, sess *smux.Session, stream *smux.Stream) { controlCtx, stop := context.WithCancel(ctx) s.sessMu.Lock() + s.controlStrm = stream s.controlStop = stop s.sessMu.Unlock() @@ -519,6 +541,7 @@ func (s *Server) startControlLoop(ctx context.Context, sess *smux.Session, strea } s.recordReconnect() logger.Infof("server reconnect reason=liveness - reinstalling smux session") + s.resetLinkPeer() s.reinstallSession(sess) }() } diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 9512f8d..ac805d8 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -49,7 +49,7 @@ func TestSetupCipherRejectsBadInput(t *testing.T) { func TestSmuxConfig(t *testing.T) { cfg := smuxConfig(0) - if cfg.Version != 2 || !cfg.KeepAliveDisabled || cfg.MaxFrameSize != 32768 || cfg.MaxReceiveBuffer != 16*1024*1024 { + if cfg.Version != 2 || cfg.KeepAliveDisabled || cfg.MaxFrameSize != 32768 || cfg.MaxReceiveBuffer != 16*1024*1024 { t.Fatalf("smuxConfig(0) = %+v", cfg) } capped := smuxConfig(4096) @@ -211,18 +211,29 @@ func TestOnDataWithNilConn(_ *testing.T) { } type serverLinkStub struct { - closed bool + closed bool + resetCount int + resetCh chan struct{} } -func (s *serverLinkStub) Connect(context.Context) error { return nil } -func (s *serverLinkStub) Send([]byte) error { return nil } -func (s *serverLinkStub) Close() error { s.closed = true; return nil } -func (s *serverLinkStub) SetReconnectCallback(func()) {} -func (s *serverLinkStub) SetShouldReconnect(func() bool) {} -func (s *serverLinkStub) SetEndedCallback(func(string)) {} -func (s *serverLinkStub) WatchConnection(context.Context) {} -func (s *serverLinkStub) CanSend() bool { return true } -func (s *serverLinkStub) Features() transport.Features { return transport.Features{} } +func (s *serverLinkStub) Connect(context.Context) error { return nil } +func (s *serverLinkStub) Send([]byte) error { return nil } +func (s *serverLinkStub) Close() error { s.closed = true; return nil } +func (s *serverLinkStub) SetReconnectCallback(func()) {} +func (s *serverLinkStub) SetShouldReconnect(func() bool) {} +func (s *serverLinkStub) SetEndedCallback(func(string)) {} +func (s *serverLinkStub) WatchConnection(context.Context) {} +func (s *serverLinkStub) CanSend() bool { return true } +func (s *serverLinkStub) Features() transport.Features { return transport.Features{} } +func (s *serverLinkStub) ResetPeer() { + s.resetCount++ + if s.resetCh != nil { + select { + case s.resetCh <- struct{}{}: + default: + } + } +} func TestShutdownClosesLinkAndConn(t *testing.T) { cipher, err := cryptopkg.NewCipher("01234567890123456789012345678901") @@ -463,6 +474,74 @@ func TestStartControlLoopReportsPong(t *testing.T) { } } +func TestStartControlLoopResetsPeerBeforeReinstall(t *testing.T) { + a, b := net.Pipe() + defer func() { + _ = a.Close() + _ = b.Close() + }() + + serverSess, err := smux.Server(a, smuxConfig(0)) + if err != nil { + t.Fatalf("smux.Server() error = %v", err) + } + clientSess, err := smux.Client(b, smuxConfig(0)) + if err != nil { + t.Fatalf("smux.Client() error = %v", err) + } + + serverStreamCh := make(chan *smux.Stream, 1) + go func() { + stream, err := serverSess.AcceptStream() + if err == nil { + serverStreamCh <- stream + } + }() + + clientStream, err := clientSess.OpenStream() + if err != nil { + t.Fatalf("OpenStream() error = %v", err) + } + serverStream := <-serverStreamCh + + cipher, err := cryptopkg.NewCipher("01234567890123456789012345678901") + if err != nil { + t.Fatalf("NewCipher() error = %v", err) + } + ln := &serverLinkStub{resetCh: make(chan struct{}, 1)} + ctx, cancel := context.WithCancel(context.Background()) + s := &Server{ + ln: ln, + cipher: cipher, + conn: muxconn.New(ln, cipher), + session: serverSess, + health: runtime.NewHealthTracker(nil), + liveness: control.Config{ + Interval: time.Hour, + Timeout: time.Hour, + Failures: 1, + }, + } + defer func() { + cancel() + s.shutdown() + s.wg.Wait() + _ = clientSess.Close() + }() + + s.startControlLoop(ctx, serverSess, serverStream) + _ = clientStream.Close() + + select { + case <-ln.resetCh: + case <-time.After(time.Second): + t.Fatal("timed out waiting for ResetPeer") + } + if ln.resetCount != 1 { + t.Fatalf("ResetPeer calls = %d, want 1", ln.resetCount) + } +} + func TestStatusRecordsReconnectAndUnhealthy(t *testing.T) { updates := 0 s := &Server{health: runtime.NewHealthTracker(func(control.Status) { updates++ })}