Merge pull request #14 from openlibrecommunity/refactor/all

Refactor/all
This commit is contained in:
zarazaex
2026-04-12 23:49:55 +03:00
committed by GitHub
8 changed files with 865 additions and 769 deletions

View File

@@ -72,7 +72,7 @@ func RunWithReady(
key, err := decodeKey(keyHex)
if err != nil {
return err
return fmt.Errorf("decodeKey failed: %w", err)
}
keyStr := string(key)
@@ -95,7 +95,7 @@ func RunWithReady(
for peerID := range 1 {
if err := c.addPeer(runCtx, roomURL, peerID, cancel); err != nil {
return err
return fmt.Errorf("addPeer failed: %w", err)
}
}
@@ -111,10 +111,6 @@ func RunWithReady(
return err
}
func peerCount() int {
return 1
}
func decodeKey(keyHex string) ([]byte, error) {
if keyHex == "" {
key := make([]byte, 32)
@@ -193,7 +189,7 @@ func (c *Client) addPeer(
peerID int,
cancel context.CancelFunc,
) error {
peer, err := telemost.NewPeer(roomURL, names.Generate(), c.onData)
peer, err := telemost.NewPeer(runCtx, roomURL, names.Generate(), c.onData)
if err != nil {
return fmt.Errorf("create peer %d: %w", peerID, err)
}
@@ -257,7 +253,7 @@ func (c *Client) sendResetSignal() {
func (c *Client) onData(data []byte) {
plaintext, err := c.cipher.Decrypt(data)
if err != nil {
logger.Debug("Decrypt error: %v", err)
logger.Debugf("Decrypt error: %v", err)
return
}
@@ -292,7 +288,7 @@ func (c *Client) runSOCKS5(
<-ctx.Done()
log.Println("Closing SOCKS5 listener...")
if err := listener.Close(); err != nil {
logger.Debug("SOCKS5 listener close error: %v", err)
logger.Debugf("SOCKS5 listener close error: %v", err)
}
}()
@@ -317,7 +313,7 @@ func (c *Client) runSOCKS5(
func (c *Client) closePeers() {
for _, peer := range c.peers {
if err := peer.Close(); err != nil {
logger.Debug("Peer close error: %v", err)
logger.Debugf("Peer close error: %v", err)
}
}
}
@@ -326,7 +322,7 @@ func (c *Client) closePeers() {
func (c *Client) handleSOCKS5(conn net.Conn, username, password string) {
defer func() {
if err := conn.Close(); err != nil {
logger.Debug("SOCKS5 connection close error: %v", err)
logger.Debugf("SOCKS5 connection close error: %v", err)
}
}()
@@ -362,7 +358,7 @@ func (c *Client) handleSOCKS5(conn net.Conn, username, password string) {
}
sid := c.mux.OpenStream()
logger.Verbose("SOCKS5 opened stream sid=%d for %s:%d", sid, addr, port)
logger.Verbosef("SOCKS5 opened stream sid=%d for %s:%d", sid, addr, port)
log.Printf("[CLIENT] sid=%d SOCKS5_START %s:%d", sid, addr, port)
if !c.sendConnectRequest(sid, addr, port) {
@@ -481,12 +477,12 @@ func (c *Client) sendConnectRequest(sid uint16, addr string, port uint16) bool {
Port: port,
})
if err != nil {
logger.Debug("Connect request marshal error: %v", err)
logger.Debugf("Connect request marshal error: %v", err)
return false
}
if err := c.mux.SendData(sid, reqData); err != nil {
logger.Debug("Connect request send error: %v", err)
logger.Debugf("Connect request send error: %v", err)
return false
}
@@ -525,7 +521,7 @@ func (c *Client) proxyStream(conn net.Conn, sid uint16) {
n, err := conn.Read(buf)
if err != nil {
if err := c.mux.CloseStream(sid); err != nil {
logger.Debug("Close stream error: %v", err)
logger.Debugf("Close stream error: %v", err)
}
return
}
@@ -579,7 +575,7 @@ func writeStreamData(conn net.Conn, data []byte) bool {
func writeResponse(conn net.Conn, response []byte) {
if _, err := conn.Write(response); err != nil {
logger.Debug("SOCKS5 response write error: %v", err)
logger.Debugf("SOCKS5 response write error: %v", err)
}
}

View File

@@ -1,48 +1,59 @@
// Package crypto provides cryptographic functions.
package crypto
import (
"crypto/cipher"
"crypto/rand"
"errors"
"fmt"
"golang.org/x/crypto/chacha20poly1305"
)
type Cipher struct {
var (
ErrInvalidKeySize = errors.New("invalid key size") //nolint:revive
ErrCiphertextTooShort = errors.New("ciphertext too short") //nolint:revive
)
type Cipher struct { //nolint:revive
aead cipher.AEAD
}
func NewCipher(keyStr string) (*Cipher, error) {
func NewCipher(keyStr string) (*Cipher, error) { //nolint:revive
key := []byte(keyStr)
if len(key) != chacha20poly1305.KeySize {
return nil, errors.New("invalid key size")
return nil, ErrInvalidKeySize
}
aead, err := chacha20poly1305.NewX(key)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to create aead: %w", err)
}
return &Cipher{aead: aead}, nil
}
func (c *Cipher) Encrypt(plaintext []byte) ([]byte, error) {
func (c *Cipher) Encrypt(plaintext []byte) ([]byte, error) { //nolint:revive
nonce := make([]byte, c.aead.NonceSize())
if _, err := rand.Read(nonce); err != nil {
return nil, err
return nil, fmt.Errorf("failed to read nonce: %w", err)
}
ciphertext := c.aead.Seal(nonce, nonce, plaintext, nil)
return ciphertext, nil
}
func (c *Cipher) Decrypt(ciphertext []byte) ([]byte, error) {
func (c *Cipher) Decrypt(ciphertext []byte) ([]byte, error) { //nolint:revive
if len(ciphertext) < c.aead.NonceSize() {
return nil, errors.New("ciphertext too short")
return nil, ErrCiphertextTooShort
}
nonce := ciphertext[:c.aead.NonceSize()]
encrypted := ciphertext[c.aead.NonceSize():]
return c.aead.Open(nil, nonce, encrypted, nil)
res, err := c.aead.Open(nil, nonce, encrypted, nil)
if err != nil {
return nil, fmt.Errorf("failed to decrypt: %w", err)
}
return res, nil
}

View File

@@ -1,27 +1,27 @@
package logger
package logger //nolint:revive
import (
"log"
"sync/atomic"
)
var verboseEnabled atomic.Bool
var verboseEnabled atomic.Bool //nolint:gochecknoglobals
func SetVerbose(enabled bool) {
func SetVerbose(enabled bool) { //nolint:revive
verboseEnabled.Store(enabled)
}
func IsVerbose() bool {
func IsVerbose() bool { //nolint:revive
return verboseEnabled.Load()
}
func Verbose(format string, v ...interface{}) {
func Verbosef(format string, v ...interface{}) { //nolint:revive
if verboseEnabled.Load() {
log.Printf("[VERBOSE] "+format, v...)
}
}
func Debug(format string, v ...interface{}) {
func Debugf(format string, v ...interface{}) { //nolint:revive
if verboseEnabled.Load() {
log.Printf("[DEBUG] "+format, v...)
}

View File

@@ -1,31 +1,33 @@
// ===========================================
// AI GENERATED / AI GENERATED / AI GENERATED
//===========================================
// Package mux provides a multiplexer for multiple streams over a single connection.
package mux
import (
"encoding/binary"
"errors"
"fmt"
"sync"
"time"
"github.com/openlibrecommunity/olcrtc/internal/logger"
)
var (
ErrClientResetID = errors.New("client reset requires a non-zero client id") //nolint:revive
)
const (
ControlStreamID uint16 = 0xFFFF
ControlLength uint16 = 0xFFFF
ControlStreamID uint16 = 0xFFFF //nolint:revive
ControlLength uint16 = 0xFFFF //nolint:revive
ControlResetClient uint32 = 1
)
type ControlFrame struct {
type ControlFrame struct { //nolint:revive
ClientID uint32
Type uint32
}
type Stream struct {
type Stream struct { //nolint:revive
ID uint16
ClientID uint32
recvBuf []byte
@@ -35,13 +37,13 @@ type Stream struct {
outOfOrder map[uint32][]byte
}
func (s *Stream) RecvBuf() []byte {
func (s *Stream) RecvBuf() []byte { //nolint:revive
s.mu.Lock()
defer s.mu.Unlock()
return s.recvBuf
}
type Multiplexer struct {
type Multiplexer struct { //nolint:revive
streams map[uint16]*Stream
nextID uint16
clientID uint32
@@ -55,7 +57,7 @@ type Multiplexer struct {
sendSeqMu sync.Mutex
}
func New(clientID uint32, onSend func([]byte) error) *Multiplexer {
func New(clientID uint32, onSend func([]byte) error) *Multiplexer { //nolint:revive
return &Multiplexer{
streams: make(map[uint16]*Stream),
nextID: 1,
@@ -68,7 +70,7 @@ func New(clientID uint32, onSend func([]byte) error) *Multiplexer {
}
}
func (m *Multiplexer) OpenStream() uint16 {
func (m *Multiplexer) OpenStream() uint16 { //nolint:revive
m.mu.Lock()
defer m.mu.Unlock()
@@ -91,7 +93,7 @@ func (m *Multiplexer) OpenStream() uint16 {
}
}
func (m *Multiplexer) SendData(sid uint16, data []byte) error {
func (m *Multiplexer) SendData(sid uint16, data []byte) error { //nolint:revive
m.mu.RLock()
stream, exists := m.streams[sid]
m.mu.RUnlock()
@@ -100,12 +102,11 @@ func (m *Multiplexer) SendData(sid uint16, data []byte) error {
return nil
}
// Keep encrypted DataChannel messages below Telemost's observed 8 KiB cap.
const chunkSize = 7000
totalChunks := (len(data) + chunkSize - 1) / chunkSize
if totalChunks > 10 {
logger.Debug("SendData: sid=%d, size=%d bytes, chunks=%d", sid, len(data), totalChunks)
logger.Debugf("SendData: sid=%d, size=%d bytes, chunks=%d", sid, len(data), totalChunks)
}
for i := 0; i < len(data); i += chunkSize {
@@ -124,19 +125,19 @@ func (m *Multiplexer) SendData(sid uint16, data []byte) error {
frame := make([]byte, 12+len(chunk))
binary.BigEndian.PutUint32(frame[0:4], m.clientID)
binary.BigEndian.PutUint16(frame[4:6], sid)
binary.BigEndian.PutUint16(frame[6:8], uint16(len(chunk)))
binary.BigEndian.PutUint16(frame[6:8], uint16(uint32(len(chunk)))) //nolint:gosec
binary.BigEndian.PutUint32(frame[8:12], seq)
copy(frame[12:], chunk)
if err := m.onSend(frame); err != nil {
return err
return fmt.Errorf("onSend failed: %w", err)
}
}
return nil
}
func (m *Multiplexer) CloseStream(sid uint16) error {
func (m *Multiplexer) CloseStream(sid uint16) error { //nolint:revive
m.mu.Lock()
defer m.mu.Unlock()
@@ -154,17 +155,23 @@ func (m *Multiplexer) CloseStream(sid uint16) error {
binary.BigEndian.PutUint16(frame[6:8], 0)
binary.BigEndian.PutUint32(frame[8:12], 0)
return m.onSend(frame)
}
func (m *Multiplexer) SendClientReset() error {
if m.clientID == 0 {
return errors.New("client reset requires a non-zero client id")
if err := m.onSend(frame); err != nil {
return fmt.Errorf("onSend failed: %w", err)
}
return m.onSend(BuildControlFrame(m.clientID, ControlResetClient))
return nil
}
func BuildControlFrame(clientID uint32, controlType uint32) []byte {
func (m *Multiplexer) SendClientReset() error { //nolint:revive
if m.clientID == 0 {
return ErrClientResetID
}
if err := m.onSend(BuildControlFrame(m.clientID, ControlResetClient)); err != nil {
return fmt.Errorf("onSend failed: %w", err)
}
return nil
}
func BuildControlFrame(clientID uint32, controlType uint32) []byte { //nolint:revive
frame := make([]byte, 12)
binary.BigEndian.PutUint32(frame[0:4], clientID)
binary.BigEndian.PutUint16(frame[4:6], ControlStreamID)
@@ -173,7 +180,7 @@ func BuildControlFrame(clientID uint32, controlType uint32) []byte {
return frame
}
func ParseControlFrame(frame []byte) (ControlFrame, bool) {
func ParseControlFrame(frame []byte) (ControlFrame, bool) { //nolint:revive
if len(frame) < 12 {
return ControlFrame{}, false
}
@@ -190,7 +197,7 @@ func ParseControlFrame(frame []byte) (ControlFrame, bool) {
}, true
}
func (m *Multiplexer) HandleFrame(frame []byte) {
func (m *Multiplexer) HandleFrame(frame []byte) { //nolint:revive
control, ok := ParseControlFrame(frame)
if ok {
m.handleControlFrame(control)
@@ -207,11 +214,7 @@ func (m *Multiplexer) HandleFrame(frame []byte) {
seq := binary.BigEndian.Uint32(frame[8:12])
if length == 0 {
m.mu.Lock()
if stream, exists := m.streams[sid]; exists && stream.ClientID == clientID {
stream.closed = true
}
m.mu.Unlock()
m.handleCloseStreamFrame(sid, clientID)
return
}
@@ -219,15 +222,45 @@ func (m *Multiplexer) HandleFrame(frame []byte) {
return
}
data := frame[12 : 12+length]
m.processDataFrame(sid, clientID, seq, frame[12:12+length])
}
func (m *Multiplexer) handleCloseStreamFrame(sid uint16, clientID uint32) {
m.mu.Lock()
defer m.mu.Unlock()
if stream, exists := m.streams[sid]; exists && stream.ClientID == clientID {
stream.closed = true
}
}
func (m *Multiplexer) processDataFrame(sid uint16, clientID uint32, seq uint32, data []byte) {
m.mu.Lock()
defer m.mu.Unlock()
stream := m.getOrCreateStream(sid, clientID)
if stream == nil {
return
}
if seq == stream.nextSeq {
if s := m.waitForBufferSpace(sid, clientID, len(data)); s != nil {
s.recvBuf = append(s.recvBuf, data...)
s.nextSeq++
m.applyOutOfOrder(s, sid, clientID)
m.notifyDataReady(sid)
}
} else if seq > stream.nextSeq {
if len(stream.outOfOrder) < 100 {
stream.outOfOrder[seq] = append([]byte(nil), data...)
}
}
}
func (m *Multiplexer) getOrCreateStream(sid uint16, clientID uint32) *Stream {
stream, exists := m.streams[sid]
if !exists {
if len(m.streams) >= m.maxStreams {
return
return nil
}
stream = &Stream{
ID: sid,
@@ -237,59 +270,42 @@ func (m *Multiplexer) HandleFrame(frame []byte) {
outOfOrder: make(map[uint32][]byte),
}
m.streams[sid] = stream
} else if stream.ClientID != clientID {
return stream
}
if stream.ClientID != clientID {
stream.ClientID = clientID
stream.recvBuf = make([]byte, 0)
stream.closed = false
stream.nextSeq = 0
stream.outOfOrder = make(map[uint32][]byte)
}
return stream
}
if seq == stream.nextSeq {
// Backpressure: if the stream buffer is full, release the mux lock and
// wait for the reader to drain it. Dropping/closing here would corrupt
// the TCP stream carried over the mux — large HTTP/2 downloads (X,
// Instagram, YouTube) that push data faster than conn.Write can accept
// would lose bytes and hang forever.
if s := m.waitForBufferSpace(sid, clientID, len(data)); s == nil {
func (m *Multiplexer) applyOutOfOrder(stream *Stream, sid uint16, clientID uint32) {
for {
nextData, ok := stream.outOfOrder[stream.nextSeq]
if !ok {
break
}
if s := m.waitForBufferSpace(sid, clientID, len(nextData)); s == nil {
return
} else {
stream = s
}
stream.recvBuf = append(stream.recvBuf, data...)
stream.recvBuf = append(stream.recvBuf, nextData...)
delete(stream.outOfOrder, stream.nextSeq)
stream.nextSeq++
logger.Verbosef("Applied out-of-order packet sid=%d seq=%d", sid, stream.nextSeq-1)
}
}
for {
nextData, ok := stream.outOfOrder[stream.nextSeq]
if !ok {
break
}
if s := m.waitForBufferSpace(sid, clientID, len(nextData)); s == nil {
return
} else {
stream = s
}
nextData, ok = stream.outOfOrder[stream.nextSeq]
if !ok {
break
}
stream.recvBuf = append(stream.recvBuf, nextData...)
delete(stream.outOfOrder, stream.nextSeq)
stream.nextSeq++
logger.Verbose("Applied out-of-order packet sid=%d seq=%d", sid, stream.nextSeq-1)
}
m.dataReadyMu.Lock()
if ch, ok := m.dataReady[sid]; ok {
select {
case ch <- struct{}{}:
default:
}
}
m.dataReadyMu.Unlock()
} else if seq > stream.nextSeq {
if len(stream.outOfOrder) < 100 {
stream.outOfOrder[seq] = append([]byte(nil), data...)
func (m *Multiplexer) notifyDataReady(sid uint16) {
m.dataReadyMu.Lock()
defer m.dataReadyMu.Unlock()
if ch, ok := m.dataReady[sid]; ok {
select {
case ch <- struct{}{}:
default:
}
}
}
@@ -299,11 +315,11 @@ func (m *Multiplexer) handleControlFrame(control ControlFrame) {
case ControlResetClient:
m.ResetClient(control.ClientID)
default:
logger.Debug("Unknown mux control frame type=%d clientID=%d", control.Type, control.ClientID)
logger.Debugf("Unknown mux control frame type=%d clientID=%d", control.Type, control.ClientID)
}
}
func (m *Multiplexer) ResetClient(clientID uint32) {
func (m *Multiplexer) ResetClient(clientID uint32) { //nolint:revive
m.mu.Lock()
defer m.mu.Unlock()
@@ -315,10 +331,6 @@ func (m *Multiplexer) ResetClient(clientID uint32) {
}
}
// waitForBufferSpace releases m.mu and waits until the stream's recvBuf has
// room for `need` more bytes, then re-acquires the lock. Returns the (possibly
// re-fetched) stream, or nil if the stream disappeared / was reset / closed.
// Caller must hold m.mu (write-locked) on entry and will hold it on return.
func (m *Multiplexer) waitForBufferSpace(sid uint16, clientID uint32, need int) *Stream {
for {
stream, ok := m.streams[sid]
@@ -334,7 +346,7 @@ func (m *Multiplexer) waitForBufferSpace(sid uint16, clientID uint32, need int)
}
}
func (m *Multiplexer) ReadStream(sid uint16) []byte {
func (m *Multiplexer) ReadStream(sid uint16) []byte { //nolint:revive
m.mu.Lock()
defer m.mu.Unlock()
@@ -348,7 +360,7 @@ func (m *Multiplexer) ReadStream(sid uint16) []byte {
return data
}
func (m *Multiplexer) StreamClosed(sid uint16) bool {
func (m *Multiplexer) StreamClosed(sid uint16) bool { //nolint:revive
m.mu.RLock()
defer m.mu.RUnlock()
@@ -356,7 +368,7 @@ func (m *Multiplexer) StreamClosed(sid uint16) bool {
return !exists || stream.closed
}
func (m *Multiplexer) GetStreams() []uint16 {
func (m *Multiplexer) GetStreams() []uint16 { //nolint:revive
m.mu.RLock()
defer m.mu.RUnlock()
@@ -367,13 +379,13 @@ func (m *Multiplexer) GetStreams() []uint16 {
return sids
}
func (m *Multiplexer) GetStream(sid uint16) *Stream {
func (m *Multiplexer) GetStream(sid uint16) *Stream { //nolint:revive
m.mu.RLock()
defer m.mu.RUnlock()
return m.streams[sid]
}
func (m *Multiplexer) Reset() {
func (m *Multiplexer) Reset() { //nolint:revive
m.mu.Lock()
defer m.mu.Unlock()
@@ -389,14 +401,14 @@ func (m *Multiplexer) Reset() {
m.sendSeqMu.Unlock()
}
func (m *Multiplexer) UpdateSendFunc(onSend func([]byte) error) {
func (m *Multiplexer) UpdateSendFunc(onSend func([]byte) error) { //nolint:revive
m.mu.Lock()
defer m.mu.Unlock()
m.onSend = onSend
}
func (m *Multiplexer) WaitForData(sid uint16) <-chan struct{} {
func (m *Multiplexer) WaitForData(sid uint16) <-chan struct{} { //nolint:revive
m.dataReadyMu.Lock()
defer m.dataReadyMu.Unlock()
@@ -406,7 +418,7 @@ func (m *Multiplexer) WaitForData(sid uint16) <-chan struct{} {
return m.dataReady[sid]
}
func (m *Multiplexer) CleanupDataChannel(sid uint16) {
func (m *Multiplexer) CleanupDataChannel(sid uint16) { //nolint:revive
m.dataReadyMu.Lock()
defer m.dataReadyMu.Unlock()

View File

@@ -1,7 +1,9 @@
// Package protect provides functions to protect sockets from VPN routing.
package protect
import (
"context"
"fmt"
"net"
"net/http"
"syscall"
@@ -10,18 +12,21 @@ import (
// Protector is called with a socket file descriptor before connect.
// On Android, this calls VpnService.protect(fd) to bypass VPN routing.
var Protector func(fd int) bool
var Protector func(fd int) bool //nolint:gochecknoglobals
func controlFunc(network, address string, c syscall.RawConn) error {
func controlFunc(network, _ string, c syscall.RawConn) error {
if Protector == nil {
return nil
}
var err error
c.Control(func(fd uintptr) {
if !Protector(int(fd)) {
controlErr := c.Control(func(fd uintptr) {
if !Protector(int(fd)) { //nolint:gosec
err = &net.OpError{Op: "protect", Net: network, Err: net.ErrClosed}
}
})
if controlErr != nil {
return fmt.Errorf("control failed: %w", controlErr)
}
return err
}
@@ -50,17 +55,27 @@ func NewHTTPClient() *http.Client {
// DialContext dials using a protected socket.
func DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return NewDialer().DialContext(ctx, network, address)
conn, err := NewDialer().DialContext(ctx, network, address)
if err != nil {
return nil, fmt.Errorf("dial failed: %w", err)
}
return conn, nil
}
// proxyDialer implements golang.org/x/net/proxy.Dialer for pion ICE.
type proxyDialer struct{}
// ProxyDialer implements golang.org/x/net/proxy.Dialer for pion ICE.
type ProxyDialer struct{}
func (d *proxyDialer) Dial(network, addr string) (net.Conn, error) {
return NewDialer().Dial(network, addr)
// Dial connects to the address on the named network using a protected socket.
func (d *ProxyDialer) Dial(network, addr string) (net.Conn, error) {
conn, err := NewDialer().Dial(network, addr)
if err != nil {
return nil, fmt.Errorf("dial failed: %w", err)
}
return conn, nil
}
// NewProxyDialer returns a proxy.Dialer that protects ICE sockets.
func NewProxyDialer() *proxyDialer {
return &proxyDialer{}
func NewProxyDialer() *ProxyDialer {
return &ProxyDialer{}
}

View File

@@ -1,3 +1,4 @@
// Package server implements the olcrtc tunnel server logic.
package server
import (
@@ -5,10 +6,12 @@ import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net"
"strconv"
"sync"
"sync/atomic"
"time"
@@ -21,7 +24,24 @@ import (
"github.com/pion/webrtc/v4"
)
type Server struct {
var (
// ErrKeySize is returned when the encryption key is not 32 bytes.
ErrKeySize = errors.New("key must be 32 bytes")
// ErrKeyStringLength is returned when the encryption key string length is not 32.
ErrKeyStringLength = errors.New("key string length must be 32")
// ErrSocks5AuthFailed is returned when SOCKS5 authentication fails.
ErrSocks5AuthFailed = errors.New("SOCKS5 auth failed")
// ErrSocks5ConnectFailed is returned when SOCKS5 connection fails.
ErrSocks5ConnectFailed = errors.New("SOCKS5 connect failed")
// ErrNoPeers is returned when no peers are available.
ErrNoPeers = errors.New("no peers available")
// ErrDialProxy is returned when dialing the proxy fails.
ErrDialProxy = errors.New("failed to dial proxy")
// ErrEncryptFailed is returned when encryption fails.
ErrEncryptFailed = errors.New("encrypt failed")
)
type Server struct { //nolint:revive
peers []*telemost.Peer
cipher *crypto.Cipher
mux *mux.Multiplexer
@@ -33,48 +53,32 @@ type Server struct {
activeClients atomic.Int32
wg sync.WaitGroup
dnsServer string
dnsCache sync.Map
resolver *net.Resolver
socksProxyAddr string
socksProxyPort int
}
type ConnectRequest struct {
type ConnectRequest struct { //nolint:revive
Cmd string `json:"cmd"`
Addr string `json:"addr"`
Port int `json:"port"`
}
func Run(ctx context.Context, roomURL, keyHex string, dnsServer, socksProxyAddr string, socksProxyPort int) error {
// Run starts the olcrtc server and listens for client connections.
func Run(
ctx context.Context,
roomURL,
keyHex string,
dnsServer,
socksProxyAddr string,
socksProxyPort int,
) error {
runCtx, cancel := context.WithCancel(ctx)
defer cancel()
var key []byte
var err error
if keyHex == "" {
key = make([]byte, 32)
if _, err := rand.Read(key); err != nil {
return err
}
log.Printf("Generated key: %x", key)
} else {
key, err = hex.DecodeString(keyHex)
if err != nil {
return err
}
if len(key) != 32 {
return fmt.Errorf("key must be 32 bytes, got %d", len(key))
}
}
keyStr := string(key)
if len(keyStr) != 32 {
return fmt.Errorf("key string length must be 32, got %d", len(keyStr))
}
cipher, err := crypto.NewCipher(keyStr)
cipher, err := setupCipher(keyHex)
if err != nil {
return err
return fmt.Errorf("setupCipher failed: %w", err)
}
s := &Server{
@@ -87,20 +91,72 @@ func Run(ctx context.Context, roomURL, keyHex string, dnsServer, socksProxyAddr
socksProxyPort: socksProxyPort,
}
if dnsServer == "" {
dnsServer = "1.1.1.1:53"
if s.dnsServer == "" {
s.dnsServer = "1.1.1.1:53"
}
s.setupResolver()
s.setupMux()
const peerCount = 1
for i := range peerCount {
if err := s.addPeer(runCtx, roomURL, i, cancel); err != nil {
return fmt.Errorf("addPeer failed: %w", err)
}
}
err = s.runLoop(runCtx)
log.Println("Waiting for server goroutines...")
s.wg.Wait()
log.Println("Server goroutines finished")
return err
}
func setupCipher(keyHex string) (*crypto.Cipher, error) {
var key []byte
var err error
if keyHex == "" {
key = make([]byte, 32)
if _, err := rand.Read(key); err != nil {
return nil, fmt.Errorf("failed to generate key: %w", err)
}
log.Printf("Generated key: %x", key)
} else {
key, err = hex.DecodeString(keyHex)
if err != nil {
return nil, fmt.Errorf("failed to decode key: %w", err)
}
if len(key) != 32 {
return nil, fmt.Errorf("%w, got %d", ErrKeySize, len(key))
}
}
keyStr := string(key)
if len(keyStr) != 32 {
return nil, fmt.Errorf("%w, got %d", ErrKeyStringLength, len(keyStr))
}
cipher, err := crypto.NewCipher(keyStr)
if err != nil {
return nil, fmt.Errorf("failed to create cipher: %w", err)
}
return cipher, nil
}
func (s *Server) setupResolver() {
s.resolver = &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
Dial: func(ctx context.Context, network, _ string) (net.Conn, error) {
d := net.Dialer{Timeout: 3 * time.Second}
return d.DialContext(ctx, network, dnsServer)
return d.DialContext(ctx, network, s.dnsServer)
},
}
}
peerCount := 1
func (s *Server) setupMux() {
s.mux = mux.New(0, func(frame []byte) error {
for {
canSend := true
@@ -118,110 +174,118 @@ func Run(ctx context.Context, roomURL, keyHex string, dnsServer, socksProxyAddr
encrypted, err := s.cipher.Encrypt(frame)
if err != nil {
return err
return fmt.Errorf("%w: %w", ErrEncryptFailed, err)
}
idx := s.peerIdx.Add(1) % uint32(len(s.peers))
if len(s.peers) == 0 {
return ErrNoPeers
}
idx := s.peerIdx.Add(1) % uint32(len(s.peers)) //nolint:gosec
return s.peers[idx].Send(encrypted)
})
}
for i := 0; i < peerCount; i++ {
peerID := i
peer, err := telemost.NewPeer(roomURL, names.Generate(), s.onData)
if err != nil {
return err
}
peer.SetEndedCallback(func(reason string) {
log.Printf("Server peer %d reported conference end: %s", peerID, reason)
cancel()
})
s.peers = append(s.peers, peer)
peer.SetReconnectCallback(func(dc *webrtc.DataChannel) {
if dc == nil {
log.Printf("Server peer %d channel closed - resetting multiplexer state", peerID)
} else {
log.Printf("Server peer %d reconnected - resetting multiplexer state", peerID)
}
s.connMu.Lock()
for sid, conn := range s.connections {
if conn != nil {
conn.Close()
}
delete(s.connections, sid)
}
s.connMu.Unlock()
if dc != nil {
s.mux.UpdateSendFunc(func(frame []byte) error {
encrypted, err := s.cipher.Encrypt(frame)
if err != nil {
return err
}
idx := s.peerIdx.Add(1) % uint32(len(s.peers))
return s.peers[idx].Send(encrypted)
})
}
s.mux.Reset()
log.Println("Server multiplexer reset complete")
})
peer.SetShouldReconnect(func() bool {
return s.activeClients.Load() > 0
})
log.Printf("Connecting peer %d to Telemost...", peerID)
if err := peer.Connect(runCtx); err != nil {
return err
}
log.Printf("Peer %d connected", peerID)
s.wg.Add(1)
go func() {
defer s.wg.Done()
peer.WatchConnection(runCtx)
}()
func (s *Server) addPeer(ctx context.Context, roomURL string, peerID int, cancel context.CancelFunc) error {
peer, err := telemost.NewPeer(ctx, roomURL, names.Generate(), s.onData)
if err != nil {
return fmt.Errorf("failed to create peer: %w", err)
}
err = s.run(runCtx)
peer.SetEndedCallback(func(reason string) {
log.Printf("Server peer %d reported conference end: %s", peerID, reason)
cancel()
})
s.peers = append(s.peers, peer)
log.Println("Waiting for server goroutines...")
s.wg.Wait()
log.Println("Server goroutines finished")
peer.SetReconnectCallback(func(dc *webrtc.DataChannel) {
s.handlePeerReconnect(peerID, dc)
})
return err
peer.SetShouldReconnect(func() bool {
return s.activeClients.Load() > 0
})
log.Printf("Connecting peer %d to Telemost...", peerID)
if err := peer.Connect(ctx); err != nil {
return fmt.Errorf("failed to connect peer: %w", err)
}
log.Printf("Peer %d connected", peerID)
s.wg.Add(1)
go func() {
defer s.wg.Done()
peer.WatchConnection(ctx)
}()
return nil
}
func (s *Server) handlePeerReconnect(peerID int, dc *webrtc.DataChannel) {
if dc == nil {
log.Printf("Server peer %d channel closed - resetting mux state", peerID)
} else {
log.Printf("Server peer %d reconnected - resetting mux state", peerID)
}
s.connMu.Lock()
for sid, conn := range s.connections {
if conn != nil {
_ = conn.Close()
}
delete(s.connections, sid)
}
s.connMu.Unlock()
if dc != nil {
s.mux.UpdateSendFunc(func(frame []byte) error {
encrypted, err := s.cipher.Encrypt(frame)
if err != nil {
return fmt.Errorf("%w: %w", ErrEncryptFailed, err)
}
if len(s.peers) == 0 {
return ErrNoPeers
}
idx := s.peerIdx.Add(1) % uint32(len(s.peers)) //nolint:gosec
return s.peers[idx].Send(encrypted)
})
}
s.mux.Reset()
log.Println("Server multiplexer reset complete")
}
func (s *Server) socks5Connect(conn net.Conn, targetAddr string, targetPort int) error {
if _, err := conn.Write([]byte{5, 1, 0}); err != nil {
return err
return fmt.Errorf("failed to write socks5 auth: %w", err)
}
resp := make([]byte, 2)
if _, err := io.ReadFull(conn, resp); err != nil {
return err
return fmt.Errorf("failed to read socks5 auth resp: %w", err)
}
if resp[0] != 5 || resp[1] != 0 {
return fmt.Errorf("SOCKS5 auth failed")
return ErrSocks5AuthFailed
}
req := []byte{5, 1, 0, 3}
req = append(req, byte(len(targetAddr)))
addrLen := len(targetAddr)
if addrLen > 255 {
addrLen = 255
targetAddr = targetAddr[:255]
}
req := make([]byte, 0, 7+addrLen)
req = append(req, 5, 1, 0, 3, byte(addrLen))
req = append(req, []byte(targetAddr)...)
req = append(req, byte(targetPort>>8), byte(targetPort))
req = append(req, byte(targetPort>>8), byte(targetPort)) //nolint:gosec
if _, err := conn.Write(req); err != nil {
return err
return fmt.Errorf("failed to write socks5 connect req: %w", err)
}
resp = make([]byte, 10)
if _, err := io.ReadFull(conn, resp); err != nil {
return err
return fmt.Errorf("failed to read socks5 connect resp: %w", err)
}
if resp[0] != 5 || resp[1] != 0 {
return fmt.Errorf("SOCKS5 connect failed: %d", resp[1])
return fmt.Errorf("%w: %d", ErrSocks5ConnectFailed, resp[1])
}
return nil
@@ -230,12 +294,12 @@ func (s *Server) socks5Connect(conn net.Conn, targetAddr string, targetPort int)
func (s *Server) onData(data []byte) {
plaintext, err := s.cipher.Decrypt(data)
if err != nil {
logger.Debug("Decrypt error: %v", err)
logger.Debugf("Decrypt error: %v", err)
return
}
if control, ok := mux.ParseControlFrame(plaintext); ok && control.Type == mux.ControlResetClient {
log.Printf("Received reset signal from client (clientID=%d) - cleaning up", control.ClientID)
log.Printf("Received reset signal from client (clientID=%d)", control.ClientID)
s.closeClientConnections(control.ClientID)
}
@@ -250,63 +314,67 @@ func (s *Server) closeClientConnections(clientID uint32) {
stream := s.mux.GetStream(streamSid)
if stream != nil && stream.ClientID == clientID {
if conn != nil {
conn.Close()
_ = conn.Close()
}
delete(s.connections, streamSid)
}
}
}
func (s *Server) run(ctx context.Context) error {
func (s *Server) runLoop(ctx context.Context) error {
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
log.Println("Server shutting down...")
s.connMu.Lock()
for _, conn := range s.connections {
if conn != nil {
conn.Close()
}
}
s.connMu.Unlock()
log.Printf("Closing %d peer(s)...", len(s.peers))
for i, peer := range s.peers {
log.Printf("Closing peer %d...", i)
peer.Close()
}
log.Println("All peers closed")
s.shutdown()
return nil
case <-ticker.C:
s.processMuxStreams(ctx)
}
sids := s.mux.GetStreams()
}
}
for _, sid := range sids {
if s.mux.StreamClosed(sid) {
s.closeStreamConnection(sid)
continue
}
func (s *Server) shutdown() {
log.Println("Server shutting down...")
s.connMu.Lock()
for _, conn := range s.connections {
if conn != nil {
_ = conn.Close()
}
}
s.connMu.Unlock()
if s.hasConnection(sid) {
continue
}
for i, peer := range s.peers {
log.Printf("Closing peer %d...", i)
_ = peer.Close()
}
log.Println("All peers closed")
}
data := s.mux.ReadStream(sid)
if len(data) == 0 {
continue
}
func (s *Server) processMuxStreams(ctx context.Context) {
sids := s.mux.GetStreams()
for _, sid := range sids {
if s.mux.StreamClosed(sid) {
s.closeStreamConnection(sid)
continue
}
var req ConnectRequest
if err := json.Unmarshal(data, &req); err == nil && req.Cmd == "connect" {
log.Printf("[SERVER] sid=%d RECEIVED_CONNECT_REQUEST %s:%d", sid, req.Addr, req.Port)
s.closeStreamConnection(sid)
go s.handleConnect(ctx, sid, req)
}
if s.hasConnection(sid) {
continue
}
data := s.mux.ReadStream(sid)
if len(data) == 0 {
continue
}
var req ConnectRequest
if err := json.Unmarshal(data, &req); err == nil && req.Cmd == "connect" {
log.Printf("[SERVER] sid=%d RECV_CONNECT %s:%d", sid, req.Addr, req.Port)
s.closeStreamConnection(sid)
go s.handleConnect(ctx, sid, req)
}
}
}
@@ -314,15 +382,14 @@ func (s *Server) run(ctx context.Context) error {
func (s *Server) hasConnection(sid uint16) bool {
s.connMu.RLock()
defer s.connMu.RUnlock()
conn := s.connections[sid]
return conn != nil
return s.connections[sid] != nil
}
func (s *Server) closeStreamConnection(sid uint16) {
s.connMu.Lock()
conn := s.connections[sid]
if conn != nil {
conn.Close()
_ = conn.Close()
delete(s.connections, sid)
}
s.connMu.Unlock()
@@ -332,7 +399,7 @@ func (s *Server) closeStreamConnectionIfCurrent(sid uint16, expected net.Conn) {
s.connMu.Lock()
conn := s.connections[sid]
if conn == expected {
conn.Close()
_ = conn.Close()
delete(s.connections, sid)
}
s.connMu.Unlock()
@@ -344,7 +411,7 @@ func (s *Server) markStreamPump(sid uint16, conn net.Conn) bool {
if current := s.streamPumps[sid]; current == conn {
return false
} else if current != nil {
current.Close()
_ = current.Close()
}
s.streamPumps[sid] = conn
return true
@@ -360,102 +427,103 @@ func (s *Server) unmarkStreamPump(sid uint16, conn net.Conn) {
func (s *Server) handleConnect(ctx context.Context, sid uint16, req ConnectRequest) {
startTime := time.Now()
addr := fmt.Sprintf("%s:%d", req.Addr, req.Port)
logger.Verbose("Handling connect request sid=%d to %s", sid, addr)
addr := net.JoinHostPort(req.Addr, strconv.Itoa(req.Port))
log.Printf("[SERVER] sid=%d CONNECT_START %s", sid, addr)
s.connMu.Lock()
oldConn, exists := s.connections[sid]
if exists && oldConn != nil {
log.Printf("Closing old connection for sid=%d", sid)
oldConn.Close()
delete(s.connections, sid)
}
s.connMu.Unlock()
s.closeStreamConnection(sid)
dialStart := time.Now()
var conn net.Conn
var err error
conn, err := s.dial(req)
dialElapsed := time.Since(dialStart)
if err != nil {
log.Printf("[SERVER] sid=%d CONNECT_FAILED dial=%v total=%v err=%v",
sid, dialElapsed, time.Since(startTime), err)
_ = s.mux.CloseStream(sid)
return
}
s.connMu.Lock()
s.connections[sid] = conn
s.connMu.Unlock()
log.Printf("[SERVER] sid=%d CONNECT_SUCCESS dial=%v", sid, dialElapsed)
s.activeClients.Add(1)
_ = s.mux.SendData(sid, []byte{0x00})
s.startStreamPump(ctx, sid, conn)
go s.pumpToMux(sid, conn)
}
func (s *Server) dial(req ConnectRequest) (net.Conn, error) {
addr := net.JoinHostPort(req.Addr, strconv.Itoa(req.Port))
if s.socksProxyAddr == "" {
dialer := &net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
Resolver: s.resolver,
}
conn, err = dialer.Dial("tcp4", addr)
logger.Verbose("TCP dial took %v for sid=%d (direct)", time.Since(dialStart), sid)
} else {
proxyAddr := fmt.Sprintf("%s:%d", s.socksProxyAddr, s.socksProxyPort)
dialer := &net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
conn, err := dialer.Dial("tcp4", addr)
if err != nil {
return nil, fmt.Errorf("dial failed: %w", err)
}
conn, err = dialer.Dial("tcp4", proxyAddr)
if err == nil {
if err := s.socks5Connect(conn, req.Addr, req.Port); err != nil {
conn.Close()
err = fmt.Errorf("SOCKS5 connect failed: %v", err)
}
}
logger.Verbose("SOCKS5 proxy dial took %v for sid=%d", time.Since(dialStart), sid)
return conn, nil
}
dialElapsed := time.Since(dialStart)
proxyAddr := net.JoinHostPort(s.socksProxyAddr, strconv.Itoa(s.socksProxyPort))
dialer := &net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
}
conn, err := dialer.Dial("tcp4", proxyAddr)
if err != nil {
log.Printf("[SERVER] sid=%d CONNECT_FAILED dial_time=%v total_elapsed=%v err=%v", sid, dialElapsed, time.Since(startTime), err)
go s.mux.CloseStream(sid)
return
return nil, fmt.Errorf("failed to dial proxy: %w", err)
}
logger.Verbose("TCP dial took %v for sid=%d", dialElapsed, sid)
s.connMu.Lock()
s.connections[sid] = conn
s.connMu.Unlock()
if err := s.socks5Connect(conn, req.Addr, req.Port); err != nil {
_ = conn.Close()
return nil, err
}
return conn, nil
}
log.Printf("[SERVER] sid=%d CONNECT_SUCCESS dial_time=%v", sid, dialElapsed)
s.activeClients.Add(1)
s.mux.SendData(sid, []byte{0x00})
s.startStreamPump(ctx, sid, conn)
go func() {
defer func() {
s.activeClients.Add(-1)
s.mux.CloseStream(sid)
s.connMu.Lock()
delete(s.connections, sid)
s.connMu.Unlock()
}()
buf := make([]byte, 16384)
totalSent := uint64(0)
lastLog := time.Now()
for {
n, err := conn.Read(buf)
if err != nil {
if totalSent > 1024*1024 {
log.Printf("[SERVER] sid=%d TRANSFER_COMPLETE total=%d MB", sid, totalSent/(1024*1024))
}
return
}
for !s.canSendData() {
time.Sleep(20 * time.Millisecond)
}
if err := s.mux.SendData(sid, buf[:n]); err != nil {
return
}
totalSent += uint64(n)
if time.Since(lastLog) > 5*time.Second {
log.Printf("[SERVER] sid=%d TRANSFER_PROGRESS sent=%d MB", sid, totalSent/(1024*1024))
lastLog = time.Now()
}
}
func (s *Server) pumpToMux(sid uint16, conn net.Conn) {
defer func() {
s.activeClients.Add(-1)
_ = s.mux.CloseStream(sid)
s.connMu.Lock()
delete(s.connections, sid)
s.connMu.Unlock()
}()
buf := make([]byte, 16384)
totalSent := uint64(0)
lastLog := time.Now()
for {
n, err := conn.Read(buf)
if err != nil {
if totalSent > 1024*1024 {
log.Printf("[SERVER] sid=%d TRANSFER_DONE total=%d MB", sid, totalSent/(1024*1024))
}
return
}
for !s.canSendData() {
time.Sleep(20 * time.Millisecond)
}
if err := s.mux.SendData(sid, buf[:n]); err != nil {
return
}
totalSent += uint64(n) //nolint:gosec
if time.Since(lastLog) > 5*time.Second {
log.Printf("[SERVER] sid=%d TRANSFER_UP sent=%d MB", sid, totalSent/(1024*1024))
lastLog = time.Now()
}
}
}
func (s *Server) startStreamPump(ctx context.Context, sid uint16, conn net.Conn) {
@@ -479,7 +547,7 @@ func (s *Server) startStreamPump(ctx context.Context, sid uint16, conn net.Conn)
data := s.mux.ReadStream(sid)
if len(data) > 0 {
if _, err := conn.Write(data); err != nil {
s.mux.CloseStream(sid)
_ = s.mux.CloseStream(sid)
s.closeStreamConnectionIfCurrent(sid, conn)
return
}

View File

@@ -1,7 +1,9 @@
package telemost
package telemost //nolint:revive
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
@@ -13,21 +15,23 @@ import (
const apiBase = "https://cloud-api.yandex.ru/telemost_front/v2/telemost"
type ConnectionInfo struct {
RoomID string `json:"room_id"`
PeerID string `json:"peer_id"`
Credentials string `json:"credentials"`
var ErrAPI = errors.New("api error") //nolint:revive
type ConnectionInfo struct { //nolint:revive
RoomID string `json:"room_id"` //nolint:tagliatelle
PeerID string `json:"peer_id"` //nolint:tagliatelle
Credentials string `json:"credentials"` //nolint:tagliatelle
ClientConfig struct {
MediaServerURL string `json:"media_server_url"`
} `json:"client_configuration"`
MediaServerURL string `json:"media_server_url"` //nolint:tagliatelle
} `json:"client_configuration"` //nolint:tagliatelle
}
func GetConnectionInfo(roomURL, displayName string) (*ConnectionInfo, error) {
func GetConnectionInfo(ctx context.Context, roomURL, displayName string) (*ConnectionInfo, error) { //nolint:revive
u := fmt.Sprintf("%s/conferences/%s/connection", apiBase, url.QueryEscape(roomURL))
req, err := http.NewRequest("GET", u, nil)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to create request: %w", err)
}
q := req.URL.Query()
@@ -48,18 +52,18 @@ func GetConnectionInfo(roomURL, displayName string) (*ConnectionInfo, error) {
client := protect.NewHTTPClient()
resp, err := client.Do(req)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to do request: %w", err)
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("API error %d: %s", resp.StatusCode, body)
return nil, fmt.Errorf("%w %d: %s", ErrAPI, resp.StatusCode, body)
}
var info ConnectionInfo
if err := json.NewDecoder(resp.Body).Decode(&info); err != nil {
return nil, err
return nil, fmt.Errorf("failed to decode response: %w", err)
}
return &info, nil

File diff suppressed because it is too large Load Diff