diff --git a/internal/client/client.go b/internal/client/client.go index 94d882b..3729588 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -247,6 +247,7 @@ func (c *Client) handleSOCKS5(conn net.Conn) { conn.Write([]byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0}) done := make(chan struct{}) + streamClosed := make(chan struct{}) go func() { defer close(done) @@ -257,11 +258,14 @@ func (c *Client) handleSOCKS5(conn net.Conn) { c.mux.CloseStream(sid) return } - c.mux.SendData(sid, buf[:n]) + if err := c.mux.SendData(sid, buf[:n]); err != nil { + return + } } }() go func() { + defer close(streamClosed) ticker := time.NewTicker(10 * time.Millisecond) defer ticker.Stop() @@ -272,7 +276,9 @@ func (c *Client) handleSOCKS5(conn net.Conn) { case <-ticker.C: data := c.mux.ReadStream(sid) if len(data) > 0 { - conn.Write(data) + if _, err := conn.Write(data); err != nil { + return + } } if c.mux.StreamClosed(sid) { @@ -282,5 +288,9 @@ func (c *Client) handleSOCKS5(conn net.Conn) { } }() - <-done + select { + case <-done: + case <-streamClosed: + } +} } diff --git a/internal/mux/mux.go b/internal/mux/mux.go index 747c21a..9963567 100644 --- a/internal/mux/mux.go +++ b/internal/mux/mux.go @@ -18,21 +18,23 @@ type Stream struct { } type Multiplexer struct { - streams map[uint16]*Stream - nextID uint16 - clientID uint32 - onSend func([]byte) error - mu sync.RWMutex - maxStreams int + streams map[uint16]*Stream + nextID uint16 + clientID uint32 + onSend func([]byte) error + mu sync.RWMutex + maxStreams int + maxBufferSize int } func New(clientID uint32, onSend func([]byte) error) *Multiplexer { return &Multiplexer{ - streams: make(map[uint16]*Stream), - nextID: 1, - clientID: clientID, - onSend: onSend, - maxStreams: 10000, + streams: make(map[uint16]*Stream), + nextID: 1, + clientID: clientID, + onSend: onSend, + maxStreams: 10000, + maxBufferSize: 1024 * 1024, } } @@ -40,15 +42,21 @@ func (m *Multiplexer) OpenStream() uint16 { m.mu.Lock() defer m.mu.Unlock() - sid := m.nextID - m.nextID++ - - m.streams[sid] = &Stream{ - ID: sid, - recvBuf: make([]byte, 0), + for { + sid := m.nextID + m.nextID++ + if m.nextID == 0 { + m.nextID = 1 + } + + if _, exists := m.streams[sid]; !exists { + m.streams[sid] = &Stream{ + ID: sid, + recvBuf: make([]byte, 0), + } + return sid + } } - - return sid } func (m *Multiplexer) SendData(sid uint16, data []byte) error { @@ -135,10 +143,11 @@ func (m *Multiplexer) HandleFrame(frame []byte) { data := frame[8 : 8+length] m.mu.Lock() + defer m.mu.Unlock() + stream, exists := m.streams[sid] if !exists { if len(m.streams) >= m.maxStreams { - m.mu.Unlock() return } stream = &Stream{ @@ -152,8 +161,13 @@ func (m *Multiplexer) HandleFrame(frame []byte) { stream.recvBuf = make([]byte, 0) stream.closed = false } + + if len(stream.recvBuf)+len(data) > m.maxBufferSize { + stream.closed = true + return + } + stream.recvBuf = append(stream.recvBuf, data...) - m.mu.Unlock() } func (m *Multiplexer) ReadStream(sid uint16) []byte { diff --git a/internal/server/server.go b/internal/server/server.go index fc78124..0621be9 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -185,16 +185,13 @@ func (s *Server) run() error { } if s.mux.StreamClosed(sid) { - s.connMu.RLock() + s.connMu.Lock() conn, exists := s.connections[sid] - s.connMu.RUnlock() - if exists && conn != nil { conn.Close() - s.connMu.Lock() delete(s.connections, sid) - s.connMu.Unlock() } + s.connMu.Unlock() } } } @@ -218,7 +215,7 @@ func (s *Server) handleConnect(sid uint16, req ConnectRequest) { conn, err := net.DialTimeout("tcp", addr, 10*time.Second) if err != nil { log.Printf("Connect failed sid=%d: %v", sid, err) - s.mux.CloseStream(sid) + go s.mux.CloseStream(sid) return } diff --git a/internal/telemost/peer.go b/internal/telemost/peer.go index 9ee1920..d8a4cc3 100644 --- a/internal/telemost/peer.go +++ b/internal/telemost/peer.go @@ -424,15 +424,17 @@ func (p *Peer) sendLeave() { p.wsMu.Lock() defer p.wsMu.Unlock() - if p.ws != nil { - leave := map[string]interface{}{ - "uid": uuid.New().String(), - "leave": map[string]interface{}{}, - } - if err := p.ws.WriteJSON(leave); err == nil { - log.Println("Sent leave message to server") - time.Sleep(200 * time.Millisecond) - } + if p.ws == nil { + return + } + + leave := map[string]interface{}{ + "uid": uuid.New().String(), + "leave": map[string]interface{}{}, + } + if err := p.ws.WriteJSON(leave); err == nil { + log.Println("Sent leave message to server") + time.Sleep(200 * time.Millisecond) } }