mirror of
https://github.com/openlibrecommunity/olcrtc.git
synced 2026-06-07 21:04:42 +00:00
Merge pull request #14 from openlibrecommunity/refactor/all
Refactor/all
This commit is contained in:
@@ -72,7 +72,7 @@ func RunWithReady(
|
||||
|
||||
key, err := decodeKey(keyHex)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("decodeKey failed: %w", err)
|
||||
}
|
||||
|
||||
keyStr := string(key)
|
||||
@@ -95,7 +95,7 @@ func RunWithReady(
|
||||
|
||||
for peerID := range 1 {
|
||||
if err := c.addPeer(runCtx, roomURL, peerID, cancel); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("addPeer failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,10 +111,6 @@ func RunWithReady(
|
||||
return err
|
||||
}
|
||||
|
||||
func peerCount() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func decodeKey(keyHex string) ([]byte, error) {
|
||||
if keyHex == "" {
|
||||
key := make([]byte, 32)
|
||||
@@ -193,7 +189,7 @@ func (c *Client) addPeer(
|
||||
peerID int,
|
||||
cancel context.CancelFunc,
|
||||
) error {
|
||||
peer, err := telemost.NewPeer(roomURL, names.Generate(), c.onData)
|
||||
peer, err := telemost.NewPeer(runCtx, roomURL, names.Generate(), c.onData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create peer %d: %w", peerID, err)
|
||||
}
|
||||
@@ -257,7 +253,7 @@ func (c *Client) sendResetSignal() {
|
||||
func (c *Client) onData(data []byte) {
|
||||
plaintext, err := c.cipher.Decrypt(data)
|
||||
if err != nil {
|
||||
logger.Debug("Decrypt error: %v", err)
|
||||
logger.Debugf("Decrypt error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -292,7 +288,7 @@ func (c *Client) runSOCKS5(
|
||||
<-ctx.Done()
|
||||
log.Println("Closing SOCKS5 listener...")
|
||||
if err := listener.Close(); err != nil {
|
||||
logger.Debug("SOCKS5 listener close error: %v", err)
|
||||
logger.Debugf("SOCKS5 listener close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -317,7 +313,7 @@ func (c *Client) runSOCKS5(
|
||||
func (c *Client) closePeers() {
|
||||
for _, peer := range c.peers {
|
||||
if err := peer.Close(); err != nil {
|
||||
logger.Debug("Peer close error: %v", err)
|
||||
logger.Debugf("Peer close error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -326,7 +322,7 @@ func (c *Client) closePeers() {
|
||||
func (c *Client) handleSOCKS5(conn net.Conn, username, password string) {
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
logger.Debug("SOCKS5 connection close error: %v", err)
|
||||
logger.Debugf("SOCKS5 connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -362,7 +358,7 @@ func (c *Client) handleSOCKS5(conn net.Conn, username, password string) {
|
||||
}
|
||||
|
||||
sid := c.mux.OpenStream()
|
||||
logger.Verbose("SOCKS5 opened stream sid=%d for %s:%d", sid, addr, port)
|
||||
logger.Verbosef("SOCKS5 opened stream sid=%d for %s:%d", sid, addr, port)
|
||||
log.Printf("[CLIENT] sid=%d SOCKS5_START %s:%d", sid, addr, port)
|
||||
|
||||
if !c.sendConnectRequest(sid, addr, port) {
|
||||
@@ -481,12 +477,12 @@ func (c *Client) sendConnectRequest(sid uint16, addr string, port uint16) bool {
|
||||
Port: port,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Debug("Connect request marshal error: %v", err)
|
||||
logger.Debugf("Connect request marshal error: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if err := c.mux.SendData(sid, reqData); err != nil {
|
||||
logger.Debug("Connect request send error: %v", err)
|
||||
logger.Debugf("Connect request send error: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -525,7 +521,7 @@ func (c *Client) proxyStream(conn net.Conn, sid uint16) {
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
if err := c.mux.CloseStream(sid); err != nil {
|
||||
logger.Debug("Close stream error: %v", err)
|
||||
logger.Debugf("Close stream error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -579,7 +575,7 @@ func writeStreamData(conn net.Conn, data []byte) bool {
|
||||
|
||||
func writeResponse(conn net.Conn, response []byte) {
|
||||
if _, err := conn.Write(response); err != nil {
|
||||
logger.Debug("SOCKS5 response write error: %v", err)
|
||||
logger.Debugf("SOCKS5 response write error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,48 +1,59 @@
|
||||
// Package crypto provides cryptographic functions.
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
type Cipher struct {
|
||||
var (
|
||||
ErrInvalidKeySize = errors.New("invalid key size") //nolint:revive
|
||||
ErrCiphertextTooShort = errors.New("ciphertext too short") //nolint:revive
|
||||
)
|
||||
|
||||
type Cipher struct { //nolint:revive
|
||||
aead cipher.AEAD
|
||||
}
|
||||
|
||||
func NewCipher(keyStr string) (*Cipher, error) {
|
||||
func NewCipher(keyStr string) (*Cipher, error) { //nolint:revive
|
||||
key := []byte(keyStr)
|
||||
if len(key) != chacha20poly1305.KeySize {
|
||||
return nil, errors.New("invalid key size")
|
||||
return nil, ErrInvalidKeySize
|
||||
}
|
||||
|
||||
aead, err := chacha20poly1305.NewX(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to create aead: %w", err)
|
||||
}
|
||||
|
||||
return &Cipher{aead: aead}, nil
|
||||
}
|
||||
|
||||
func (c *Cipher) Encrypt(plaintext []byte) ([]byte, error) {
|
||||
func (c *Cipher) Encrypt(plaintext []byte) ([]byte, error) { //nolint:revive
|
||||
nonce := make([]byte, c.aead.NonceSize())
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to read nonce: %w", err)
|
||||
}
|
||||
|
||||
ciphertext := c.aead.Seal(nonce, nonce, plaintext, nil)
|
||||
return ciphertext, nil
|
||||
}
|
||||
|
||||
func (c *Cipher) Decrypt(ciphertext []byte) ([]byte, error) {
|
||||
func (c *Cipher) Decrypt(ciphertext []byte) ([]byte, error) { //nolint:revive
|
||||
if len(ciphertext) < c.aead.NonceSize() {
|
||||
return nil, errors.New("ciphertext too short")
|
||||
return nil, ErrCiphertextTooShort
|
||||
}
|
||||
|
||||
nonce := ciphertext[:c.aead.NonceSize()]
|
||||
encrypted := ciphertext[c.aead.NonceSize():]
|
||||
|
||||
return c.aead.Open(nil, nonce, encrypted, nil)
|
||||
res, err := c.aead.Open(nil, nonce, encrypted, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt: %w", err)
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
@@ -1,27 +1,27 @@
|
||||
package logger
|
||||
package logger //nolint:revive
|
||||
|
||||
import (
|
||||
"log"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
var verboseEnabled atomic.Bool
|
||||
var verboseEnabled atomic.Bool //nolint:gochecknoglobals
|
||||
|
||||
func SetVerbose(enabled bool) {
|
||||
func SetVerbose(enabled bool) { //nolint:revive
|
||||
verboseEnabled.Store(enabled)
|
||||
}
|
||||
|
||||
func IsVerbose() bool {
|
||||
func IsVerbose() bool { //nolint:revive
|
||||
return verboseEnabled.Load()
|
||||
}
|
||||
|
||||
func Verbose(format string, v ...interface{}) {
|
||||
func Verbosef(format string, v ...interface{}) { //nolint:revive
|
||||
if verboseEnabled.Load() {
|
||||
log.Printf("[VERBOSE] "+format, v...)
|
||||
}
|
||||
}
|
||||
|
||||
func Debug(format string, v ...interface{}) {
|
||||
func Debugf(format string, v ...interface{}) { //nolint:revive
|
||||
if verboseEnabled.Load() {
|
||||
log.Printf("[DEBUG] "+format, v...)
|
||||
}
|
||||
|
||||
@@ -1,31 +1,33 @@
|
||||
// ===========================================
|
||||
// AI GENERATED / AI GENERATED / AI GENERATED
|
||||
//===========================================
|
||||
|
||||
// Package mux provides a multiplexer for multiple streams over a single connection.
|
||||
package mux
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrClientResetID = errors.New("client reset requires a non-zero client id") //nolint:revive
|
||||
)
|
||||
|
||||
const (
|
||||
ControlStreamID uint16 = 0xFFFF
|
||||
ControlLength uint16 = 0xFFFF
|
||||
ControlStreamID uint16 = 0xFFFF //nolint:revive
|
||||
ControlLength uint16 = 0xFFFF //nolint:revive
|
||||
|
||||
ControlResetClient uint32 = 1
|
||||
)
|
||||
|
||||
type ControlFrame struct {
|
||||
type ControlFrame struct { //nolint:revive
|
||||
ClientID uint32
|
||||
Type uint32
|
||||
}
|
||||
|
||||
type Stream struct {
|
||||
type Stream struct { //nolint:revive
|
||||
ID uint16
|
||||
ClientID uint32
|
||||
recvBuf []byte
|
||||
@@ -35,13 +37,13 @@ type Stream struct {
|
||||
outOfOrder map[uint32][]byte
|
||||
}
|
||||
|
||||
func (s *Stream) RecvBuf() []byte {
|
||||
func (s *Stream) RecvBuf() []byte { //nolint:revive
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.recvBuf
|
||||
}
|
||||
|
||||
type Multiplexer struct {
|
||||
type Multiplexer struct { //nolint:revive
|
||||
streams map[uint16]*Stream
|
||||
nextID uint16
|
||||
clientID uint32
|
||||
@@ -55,7 +57,7 @@ type Multiplexer struct {
|
||||
sendSeqMu sync.Mutex
|
||||
}
|
||||
|
||||
func New(clientID uint32, onSend func([]byte) error) *Multiplexer {
|
||||
func New(clientID uint32, onSend func([]byte) error) *Multiplexer { //nolint:revive
|
||||
return &Multiplexer{
|
||||
streams: make(map[uint16]*Stream),
|
||||
nextID: 1,
|
||||
@@ -68,7 +70,7 @@ func New(clientID uint32, onSend func([]byte) error) *Multiplexer {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Multiplexer) OpenStream() uint16 {
|
||||
func (m *Multiplexer) OpenStream() uint16 { //nolint:revive
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
@@ -91,7 +93,7 @@ func (m *Multiplexer) OpenStream() uint16 {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Multiplexer) SendData(sid uint16, data []byte) error {
|
||||
func (m *Multiplexer) SendData(sid uint16, data []byte) error { //nolint:revive
|
||||
m.mu.RLock()
|
||||
stream, exists := m.streams[sid]
|
||||
m.mu.RUnlock()
|
||||
@@ -100,12 +102,11 @@ func (m *Multiplexer) SendData(sid uint16, data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Keep encrypted DataChannel messages below Telemost's observed 8 KiB cap.
|
||||
const chunkSize = 7000
|
||||
totalChunks := (len(data) + chunkSize - 1) / chunkSize
|
||||
|
||||
if totalChunks > 10 {
|
||||
logger.Debug("SendData: sid=%d, size=%d bytes, chunks=%d", sid, len(data), totalChunks)
|
||||
logger.Debugf("SendData: sid=%d, size=%d bytes, chunks=%d", sid, len(data), totalChunks)
|
||||
}
|
||||
|
||||
for i := 0; i < len(data); i += chunkSize {
|
||||
@@ -124,19 +125,19 @@ func (m *Multiplexer) SendData(sid uint16, data []byte) error {
|
||||
frame := make([]byte, 12+len(chunk))
|
||||
binary.BigEndian.PutUint32(frame[0:4], m.clientID)
|
||||
binary.BigEndian.PutUint16(frame[4:6], sid)
|
||||
binary.BigEndian.PutUint16(frame[6:8], uint16(len(chunk)))
|
||||
binary.BigEndian.PutUint16(frame[6:8], uint16(uint32(len(chunk)))) //nolint:gosec
|
||||
binary.BigEndian.PutUint32(frame[8:12], seq)
|
||||
copy(frame[12:], chunk)
|
||||
|
||||
if err := m.onSend(frame); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("onSend failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Multiplexer) CloseStream(sid uint16) error {
|
||||
func (m *Multiplexer) CloseStream(sid uint16) error { //nolint:revive
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
@@ -154,17 +155,23 @@ func (m *Multiplexer) CloseStream(sid uint16) error {
|
||||
binary.BigEndian.PutUint16(frame[6:8], 0)
|
||||
binary.BigEndian.PutUint32(frame[8:12], 0)
|
||||
|
||||
return m.onSend(frame)
|
||||
}
|
||||
|
||||
func (m *Multiplexer) SendClientReset() error {
|
||||
if m.clientID == 0 {
|
||||
return errors.New("client reset requires a non-zero client id")
|
||||
if err := m.onSend(frame); err != nil {
|
||||
return fmt.Errorf("onSend failed: %w", err)
|
||||
}
|
||||
return m.onSend(BuildControlFrame(m.clientID, ControlResetClient))
|
||||
return nil
|
||||
}
|
||||
|
||||
func BuildControlFrame(clientID uint32, controlType uint32) []byte {
|
||||
func (m *Multiplexer) SendClientReset() error { //nolint:revive
|
||||
if m.clientID == 0 {
|
||||
return ErrClientResetID
|
||||
}
|
||||
if err := m.onSend(BuildControlFrame(m.clientID, ControlResetClient)); err != nil {
|
||||
return fmt.Errorf("onSend failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func BuildControlFrame(clientID uint32, controlType uint32) []byte { //nolint:revive
|
||||
frame := make([]byte, 12)
|
||||
binary.BigEndian.PutUint32(frame[0:4], clientID)
|
||||
binary.BigEndian.PutUint16(frame[4:6], ControlStreamID)
|
||||
@@ -173,7 +180,7 @@ func BuildControlFrame(clientID uint32, controlType uint32) []byte {
|
||||
return frame
|
||||
}
|
||||
|
||||
func ParseControlFrame(frame []byte) (ControlFrame, bool) {
|
||||
func ParseControlFrame(frame []byte) (ControlFrame, bool) { //nolint:revive
|
||||
if len(frame) < 12 {
|
||||
return ControlFrame{}, false
|
||||
}
|
||||
@@ -190,7 +197,7 @@ func ParseControlFrame(frame []byte) (ControlFrame, bool) {
|
||||
}, true
|
||||
}
|
||||
|
||||
func (m *Multiplexer) HandleFrame(frame []byte) {
|
||||
func (m *Multiplexer) HandleFrame(frame []byte) { //nolint:revive
|
||||
control, ok := ParseControlFrame(frame)
|
||||
if ok {
|
||||
m.handleControlFrame(control)
|
||||
@@ -207,11 +214,7 @@ func (m *Multiplexer) HandleFrame(frame []byte) {
|
||||
seq := binary.BigEndian.Uint32(frame[8:12])
|
||||
|
||||
if length == 0 {
|
||||
m.mu.Lock()
|
||||
if stream, exists := m.streams[sid]; exists && stream.ClientID == clientID {
|
||||
stream.closed = true
|
||||
}
|
||||
m.mu.Unlock()
|
||||
m.handleCloseStreamFrame(sid, clientID)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -219,15 +222,45 @@ func (m *Multiplexer) HandleFrame(frame []byte) {
|
||||
return
|
||||
}
|
||||
|
||||
data := frame[12 : 12+length]
|
||||
m.processDataFrame(sid, clientID, seq, frame[12:12+length])
|
||||
}
|
||||
|
||||
func (m *Multiplexer) handleCloseStreamFrame(sid uint16, clientID uint32) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if stream, exists := m.streams[sid]; exists && stream.ClientID == clientID {
|
||||
stream.closed = true
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Multiplexer) processDataFrame(sid uint16, clientID uint32, seq uint32, data []byte) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
stream := m.getOrCreateStream(sid, clientID)
|
||||
if stream == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if seq == stream.nextSeq {
|
||||
if s := m.waitForBufferSpace(sid, clientID, len(data)); s != nil {
|
||||
s.recvBuf = append(s.recvBuf, data...)
|
||||
s.nextSeq++
|
||||
m.applyOutOfOrder(s, sid, clientID)
|
||||
m.notifyDataReady(sid)
|
||||
}
|
||||
} else if seq > stream.nextSeq {
|
||||
if len(stream.outOfOrder) < 100 {
|
||||
stream.outOfOrder[seq] = append([]byte(nil), data...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Multiplexer) getOrCreateStream(sid uint16, clientID uint32) *Stream {
|
||||
stream, exists := m.streams[sid]
|
||||
if !exists {
|
||||
if len(m.streams) >= m.maxStreams {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
stream = &Stream{
|
||||
ID: sid,
|
||||
@@ -237,59 +270,42 @@ func (m *Multiplexer) HandleFrame(frame []byte) {
|
||||
outOfOrder: make(map[uint32][]byte),
|
||||
}
|
||||
m.streams[sid] = stream
|
||||
} else if stream.ClientID != clientID {
|
||||
return stream
|
||||
}
|
||||
|
||||
if stream.ClientID != clientID {
|
||||
stream.ClientID = clientID
|
||||
stream.recvBuf = make([]byte, 0)
|
||||
stream.closed = false
|
||||
stream.nextSeq = 0
|
||||
stream.outOfOrder = make(map[uint32][]byte)
|
||||
}
|
||||
return stream
|
||||
}
|
||||
|
||||
if seq == stream.nextSeq {
|
||||
// Backpressure: if the stream buffer is full, release the mux lock and
|
||||
// wait for the reader to drain it. Dropping/closing here would corrupt
|
||||
// the TCP stream carried over the mux — large HTTP/2 downloads (X,
|
||||
// Instagram, YouTube) that push data faster than conn.Write can accept
|
||||
// would lose bytes and hang forever.
|
||||
if s := m.waitForBufferSpace(sid, clientID, len(data)); s == nil {
|
||||
func (m *Multiplexer) applyOutOfOrder(stream *Stream, sid uint16, clientID uint32) {
|
||||
for {
|
||||
nextData, ok := stream.outOfOrder[stream.nextSeq]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
if s := m.waitForBufferSpace(sid, clientID, len(nextData)); s == nil {
|
||||
return
|
||||
} else {
|
||||
stream = s
|
||||
}
|
||||
stream.recvBuf = append(stream.recvBuf, data...)
|
||||
stream.recvBuf = append(stream.recvBuf, nextData...)
|
||||
delete(stream.outOfOrder, stream.nextSeq)
|
||||
stream.nextSeq++
|
||||
logger.Verbosef("Applied out-of-order packet sid=%d seq=%d", sid, stream.nextSeq-1)
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
nextData, ok := stream.outOfOrder[stream.nextSeq]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
if s := m.waitForBufferSpace(sid, clientID, len(nextData)); s == nil {
|
||||
return
|
||||
} else {
|
||||
stream = s
|
||||
}
|
||||
nextData, ok = stream.outOfOrder[stream.nextSeq]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
stream.recvBuf = append(stream.recvBuf, nextData...)
|
||||
delete(stream.outOfOrder, stream.nextSeq)
|
||||
stream.nextSeq++
|
||||
logger.Verbose("Applied out-of-order packet sid=%d seq=%d", sid, stream.nextSeq-1)
|
||||
}
|
||||
|
||||
m.dataReadyMu.Lock()
|
||||
if ch, ok := m.dataReady[sid]; ok {
|
||||
select {
|
||||
case ch <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
m.dataReadyMu.Unlock()
|
||||
} else if seq > stream.nextSeq {
|
||||
if len(stream.outOfOrder) < 100 {
|
||||
stream.outOfOrder[seq] = append([]byte(nil), data...)
|
||||
func (m *Multiplexer) notifyDataReady(sid uint16) {
|
||||
m.dataReadyMu.Lock()
|
||||
defer m.dataReadyMu.Unlock()
|
||||
if ch, ok := m.dataReady[sid]; ok {
|
||||
select {
|
||||
case ch <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -299,11 +315,11 @@ func (m *Multiplexer) handleControlFrame(control ControlFrame) {
|
||||
case ControlResetClient:
|
||||
m.ResetClient(control.ClientID)
|
||||
default:
|
||||
logger.Debug("Unknown mux control frame type=%d clientID=%d", control.Type, control.ClientID)
|
||||
logger.Debugf("Unknown mux control frame type=%d clientID=%d", control.Type, control.ClientID)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Multiplexer) ResetClient(clientID uint32) {
|
||||
func (m *Multiplexer) ResetClient(clientID uint32) { //nolint:revive
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
@@ -315,10 +331,6 @@ func (m *Multiplexer) ResetClient(clientID uint32) {
|
||||
}
|
||||
}
|
||||
|
||||
// waitForBufferSpace releases m.mu and waits until the stream's recvBuf has
|
||||
// room for `need` more bytes, then re-acquires the lock. Returns the (possibly
|
||||
// re-fetched) stream, or nil if the stream disappeared / was reset / closed.
|
||||
// Caller must hold m.mu (write-locked) on entry and will hold it on return.
|
||||
func (m *Multiplexer) waitForBufferSpace(sid uint16, clientID uint32, need int) *Stream {
|
||||
for {
|
||||
stream, ok := m.streams[sid]
|
||||
@@ -334,7 +346,7 @@ func (m *Multiplexer) waitForBufferSpace(sid uint16, clientID uint32, need int)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Multiplexer) ReadStream(sid uint16) []byte {
|
||||
func (m *Multiplexer) ReadStream(sid uint16) []byte { //nolint:revive
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
@@ -348,7 +360,7 @@ func (m *Multiplexer) ReadStream(sid uint16) []byte {
|
||||
return data
|
||||
}
|
||||
|
||||
func (m *Multiplexer) StreamClosed(sid uint16) bool {
|
||||
func (m *Multiplexer) StreamClosed(sid uint16) bool { //nolint:revive
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
@@ -356,7 +368,7 @@ func (m *Multiplexer) StreamClosed(sid uint16) bool {
|
||||
return !exists || stream.closed
|
||||
}
|
||||
|
||||
func (m *Multiplexer) GetStreams() []uint16 {
|
||||
func (m *Multiplexer) GetStreams() []uint16 { //nolint:revive
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
@@ -367,13 +379,13 @@ func (m *Multiplexer) GetStreams() []uint16 {
|
||||
return sids
|
||||
}
|
||||
|
||||
func (m *Multiplexer) GetStream(sid uint16) *Stream {
|
||||
func (m *Multiplexer) GetStream(sid uint16) *Stream { //nolint:revive
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.streams[sid]
|
||||
}
|
||||
|
||||
func (m *Multiplexer) Reset() {
|
||||
func (m *Multiplexer) Reset() { //nolint:revive
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
@@ -389,14 +401,14 @@ func (m *Multiplexer) Reset() {
|
||||
m.sendSeqMu.Unlock()
|
||||
}
|
||||
|
||||
func (m *Multiplexer) UpdateSendFunc(onSend func([]byte) error) {
|
||||
func (m *Multiplexer) UpdateSendFunc(onSend func([]byte) error) { //nolint:revive
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.onSend = onSend
|
||||
}
|
||||
|
||||
func (m *Multiplexer) WaitForData(sid uint16) <-chan struct{} {
|
||||
func (m *Multiplexer) WaitForData(sid uint16) <-chan struct{} { //nolint:revive
|
||||
m.dataReadyMu.Lock()
|
||||
defer m.dataReadyMu.Unlock()
|
||||
|
||||
@@ -406,7 +418,7 @@ func (m *Multiplexer) WaitForData(sid uint16) <-chan struct{} {
|
||||
return m.dataReady[sid]
|
||||
}
|
||||
|
||||
func (m *Multiplexer) CleanupDataChannel(sid uint16) {
|
||||
func (m *Multiplexer) CleanupDataChannel(sid uint16) { //nolint:revive
|
||||
m.dataReadyMu.Lock()
|
||||
defer m.dataReadyMu.Unlock()
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
// Package protect provides functions to protect sockets from VPN routing.
|
||||
package protect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"syscall"
|
||||
@@ -10,18 +12,21 @@ import (
|
||||
|
||||
// Protector is called with a socket file descriptor before connect.
|
||||
// On Android, this calls VpnService.protect(fd) to bypass VPN routing.
|
||||
var Protector func(fd int) bool
|
||||
var Protector func(fd int) bool //nolint:gochecknoglobals
|
||||
|
||||
func controlFunc(network, address string, c syscall.RawConn) error {
|
||||
func controlFunc(network, _ string, c syscall.RawConn) error {
|
||||
if Protector == nil {
|
||||
return nil
|
||||
}
|
||||
var err error
|
||||
c.Control(func(fd uintptr) {
|
||||
if !Protector(int(fd)) {
|
||||
controlErr := c.Control(func(fd uintptr) {
|
||||
if !Protector(int(fd)) { //nolint:gosec
|
||||
err = &net.OpError{Op: "protect", Net: network, Err: net.ErrClosed}
|
||||
}
|
||||
})
|
||||
if controlErr != nil {
|
||||
return fmt.Errorf("control failed: %w", controlErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -50,17 +55,27 @@ func NewHTTPClient() *http.Client {
|
||||
|
||||
// DialContext dials using a protected socket.
|
||||
func DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return NewDialer().DialContext(ctx, network, address)
|
||||
conn, err := NewDialer().DialContext(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial failed: %w", err)
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// proxyDialer implements golang.org/x/net/proxy.Dialer for pion ICE.
|
||||
type proxyDialer struct{}
|
||||
// ProxyDialer implements golang.org/x/net/proxy.Dialer for pion ICE.
|
||||
type ProxyDialer struct{}
|
||||
|
||||
func (d *proxyDialer) Dial(network, addr string) (net.Conn, error) {
|
||||
return NewDialer().Dial(network, addr)
|
||||
// Dial connects to the address on the named network using a protected socket.
|
||||
func (d *ProxyDialer) Dial(network, addr string) (net.Conn, error) {
|
||||
conn, err := NewDialer().Dial(network, addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial failed: %w", err)
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// NewProxyDialer returns a proxy.Dialer that protects ICE sockets.
|
||||
func NewProxyDialer() *proxyDialer {
|
||||
return &proxyDialer{}
|
||||
func NewProxyDialer() *ProxyDialer {
|
||||
return &ProxyDialer{}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package server implements the olcrtc tunnel server logic.
|
||||
package server
|
||||
|
||||
import (
|
||||
@@ -5,10 +6,12 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -21,7 +24,24 @@ import (
|
||||
"github.com/pion/webrtc/v4"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
var (
|
||||
// ErrKeySize is returned when the encryption key is not 32 bytes.
|
||||
ErrKeySize = errors.New("key must be 32 bytes")
|
||||
// ErrKeyStringLength is returned when the encryption key string length is not 32.
|
||||
ErrKeyStringLength = errors.New("key string length must be 32")
|
||||
// ErrSocks5AuthFailed is returned when SOCKS5 authentication fails.
|
||||
ErrSocks5AuthFailed = errors.New("SOCKS5 auth failed")
|
||||
// ErrSocks5ConnectFailed is returned when SOCKS5 connection fails.
|
||||
ErrSocks5ConnectFailed = errors.New("SOCKS5 connect failed")
|
||||
// ErrNoPeers is returned when no peers are available.
|
||||
ErrNoPeers = errors.New("no peers available")
|
||||
// ErrDialProxy is returned when dialing the proxy fails.
|
||||
ErrDialProxy = errors.New("failed to dial proxy")
|
||||
// ErrEncryptFailed is returned when encryption fails.
|
||||
ErrEncryptFailed = errors.New("encrypt failed")
|
||||
)
|
||||
|
||||
type Server struct { //nolint:revive
|
||||
peers []*telemost.Peer
|
||||
cipher *crypto.Cipher
|
||||
mux *mux.Multiplexer
|
||||
@@ -33,48 +53,32 @@ type Server struct {
|
||||
activeClients atomic.Int32
|
||||
wg sync.WaitGroup
|
||||
dnsServer string
|
||||
dnsCache sync.Map
|
||||
resolver *net.Resolver
|
||||
socksProxyAddr string
|
||||
socksProxyPort int
|
||||
}
|
||||
|
||||
type ConnectRequest struct {
|
||||
type ConnectRequest struct { //nolint:revive
|
||||
Cmd string `json:"cmd"`
|
||||
Addr string `json:"addr"`
|
||||
Port int `json:"port"`
|
||||
}
|
||||
|
||||
func Run(ctx context.Context, roomURL, keyHex string, dnsServer, socksProxyAddr string, socksProxyPort int) error {
|
||||
// Run starts the olcrtc server and listens for client connections.
|
||||
func Run(
|
||||
ctx context.Context,
|
||||
roomURL,
|
||||
keyHex string,
|
||||
dnsServer,
|
||||
socksProxyAddr string,
|
||||
socksProxyPort int,
|
||||
) error {
|
||||
runCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
var key []byte
|
||||
var err error
|
||||
|
||||
if keyHex == "" {
|
||||
key = make([]byte, 32)
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("Generated key: %x", key)
|
||||
} else {
|
||||
key, err = hex.DecodeString(keyHex)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(key) != 32 {
|
||||
return fmt.Errorf("key must be 32 bytes, got %d", len(key))
|
||||
}
|
||||
}
|
||||
|
||||
keyStr := string(key)
|
||||
if len(keyStr) != 32 {
|
||||
return fmt.Errorf("key string length must be 32, got %d", len(keyStr))
|
||||
}
|
||||
|
||||
cipher, err := crypto.NewCipher(keyStr)
|
||||
cipher, err := setupCipher(keyHex)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("setupCipher failed: %w", err)
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
@@ -87,20 +91,72 @@ func Run(ctx context.Context, roomURL, keyHex string, dnsServer, socksProxyAddr
|
||||
socksProxyPort: socksProxyPort,
|
||||
}
|
||||
|
||||
if dnsServer == "" {
|
||||
dnsServer = "1.1.1.1:53"
|
||||
if s.dnsServer == "" {
|
||||
s.dnsServer = "1.1.1.1:53"
|
||||
}
|
||||
|
||||
s.setupResolver()
|
||||
s.setupMux()
|
||||
|
||||
const peerCount = 1
|
||||
for i := range peerCount {
|
||||
if err := s.addPeer(runCtx, roomURL, i, cancel); err != nil {
|
||||
return fmt.Errorf("addPeer failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = s.runLoop(runCtx)
|
||||
|
||||
log.Println("Waiting for server goroutines...")
|
||||
s.wg.Wait()
|
||||
log.Println("Server goroutines finished")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func setupCipher(keyHex string) (*crypto.Cipher, error) {
|
||||
var key []byte
|
||||
var err error
|
||||
|
||||
if keyHex == "" {
|
||||
key = make([]byte, 32)
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate key: %w", err)
|
||||
}
|
||||
log.Printf("Generated key: %x", key)
|
||||
} else {
|
||||
key, err = hex.DecodeString(keyHex)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode key: %w", err)
|
||||
}
|
||||
if len(key) != 32 {
|
||||
return nil, fmt.Errorf("%w, got %d", ErrKeySize, len(key))
|
||||
}
|
||||
}
|
||||
|
||||
keyStr := string(key)
|
||||
if len(keyStr) != 32 {
|
||||
return nil, fmt.Errorf("%w, got %d", ErrKeyStringLength, len(keyStr))
|
||||
}
|
||||
|
||||
cipher, err := crypto.NewCipher(keyStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
return cipher, nil
|
||||
}
|
||||
|
||||
func (s *Server) setupResolver() {
|
||||
s.resolver = &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
Dial: func(ctx context.Context, network, _ string) (net.Conn, error) {
|
||||
d := net.Dialer{Timeout: 3 * time.Second}
|
||||
return d.DialContext(ctx, network, dnsServer)
|
||||
return d.DialContext(ctx, network, s.dnsServer)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
peerCount := 1
|
||||
|
||||
func (s *Server) setupMux() {
|
||||
s.mux = mux.New(0, func(frame []byte) error {
|
||||
for {
|
||||
canSend := true
|
||||
@@ -118,110 +174,118 @@ func Run(ctx context.Context, roomURL, keyHex string, dnsServer, socksProxyAddr
|
||||
|
||||
encrypted, err := s.cipher.Encrypt(frame)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("%w: %w", ErrEncryptFailed, err)
|
||||
}
|
||||
idx := s.peerIdx.Add(1) % uint32(len(s.peers))
|
||||
if len(s.peers) == 0 {
|
||||
return ErrNoPeers
|
||||
}
|
||||
idx := s.peerIdx.Add(1) % uint32(len(s.peers)) //nolint:gosec
|
||||
return s.peers[idx].Send(encrypted)
|
||||
})
|
||||
}
|
||||
|
||||
for i := 0; i < peerCount; i++ {
|
||||
peerID := i
|
||||
peer, err := telemost.NewPeer(roomURL, names.Generate(), s.onData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
peer.SetEndedCallback(func(reason string) {
|
||||
log.Printf("Server peer %d reported conference end: %s", peerID, reason)
|
||||
cancel()
|
||||
})
|
||||
s.peers = append(s.peers, peer)
|
||||
|
||||
peer.SetReconnectCallback(func(dc *webrtc.DataChannel) {
|
||||
if dc == nil {
|
||||
log.Printf("Server peer %d channel closed - resetting multiplexer state", peerID)
|
||||
} else {
|
||||
log.Printf("Server peer %d reconnected - resetting multiplexer state", peerID)
|
||||
}
|
||||
|
||||
s.connMu.Lock()
|
||||
for sid, conn := range s.connections {
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
delete(s.connections, sid)
|
||||
}
|
||||
s.connMu.Unlock()
|
||||
|
||||
if dc != nil {
|
||||
s.mux.UpdateSendFunc(func(frame []byte) error {
|
||||
encrypted, err := s.cipher.Encrypt(frame)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
idx := s.peerIdx.Add(1) % uint32(len(s.peers))
|
||||
return s.peers[idx].Send(encrypted)
|
||||
})
|
||||
}
|
||||
|
||||
s.mux.Reset()
|
||||
|
||||
log.Println("Server multiplexer reset complete")
|
||||
})
|
||||
|
||||
peer.SetShouldReconnect(func() bool {
|
||||
return s.activeClients.Load() > 0
|
||||
})
|
||||
|
||||
log.Printf("Connecting peer %d to Telemost...", peerID)
|
||||
if err := peer.Connect(runCtx); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("Peer %d connected", peerID)
|
||||
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
peer.WatchConnection(runCtx)
|
||||
}()
|
||||
func (s *Server) addPeer(ctx context.Context, roomURL string, peerID int, cancel context.CancelFunc) error {
|
||||
peer, err := telemost.NewPeer(ctx, roomURL, names.Generate(), s.onData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create peer: %w", err)
|
||||
}
|
||||
|
||||
err = s.run(runCtx)
|
||||
peer.SetEndedCallback(func(reason string) {
|
||||
log.Printf("Server peer %d reported conference end: %s", peerID, reason)
|
||||
cancel()
|
||||
})
|
||||
s.peers = append(s.peers, peer)
|
||||
|
||||
log.Println("Waiting for server goroutines...")
|
||||
s.wg.Wait()
|
||||
log.Println("Server goroutines finished")
|
||||
peer.SetReconnectCallback(func(dc *webrtc.DataChannel) {
|
||||
s.handlePeerReconnect(peerID, dc)
|
||||
})
|
||||
|
||||
return err
|
||||
peer.SetShouldReconnect(func() bool {
|
||||
return s.activeClients.Load() > 0
|
||||
})
|
||||
|
||||
log.Printf("Connecting peer %d to Telemost...", peerID)
|
||||
if err := peer.Connect(ctx); err != nil {
|
||||
return fmt.Errorf("failed to connect peer: %w", err)
|
||||
}
|
||||
log.Printf("Peer %d connected", peerID)
|
||||
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
peer.WatchConnection(ctx)
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handlePeerReconnect(peerID int, dc *webrtc.DataChannel) {
|
||||
if dc == nil {
|
||||
log.Printf("Server peer %d channel closed - resetting mux state", peerID)
|
||||
} else {
|
||||
log.Printf("Server peer %d reconnected - resetting mux state", peerID)
|
||||
}
|
||||
|
||||
s.connMu.Lock()
|
||||
for sid, conn := range s.connections {
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
delete(s.connections, sid)
|
||||
}
|
||||
s.connMu.Unlock()
|
||||
|
||||
if dc != nil {
|
||||
s.mux.UpdateSendFunc(func(frame []byte) error {
|
||||
encrypted, err := s.cipher.Encrypt(frame)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %w", ErrEncryptFailed, err)
|
||||
}
|
||||
if len(s.peers) == 0 {
|
||||
return ErrNoPeers
|
||||
}
|
||||
idx := s.peerIdx.Add(1) % uint32(len(s.peers)) //nolint:gosec
|
||||
return s.peers[idx].Send(encrypted)
|
||||
})
|
||||
}
|
||||
|
||||
s.mux.Reset()
|
||||
log.Println("Server multiplexer reset complete")
|
||||
}
|
||||
|
||||
func (s *Server) socks5Connect(conn net.Conn, targetAddr string, targetPort int) error {
|
||||
if _, err := conn.Write([]byte{5, 1, 0}); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to write socks5 auth: %w", err)
|
||||
}
|
||||
|
||||
resp := make([]byte, 2)
|
||||
if _, err := io.ReadFull(conn, resp); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to read socks5 auth resp: %w", err)
|
||||
}
|
||||
if resp[0] != 5 || resp[1] != 0 {
|
||||
return fmt.Errorf("SOCKS5 auth failed")
|
||||
return ErrSocks5AuthFailed
|
||||
}
|
||||
|
||||
req := []byte{5, 1, 0, 3}
|
||||
req = append(req, byte(len(targetAddr)))
|
||||
addrLen := len(targetAddr)
|
||||
if addrLen > 255 {
|
||||
addrLen = 255
|
||||
targetAddr = targetAddr[:255]
|
||||
}
|
||||
|
||||
req := make([]byte, 0, 7+addrLen)
|
||||
req = append(req, 5, 1, 0, 3, byte(addrLen))
|
||||
req = append(req, []byte(targetAddr)...)
|
||||
req = append(req, byte(targetPort>>8), byte(targetPort))
|
||||
req = append(req, byte(targetPort>>8), byte(targetPort)) //nolint:gosec
|
||||
|
||||
if _, err := conn.Write(req); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to write socks5 connect req: %w", err)
|
||||
}
|
||||
|
||||
resp = make([]byte, 10)
|
||||
if _, err := io.ReadFull(conn, resp); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to read socks5 connect resp: %w", err)
|
||||
}
|
||||
if resp[0] != 5 || resp[1] != 0 {
|
||||
return fmt.Errorf("SOCKS5 connect failed: %d", resp[1])
|
||||
return fmt.Errorf("%w: %d", ErrSocks5ConnectFailed, resp[1])
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -230,12 +294,12 @@ func (s *Server) socks5Connect(conn net.Conn, targetAddr string, targetPort int)
|
||||
func (s *Server) onData(data []byte) {
|
||||
plaintext, err := s.cipher.Decrypt(data)
|
||||
if err != nil {
|
||||
logger.Debug("Decrypt error: %v", err)
|
||||
logger.Debugf("Decrypt error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if control, ok := mux.ParseControlFrame(plaintext); ok && control.Type == mux.ControlResetClient {
|
||||
log.Printf("Received reset signal from client (clientID=%d) - cleaning up", control.ClientID)
|
||||
log.Printf("Received reset signal from client (clientID=%d)", control.ClientID)
|
||||
s.closeClientConnections(control.ClientID)
|
||||
}
|
||||
|
||||
@@ -250,63 +314,67 @@ func (s *Server) closeClientConnections(clientID uint32) {
|
||||
stream := s.mux.GetStream(streamSid)
|
||||
if stream != nil && stream.ClientID == clientID {
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
_ = conn.Close()
|
||||
}
|
||||
delete(s.connections, streamSid)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) run(ctx context.Context) error {
|
||||
func (s *Server) runLoop(ctx context.Context) error {
|
||||
ticker := time.NewTicker(10 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Println("Server shutting down...")
|
||||
s.connMu.Lock()
|
||||
for _, conn := range s.connections {
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
s.connMu.Unlock()
|
||||
|
||||
log.Printf("Closing %d peer(s)...", len(s.peers))
|
||||
for i, peer := range s.peers {
|
||||
log.Printf("Closing peer %d...", i)
|
||||
peer.Close()
|
||||
}
|
||||
log.Println("All peers closed")
|
||||
|
||||
s.shutdown()
|
||||
return nil
|
||||
|
||||
case <-ticker.C:
|
||||
s.processMuxStreams(ctx)
|
||||
}
|
||||
sids := s.mux.GetStreams()
|
||||
}
|
||||
}
|
||||
|
||||
for _, sid := range sids {
|
||||
if s.mux.StreamClosed(sid) {
|
||||
s.closeStreamConnection(sid)
|
||||
continue
|
||||
}
|
||||
func (s *Server) shutdown() {
|
||||
log.Println("Server shutting down...")
|
||||
s.connMu.Lock()
|
||||
for _, conn := range s.connections {
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
}
|
||||
s.connMu.Unlock()
|
||||
|
||||
if s.hasConnection(sid) {
|
||||
continue
|
||||
}
|
||||
for i, peer := range s.peers {
|
||||
log.Printf("Closing peer %d...", i)
|
||||
_ = peer.Close()
|
||||
}
|
||||
log.Println("All peers closed")
|
||||
}
|
||||
|
||||
data := s.mux.ReadStream(sid)
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
func (s *Server) processMuxStreams(ctx context.Context) {
|
||||
sids := s.mux.GetStreams()
|
||||
for _, sid := range sids {
|
||||
if s.mux.StreamClosed(sid) {
|
||||
s.closeStreamConnection(sid)
|
||||
continue
|
||||
}
|
||||
|
||||
var req ConnectRequest
|
||||
if err := json.Unmarshal(data, &req); err == nil && req.Cmd == "connect" {
|
||||
log.Printf("[SERVER] sid=%d RECEIVED_CONNECT_REQUEST %s:%d", sid, req.Addr, req.Port)
|
||||
s.closeStreamConnection(sid)
|
||||
go s.handleConnect(ctx, sid, req)
|
||||
}
|
||||
if s.hasConnection(sid) {
|
||||
continue
|
||||
}
|
||||
|
||||
data := s.mux.ReadStream(sid)
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var req ConnectRequest
|
||||
if err := json.Unmarshal(data, &req); err == nil && req.Cmd == "connect" {
|
||||
log.Printf("[SERVER] sid=%d RECV_CONNECT %s:%d", sid, req.Addr, req.Port)
|
||||
s.closeStreamConnection(sid)
|
||||
go s.handleConnect(ctx, sid, req)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -314,15 +382,14 @@ func (s *Server) run(ctx context.Context) error {
|
||||
func (s *Server) hasConnection(sid uint16) bool {
|
||||
s.connMu.RLock()
|
||||
defer s.connMu.RUnlock()
|
||||
conn := s.connections[sid]
|
||||
return conn != nil
|
||||
return s.connections[sid] != nil
|
||||
}
|
||||
|
||||
func (s *Server) closeStreamConnection(sid uint16) {
|
||||
s.connMu.Lock()
|
||||
conn := s.connections[sid]
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
_ = conn.Close()
|
||||
delete(s.connections, sid)
|
||||
}
|
||||
s.connMu.Unlock()
|
||||
@@ -332,7 +399,7 @@ func (s *Server) closeStreamConnectionIfCurrent(sid uint16, expected net.Conn) {
|
||||
s.connMu.Lock()
|
||||
conn := s.connections[sid]
|
||||
if conn == expected {
|
||||
conn.Close()
|
||||
_ = conn.Close()
|
||||
delete(s.connections, sid)
|
||||
}
|
||||
s.connMu.Unlock()
|
||||
@@ -344,7 +411,7 @@ func (s *Server) markStreamPump(sid uint16, conn net.Conn) bool {
|
||||
if current := s.streamPumps[sid]; current == conn {
|
||||
return false
|
||||
} else if current != nil {
|
||||
current.Close()
|
||||
_ = current.Close()
|
||||
}
|
||||
s.streamPumps[sid] = conn
|
||||
return true
|
||||
@@ -360,102 +427,103 @@ func (s *Server) unmarkStreamPump(sid uint16, conn net.Conn) {
|
||||
|
||||
func (s *Server) handleConnect(ctx context.Context, sid uint16, req ConnectRequest) {
|
||||
startTime := time.Now()
|
||||
addr := fmt.Sprintf("%s:%d", req.Addr, req.Port)
|
||||
logger.Verbose("Handling connect request sid=%d to %s", sid, addr)
|
||||
addr := net.JoinHostPort(req.Addr, strconv.Itoa(req.Port))
|
||||
log.Printf("[SERVER] sid=%d CONNECT_START %s", sid, addr)
|
||||
|
||||
s.connMu.Lock()
|
||||
oldConn, exists := s.connections[sid]
|
||||
if exists && oldConn != nil {
|
||||
log.Printf("Closing old connection for sid=%d", sid)
|
||||
oldConn.Close()
|
||||
delete(s.connections, sid)
|
||||
}
|
||||
s.connMu.Unlock()
|
||||
s.closeStreamConnection(sid)
|
||||
|
||||
dialStart := time.Now()
|
||||
var conn net.Conn
|
||||
var err error
|
||||
conn, err := s.dial(req)
|
||||
dialElapsed := time.Since(dialStart)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[SERVER] sid=%d CONNECT_FAILED dial=%v total=%v err=%v",
|
||||
sid, dialElapsed, time.Since(startTime), err)
|
||||
_ = s.mux.CloseStream(sid)
|
||||
return
|
||||
}
|
||||
|
||||
s.connMu.Lock()
|
||||
s.connections[sid] = conn
|
||||
s.connMu.Unlock()
|
||||
|
||||
log.Printf("[SERVER] sid=%d CONNECT_SUCCESS dial=%v", sid, dialElapsed)
|
||||
|
||||
s.activeClients.Add(1)
|
||||
_ = s.mux.SendData(sid, []byte{0x00})
|
||||
s.startStreamPump(ctx, sid, conn)
|
||||
|
||||
go s.pumpToMux(sid, conn)
|
||||
}
|
||||
|
||||
func (s *Server) dial(req ConnectRequest) (net.Conn, error) {
|
||||
addr := net.JoinHostPort(req.Addr, strconv.Itoa(req.Port))
|
||||
if s.socksProxyAddr == "" {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
Resolver: s.resolver,
|
||||
}
|
||||
conn, err = dialer.Dial("tcp4", addr)
|
||||
logger.Verbose("TCP dial took %v for sid=%d (direct)", time.Since(dialStart), sid)
|
||||
} else {
|
||||
proxyAddr := fmt.Sprintf("%s:%d", s.socksProxyAddr, s.socksProxyPort)
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
conn, err := dialer.Dial("tcp4", addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial failed: %w", err)
|
||||
}
|
||||
conn, err = dialer.Dial("tcp4", proxyAddr)
|
||||
if err == nil {
|
||||
if err := s.socks5Connect(conn, req.Addr, req.Port); err != nil {
|
||||
conn.Close()
|
||||
err = fmt.Errorf("SOCKS5 connect failed: %v", err)
|
||||
}
|
||||
}
|
||||
logger.Verbose("SOCKS5 proxy dial took %v for sid=%d", time.Since(dialStart), sid)
|
||||
return conn, nil
|
||||
}
|
||||
dialElapsed := time.Since(dialStart)
|
||||
|
||||
proxyAddr := net.JoinHostPort(s.socksProxyAddr, strconv.Itoa(s.socksProxyPort))
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
conn, err := dialer.Dial("tcp4", proxyAddr)
|
||||
if err != nil {
|
||||
log.Printf("[SERVER] sid=%d CONNECT_FAILED dial_time=%v total_elapsed=%v err=%v", sid, dialElapsed, time.Since(startTime), err)
|
||||
go s.mux.CloseStream(sid)
|
||||
return
|
||||
return nil, fmt.Errorf("failed to dial proxy: %w", err)
|
||||
}
|
||||
|
||||
logger.Verbose("TCP dial took %v for sid=%d", dialElapsed, sid)
|
||||
s.connMu.Lock()
|
||||
s.connections[sid] = conn
|
||||
s.connMu.Unlock()
|
||||
if err := s.socks5Connect(conn, req.Addr, req.Port); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
log.Printf("[SERVER] sid=%d CONNECT_SUCCESS dial_time=%v", sid, dialElapsed)
|
||||
|
||||
s.activeClients.Add(1)
|
||||
s.mux.SendData(sid, []byte{0x00})
|
||||
s.startStreamPump(ctx, sid, conn)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
s.activeClients.Add(-1)
|
||||
s.mux.CloseStream(sid)
|
||||
s.connMu.Lock()
|
||||
delete(s.connections, sid)
|
||||
s.connMu.Unlock()
|
||||
}()
|
||||
|
||||
buf := make([]byte, 16384)
|
||||
totalSent := uint64(0)
|
||||
lastLog := time.Now()
|
||||
|
||||
for {
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
if totalSent > 1024*1024 {
|
||||
log.Printf("[SERVER] sid=%d TRANSFER_COMPLETE total=%d MB", sid, totalSent/(1024*1024))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
for !s.canSendData() {
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
|
||||
if err := s.mux.SendData(sid, buf[:n]); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
totalSent += uint64(n)
|
||||
if time.Since(lastLog) > 5*time.Second {
|
||||
log.Printf("[SERVER] sid=%d TRANSFER_PROGRESS sent=%d MB", sid, totalSent/(1024*1024))
|
||||
lastLog = time.Now()
|
||||
}
|
||||
}
|
||||
func (s *Server) pumpToMux(sid uint16, conn net.Conn) {
|
||||
defer func() {
|
||||
s.activeClients.Add(-1)
|
||||
_ = s.mux.CloseStream(sid)
|
||||
s.connMu.Lock()
|
||||
delete(s.connections, sid)
|
||||
s.connMu.Unlock()
|
||||
}()
|
||||
|
||||
buf := make([]byte, 16384)
|
||||
totalSent := uint64(0)
|
||||
lastLog := time.Now()
|
||||
|
||||
for {
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
if totalSent > 1024*1024 {
|
||||
log.Printf("[SERVER] sid=%d TRANSFER_DONE total=%d MB", sid, totalSent/(1024*1024))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
for !s.canSendData() {
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
|
||||
if err := s.mux.SendData(sid, buf[:n]); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
totalSent += uint64(n) //nolint:gosec
|
||||
if time.Since(lastLog) > 5*time.Second {
|
||||
log.Printf("[SERVER] sid=%d TRANSFER_UP sent=%d MB", sid, totalSent/(1024*1024))
|
||||
lastLog = time.Now()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) startStreamPump(ctx context.Context, sid uint16, conn net.Conn) {
|
||||
@@ -479,7 +547,7 @@ func (s *Server) startStreamPump(ctx context.Context, sid uint16, conn net.Conn)
|
||||
data := s.mux.ReadStream(sid)
|
||||
if len(data) > 0 {
|
||||
if _, err := conn.Write(data); err != nil {
|
||||
s.mux.CloseStream(sid)
|
||||
_ = s.mux.CloseStream(sid)
|
||||
s.closeStreamConnectionIfCurrent(sid, conn)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package telemost
|
||||
package telemost //nolint:revive
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -13,21 +15,23 @@ import (
|
||||
|
||||
const apiBase = "https://cloud-api.yandex.ru/telemost_front/v2/telemost"
|
||||
|
||||
type ConnectionInfo struct {
|
||||
RoomID string `json:"room_id"`
|
||||
PeerID string `json:"peer_id"`
|
||||
Credentials string `json:"credentials"`
|
||||
var ErrAPI = errors.New("api error") //nolint:revive
|
||||
|
||||
type ConnectionInfo struct { //nolint:revive
|
||||
RoomID string `json:"room_id"` //nolint:tagliatelle
|
||||
PeerID string `json:"peer_id"` //nolint:tagliatelle
|
||||
Credentials string `json:"credentials"` //nolint:tagliatelle
|
||||
ClientConfig struct {
|
||||
MediaServerURL string `json:"media_server_url"`
|
||||
} `json:"client_configuration"`
|
||||
MediaServerURL string `json:"media_server_url"` //nolint:tagliatelle
|
||||
} `json:"client_configuration"` //nolint:tagliatelle
|
||||
}
|
||||
|
||||
func GetConnectionInfo(roomURL, displayName string) (*ConnectionInfo, error) {
|
||||
func GetConnectionInfo(ctx context.Context, roomURL, displayName string) (*ConnectionInfo, error) { //nolint:revive
|
||||
u := fmt.Sprintf("%s/conferences/%s/connection", apiBase, url.QueryEscape(roomURL))
|
||||
|
||||
req, err := http.NewRequest("GET", u, nil)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
q := req.URL.Query()
|
||||
@@ -48,18 +52,18 @@ func GetConnectionInfo(roomURL, displayName string) (*ConnectionInfo, error) {
|
||||
client := protect.NewHTTPClient()
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to do request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("API error %d: %s", resp.StatusCode, body)
|
||||
return nil, fmt.Errorf("%w %d: %s", ErrAPI, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
var info ConnectionInfo
|
||||
if err := json.NewDecoder(resp.Body).Decode(&info); err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
return &info, nil
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user