refactor(server): replace context with done channel for stop signal

This commit is contained in:
zarazaex69
2026-05-18 08:14:39 +03:00
parent 7ca82dfa74
commit 535c3b75d1

View File

@@ -79,6 +79,8 @@ type Server struct {
socksProxyPort int
liveness control.Config
health *runtime.HealthTracker
done chan struct{}
doneOnce sync.Once
}
type peerSession struct {
@@ -167,6 +169,7 @@ func Run(ctx context.Context, cfg Config) error {
liveness: cfg.Liveness,
health: runtime.NewHealthTracker(cfg.OnHealth),
peerSessions: make(map[string]*peerSession),
done: make(chan struct{}),
}
s.setupResolver()
@@ -372,10 +375,10 @@ func (s *Server) closeSession() {
s.deviceID = ""
s.sessMu.Unlock()
notifyControlClose(control)
if controlStop != nil {
controlStop()
}
notifyControlClose(control)
if sess != nil {
_ = sess.Close()
}
@@ -401,10 +404,10 @@ func (s *Server) removePeerSession(peerID, reason string) {
}
func (s *Server) closePeerSession(ps *peerSession, reason string) {
notifyControlClose(ps.controlStrm)
if ps.controlStop != nil {
ps.controlStop()
}
notifyControlClose(ps.controlStrm)
if ps.session != nil {
_ = ps.session.Close()
}
@@ -472,7 +475,7 @@ func (s *Server) getPeerSession(peerID string) *peerSession {
s.wg.Add(1)
go func() {
defer s.wg.Done()
s.servePeer(context.Background(), ps)
s.servePeer(ps)
}()
return ps
}
@@ -587,18 +590,18 @@ func (s *Server) acceptHandshake(ctx context.Context, sess *smux.Session) bool {
return true
}
func (s *Server) servePeer(ctx context.Context, ps *peerSession) {
if !s.acceptPeerHandshake(ctx, ps) {
func (s *Server) servePeer(ps *peerSession) {
if !s.acceptPeerHandshake(ps) {
s.removePeerSession(ps.peerID, "closed")
return
}
for {
if contextDone(ctx) {
if s.stopping() {
return
}
stream, err := ps.session.AcceptStream()
if err != nil {
if contextDone(ctx) {
if s.stopping() {
return
}
logger.Debugf("AcceptStream(peer=%s) returned %v - closing peer session", ps.peerID, err)
@@ -608,15 +611,15 @@ func (s *Server) servePeer(ctx context.Context, ps *peerSession) {
s.wg.Add(1)
go func() {
defer s.wg.Done()
s.handleStream(ctx, stream, ps.sessionID)
s.handleStream(context.Background(), stream, ps.sessionID)
}()
}
}
func (s *Server) acceptPeerHandshake(ctx context.Context, ps *peerSession) bool {
func (s *Server) acceptPeerHandshake(ps *peerSession) bool {
stream, err := ps.session.AcceptStream()
if err != nil {
if !contextDone(ctx) {
if !s.stopping() {
logger.Debugf("AcceptStream(control peer=%s) returned %v", ps.peerID, err)
}
return false
@@ -635,7 +638,7 @@ func (s *Server) acceptPeerHandshake(ctx context.Context, ps *peerSession) bool
s.recordSession(sid)
s.onOpen(sid, hello.DeviceID, hello.Claims)
logger.Infof("session %s opened (device=%s peer=%s)", sid, hello.DeviceID, ps.peerID)
s.startPeerControlLoop(ctx, ps, stream)
s.startPeerControlLoop(ps, stream)
return true
}
@@ -702,8 +705,8 @@ func (s *Server) startControlLoop(ctx context.Context, sess *smux.Session, strea
}()
}
func (s *Server) startPeerControlLoop(ctx context.Context, ps *peerSession, stream *smux.Stream) {
controlCtx, stop := context.WithCancel(ctx)
func (s *Server) startPeerControlLoop(ps *peerSession, stream *smux.Stream) {
controlCtx, stop := context.WithCancel(context.Background())
ps.controlStop = stop
liveness := s.liveness
@@ -739,7 +742,7 @@ func (s *Server) startPeerControlLoop(ctx context.Context, ps *peerSession, stre
defer s.wg.Done()
defer func() { _ = stream.Close() }()
err := control.Run(controlCtx, stream, liveness)
if controlCtx.Err() != nil || ctx.Err() != nil {
if controlCtx.Err() != nil || s.stopping() {
return
}
if err != nil {
@@ -750,6 +753,15 @@ func (s *Server) startPeerControlLoop(ctx context.Context, ps *peerSession, stre
}()
}
func (s *Server) stopping() bool {
select {
case <-s.done:
return true
default:
return false
}
}
// Status returns the latest server-side control health snapshot.
func (s *Server) Status() control.Status {
return s.health.Status()
@@ -762,6 +774,9 @@ func (s *Server) recordUnhealthy(missed int) { s.health.RecordUnhealthy(miss
func (s *Server) recordReconnect() { s.health.RecordReconnect() }
func (s *Server) shutdown() {
if s.done != nil {
s.doneOnce.Do(func() { close(s.done) })
}
s.closeSession()
if s.ln != nil {
_ = s.ln.Close()