diff --git a/internal/client/client.go b/internal/client/client.go index e4a0f61..37c604b 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -88,6 +88,15 @@ func Run(roomURL, keyHex string, socksPort int) error { } log.Println("Connected to Telemost") + time.Sleep(100 * time.Millisecond) + + resetFrame := make([]byte, 4) + binary.BigEndian.PutUint16(resetFrame[0:2], 0xFFFF) + binary.BigEndian.PutUint16(resetFrame[2:4], 0xFFFF) + encrypted, _ := cipher.Encrypt(resetFrame) + peer.Send(encrypted) + log.Println("Sent reset signal to server") + go peer.WatchConnection(ctx) return c.runSOCKS5(socksPort) diff --git a/internal/mux/mux.go b/internal/mux/mux.go index 4932913..ba86e9b 100644 --- a/internal/mux/mux.go +++ b/internal/mux/mux.go @@ -99,6 +99,17 @@ func (m *Multiplexer) HandleFrame(frame []byte) { sid := binary.BigEndian.Uint16(frame[0:2]) length := binary.BigEndian.Uint16(frame[2:4]) + if sid == 0xFFFF && length == 0xFFFF { + m.mu.Lock() + for _, stream := range m.streams { + stream.closed = true + } + m.streams = make(map[uint16]*Stream) + m.nextID = 1 + m.mu.Unlock() + return + } + if length == 0 { m.mu.Lock() if stream, exists := m.streams[sid]; exists { diff --git a/internal/server/server.go b/internal/server/server.go index ca9b88c..c7be0c8 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -3,6 +3,7 @@ package server import ( "context" "crypto/rand" + "encoding/binary" "encoding/hex" "encoding/json" "fmt" @@ -85,13 +86,15 @@ func Run(roomURL, keyHex string) error { } s.connMu.Unlock() - s.mux.UpdateSendFunc(func(frame []byte) error { - encrypted, err := s.cipher.Encrypt(frame) - if err != nil { - return err - } - return s.peer.Send(encrypted) - }) + if dc != nil { + s.mux.UpdateSendFunc(func(frame []byte) error { + encrypted, err := s.cipher.Encrypt(frame) + if err != nil { + return err + } + return s.peer.Send(encrypted) + }) + } s.mux.Reset() @@ -116,6 +119,23 @@ func (s *Server) onData(data []byte) { return } + if len(plaintext) >= 4 { + sid := binary.BigEndian.Uint16(plaintext[0:2]) + length := binary.BigEndian.Uint16(plaintext[2:4]) + + if sid == 0xFFFF && length == 0xFFFF { + log.Println("Received reset signal from client - cleaning up") + s.connMu.Lock() + for sid, conn := range s.connections { + if conn != nil { + conn.Close() + } + delete(s.connections, sid) + } + s.connMu.Unlock() + } + } + s.mux.HandleFrame(plaintext) } @@ -174,6 +194,14 @@ func (s *Server) handleConnect(sid uint16, req ConnectRequest) { addr := fmt.Sprintf("%s:%d", req.Addr, req.Port) log.Printf("Connecting sid=%d to %s", sid, addr) + s.connMu.Lock() + if oldConn, exists := s.connections[sid]; exists && oldConn != nil { + log.Printf("Closing old connection for sid=%d", sid) + oldConn.Close() + delete(s.connections, sid) + } + s.connMu.Unlock() + conn, err := net.DialTimeout("tcp", addr, 10*time.Second) if err != nil { log.Printf("Connect failed sid=%d: %v", sid, err) @@ -192,6 +220,9 @@ func (s *Server) handleConnect(sid uint16, req ConnectRequest) { n, err := conn.Read(buf) if err != nil { s.mux.CloseStream(sid) + s.connMu.Lock() + delete(s.connections, sid) + s.connMu.Unlock() return } diff --git a/internal/telemost/peer.go b/internal/telemost/peer.go index 59c9698..0ec09f3 100644 --- a/internal/telemost/peer.go +++ b/internal/telemost/peer.go @@ -100,7 +100,11 @@ func (p *Peer) Connect(ctx context.Context) error { }) p.dc.OnClose(func() { - log.Println("DataChannel closed - triggering reconnect") + log.Println("DataChannel closed") + if p.onReconnect != nil { + log.Println("Calling reconnect callback for cleanup") + p.onReconnect(nil) + } select { case p.reconnectCh <- struct{}{}: default: