fix: handle graceful control shutdown and reconnects

This commit is contained in:
zarazaex69
2026-05-17 18:35:05 +03:00
parent 9a2bbfd44e
commit c6c301c058
9 changed files with 256 additions and 33 deletions

View File

@@ -334,13 +334,13 @@ func (f filteredWriter) Write(p []byte) (int, error) {
}
// configureLogging installs the process-wide logging setup.
//
// The standard library logger is always routed through a filteredWriter
// wrapping os.Stderr. When debug is true, verbose logging is switched on and
// nothing else is touched. Otherwise PION_LOG_DISABLE is set to "all" and
// lksdk is handed the discard logger.
func configureLogging(debug bool) {
	// Install the filter once, up front, so it is active in both modes.
	log.SetOutput(filteredWriter{w: os.Stderr})
	if debug {
		logger.SetVerbose(true)
		return
	}
	// Best effort: a failed Setenv only leaves that logging enabled.
	_ = os.Setenv("PION_LOG_DISABLE", "all")
	lksdk.SetLogger(protoLogger.GetDiscardLogger())
}
func resolveDataDir(dataDir string) (string, error) {

View File

@@ -217,7 +217,7 @@ func (c *Client) bringUpLink(
return fmt.Errorf("smux client: %w", err)
}
control, sid, err := openControlStream(sess, c.deviceID, c.claims)
control, sid, err := openControlStream(ctx, sess, c.deviceID, c.claims)
if err != nil {
_ = sess.Close()
_ = c.conn.Close()
@@ -241,14 +241,16 @@ func (c *Client) bringUpLink(
// The stream stays open for the lifetime of the smux session and carries
// post-handshake control messages.
func openControlStream(
	ctx context.Context,
	sess *smux.Session,
	deviceID string,
	claims map[string]any,
) (*smux.Stream, string, error) {
	// Thin wrapper: same call as openControlStreamTimeout, with the handshake
	// bounded by handshake.DefaultTimeout. ctx must be forwarded so that the
	// callee can observe cancellation.
	return openControlStreamTimeout(ctx, sess, deviceID, claims, handshake.DefaultTimeout)
}
func openControlStreamTimeout(
ctx context.Context,
sess *smux.Session,
deviceID string,
claims map[string]any,
@@ -258,11 +260,23 @@ func openControlStreamTimeout(
if err != nil {
return nil, "", fmt.Errorf("open control stream: %w", err)
}
done := make(chan struct{})
go func() {
select {
case <-ctx.Done():
_ = stream.Close()
case <-done:
}
}()
defer close(done)
_ = stream.SetDeadline(time.Now().Add(timeout))
sid, err := handshake.Client(stream, deviceID, claims)
_ = stream.SetDeadline(time.Time{})
if err != nil {
_ = stream.Close()
if ctx.Err() != nil {
return nil, "", fmt.Errorf("handshake client: %w", ctx.Err())
}
return nil, "", fmt.Errorf("handshake client: %w", err)
}
return stream, sid, nil
@@ -316,6 +330,7 @@ func (c *Client) handleReconnect(ctx context.Context, cfg Config, cancel context
c.recordReconnect()
logger.Infof("client reconnect reason=%s - tearing down smux session", reason)
c.resetLinkPeer()
// Install a fresh muxconn immediately so onData never hits nil while
// the old session is being torn down. tryReopenSession will swap it
@@ -371,6 +386,15 @@ func (c *Client) handleReconnect(ctx context.Context, cfg Config, cancel context
return false
}
// resetLinkPeer calls ResetPeer on the client's current link when the link
// implements that optional method; otherwise it does nothing.
//
// The link is read under the session read lock, and ResetPeer is invoked
// after the lock has been released.
func (c *Client) resetLinkPeer() {
	type peerResetter interface{ ResetPeer() }

	c.sessMu.RLock()
	link := c.ln
	c.sessMu.RUnlock()

	target, supported := link.(peerResetter)
	if !supported {
		return
	}
	target.ResetPeer()
}
func (c *Client) tryReopenSession(
ctx context.Context,
cfg Config,
@@ -392,7 +416,7 @@ func (c *Client) tryReopenSession(
logger.Warnf("smux re-init failed (attempt %d): %v", attempt, err)
return false
}
control, sid, err := openControlStreamTimeout(sess, c.deviceID, c.claims, 2*time.Second)
control, sid, err := openControlStreamTimeout(ctx, sess, c.deviceID, c.claims, 2*time.Second)
if err != nil {
logger.Warnf("handshake on reconnect failed (attempt %d): %v", attempt, err)
_ = sess.Close()
@@ -486,6 +510,7 @@ func (c *Client) shutdown() {
c.conn = nil
c.sessMu.Unlock()
notifyControlClose(control)
if controlStop != nil {
controlStop()
}
@@ -503,6 +528,18 @@ func (c *Client) shutdown() {
}
}
// notifyControlClose tells the peer, on a best-effort basis, that the control
// stream is being closed on purpose. A nil stream is a no-op.
//
// The close message is written under a two-second write deadline. If the
// write succeeds the function pauses for 200ms before clearing the deadline
// and calling CloseWrite (presumably to give the frame time to reach the
// peer; confirm). Every error is deliberately ignored.
func notifyControlClose(stream *smux.Stream) {
	const (
		writeBudget = 2 * time.Second
		flushGrace  = 200 * time.Millisecond
	)

	if stream == nil {
		return
	}

	_ = stream.SetWriteDeadline(time.Now().Add(writeBudget))
	sendErr := control.SendClose(stream)
	if sendErr == nil {
		time.Sleep(flushGrace)
	}

	// Drop the deadline again before closing the write side.
	_ = stream.SetWriteDeadline(time.Time{})
	_ = stream.CloseWrite()
}
func setupCipher(keyHex string) (*crypto.Cipher, error) {
cipher, err := runtime.SetupCipher(keyHex)
if err != nil {

View File

@@ -48,7 +48,7 @@ func TestSetupCipherRejectsBadInput(t *testing.T) {
func TestSmuxConfig(t *testing.T) {
cfg := smuxConfig(0)
if cfg.Version != 2 || !cfg.KeepAliveDisabled || cfg.MaxFrameSize != 32768 || cfg.MaxReceiveBuffer != 16*1024*1024 {
if cfg.Version != 2 || cfg.KeepAliveDisabled || cfg.MaxFrameSize != 32768 || cfg.MaxReceiveBuffer != 16*1024*1024 {
t.Fatalf("smuxConfig(0) = %+v", cfg)
}
capped := smuxConfig(4096)
@@ -491,19 +491,59 @@ func TestSendConnectRequestRejectsBadAck(t *testing.T) {
}
}
type closerLinkStub struct {
closed bool
// TestOpenControlStreamStopsOnContextCancel verifies that cancelling the
// context aborts openControlStreamTimeout long before its one-hour timeout and
// that the returned error wraps context.Canceled.
//
// Nothing on the server session accepts the stream, so the call can only
// finish through the cancellation.
func TestOpenControlStreamStopsOnContextCancel(t *testing.T) {
	serverConn, clientConn := net.Pipe()
	defer func() {
		_ = serverConn.Close()
		_ = clientConn.Close()
	}()

	serverSess, serverErr := smux.Server(serverConn, smuxConfig(0))
	if serverErr != nil {
		t.Fatalf("smux.Server() error = %v", serverErr)
	}
	defer func() { _ = serverSess.Close() }()

	clientSess, clientErr := smux.Client(clientConn, smuxConfig(0))
	if clientErr != nil {
		t.Fatalf("smux.Client() error = %v", clientErr)
	}
	defer func() { _ = clientSess.Close() }()

	ctx, cancel := context.WithCancel(context.Background())
	result := make(chan error, 1)
	go func() {
		_, _, openErr := openControlStreamTimeout(ctx, clientSess, "dev", nil, time.Hour)
		result <- openErr
	}()

	// Give the call a moment to get going before pulling the plug.
	time.Sleep(20 * time.Millisecond)
	cancel()

	select {
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for context cancellation")
	case got := <-result:
		if errors.Is(got, context.Canceled) {
			return
		}
		t.Fatalf("openControlStreamTimeout() error = %v, want context.Canceled", got)
	}
}
func (s *closerLinkStub) Connect(context.Context) error { return nil }
func (s *closerLinkStub) Send([]byte) error { return nil }
func (s *closerLinkStub) Close() error { s.closed = true; return nil }
func (s *closerLinkStub) SetReconnectCallback(func()) {}
func (s *closerLinkStub) SetShouldReconnect(func() bool) {}
func (s *closerLinkStub) SetEndedCallback(func(string)) {}
func (s *closerLinkStub) WatchConnection(context.Context) {}
func (s *closerLinkStub) CanSend() bool { return true }
func (s *closerLinkStub) Features() transport.Features { return transport.Features{} }
type closerLinkStub struct {
closed bool
resetCount int
}
func (s *closerLinkStub) Connect(context.Context) error { return nil }
func (s *closerLinkStub) Send([]byte) error { return nil }
func (s *closerLinkStub) Close() error { s.closed = true; return nil }
func (s *closerLinkStub) SetReconnectCallback(func()) {}
func (s *closerLinkStub) SetShouldReconnect(func() bool) {}
func (s *closerLinkStub) SetEndedCallback(func(string)) {}
func (s *closerLinkStub) WatchConnection(context.Context) {}
func (s *closerLinkStub) CanSend() bool { return true }
func (s *closerLinkStub) Features() transport.Features { return transport.Features{} }
func (s *closerLinkStub) ResetPeer() { s.resetCount++ }
func TestOnDataWithNilConn(_ *testing.T) {
c := &Client{}
@@ -527,6 +567,15 @@ func TestShutdownClosesLinkAndConn(t *testing.T) {
}
}
// TestResetLinkPeer checks that Client.resetLinkPeer forwards to the link's
// ResetPeer method exactly once.
func TestResetLinkPeer(t *testing.T) {
	stub := new(closerLinkStub)
	client := &Client{ln: stub}

	client.resetLinkPeer()

	if got := stub.resetCount; got != 1 {
		t.Fatalf("ResetPeer calls = %d, want 1", got)
	}
}
//nolint:cyclop // integration-style control loop test needs setup and async assertions together
func TestStartControlLoopReportsPong(t *testing.T) {
a, b := net.Pipe()

View File

@@ -44,11 +44,15 @@ const (
TypePing MsgType = "CONTROL_PING"
// TypePong replies to a ping with the same sequence and timestamp.
TypePong MsgType = "CONTROL_PONG"
// TypeClose tells the peer this control session is intentionally closing.
TypeClose MsgType = "CONTROL_CLOSE"
)
var (
// ErrUnhealthy is returned when the stream misses too many pong replies.
ErrUnhealthy = errors.New("control stream unhealthy")
// ErrClosedByPeer is returned when the peer gracefully closes the control session.
ErrClosedByPeer = errors.New("control stream closed by peer")
// ErrProtocolVersion is returned when the peer announces an incompatible version.
ErrProtocolVersion = errors.New("incompatible control protocol version")
// ErrUnexpectedMessage is returned for unknown or malformed control message types.
@@ -184,6 +188,8 @@ func (s *state) readLoop(ctx context.Context) error {
}
case TypePong:
s.handlePong(msg)
case TypeClose:
return ErrClosedByPeer
default:
return fmt.Errorf("%w: got %q", ErrUnexpectedMessage, msg.Type)
}
@@ -302,12 +308,17 @@ func parseMessage(raw []byte) (Message, error) {
return Message{}, fmt.Errorf("%w: peer v%d, local v%d",
ErrProtocolVersion, msg.Version, ProtoVersion)
}
if msg.Type != TypePing && msg.Type != TypePong {
if msg.Type != TypePing && msg.Type != TypePong && msg.Type != TypeClose {
return Message{}, fmt.Errorf("%w: got %q", ErrUnexpectedMessage, msg.Type)
}
return msg, nil
}
// SendClose sends a best-effort graceful close notification on the control stream.
func SendClose(w io.Writer) error {
return writeFrame(w, Message{Version: ProtoVersion, Type: TypeClose})
}
func writeFrame(w io.Writer, msg Message) error {
if err := framing.WriteJSON(w, msg, MaxMessageSize); err != nil {
return fmt.Errorf("control: %w", err)

View File

@@ -124,6 +124,26 @@ func TestRunRejectsBadProtocolVersion(t *testing.T) {
}
}
// TestRunStopsOnPeerClose checks that Run returns an error matching
// ErrClosedByPeer within one second of the peer sending a close message.
func TestRunStopsOnPeerClose(t *testing.T) {
	local, peer := controlPair(t)

	runErr := make(chan error, 1)
	go func() {
		runErr <- Run(context.Background(), local, Config{Interval: time.Hour})
	}()

	if sendErr := SendClose(peer); sendErr != nil {
		t.Fatalf("SendClose() error = %v", sendErr)
	}

	select {
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for peer close")
	case got := <-runErr:
		if errors.Is(got, ErrClosedByPeer) {
			return
		}
		t.Fatalf("Run() error = %v, want ErrClosedByPeer", got)
	}
}
func TestReadFrameRejectsTooLarge(t *testing.T) {
a, b := controlPair(t)
go func() {

View File

@@ -49,7 +49,7 @@ func SetupCipher(keyHex string) (*crypto.Cipher, error) {
func SmuxConfig(maxWirePayload int) *smux.Config {
cfg := smux.DefaultConfig()
cfg.Version = 2
cfg.KeepAliveDisabled = true
cfg.KeepAliveDisabled = false
cfg.MaxFrameSize = 32768
if maxWirePayload > crypto.WireOverhead {
maxFrameSize := maxWirePayload - crypto.WireOverhead
@@ -60,7 +60,7 @@ func SmuxConfig(maxWirePayload int) *smux.Config {
cfg.MaxReceiveBuffer = 16 * 1024 * 1024
cfg.MaxStreamBuffer = 1024 * 1024
cfg.KeepAliveInterval = 10 * time.Second
cfg.KeepAliveTimeout = 60 * time.Second
cfg.KeepAliveTimeout = 30 * time.Second
return cfg
}
@@ -76,9 +76,9 @@ func MaxPayload(tr transport.Transport) int {
// Server and client both embed a HealthTracker to avoid open-coding the
// same record* methods on both sides.
type HealthTracker struct {
	// mu is the tracker's lock; presumably it guards status (confirm against
	// the methods that use it).
	mu sync.RWMutex
	// status is the control.Status value held by the tracker.
	status control.Status
	// notify is an optional callback that takes a control.Status; presumably
	// it is invoked when status changes (confirm against the record* methods).
	notify func(control.Status)
}
// NewHealthTracker creates a HealthTracker that publishes the latest

View File

@@ -37,6 +37,10 @@ func TestSmuxConfigDefault(t *testing.T) {
if cfg.Version != 2 || cfg.MaxFrameSize != 32768 {
t.Fatalf("SmuxConfig(0) = %+v", cfg)
}
if cfg.KeepAliveDisabled || cfg.KeepAliveInterval != 10*time.Second ||
cfg.KeepAliveTimeout != 30*time.Second {
t.Fatalf("SmuxConfig(0) keepalive = %+v", cfg)
}
}
func TestSmuxConfigShrinks(t *testing.T) {

View File

@@ -60,6 +60,7 @@ type Server struct {
cipher *crypto.Cipher
conn *muxconn.Conn
session *smux.Session
controlStrm *smux.Stream
controlStop context.CancelFunc
sessMu sync.RWMutex
reinstallMu sync.Mutex
@@ -307,10 +308,12 @@ func (s *Server) reinstallSession(dead *smux.Session) {
}
oldSess := s.session
oldConn := s.conn
oldControl := s.controlStrm
oldControlStop := s.controlStop
oldSID := s.sessionID
s.session = newSess
s.conn = newConn
s.controlStrm = nil
s.controlStop = nil
s.sessionID = ""
s.deviceID = ""
@@ -325,6 +328,9 @@ func (s *Server) reinstallSession(dead *smux.Session) {
if oldConn != nil {
_ = oldConn.Close()
}
if oldControl != nil {
_ = oldControl.Close()
}
if oldSID != "" {
s.onClose(oldSID, "reconnect")
}
@@ -334,15 +340,18 @@ func (s *Server) closeSession() {
s.sessMu.Lock()
sess := s.session
conn := s.conn
control := s.controlStrm
controlStop := s.controlStop
s.session = nil
s.conn = nil
s.controlStrm = nil
s.controlStop = nil
oldSID := s.sessionID
s.sessionID = ""
s.deviceID = ""
s.sessMu.Unlock()
notifyControlClose(control)
if controlStop != nil {
controlStop()
}
@@ -357,6 +366,18 @@ func (s *Server) closeSession() {
}
}
// notifyControlClose makes a best-effort attempt to announce an intentional
// shutdown on the control stream. It returns immediately for a nil stream.
//
// The close message gets a two-second write deadline. A successful send is
// followed by a 200ms pause (presumably so the frame can reach the peer;
// confirm), after which the deadline is cleared and CloseWrite is called.
// All errors are ignored on purpose.
func notifyControlClose(stream *smux.Stream) {
	if stream == nil {
		return
	}

	writeDeadline := time.Now().Add(2 * time.Second)
	_ = stream.SetWriteDeadline(writeDeadline)

	delivered := control.SendClose(stream) == nil
	if delivered {
		time.Sleep(200 * time.Millisecond)
	}

	var noDeadline time.Time
	_ = stream.SetWriteDeadline(noDeadline)
	_ = stream.CloseWrite()
}
func (s *Server) onData(data []byte) {
s.sessMu.RLock()
conn := s.conn
@@ -474,6 +495,7 @@ func (s *Server) resetLinkPeer() {
func (s *Server) startControlLoop(ctx context.Context, sess *smux.Session, stream *smux.Stream) {
controlCtx, stop := context.WithCancel(ctx)
s.sessMu.Lock()
s.controlStrm = stream
s.controlStop = stop
s.sessMu.Unlock()
@@ -519,6 +541,7 @@ func (s *Server) startControlLoop(ctx context.Context, sess *smux.Session, strea
}
s.recordReconnect()
logger.Infof("server reconnect reason=liveness - reinstalling smux session")
s.resetLinkPeer()
s.reinstallSession(sess)
}()
}

View File

@@ -49,7 +49,7 @@ func TestSetupCipherRejectsBadInput(t *testing.T) {
func TestSmuxConfig(t *testing.T) {
cfg := smuxConfig(0)
if cfg.Version != 2 || !cfg.KeepAliveDisabled || cfg.MaxFrameSize != 32768 || cfg.MaxReceiveBuffer != 16*1024*1024 {
if cfg.Version != 2 || cfg.KeepAliveDisabled || cfg.MaxFrameSize != 32768 || cfg.MaxReceiveBuffer != 16*1024*1024 {
t.Fatalf("smuxConfig(0) = %+v", cfg)
}
capped := smuxConfig(4096)
@@ -211,18 +211,29 @@ func TestOnDataWithNilConn(_ *testing.T) {
}
type serverLinkStub struct {
closed bool
closed bool
resetCount int
resetCh chan struct{}
}
func (s *serverLinkStub) Connect(context.Context) error { return nil }
func (s *serverLinkStub) Send([]byte) error { return nil }
func (s *serverLinkStub) Close() error { s.closed = true; return nil }
func (s *serverLinkStub) SetReconnectCallback(func()) {}
func (s *serverLinkStub) SetShouldReconnect(func() bool) {}
func (s *serverLinkStub) SetEndedCallback(func(string)) {}
func (s *serverLinkStub) WatchConnection(context.Context) {}
func (s *serverLinkStub) CanSend() bool { return true }
func (s *serverLinkStub) Features() transport.Features { return transport.Features{} }
func (s *serverLinkStub) Connect(context.Context) error { return nil }
func (s *serverLinkStub) Send([]byte) error { return nil }
func (s *serverLinkStub) Close() error { s.closed = true; return nil }
func (s *serverLinkStub) SetReconnectCallback(func()) {}
func (s *serverLinkStub) SetShouldReconnect(func() bool) {}
func (s *serverLinkStub) SetEndedCallback(func(string)) {}
func (s *serverLinkStub) WatchConnection(context.Context) {}
func (s *serverLinkStub) CanSend() bool { return true }
func (s *serverLinkStub) Features() transport.Features { return transport.Features{} }
func (s *serverLinkStub) ResetPeer() {
s.resetCount++
if s.resetCh != nil {
select {
case s.resetCh <- struct{}{}:
default:
}
}
}
func TestShutdownClosesLinkAndConn(t *testing.T) {
cipher, err := cryptopkg.NewCipher("01234567890123456789012345678901")
@@ -463,6 +474,74 @@ func TestStartControlLoopReportsPong(t *testing.T) {
}
}
// TestStartControlLoopResetsPeerBeforeReinstall checks that, once the client
// closes its end of the control stream, the server's control loop calls
// ResetPeer on the link exactly once.
func TestStartControlLoopResetsPeerBeforeReinstall(t *testing.T) {
	a, b := net.Pipe()
	defer func() {
		_ = a.Close()
		_ = b.Close()
	}()

	serverSess, err := smux.Server(a, smuxConfig(0))
	if err != nil {
		t.Fatalf("smux.Server() error = %v", err)
	}
	clientSess, err := smux.Client(b, smuxConfig(0))
	if err != nil {
		t.Fatalf("smux.Client() error = %v", err)
	}

	// Report the accept outcome, error included, so that a failed accept
	// fails the test instead of leaving it blocked on an empty channel.
	type acceptResult struct {
		stream *smux.Stream
		err    error
	}
	acceptCh := make(chan acceptResult, 1)
	go func() {
		stream, acceptErr := serverSess.AcceptStream()
		acceptCh <- acceptResult{stream: stream, err: acceptErr}
	}()

	clientStream, err := clientSess.OpenStream()
	if err != nil {
		t.Fatalf("OpenStream() error = %v", err)
	}

	var serverStream *smux.Stream
	select {
	case res := <-acceptCh:
		if res.err != nil {
			t.Fatalf("AcceptStream() error = %v", res.err)
		}
		serverStream = res.stream
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for AcceptStream")
	}

	cipher, err := cryptopkg.NewCipher("01234567890123456789012345678901")
	if err != nil {
		t.Fatalf("NewCipher() error = %v", err)
	}

	ln := &serverLinkStub{resetCh: make(chan struct{}, 1)}
	ctx, cancel := context.WithCancel(context.Background())
	s := &Server{
		ln:      ln,
		cipher:  cipher,
		conn:    muxconn.New(ln, cipher),
		session: serverSess,
		health:  runtime.NewHealthTracker(nil),
		// Hour-long interval and timeout: presumably this keeps ping-based
		// liveness from firing during the test, so that only the stream close
		// below can end the loop (confirm against control.Run).
		liveness: control.Config{
			Interval: time.Hour,
			Timeout:  time.Hour,
			Failures: 1,
		},
	}
	defer func() {
		cancel()
		s.shutdown()
		s.wg.Wait()
		_ = clientSess.Close()
	}()

	s.startControlLoop(ctx, serverSess, serverStream)

	// Closing the client end is what should drive the loop to ResetPeer.
	_ = clientStream.Close()

	select {
	case <-ln.resetCh:
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for ResetPeer")
	}
	if ln.resetCount != 1 {
		t.Fatalf("ResetPeer calls = %d, want 1", ln.resetCount)
	}
}
func TestStatusRecordsReconnectAndUnhealthy(t *testing.T) {
updates := 0
s := &Server{health: runtime.NewHealthTracker(func(control.Status) { updates++ })}