feat(client,server,mux): Add input validation and improve connection handling

This commit is contained in:
zarazaex69
2026-04-09 17:51:34 +03:00
parent 12c8241987
commit 562572fe8a
5 changed files with 78 additions and 30 deletions

View File

@@ -42,6 +42,10 @@ func main() {
log.Fatal("Room ID required")
}
if mode != "srv" && mode != "cnc" {
log.Fatal("Specify -mode srv or -mode cnc")
}
namesPath := filepath.Join(dataDir, "names")
surnamesPath := filepath.Join(dataDir, "surnames")
@@ -60,7 +64,5 @@ func main() {
if err := client.Run(roomURL, keyHex, socksPort); err != nil {
log.Fatal(err)
}
default:
log.Fatal("Specify -mode srv or -mode cnc")
}
}

View File

@@ -41,9 +41,17 @@ func Run(roomURL, keyHex string, socksPort int) error {
if err != nil {
return err
}
if len(key) != 32 {
return fmt.Errorf("key must be 32 bytes, got %d", len(key))
}
}
cipher, err := crypto.NewCipher(string(key))
keyStr := string(key)
if len(keyStr) != 32 {
return fmt.Errorf("key string length must be 32, got %d", len(keyStr))
}
cipher, err := crypto.NewCipher(keyStr)
if err != nil {
return err
}
@@ -95,9 +103,9 @@ func Run(roomURL, keyHex string, socksPort int) error {
time.Sleep(100 * time.Millisecond)
resetFrame := make([]byte, 8)
binary.BigEndian.PutUint16(resetFrame[0:2], 0xFFFF)
binary.BigEndian.PutUint16(resetFrame[2:4], 0xFFFF)
binary.BigEndian.PutUint32(resetFrame[4:8], c.clientID)
binary.BigEndian.PutUint32(resetFrame[0:4], c.clientID)
binary.BigEndian.PutUint16(resetFrame[4:6], 0xFFFF)
binary.BigEndian.PutUint16(resetFrame[6:8], 0xFFFF)
encrypted, _ := cipher.Encrypt(resetFrame)
peer.Send(encrypted)
log.Printf("Sent reset signal to server (clientID=%d)", c.clientID)
@@ -205,7 +213,32 @@ func (c *Client) handleSOCKS5(conn net.Conn) {
reqData, _ := json.Marshal(req)
c.mux.SendData(sid, reqData)
time.Sleep(500 * time.Millisecond)
connected := make(chan bool, 1)
timeout := time.NewTimer(10 * time.Second)
defer timeout.Stop()
go func() {
for i := 0; i < 100; i++ {
time.Sleep(50 * time.Millisecond)
data := c.mux.ReadStream(sid)
if len(data) > 0 || c.mux.StreamClosed(sid) {
connected <- len(data) > 0
return
}
}
connected <- false
}()
select {
case success := <-connected:
if !success {
conn.Write([]byte{5, 4, 0, 1, 0, 0, 0, 0, 0, 0})
return
}
case <-timeout.C:
conn.Write([]byte{5, 4, 0, 1, 0, 0, 0, 0, 0, 0})
return
}
conn.Write([]byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0})

View File

@@ -18,19 +18,21 @@ type Stream struct {
}
type Multiplexer struct {
streams map[uint16]*Stream
nextID uint16
clientID uint32
onSend func([]byte) error
mu sync.RWMutex
streams map[uint16]*Stream
nextID uint16
clientID uint32
onSend func([]byte) error
mu sync.RWMutex
maxStreams int
}
func New(clientID uint32, onSend func([]byte) error) *Multiplexer {
return &Multiplexer{
streams: make(map[uint16]*Stream),
nextID: 1,
clientID: clientID,
onSend: onSend,
streams: make(map[uint16]*Stream),
nextID: 1,
clientID: clientID,
onSend: onSend,
maxStreams: 10000,
}
}
@@ -135,6 +137,10 @@ func (m *Multiplexer) HandleFrame(frame []byte) {
m.mu.Lock()
stream, exists := m.streams[sid]
if !exists {
if len(m.streams) >= m.maxStreams {
m.mu.Unlock()
return
}
stream = &Stream{
ID: sid,
ClientID: clientID,

View File

@@ -2,10 +2,9 @@ package names
import (
"bufio"
"math/rand"
"math/rand/v2"
"os"
"strings"
"time"
)
var (
@@ -13,10 +12,6 @@ var (
lastNames []string
)
func init() {
rand.Seed(time.Now().UnixNano())
}
func loadNames(path string) ([]string, error) {
file, err := os.Open(path)
if err != nil {
@@ -56,8 +51,8 @@ func Generate() string {
return "Unknown User"
}
first := firstNames[rand.Intn(len(firstNames))]
last := lastNames[rand.Intn(len(lastNames))]
first := firstNames[rand.IntN(len(firstNames))]
last := lastNames[rand.IntN(len(lastNames))]
return first + " " + last
}

View File

@@ -48,9 +48,17 @@ func Run(roomURL, keyHex string) error {
if err != nil {
return err
}
if len(key) != 32 {
return fmt.Errorf("key must be 32 bytes, got %d", len(key))
}
}
cipher, err := crypto.NewCipher(string(key))
keyStr := string(key)
if len(keyStr) != 32 {
return fmt.Errorf("key string length must be 32, got %d", len(keyStr))
}
cipher, err := crypto.NewCipher(keyStr)
if err != nil {
return err
}
@@ -199,7 +207,8 @@ func (s *Server) handleConnect(sid uint16, req ConnectRequest) {
log.Printf("Connecting sid=%d to %s", sid, addr)
s.connMu.Lock()
if oldConn, exists := s.connections[sid]; exists && oldConn != nil {
oldConn, exists := s.connections[sid]
if exists && oldConn != nil {
log.Printf("Closing old connection for sid=%d", sid)
oldConn.Close()
delete(s.connections, sid)
@@ -219,14 +228,17 @@ func (s *Server) handleConnect(sid uint16, req ConnectRequest) {
log.Printf("Connected sid=%d", sid)
go func() {
defer func() {
s.mux.CloseStream(sid)
s.connMu.Lock()
delete(s.connections, sid)
s.connMu.Unlock()
}()
buf := make([]byte, 4096)
for {
n, err := conn.Read(buf)
if err != nil {
s.mux.CloseStream(sid)
s.connMu.Lock()
delete(s.connections, sid)
s.connMu.Unlock()
return
}