diff --git a/internal/client/client.go b/internal/client/client.go index d8e4f44..6012faf 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -134,12 +134,28 @@ func RunWithReady( select { case <-runCtx.Done(): + c.shutdown() return nil case err := <-errCh: return err } } +func (c *Client) shutdown() { + c.connMu.Lock() + for _, conn := range c.connections { + if conn != nil { + _ = conn.Close() + } + } + c.connMu.Unlock() + + for i, ln := range c.links { + logger.Infof("closing link %d", i) + _ = ln.Close() + } +} + func setupCipher(keyHex string) (*crypto.Cipher, error) { key, err := hex.DecodeString(keyHex) if err != nil { diff --git a/internal/server/server.go b/internal/server/server.go index ea4cc8a..7d7e983 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -110,6 +110,7 @@ func Run( err = s.runLoop(runCtx) + s.shutdown() s.wg.Wait() return err @@ -334,7 +335,6 @@ func (s *Server) runLoop(ctx context.Context) error { for { select { case <-ctx.Done(): - s.shutdown() return nil case <-ticker.C: s.processMuxStreams(ctx) @@ -351,6 +351,14 @@ func (s *Server) shutdown() { } s.connMu.Unlock() + s.pumpMu.Lock() + for _, conn := range s.streamPumps { + if conn != nil { + _ = conn.Close() + } + } + s.pumpMu.Unlock() + for i, tr := range s.links { logger.Infof("closing link %d", i) _ = tr.Close()