From 6d2f594bb3ee991834dcefd22b76ec0a86a98ace Mon Sep 17 00:00:00 2001 From: zarazaex69 Date: Thu, 9 Apr 2026 18:42:15 +0300 Subject: [PATCH] feat(main,client,server): Add graceful shutdown with context propagation --- cmd/olcrtc/main.go | 34 ++++++++++++++++++++++++++++------ internal/client/client.go | 21 +++++++++++++++------ internal/server/server.go | 22 ++++++++++++++++++---- 3 files changed, 61 insertions(+), 16 deletions(-) diff --git a/cmd/olcrtc/main.go b/cmd/olcrtc/main.go index 9411471..d5617d8 100644 --- a/cmd/olcrtc/main.go +++ b/cmd/olcrtc/main.go @@ -1,9 +1,13 @@ package main import ( + "context" "flag" "log" + "os" + "os/signal" "path/filepath" + "syscall" "github.com/openlibrecommunity/olcrtc/internal/client" "github.com/openlibrecommunity/olcrtc/internal/names" @@ -55,13 +59,31 @@ func main() { roomURL := "https://telemost.yandex.ru/j/" + roomID - switch mode { - case "srv": - if err := server.Run(roomURL, keyHex); err != nil { - log.Fatal(err) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) + + errCh := make(chan error, 1) + + go func() { + switch mode { + case "srv": + errCh <- server.Run(ctx, roomURL, keyHex) + case "cnc": + errCh <- client.Run(ctx, roomURL, keyHex, socksPort) } - case "cnc": - if err := client.Run(roomURL, keyHex, socksPort); err != nil { + }() + + select { + case <-sigCh: + log.Println("Shutting down gracefully...") + cancel() + <-errCh + log.Println("Shutdown complete") + case err := <-errCh: + if err != nil { log.Fatal(err) } } diff --git a/internal/client/client.go b/internal/client/client.go index 5304ca5..670e6fc 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -26,7 +26,7 @@ type Client struct { clientID uint32 } -func Run(roomURL, keyHex string, socksPort int) error { +func Run(ctx context.Context, roomURL, keyHex string, socksPort int) error { var key []byte var err error @@ -94,7 +94,6 @@ func Run(roomURL, keyHex string, socksPort int) error { }) log.Println("Connecting to Telemost...") - ctx := context.Background() if err := peer.Connect(ctx); err != nil { return err } @@ -112,7 +111,7 @@ func Run(roomURL, keyHex string, socksPort int) error { go peer.WatchConnection(ctx) - return c.runSOCKS5(socksPort) + return c.runSOCKS5(ctx, socksPort) } func (c *Client) onData(data []byte) { @@ -124,7 +123,7 @@ func (c *Client) onData(data []byte) { c.mux.HandleFrame(plaintext) } -func (c *Client) runSOCKS5(port int) error { +func (c *Client) runSOCKS5(ctx context.Context, port int) error { listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", port)) if err != nil { return err @@ -133,11 +132,21 @@ func (c *Client) runSOCKS5(port int) error { log.Printf("SOCKS5 proxy listening on 0.0.0.0:%d", port) + go func() { + <-ctx.Done() + listener.Close() + }() + for { conn, err := listener.Accept() if err != nil { - log.Printf("Accept error: %v", err) - continue + select { + case <-ctx.Done(): + return nil + default: + log.Printf("Accept error: %v", err) + continue + } } go c.handleSOCKS5(conn) diff --git a/internal/server/server.go b/internal/server/server.go index 0994b60..6d4f3c0 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -33,7 +33,7 @@ type ConnectRequest struct { Port int `json:"port"` } -func Run(roomURL, keyHex string) error { +func Run(ctx context.Context, roomURL, keyHex string) error { var key []byte var err error @@ -110,7 +110,6 @@ func Run(roomURL, keyHex string) error { }) log.Println("Connecting to Telemost...") - ctx := context.Background() if err := peer.Connect(ctx); err != nil { return err } @@ -118,7 +117,7 @@ func Run(roomURL, keyHex string) error { go peer.WatchConnection(ctx) - return s.run() + return s.run(ctx) } func (s *Server) onData(data []byte) { @@ -151,8 +150,23 @@ func (s *Server) onData(data []byte) { s.mux.HandleFrame(plaintext) } -func (s *Server) run() error { +func (s *Server) run(ctx context.Context) error { for { + select { + case <-ctx.Done(): + log.Println("Server shutting down...") + s.connMu.Lock() + for _, conn := range s.connections { + if conn != nil { + conn.Close() + } + } + s.connMu.Unlock() + s.peer.Close() + return nil + default: + } + sids := s.mux.GetStreams() for _, sid := range sids {