Files
olcrtc/internal/server/server.go
2026-04-28 19:18:23 +03:00

591 lines
14 KiB
Go

// Package server implements the olcrtc tunnel server logic.
package server
import (
"context"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/openlibrecommunity/olcrtc/internal/crypto"
"github.com/openlibrecommunity/olcrtc/internal/link"
"github.com/openlibrecommunity/olcrtc/internal/logger"
"github.com/openlibrecommunity/olcrtc/internal/mux"
"github.com/openlibrecommunity/olcrtc/internal/names"
)
// Sentinel errors of the server package. Callers can match them with
// errors.Is; several call sites wrap them with additional detail.
var (
// ErrKeySize reports that the decoded encryption key is not 32 bytes long.
ErrKeySize = errors.New("key must be 32 bytes")
// ErrKeyStringLength reports that the encryption key string is not 32 bytes long.
ErrKeyStringLength = errors.New("key string length must be 32")
// ErrSocks5AuthFailed reports that the SOCKS5 proxy rejected the
// method negotiation (wrong version or no acceptable method).
ErrSocks5AuthFailed = errors.New("SOCKS5 auth failed")
// ErrSocks5ConnectFailed reports that the SOCKS5 proxy answered the
// CONNECT request with a non-success reply.
ErrSocks5ConnectFailed = errors.New("SOCKS5 connect failed")
// ErrNoLinks reports that a frame could not be sent because the server
// has no links.
ErrNoLinks = errors.New("no links available")
// ErrDialProxy reports that dialing the proxy failed.
ErrDialProxy = errors.New("failed to dial proxy")
// ErrEncryptFailed reports that encrypting an outgoing frame failed.
ErrEncryptFailed = errors.New("encrypt failed")
)
// Server handles incoming tunnel connections and proxies their traffic.
//
// Frames received from a link are decrypted and fed into the multiplexer;
// each multiplexed stream that sends a connect request gets an outbound
// connection whose traffic is pumped in both directions.
type Server struct {
// links holds the links frames are sent over; it is filled by addLink.
links []link.Link
// cipher encrypts outgoing frames and decrypts incoming data.
cipher *crypto.Cipher
// mux multiplexes the logical streams carried over the links.
mux *mux.Multiplexer
// connections maps a stream id to its outbound connection.
// Guarded by connMu.
connections map[uint16]net.Conn
connMu sync.RWMutex
// streamPumps maps a stream id to the connection whose mux-to-connection
// pump goroutine is currently registered. Guarded by pumpMu.
streamPumps map[uint16]net.Conn
pumpMu sync.Mutex
// linkIdx is a running counter used to pick the link for the next frame.
linkIdx atomic.Uint32
// activeClients is incremented after a successful connect and
// decremented when the matching connection-to-mux pump exits.
activeClients atomic.Int32
// wg tracks the link watcher and stream pump goroutines; Run waits on it
// before returning.
wg sync.WaitGroup
// dnsServer is the address dialed by resolver for DNS queries.
dnsServer string
// resolver is used for direct (non-proxied) outbound dials.
resolver *net.Resolver
// socksProxyAddr and socksProxyPort select a SOCKS5 proxy for outbound
// dials; an empty socksProxyAddr means dial directly.
socksProxyAddr string
socksProxyPort int
}
// ConnectRequest is a message from the client to establish a new connection.
// It is decoded from JSON read off a multiplexed stream.
type ConnectRequest struct {
// Cmd is the requested command; only "connect" is acted upon.
Cmd string `json:"cmd"`
// Addr is the target host to connect to.
Addr string `json:"addr"`
// Port is the target TCP port.
Port int `json:"port"`
}
// Run assembles a Server from the supplied settings, brings up its link and
// then serves multiplexed streams until ctx is cancelled or the link reports
// that the conference ended.
//
// keyHex is the hex-encoded 32-byte encryption key. dnsServer is the address
// dialed for DNS lookups. When socksProxyAddr is non-empty, outbound dials go
// through that SOCKS5 proxy on socksProxyPort. The video* and vp8* values are
// passed to the link configuration unchanged.
//
// An error is returned when the cipher or the link cannot be set up.
// Otherwise Run returns the result of the serve loop, after closing all
// connections and links and waiting for the tracked goroutines.
func Run(
	ctx context.Context,
	linkName, transportName, carrierName, roomURL, keyHex string,
	dnsServer, socksProxyAddr string,
	socksProxyPort int,
	videoWidth int,
	videoHeight int,
	videoFPS int,
	videoBitrate string,
	videoHW string,
	videoQRSize int,
	videoQRRecovery string,
	videoCodec string,
	vp8FPS int,
	vp8BatchSize int,
) error {
	runCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	cipher, err := setupCipher(keyHex)
	if err != nil {
		return fmt.Errorf("setupCipher failed: %w", err)
	}

	srv := &Server{
		links:          make([]link.Link, 0),
		cipher:         cipher,
		connections:    make(map[uint16]net.Conn),
		streamPumps:    make(map[uint16]net.Conn),
		dnsServer:      dnsServer,
		socksProxyAddr: socksProxyAddr,
		socksProxyPort: socksProxyPort,
	}
	srv.setupResolver()
	srv.setupMux()

	// Only a single link is created for now; the loop keeps the code ready
	// for more.
	const linkCount = 1
	for linkID := 0; linkID < linkCount; linkID++ {
		linkErr := srv.addLink(
			runCtx, linkName, transportName, carrierName, roomURL, linkID, cancel,
			videoWidth, videoHeight, videoFPS, videoBitrate, videoHW,
			videoQRSize, videoQRRecovery, videoCodec, vp8FPS, vp8BatchSize,
		)
		if linkErr != nil {
			return fmt.Errorf("addLink failed: %w", linkErr)
		}
	}

	loopErr := srv.runLoop(runCtx)
	srv.shutdown()
	srv.wg.Wait()
	return loopErr
}
// setupCipher decodes the hex-encoded key and builds the frame cipher from it.
//
// keyHex must be non-empty and decode to exactly 32 bytes; otherwise an error
// is returned (wrapping ErrKeySize when only the length is wrong). Errors from
// hex decoding and from crypto.NewCipher are wrapped with context.
func setupCipher(keyHex string) (*crypto.Cipher, error) {
	if keyHex == "" {
		return nil, errors.New("key required (use -key <hex>)")
	}

	key, err := hex.DecodeString(keyHex)
	if err != nil {
		return nil, fmt.Errorf("failed to decode key: %w", err)
	}
	if len(key) != 32 {
		return nil, fmt.Errorf("%w, got %d", ErrKeySize, len(key))
	}

	// Converting a byte slice to a string preserves its byte length, so the
	// 32-byte check above already covers the string handed to NewCipher.
	cipher, err := crypto.NewCipher(string(key))
	if err != nil {
		return nil, fmt.Errorf("failed to create cipher: %w", err)
	}
	return cipher, nil
}
// setupResolver installs a pure-Go resolver whose DNS connections always go
// to s.dnsServer; the server address proposed by the resolver itself is
// ignored. Each DNS dial is limited to three seconds.
func (s *Server) setupResolver() {
	dialDNS := func(ctx context.Context, network, _ string) (net.Conn, error) {
		dialer := net.Dialer{Timeout: 3 * time.Second}
		return dialer.DialContext(ctx, network, s.dnsServer)
	}
	s.resolver = &net.Resolver{
		PreferGo: true,
		Dial:     dialDNS,
	}
}
// setupMux creates the multiplexer and gives it the function used to send
// outgoing frames.
//
// For every frame the send function waits, polling every 10ms, until all
// links report CanSend, encrypts the frame, and sends it on the link selected
// by the running linkIdx counter. It returns a wrapped ErrEncryptFailed when
// encryption fails and ErrNoLinks when the server has no links.
func (s *Server) setupMux() {
	send := func(frame []byte) error {
		// Backpressure: hold the frame until every link can accept data.
		for !s.canSendData() {
			time.Sleep(10 * time.Millisecond)
		}

		ciphertext, err := s.cipher.Encrypt(frame)
		if err != nil {
			return fmt.Errorf("%w: %w", ErrEncryptFailed, err)
		}

		linkCount := len(s.links)
		if linkCount == 0 {
			return ErrNoLinks
		}
		next := s.linkIdx.Add(1) % uint32(linkCount) //nolint:gosec
		return s.links[next].Send(ciphertext)
	}

	s.mux = mux.New(0, send)
}
// addLink creates a link from the given names and settings, registers it on
// the server, connects it and starts a goroutine watching the connection.
//
// linkID is only used for logging and is passed to the reconnect handler.
// cancel is invoked when the link reports that the conference ended. The
// video* and vp8* values are copied into the link configuration as-is.
//
// The link is appended to s.links before Connect is called. An error is
// returned when the link cannot be created or connected.
func (s *Server) addLink(
	ctx context.Context,
	linkName, transportName, carrierName, roomURL string,
	linkID int,
	cancel context.CancelFunc,
	videoWidth, videoHeight, videoFPS int,
	videoBitrate, videoHW string,
	videoQRSize int,
	videoQRRecovery string,
	videoCodec string,
	vp8FPS int,
	vp8BatchSize int,
) error {
	cfg := link.Config{
		Transport:       transportName,
		Carrier:         carrierName,
		RoomURL:         roomURL,
		Name:            names.Generate(),
		OnData:          s.onData,
		DNSServer:       s.dnsServer,
		ProxyAddr:       s.socksProxyAddr,
		ProxyPort:       s.socksProxyPort,
		VideoWidth:      videoWidth,
		VideoHeight:     videoHeight,
		VideoFPS:        videoFPS,
		VideoBitrate:    videoBitrate,
		VideoHW:         videoHW,
		VideoQRSize:     videoQRSize,
		VideoQRRecovery: videoQRRecovery,
		VideoCodec:      videoCodec,
		VP8FPS:          vp8FPS,
		VP8BatchSize:    vp8BatchSize,
	}

	lnk, err := link.New(ctx, linkName, cfg)
	if err != nil {
		return fmt.Errorf("failed to create link: %w", err)
	}

	// A finished conference stops the whole server run.
	lnk.SetEndedCallback(func(reason string) {
		logger.Infof("Server link %d reported conference end: %s", linkID, reason)
		cancel()
	})
	s.links = append(s.links, lnk)
	lnk.SetReconnectCallback(func() {
		s.handleLinkReconnect(linkID)
	})

	logger.Infof("Connecting link %d via %s/%s/%s...", linkID, linkName, transportName, carrierName)
	if connectErr := lnk.Connect(ctx); connectErr != nil {
		return fmt.Errorf("failed to connect link: %w", connectErr)
	}
	logger.Infof("Link %d connected", linkID)

	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		lnk.WatchConnection(ctx)
	}()
	return nil
}
// handleLinkReconnect reacts to a reconnect event of link linkID: it closes
// and forgets every outbound connection, installs a fresh send function on
// the multiplexer and resets the multiplexer.
//
// NOTE(review): the send function installed here encrypts and sends
// immediately, without first waiting for the links to report CanSend —
// confirm that dropping that wait after a reconnect is intended.
func (s *Server) handleLinkReconnect(linkID int) {
	logger.Infof("link %d reconnect event", linkID)

	s.connMu.Lock()
	for _, conn := range s.connections {
		if conn != nil {
			_ = conn.Close()
		}
	}
	clear(s.connections)
	s.connMu.Unlock()

	send := func(frame []byte) error {
		ciphertext, err := s.cipher.Encrypt(frame)
		if err != nil {
			return fmt.Errorf("%w: %w", ErrEncryptFailed, err)
		}
		linkCount := len(s.links)
		if linkCount == 0 {
			return ErrNoLinks
		}
		next := s.linkIdx.Add(1) % uint32(linkCount) //nolint:gosec
		return s.links[next].Send(ciphertext)
	}
	s.mux.UpdateSendFunc(send)
	s.mux.Reset()
}
// socks5Connect performs a SOCKS5 (RFC 1928) handshake on conn: it offers
// the "no authentication" method and then asks the proxy to CONNECT to
// targetAddr:targetPort, sending targetAddr as a domain name.
//
// targetAddr must be at most 255 bytes and targetPort must fit in 16 bits;
// values outside these limits are rejected up front instead of being
// truncated into a request for a different destination. The proxy's reply is
// consumed completely, whatever bound-address type it carries, so that the
// next byte read from conn is tunnel payload.
//
// The handshake is bounded by a deadline on conn, which is cleared again
// before returning.
//
// Errors: ErrSocks5AuthFailed when the proxy rejects the method offer, a
// wrapped ErrSocks5ConnectFailed for invalid targets, non-success replies
// and unknown address types, and wrapped I/O errors otherwise. conn is never
// closed here.
func (s *Server) socks5Connect(conn net.Conn, targetAddr string, targetPort int) error {
	const (
		handshakeTimeout = 10 * time.Second

		socksVersion = 5
		cmdConnect   = 1
		atypIPv4     = 1
		atypDomain   = 3
		atypIPv6     = 4
	)

	if len(targetAddr) > 255 {
		return fmt.Errorf("%w: address is %d bytes, SOCKS5 allows at most 255",
			ErrSocks5ConnectFailed, len(targetAddr))
	}
	if targetPort < 0 || targetPort > 65535 {
		return fmt.Errorf("%w: port %d out of range", ErrSocks5ConnectFailed, targetPort)
	}

	// Best effort: bound the whole handshake so a stalled proxy cannot block
	// the caller forever. Errors are ignored on purpose: a conn without
	// deadline support simply behaves as it did without a deadline.
	_ = conn.SetDeadline(time.Now().Add(handshakeTimeout))
	defer func() { _ = conn.SetDeadline(time.Time{}) }()

	// Method negotiation: version 5, one method offered, "no authentication".
	if _, err := conn.Write([]byte{socksVersion, 1, 0}); err != nil {
		return fmt.Errorf("failed to write socks5 auth: %w", err)
	}
	greeting := make([]byte, 2)
	if _, err := io.ReadFull(conn, greeting); err != nil {
		return fmt.Errorf("failed to read socks5 auth resp: %w", err)
	}
	if greeting[0] != socksVersion || greeting[1] != 0 {
		return ErrSocks5AuthFailed
	}

	// CONNECT request: VER CMD RSV ATYP LEN HOST... PORT(big endian).
	req := make([]byte, 0, 7+len(targetAddr))
	req = append(req, socksVersion, cmdConnect, 0, atypDomain, byte(len(targetAddr)))
	req = append(req, targetAddr...)
	req = append(req, byte(targetPort>>8), byte(targetPort)) //nolint:gosec
	if _, err := conn.Write(req); err != nil {
		return fmt.Errorf("failed to write socks5 connect req: %w", err)
	}

	// Reply: VER REP RSV ATYP BND.ADDR BND.PORT. Check the status first, then
	// skip the bound address, whose length depends on ATYP.
	head := make([]byte, 4)
	if _, err := io.ReadFull(conn, head); err != nil {
		return fmt.Errorf("failed to read socks5 connect resp: %w", err)
	}
	if head[0] != socksVersion || head[1] != 0 {
		return fmt.Errorf("%w: %d", ErrSocks5ConnectFailed, head[1])
	}

	var remaining int64
	switch head[3] {
	case atypIPv4:
		remaining = net.IPv4len + 2
	case atypIPv6:
		remaining = net.IPv6len + 2
	case atypDomain:
		var nameLen [1]byte
		if _, err := io.ReadFull(conn, nameLen[:]); err != nil {
			return fmt.Errorf("failed to read socks5 connect resp: %w", err)
		}
		remaining = int64(nameLen[0]) + 2
	default:
		return fmt.Errorf("%w: unsupported bound address type %d", ErrSocks5ConnectFailed, head[3])
	}
	if _, err := io.CopyN(io.Discard, conn, remaining); err != nil {
		return fmt.Errorf("failed to read socks5 connect resp: %w", err)
	}
	return nil
}
// onData is the link's data callback. It decrypts data, closes the outbound
// connections of a client that sent a reset control frame, and then hands the
// decrypted frame to the multiplexer. Data that fails to decrypt is logged at
// debug level and dropped.
func (s *Server) onData(data []byte) {
	frame, err := s.cipher.Decrypt(data)
	if err != nil {
		logger.Debugf("Decrypt error: %v", err)
		return
	}

	control, isControl := mux.ParseControlFrame(frame)
	if isControl && control.Type == mux.ControlResetClient {
		logger.Infof("Received reset signal from client (clientID=%d)", control.ClientID)
		s.closeClientConnections(control.ClientID)
	}

	// Every decrypted frame, reset frames included, is passed on to the mux.
	s.mux.HandleFrame(frame)
}
// closeClientConnections closes and forgets every outbound connection whose
// mux stream belongs to clientID. Connections whose stream is unknown to the
// mux are left alone.
func (s *Server) closeClientConnections(clientID uint32) {
	s.connMu.Lock()
	defer s.connMu.Unlock()

	for sid, conn := range s.connections {
		stream := s.mux.GetStream(sid)
		if stream == nil || stream.ClientID != clientID {
			continue
		}
		if conn != nil {
			_ = conn.Close()
		}
		delete(s.connections, sid)
	}
}
// runLoop calls processMuxStreams every 10ms until ctx is done.
// It always returns nil.
func (s *Server) runLoop(ctx context.Context) error {
	tick := time.NewTicker(10 * time.Millisecond)
	defer tick.Stop()

	done := ctx.Done()
	for {
		select {
		case <-done:
			return nil
		case <-tick.C:
			s.processMuxStreams(ctx)
		}
	}
}
// shutdown closes every outbound connection, every connection registered
// with a stream pump, and finally every link. Close errors are ignored and
// the maps themselves are left untouched.
func (s *Server) shutdown() {
	func() {
		s.connMu.Lock()
		defer s.connMu.Unlock()
		for _, c := range s.connections {
			if c == nil {
				continue
			}
			_ = c.Close()
		}
	}()

	func() {
		s.pumpMu.Lock()
		defer s.pumpMu.Unlock()
		for _, c := range s.streamPumps {
			if c == nil {
				continue
			}
			_ = c.Close()
		}
	}()

	for idx := range s.links {
		logger.Infof("closing link %d", idx)
		_ = s.links[idx].Close()
	}
}
// processMuxStreams makes one pass over the mux streams. For a closed stream
// it closes the matching outbound connection. For a stream without an
// outbound connection it reads pending data and, if that data is a JSON
// ConnectRequest with cmd "connect", starts handleConnect in a new goroutine.
// Streams that already have a connection are skipped; data that is not a
// connect request is discarded.
func (s *Server) processMuxStreams(ctx context.Context) {
	for _, sid := range s.mux.GetStreams() {
		switch {
		case s.mux.StreamClosed(sid):
			s.closeStreamConnection(sid)
			continue
		case s.hasConnection(sid):
			continue
		}

		payload := s.mux.ReadStream(sid)
		if len(payload) == 0 {
			continue
		}

		var req ConnectRequest
		if err := json.Unmarshal(payload, &req); err != nil || req.Cmd != "connect" {
			continue
		}
		logger.Infof("sid=%d connect %s:%d", sid, req.Addr, req.Port)
		s.closeStreamConnection(sid)
		go s.handleConnect(ctx, sid, req)
	}
}
// hasConnection reports whether a non-nil outbound connection is registered
// for stream sid.
func (s *Server) hasConnection(sid uint16) bool {
	s.connMu.RLock()
	conn := s.connections[sid]
	s.connMu.RUnlock()
	return conn != nil
}
// closeStreamConnection closes the outbound connection registered for stream
// sid, if there is one, and removes it from the connection table.
func (s *Server) closeStreamConnection(sid uint16) {
	s.connMu.Lock()
	defer s.connMu.Unlock()

	if conn := s.connections[sid]; conn != nil {
		_ = conn.Close()
		delete(s.connections, sid)
	}
}
// closeStreamConnectionIfCurrent closes and unregisters the outbound
// connection of stream sid, but only if the registered connection is still
// expected. This lets a goroutine that owns an older connection clean up
// without touching a newer connection registered for the same stream.
//
// A nil expected never matches, so the call is a no-op instead of invoking
// Close on a nil connection.
func (s *Server) closeStreamConnectionIfCurrent(sid uint16, expected net.Conn) {
	s.connMu.Lock()
	defer s.connMu.Unlock()

	conn := s.connections[sid]
	if conn == nil || conn != expected {
		return
	}
	_ = conn.Close()
	delete(s.connections, sid)
}
// markStreamPump registers conn as the connection being pumped for stream
// sid. It returns false, changing nothing, when conn is already registered
// for sid. A different connection registered before is closed and replaced,
// and true is returned.
func (s *Server) markStreamPump(sid uint16, conn net.Conn) bool {
	s.pumpMu.Lock()
	defer s.pumpMu.Unlock()

	switch current := s.streamPumps[sid]; {
	case current == conn:
		return false
	case current != nil:
		_ = current.Close()
	}
	s.streamPumps[sid] = conn
	return true
}
// unmarkStreamPump removes the pump registration of stream sid, but only if
// it still refers to conn; a registration for another connection is kept.
func (s *Server) unmarkStreamPump(sid uint16, conn net.Conn) {
	s.pumpMu.Lock()
	defer s.pumpMu.Unlock()

	if s.streamPumps[sid] != conn {
		return
	}
	delete(s.streamPumps, sid)
}
// handleConnect serves a connect request for stream sid: it closes any
// connection still registered for the stream, dials the requested target,
// registers the new connection, acknowledges the connect to the client with a
// single 0x00 byte and starts pumping data in both directions.
//
// When the dial fails the mux stream is closed. When ctx was cancelled while
// the dial was in progress, the fresh connection is closed and dropped
// instead of being registered.
func (s *Server) handleConnect(ctx context.Context, sid uint16, req ConnectRequest) {
	addr := net.JoinHostPort(req.Addr, strconv.Itoa(req.Port))
	s.closeStreamConnection(sid)

	dialStart := time.Now()
	conn, err := s.dial(req)
	dialElapsed := time.Since(dialStart)
	if err != nil {
		logger.Infof("sid=%d dial %s failed (%v): %v", sid, addr, dialElapsed, err)
		_ = s.mux.CloseStream(sid)
		return
	}

	// The dial is not bound to ctx and may take seconds. If the server was
	// cancelled in the meantime, nothing would close a connection registered
	// now: the stream pump exits on ctx.Done without closing it and the
	// reader goroutine would stay blocked in Read. Drop it here instead. The
	// mux is deliberately left alone, since its links may already be closed.
	if ctx.Err() != nil {
		logger.Infof("sid=%d dropping %s: server stopped during dial", sid, addr)
		_ = conn.Close()
		return
	}

	s.connMu.Lock()
	s.connections[sid] = conn
	s.connMu.Unlock()
	logger.Infof("sid=%d connected %s in %v", sid, addr, dialElapsed)
	s.activeClients.Add(1)

	// A single zero byte tells the client that the connection is up. A send
	// failure is ignored here; the pumps notice a broken stream on their own.
	_ = s.mux.SendData(sid, []byte{0x00})

	s.startStreamPump(ctx, sid, conn)
	go s.pumpToMux(sid, conn)
}
// dial opens the outbound TCP connection for req.
//
// Without a configured SOCKS5 proxy it dials req.Addr:req.Port directly,
// resolving names through s.resolver. With a proxy it dials the proxy and
// performs a SOCKS5 CONNECT to the target. Both paths use "tcp4" with a 10s
// dial timeout and 30s keep-alive.
//
// A failure to reach the proxy is wrapped in ErrDialProxy so callers can
// detect it with errors.Is; a failed SOCKS5 handshake closes the proxy
// connection and returns the handshake error unchanged.
func (s *Server) dial(req ConnectRequest) (net.Conn, error) {
	if s.socksProxyAddr == "" {
		addr := net.JoinHostPort(req.Addr, strconv.Itoa(req.Port))
		dialer := &net.Dialer{
			Timeout:   10 * time.Second,
			KeepAlive: 30 * time.Second,
			Resolver:  s.resolver,
		}
		conn, err := dialer.Dial("tcp4", addr)
		if err != nil {
			return nil, fmt.Errorf("dial failed: %w", err)
		}
		return conn, nil
	}

	proxyAddr := net.JoinHostPort(s.socksProxyAddr, strconv.Itoa(s.socksProxyPort))
	dialer := &net.Dialer{
		Timeout:   10 * time.Second,
		KeepAlive: 30 * time.Second,
	}
	conn, err := dialer.Dial("tcp4", proxyAddr)
	if err != nil {
		// Same message text as before ("failed to dial proxy: ..."), but now
		// matchable via the sentinel.
		return nil, fmt.Errorf("%w: %w", ErrDialProxy, err)
	}
	if err := s.socks5Connect(conn, req.Addr, req.Port); err != nil {
		_ = conn.Close()
		return nil, err
	}
	return conn, nil
}
// pumpToMux copies everything read from conn into mux stream sid until the
// read or the send fails. Before each send it waits, polling every 20ms,
// until the links can accept data.
//
// On exit it decrements the active client count, closes the mux stream,
// closes conn, and removes conn from the connection table — but only if the
// table still refers to conn, so a newer connection registered for the same
// stream id is left in place.
func (s *Server) pumpToMux(sid uint16, conn net.Conn) {
	defer func() {
		s.activeClients.Add(-1)
		_ = s.mux.CloseStream(sid)
		// Unregister only our own entry; the helper also closes conn then.
		s.closeStreamConnectionIfCurrent(sid, conn)
		// Close unconditionally as well: when the entry was already gone or
		// replaced, nothing else would release this connection. A second
		// Close only returns an error, which is irrelevant here.
		_ = conn.Close()
	}()

	buf := make([]byte, 16384)
	totalSent := uint64(0)
	for {
		n, err := conn.Read(buf)
		if err != nil {
			// Only transfers above 1 MiB are worth a log line.
			if totalSent > 1024*1024 {
				logger.Infof("sid=%d done total=%dMB", sid, totalSent/(1024*1024))
			}
			return
		}

		for !s.canSendData() {
			time.Sleep(20 * time.Millisecond)
		}
		if err := s.mux.SendData(sid, buf[:n]); err != nil {
			return
		}
		totalSent += uint64(n) //nolint:gosec
	}
}
// startStreamPump starts the goroutine that moves data from mux stream sid to
// conn, unless a pump for exactly this conn is already registered.
//
// Every 10ms the goroutine writes pending stream data to conn. It stops when
// ctx is done, when a write fails (the mux stream is then closed as well) or
// when the mux reports the stream closed; in the latter two cases conn is
// closed and unregistered if it is still the stream's current connection.
// The goroutine is tracked by s.wg and removes its pump registration on exit.
func (s *Server) startStreamPump(ctx context.Context, sid uint16, conn net.Conn) {
	if !s.markStreamPump(sid, conn) {
		return
	}

	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		defer s.unmarkStreamPump(sid, conn)

		tick := time.NewTicker(10 * time.Millisecond)
		defer tick.Stop()

		for {
			select {
			case <-ctx.Done():
				return
			case <-tick.C:
			}

			if chunk := s.mux.ReadStream(sid); len(chunk) > 0 {
				if _, err := conn.Write(chunk); err != nil {
					_ = s.mux.CloseStream(sid)
					s.closeStreamConnectionIfCurrent(sid, conn)
					return
				}
			}
			if s.mux.StreamClosed(sid) {
				s.closeStreamConnectionIfCurrent(sid, conn)
				return
			}
		}
	}()
}
// canSendData reports whether every link currently accepts data. With no
// links at all it reports true.
func (s *Server) canSendData() bool {
	for i := range s.links {
		if !s.links[i].CanSend() {
			return false
		}
	}
	return true
}