From 535c3b75d10b2e95456428b38e5c5f1881f4dbc3 Mon Sep 17 00:00:00 2001 From: zarazaex69 Date: Mon, 18 May 2026 08:14:39 +0300 Subject: [PATCH] refactor(server): replace context with done channel for stop signal --- internal/server/server.go | 43 ++++++++++++++++++++++++++------------- 1 file changed, 29 insertions(+), 14 deletions(-) diff --git a/internal/server/server.go b/internal/server/server.go index 63a667f..587e1df 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -79,6 +79,8 @@ type Server struct { socksProxyPort int liveness control.Config health *runtime.HealthTracker + done chan struct{} + doneOnce sync.Once } type peerSession struct { @@ -167,6 +169,7 @@ func Run(ctx context.Context, cfg Config) error { liveness: cfg.Liveness, health: runtime.NewHealthTracker(cfg.OnHealth), peerSessions: make(map[string]*peerSession), + done: make(chan struct{}), } s.setupResolver() @@ -372,10 +375,10 @@ func (s *Server) closeSession() { s.deviceID = "" s.sessMu.Unlock() - notifyControlClose(control) if controlStop != nil { controlStop() } + notifyControlClose(control) if sess != nil { _ = sess.Close() } @@ -401,10 +404,10 @@ func (s *Server) removePeerSession(peerID, reason string) { } func (s *Server) closePeerSession(ps *peerSession, reason string) { - notifyControlClose(ps.controlStrm) if ps.controlStop != nil { ps.controlStop() } + notifyControlClose(ps.controlStrm) if ps.session != nil { _ = ps.session.Close() } @@ -472,7 +475,7 @@ func (s *Server) getPeerSession(peerID string) *peerSession { s.wg.Add(1) go func() { defer s.wg.Done() - s.servePeer(context.Background(), ps) + s.servePeer(ps) }() return ps } @@ -587,18 +590,18 @@ func (s *Server) acceptHandshake(ctx context.Context, sess *smux.Session) bool { return true } -func (s *Server) servePeer(ctx context.Context, ps *peerSession) { - if !s.acceptPeerHandshake(ctx, ps) { +func (s *Server) servePeer(ps *peerSession) { + if !s.acceptPeerHandshake(ps) { s.removePeerSession(ps.peerID, "closed") return } for { - if contextDone(ctx) { + if s.stopping() { return } stream, err := ps.session.AcceptStream() if err != nil { - if contextDone(ctx) { + if s.stopping() { return } logger.Debugf("AcceptStream(peer=%s) returned %v - closing peer session", ps.peerID, err) @@ -608,15 +611,15 @@ func (s *Server) servePeer(ctx context.Context, ps *peerSession) { s.wg.Add(1) go func() { defer s.wg.Done() - s.handleStream(ctx, stream, ps.sessionID) + s.handleStream(context.Background(), stream, ps.sessionID) }() } } -func (s *Server) acceptPeerHandshake(ctx context.Context, ps *peerSession) bool { +func (s *Server) acceptPeerHandshake(ps *peerSession) bool { stream, err := ps.session.AcceptStream() if err != nil { - if !contextDone(ctx) { + if !s.stopping() { logger.Debugf("AcceptStream(control peer=%s) returned %v", ps.peerID, err) } return false @@ -635,7 +638,7 @@ func (s *Server) acceptPeerHandshake(ctx context.Context, ps *peerSession) bool s.recordSession(sid) s.onOpen(sid, hello.DeviceID, hello.Claims) logger.Infof("session %s opened (device=%s peer=%s)", sid, hello.DeviceID, ps.peerID) - s.startPeerControlLoop(ctx, ps, stream) + s.startPeerControlLoop(ps, stream) return true } @@ -702,8 +705,8 @@ func (s *Server) startControlLoop(ctx context.Context, sess *smux.Session, strea }() } -func (s *Server) startPeerControlLoop(ctx context.Context, ps *peerSession, stream *smux.Stream) { - controlCtx, stop := context.WithCancel(ctx) +func (s *Server) startPeerControlLoop(ps *peerSession, stream *smux.Stream) { + controlCtx, stop := context.WithCancel(context.Background()) ps.controlStop = stop liveness := s.liveness @@ -739,7 +742,7 @@ func (s *Server) startPeerControlLoop(ctx context.Context, ps *peerSession, stre defer s.wg.Done() defer func() { _ = stream.Close() }() err := control.Run(controlCtx, stream, liveness) - if controlCtx.Err() != nil || ctx.Err() != nil { + if controlCtx.Err() != nil || s.stopping() { return } if err != nil { @@ -750,6 +753,15 @@ func (s *Server) startPeerControlLoop(ctx context.Context, ps *peerSession, stre }() } +func (s *Server) stopping() bool { + select { + case <-s.done: + return true + default: + return false + } +} + // Status returns the latest server-side control health snapshot. func (s *Server) Status() control.Status { return s.health.Status() @@ -762,6 +774,9 @@ func (s *Server) recordUnhealthy(missed int) { s.health.RecordUnhealthy(miss func (s *Server) recordReconnect() { s.health.RecordReconnect() } func (s *Server) shutdown() { + if s.done != nil { + s.doneOnce.Do(func() { close(s.done) }) + } s.closeSession() if s.ln != nil { _ = s.ln.Close()