diff --git a/internal/server/server.go b/internal/server/server.go index 3ef25c2..594d9fe 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -39,6 +39,7 @@ type Server struct { conn *muxconn.Conn session *smux.Session sessMu sync.RWMutex + reinstallMu sync.Mutex wg sync.WaitGroup dnsServer string resolver *net.Resolver @@ -229,7 +230,21 @@ func (s *Server) installSession() { func (s *Server) handleReconnect() { logger.Infof("server link reconnect — tearing down smux session") + s.sessMu.RLock() + current := s.session + s.sessMu.RUnlock() + s.reinstallSession(current) +} + +func (s *Server) reinstallSession(dead *smux.Session) { + s.reinstallMu.Lock() + defer s.reinstallMu.Unlock() + s.sessMu.Lock() + if s.session != dead { + s.sessMu.Unlock() + return + } if s.session != nil { _ = s.session.Close() s.session = nil @@ -281,8 +296,8 @@ func (s *Server) serve(ctx context.Context) { return default: } - logger.Infof("AcceptStream returned %v — waiting for new session", err) - s.waitForNewSession(ctx, sess) + logger.Infof("AcceptStream returned %v — reinstalling session", err) + s.reinstallSession(sess) continue } @@ -294,22 +309,6 @@ func (s *Server) serve(ctx context.Context) { } } -func (s *Server) waitForNewSession(ctx context.Context, dead *smux.Session) { - for { - select { - case <-ctx.Done(): - return - case <-time.After(50 * time.Millisecond): - } - s.sessMu.RLock() - current := s.session - s.sessMu.RUnlock() - if current != dead { - return - } - } -} - func (s *Server) shutdown() { s.sessMu.Lock() if s.session != nil {