mirror of
https://github.com/openlibrecommunity/olcrtc.git
synced 2026-05-27 23:49:44 +00:00
Merge branch 'transport/smux' into transport/videochannel
This commit is contained in:
1
go.mod
1
go.mod
@@ -68,6 +68,7 @@ require (
|
||||
github.com/tjfoc/gmsm v1.4.1 // indirect
|
||||
github.com/twitchtv/twirp v8.1.3+incompatible // indirect
|
||||
github.com/wlynxg/anet v0.0.5 // indirect
|
||||
github.com/xtaci/smux v1.5.57 // indirect
|
||||
github.com/zeebo/xxh3 v1.1.0 // indirect
|
||||
go.opentelemetry.io/otel v1.40.0 // indirect
|
||||
go.uber.org/atomic v1.11.0 // indirect
|
||||
|
||||
2
go.sum
2
go.sum
@@ -226,6 +226,8 @@ github.com/xtaci/kcp-go/v5 v5.6.72 h1:FLaQPalgpufJYQRk0OK+gErEhXGLUPjv6FSRPrFR8L
|
||||
github.com/xtaci/kcp-go/v5 v5.6.72/go.mod h1:9O3D8WR+cyyUjGiTILYfg17vn72otWuXK2AFfqIe6CM=
|
||||
github.com/xtaci/lossyconn v0.0.0-20190602105132-8df528c0c9ae h1:J0GxkO96kL4WF+AIT3M4mfUVinOCPgf2uUWYFUzN0sM=
|
||||
github.com/xtaci/lossyconn v0.0.0-20190602105132-8df528c0c9ae/go.mod h1:gXtu8J62kEgmN++bm9BVICuT/e8yiLI2KFobd/TRFsE=
|
||||
github.com/xtaci/smux v1.5.57 h1:N72VbGoSYxgcm6mPOYX0QzEZNVD3UI/JlVvAtXF+WrY=
|
||||
github.com/xtaci/smux v1.5.57/go.mod h1:IGQ9QYrBphmb/4aTnLEcJby0TNr3NV+OslIOMrX825Q=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
github.com/zarazaex69/b v0.0.0-20260423064626-c0bd20863b89 h1:ytA0RfQZTYfjqFA9lBJMX1DTnXpTuKg0nf4udgdpunE=
|
||||
github.com/zarazaex69/b v0.0.0-20260423064626-c0bd20863b89/go.mod h1:OUqzZNoXsg+ccaiAnSe0t4f8qc0W/cFx6io0lWsE1Gw=
|
||||
|
||||
@@ -3,7 +3,6 @@ package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
@@ -17,8 +16,9 @@ import (
|
||||
"github.com/openlibrecommunity/olcrtc/internal/crypto"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/link"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/logger"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/mux"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/names"
|
||||
"github.com/xtaci/smux"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -26,21 +26,16 @@ var (
|
||||
ErrConnectFailed = errors.New("tunnel connection failed")
|
||||
// ErrProxyAuth is returned when SOCKS proxy authentication fails.
|
||||
ErrProxyAuth = errors.New("SOCKS proxy auth failed")
|
||||
// ErrMuxExited is returned when the multiplexer loop exits unexpectedly.
|
||||
ErrMuxExited = errors.New("multiplexer loop exited")
|
||||
// ErrNoAvailableLinks is returned when no links are ready for sending.
|
||||
ErrNoAvailableLinks = errors.New("no available links")
|
||||
)
|
||||
|
||||
// Client handles local SOCKS5 connections and tunnels them to the server.
|
||||
type Client struct {
|
||||
links []link.Link
|
||||
cipher *crypto.Cipher
|
||||
mux *mux.Multiplexer
|
||||
connections map[uint16]net.Conn
|
||||
connMu sync.RWMutex
|
||||
clientID uint32
|
||||
dnsServer string
|
||||
ln link.Link
|
||||
cipher *crypto.Cipher
|
||||
conn *muxconn.Conn
|
||||
session *smux.Session
|
||||
sessMu sync.RWMutex
|
||||
dnsServer string
|
||||
}
|
||||
|
||||
// Run starts the client with the specified parameters.
|
||||
@@ -105,37 +100,27 @@ func RunWithReady(
|
||||
return fmt.Errorf("setupCipher failed: %w", err)
|
||||
}
|
||||
|
||||
clientIDBytes := make([]byte, 4)
|
||||
if _, err := rand.Read(clientIDBytes); err != nil {
|
||||
return fmt.Errorf("failed to generate client ID: %w", err)
|
||||
}
|
||||
clientID := binary.BigEndian.Uint32(clientIDBytes)
|
||||
c := &Client{cipher: cipher, dnsServer: dnsServer}
|
||||
|
||||
c := &Client{
|
||||
cipher: cipher,
|
||||
connections: make(map[uint16]net.Conn),
|
||||
links: make([]link.Link, 0),
|
||||
clientID: clientID,
|
||||
dnsServer: dnsServer,
|
||||
}
|
||||
|
||||
c.setupMux()
|
||||
|
||||
const linkCount = 1
|
||||
for i := range linkCount {
|
||||
if err := c.addLink(runCtx, linkName, transportName, carrierName, roomURL, i, cancel, dnsServer, "", 0, videoWidth, videoHeight, videoFPS, videoBitrate, videoHW, videoQRSize, videoQRRecovery, videoCodec, videoTileModule, videoTileRS, vp8FPS, vp8BatchSize); err != nil {
|
||||
return fmt.Errorf("addLink failed: %w", err)
|
||||
}
|
||||
if err := c.bringUpLink(
|
||||
runCtx, linkName, transportName, carrierName, roomURL, cancel,
|
||||
dnsServer, "", 0,
|
||||
videoWidth, videoHeight, videoFPS, videoBitrate, videoHW,
|
||||
videoQRSize, videoQRRecovery, videoCodec, videoTileModule, videoTileRS,
|
||||
vp8FPS, vp8BatchSize,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
defer c.shutdown()
|
||||
|
||||
lc := net.ListenConfig{}
|
||||
ln, err := lc.Listen(runCtx, "tcp4", localAddr)
|
||||
listener, err := lc.Listen(runCtx, "tcp4", localAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on %s: %w", localAddr, err)
|
||||
}
|
||||
defer ln.Close()
|
||||
defer listener.Close()
|
||||
|
||||
logger.Infof("SOCKS5 server listening on %s (ClientID: %d)", localAddr, clientID)
|
||||
logger.Infof("SOCKS5 server listening on %s", localAddr)
|
||||
|
||||
if onReady != nil {
|
||||
onReady()
|
||||
@@ -143,96 +128,30 @@ func RunWithReady(
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- c.acceptLoop(runCtx, ln)
|
||||
errCh <- c.acceptLoop(runCtx, listener)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-runCtx.Done():
|
||||
c.shutdown()
|
||||
return nil
|
||||
case err := <-errCh:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) shutdown() {
|
||||
c.connMu.Lock()
|
||||
for _, conn := range c.connections {
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
}
|
||||
c.connMu.Unlock()
|
||||
|
||||
for i, ln := range c.links {
|
||||
logger.Infof("closing link %d", i)
|
||||
_ = ln.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func setupCipher(keyHex string) (*crypto.Cipher, error) {
|
||||
key, err := hex.DecodeString(keyHex)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode key: %w", err)
|
||||
}
|
||||
if len(key) != 32 {
|
||||
return nil, fmt.Errorf("key must be 32 bytes, got %d", len(key))
|
||||
}
|
||||
|
||||
cipher, err := crypto.NewCipher(string(key))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
return cipher, nil
|
||||
}
|
||||
|
||||
func (c *Client) setupMux() {
|
||||
c.mux = mux.New(c.clientID, func(frame []byte) error {
|
||||
for {
|
||||
canSend := true
|
||||
for _, ln := range c.links {
|
||||
if !ln.CanSend() {
|
||||
canSend = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if canSend {
|
||||
break
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
encrypted, err := c.cipher.Encrypt(frame)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(c.links) == 0 {
|
||||
return ErrNoAvailableLinks
|
||||
}
|
||||
return c.links[0].Send(encrypted)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Client) addLink(
|
||||
func (c *Client) bringUpLink(
|
||||
ctx context.Context,
|
||||
linkName,
|
||||
transportName,
|
||||
carrierName,
|
||||
roomURL string,
|
||||
linkID int,
|
||||
linkName, transportName, carrierName, roomURL string,
|
||||
cancel context.CancelFunc,
|
||||
dnsServer,
|
||||
socksProxyAddr string,
|
||||
dnsServer, socksProxyAddr string,
|
||||
socksProxyPort int,
|
||||
videoWidth, videoHeight, videoFPS int,
|
||||
videoBitrate, videoHW string,
|
||||
videoQRSize int,
|
||||
videoQRRecovery string,
|
||||
videoCodec string,
|
||||
videoTileModule int,
|
||||
videoTileRS int,
|
||||
vp8FPS int,
|
||||
vp8BatchSize int,
|
||||
videoTileModule, videoTileRS int,
|
||||
vp8FPS, vp8BatchSize int,
|
||||
) error {
|
||||
ln, err := link.New(ctx, linkName, link.Config{
|
||||
Transport: transportName,
|
||||
@@ -259,56 +178,104 @@ func (c *Client) addLink(
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create link: %w", err)
|
||||
}
|
||||
c.ln = ln
|
||||
|
||||
ln.SetEndedCallback(func(reason string) {
|
||||
logger.Infof("Client link %d reported conference end: %s", linkID, reason)
|
||||
logger.Infof("Client link reported conference end: %s", reason)
|
||||
cancel()
|
||||
})
|
||||
c.links = append(c.links, ln)
|
||||
|
||||
ln.SetReconnectCallback(func() {
|
||||
c.handleLinkReconnect(linkID)
|
||||
})
|
||||
ln.SetReconnectCallback(func() { c.handleReconnect() })
|
||||
|
||||
if err := ln.Connect(ctx); err != nil {
|
||||
return fmt.Errorf("failed to connect link: %w", err)
|
||||
}
|
||||
|
||||
c.conn = muxconn.New(ln, c.cipher)
|
||||
sess, err := smux.Client(c.conn, smuxConfig())
|
||||
if err != nil {
|
||||
return fmt.Errorf("smux client: %w", err)
|
||||
}
|
||||
c.sessMu.Lock()
|
||||
c.session = sess
|
||||
c.sessMu.Unlock()
|
||||
|
||||
go ln.WatchConnection(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) handleLinkReconnect(linkID int) {
|
||||
logger.Infof("link %d reconnect event", linkID)
|
||||
c.sendResetSignal()
|
||||
|
||||
c.connMu.Lock()
|
||||
for sid, conn := range c.connections {
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
delete(c.connections, sid)
|
||||
}
|
||||
c.connMu.Unlock()
|
||||
|
||||
c.mux.UpdateSendFunc(func(frame []byte) error {
|
||||
encrypted, err := c.cipher.Encrypt(frame)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(c.links) == 0 {
|
||||
return ErrNoAvailableLinks
|
||||
}
|
||||
return c.links[0].Send(encrypted)
|
||||
})
|
||||
c.mux.Reset()
|
||||
// smuxConfig returns the tuned smux config used on both ends.
|
||||
func smuxConfig() *smux.Config {
|
||||
cfg := smux.DefaultConfig()
|
||||
cfg.Version = 2
|
||||
cfg.MaxFrameSize = 32768
|
||||
cfg.MaxReceiveBuffer = 16 * 1024 * 1024
|
||||
cfg.MaxStreamBuffer = 1024 * 1024
|
||||
cfg.KeepAliveInterval = 10 * time.Second
|
||||
cfg.KeepAliveTimeout = 60 * time.Second
|
||||
return cfg
|
||||
}
|
||||
|
||||
func (c *Client) sendResetSignal() {
|
||||
resetFrame := mux.BuildControlFrame(c.clientID, mux.ControlResetClient)
|
||||
encrypted, _ := c.cipher.Encrypt(resetFrame)
|
||||
if len(c.links) > 0 {
|
||||
_ = c.links[0].Send(encrypted)
|
||||
func (c *Client) handleReconnect() {
|
||||
logger.Infof("client link reconnect — tearing down smux session")
|
||||
c.sessMu.Lock()
|
||||
if c.session != nil {
|
||||
_ = c.session.Close()
|
||||
c.session = nil
|
||||
}
|
||||
if c.conn != nil {
|
||||
_ = c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
c.sessMu.Unlock()
|
||||
// New SOCKS5 connections will fail until the link comes back up; the
|
||||
// caller will reissue them. Existing streams die with the smux session.
|
||||
c.conn = muxconn.New(c.ln, c.cipher)
|
||||
sess, err := smux.Client(c.conn, smuxConfig())
|
||||
if err != nil {
|
||||
logger.Warnf("smux re-init failed: %v", err)
|
||||
return
|
||||
}
|
||||
c.sessMu.Lock()
|
||||
c.session = sess
|
||||
c.sessMu.Unlock()
|
||||
}
|
||||
|
||||
func (c *Client) shutdown() {
|
||||
c.sessMu.Lock()
|
||||
if c.session != nil {
|
||||
_ = c.session.Close()
|
||||
}
|
||||
if c.conn != nil {
|
||||
_ = c.conn.Close()
|
||||
}
|
||||
c.sessMu.Unlock()
|
||||
if c.ln != nil {
|
||||
_ = c.ln.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func setupCipher(keyHex string) (*crypto.Cipher, error) {
|
||||
key, err := hex.DecodeString(keyHex)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode key: %w", err)
|
||||
}
|
||||
if len(key) != 32 {
|
||||
return nil, fmt.Errorf("key must be 32 bytes, got %d", len(key))
|
||||
}
|
||||
|
||||
cipher, err := crypto.NewCipher(string(key))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
return cipher, nil
|
||||
}
|
||||
|
||||
func (c *Client) onData(data []byte) {
|
||||
c.sessMu.RLock()
|
||||
conn := c.conn
|
||||
c.sessMu.RUnlock()
|
||||
if conn != nil {
|
||||
conn.Push(data)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -340,19 +307,23 @@ func (c *Client) handleSocks5(ctx context.Context, conn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
sid := c.mux.OpenStream()
|
||||
defer c.mux.CloseStream(sid)
|
||||
c.sessMu.RLock()
|
||||
sess := c.session
|
||||
c.sessMu.RUnlock()
|
||||
if sess == nil || sess.IsClosed() {
|
||||
_, _ = conn.Write(replyHostUnreachable())
|
||||
return
|
||||
}
|
||||
|
||||
c.connMu.Lock()
|
||||
c.connections[sid] = conn
|
||||
c.connMu.Unlock()
|
||||
defer func() {
|
||||
c.connMu.Lock()
|
||||
delete(c.connections, sid)
|
||||
c.connMu.Unlock()
|
||||
}()
|
||||
stream, err := sess.OpenStream()
|
||||
if err != nil {
|
||||
logger.Warnf("OpenStream failed: %v", err)
|
||||
_, _ = conn.Write(replyHostUnreachable())
|
||||
return
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
logger.Infof("sid=%d tunnel to %s:%d", sid, targetAddr, targetPort)
|
||||
logger.Infof("sid=%d tunnel to %s:%d", stream.ID(), targetAddr, targetPort)
|
||||
|
||||
connectReq, _ := json.Marshal(map[string]any{
|
||||
"cmd": "connect",
|
||||
@@ -360,45 +331,34 @@ func (c *Client) handleSocks5(ctx context.Context, conn net.Conn) {
|
||||
"port": targetPort,
|
||||
})
|
||||
|
||||
if err := c.mux.SendData(sid, connectReq); err != nil {
|
||||
logger.Warnf("sid=%d tunnel setup failed: %v", sid, err)
|
||||
_ = stream.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
if _, err := stream.Write(connectReq); err != nil {
|
||||
logger.Warnf("sid=%d connect req failed: %v", stream.ID(), err)
|
||||
_, _ = conn.Write(replyHostUnreachable())
|
||||
return
|
||||
}
|
||||
_ = stream.SetWriteDeadline(time.Time{})
|
||||
|
||||
readyTimer := time.NewTimer(10 * time.Second)
|
||||
defer readyTimer.Stop()
|
||||
|
||||
dataReady := c.mux.WaitForData(sid)
|
||||
|
||||
var initialData []byte
|
||||
select {
|
||||
case <-readyTimer.C:
|
||||
logger.Warnf("sid=%d tunnel setup failed: timeout waiting for remote ready", sid)
|
||||
ack := make([]byte, 1)
|
||||
_ = stream.SetReadDeadline(time.Now().Add(15 * time.Second))
|
||||
if _, err := io.ReadFull(stream, ack); err != nil || ack[0] != 0x00 {
|
||||
logger.Warnf("sid=%d remote ready failed: err=%v ack=%v", stream.ID(), err, ack)
|
||||
_, _ = conn.Write(replyHostUnreachable())
|
||||
return
|
||||
case <-dataReady:
|
||||
initialData = c.mux.ReadStream(sid)
|
||||
if len(initialData) == 0 || initialData[0] != 0x00 {
|
||||
logger.Warnf("sid=%d tunnel setup failed: invalid remote ready", sid)
|
||||
_, _ = conn.Write(replyHostUnreachable())
|
||||
return
|
||||
}
|
||||
}
|
||||
_ = stream.SetReadDeadline(time.Time{})
|
||||
|
||||
if _, err := conn.Write(replySuccess()); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Handle the rest of initialData if any (unlikely for 0x00 packet)
|
||||
if len(initialData) > 1 {
|
||||
if _, err := conn.Write(initialData[1:]); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
go func() {
|
||||
_, _ = io.Copy(stream, conn)
|
||||
_ = stream.Close()
|
||||
}()
|
||||
_, _ = io.Copy(conn, stream)
|
||||
|
||||
go c.pumpFromMux(ctx, sid, conn)
|
||||
c.pumpToMux(sid, conn)
|
||||
_ = ctx // keep signature
|
||||
}
|
||||
|
||||
func (c *Client) socks5Handshake(conn net.Conn) error {
|
||||
@@ -459,62 +419,6 @@ func (c *Client) socks5Request(conn net.Conn) (string, int, error) {
|
||||
return addr, port, nil
|
||||
}
|
||||
|
||||
func (c *Client) pumpToMux(sid uint16, conn net.Conn) {
|
||||
buf := make([]byte, 16384)
|
||||
for {
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for !c.canSendData() {
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
|
||||
if err := c.mux.SendData(sid, buf[:n]); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) pumpFromMux(ctx context.Context, sid uint16, conn net.Conn) {
|
||||
defer c.mux.CleanupDataChannel(sid)
|
||||
dataReady := c.mux.WaitForData(sid)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-dataReady:
|
||||
data := c.mux.ReadStream(sid)
|
||||
if len(data) > 0 {
|
||||
if _, err := conn.Write(data); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if c.mux.StreamClosed(sid) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) onData(data []byte) {
|
||||
plaintext, err := c.cipher.Decrypt(data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.mux.HandleFrame(plaintext)
|
||||
}
|
||||
|
||||
func (c *Client) canSendData() bool {
|
||||
for _, tr := range c.links {
|
||||
if !tr.CanSend() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func replySuccess() []byte {
|
||||
return []byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0}
|
||||
}
|
||||
|
||||
@@ -1,477 +0,0 @@
|
||||
// Package mux provides a multiplexer for multiple streams over a single connection.
|
||||
package mux
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"sync"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrClientResetID is returned when a client reset is attempted with a zero client ID.
|
||||
ErrClientResetID = errors.New("client reset requires a non-zero client id")
|
||||
// ErrDataTooLarge is returned when a data chunk exceeds the maximum frame size.
|
||||
ErrDataTooLarge = errors.New("data chunk too large")
|
||||
)
|
||||
|
||||
const (
|
||||
// HeaderSize is the size of the frame header in bytes.
|
||||
HeaderSize = 12
|
||||
|
||||
// ControlStreamID is a special stream ID used for control frames.
|
||||
ControlStreamID uint16 = 0xFFFF
|
||||
|
||||
// ControlResetClient is a control frame type used to signal a client reset.
|
||||
ControlResetClient uint32 = 1
|
||||
|
||||
// FrameTypeData is a marker for data frames.
|
||||
FrameTypeData uint16 = 0
|
||||
// FrameTypeControl is a marker for control frames.
|
||||
FrameTypeControl uint16 = 0xFFFF
|
||||
)
|
||||
|
||||
// ControlFrame represents a control message between multiplexers.
|
||||
type ControlFrame struct {
|
||||
ClientID uint32
|
||||
Type uint32
|
||||
}
|
||||
|
||||
// Stream represents a single multiplexed data stream.
|
||||
type Stream struct {
|
||||
ID uint16
|
||||
ClientID uint32
|
||||
recvBuf []byte
|
||||
closed bool
|
||||
mu sync.Mutex
|
||||
nextSeq uint32
|
||||
outOfOrder map[uint32][]byte
|
||||
}
|
||||
|
||||
// RecvBuf returns the current receive buffer content.
|
||||
func (s *Stream) RecvBuf() []byte {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.recvBuf
|
||||
}
|
||||
|
||||
// Multiplexer coordinates multiple Streams over a single transport channel.
|
||||
type Multiplexer struct {
|
||||
streams map[uint16]*Stream
|
||||
nextID uint16
|
||||
clientID uint32
|
||||
onSend func([]byte) error
|
||||
mu sync.RWMutex
|
||||
maxStreams int
|
||||
maxBufferSize int
|
||||
dataReady map[uint16]chan struct{}
|
||||
dataReadyMu sync.Mutex
|
||||
sendSeq map[uint16]uint32
|
||||
sendSeqMu sync.Mutex
|
||||
|
||||
// bufferCond is used to wait for space in receive buffers
|
||||
bufferCond *sync.Cond
|
||||
}
|
||||
|
||||
// New creates a new Multiplexer instance.
|
||||
func New(clientID uint32, onSend func([]byte) error) *Multiplexer {
|
||||
m := &Multiplexer{
|
||||
streams: make(map[uint16]*Stream),
|
||||
nextID: 1,
|
||||
clientID: clientID,
|
||||
onSend: onSend,
|
||||
maxStreams: 10000,
|
||||
maxBufferSize: 32 * 1024 * 1024,
|
||||
dataReady: make(map[uint16]chan struct{}),
|
||||
sendSeq: make(map[uint16]uint32),
|
||||
}
|
||||
m.bufferCond = sync.NewCond(&m.mu)
|
||||
return m
|
||||
}
|
||||
|
||||
// OpenStream allocates and returns a new unique stream ID.
|
||||
func (m *Multiplexer) OpenStream() uint16 {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for {
|
||||
sid := m.nextID
|
||||
m.nextID++
|
||||
if m.nextID == 0 {
|
||||
m.nextID = 1
|
||||
}
|
||||
|
||||
if _, exists := m.streams[sid]; !exists {
|
||||
m.streams[sid] = &Stream{
|
||||
ID: sid,
|
||||
recvBuf: make([]byte, 0),
|
||||
nextSeq: 0,
|
||||
outOfOrder: make(map[uint32][]byte),
|
||||
}
|
||||
return sid
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SendData fragments and sends data over a specific stream.
|
||||
func (m *Multiplexer) SendData(sid uint16, data []byte) error {
|
||||
m.mu.RLock()
|
||||
stream, exists := m.streams[sid]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists || stream.closed {
|
||||
return nil
|
||||
}
|
||||
|
||||
const chunkSize = 7000
|
||||
|
||||
for i := 0; i < len(data); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(data) {
|
||||
end = len(data)
|
||||
}
|
||||
|
||||
chunk := data[i:end]
|
||||
|
||||
m.sendSeqMu.Lock()
|
||||
seq := m.sendSeq[sid]
|
||||
m.sendSeq[sid]++
|
||||
m.sendSeqMu.Unlock()
|
||||
|
||||
if len(chunk) > math.MaxUint16 {
|
||||
return ErrDataTooLarge
|
||||
}
|
||||
|
||||
frame := make([]byte, HeaderSize+len(chunk))
|
||||
binary.BigEndian.PutUint32(frame[0:4], m.clientID)
|
||||
binary.BigEndian.PutUint16(frame[4:6], sid)
|
||||
binary.BigEndian.PutUint16(frame[6:8], uint16(len(chunk))) //nolint:gosec // Length checked above
|
||||
binary.BigEndian.PutUint32(frame[8:12], seq)
|
||||
copy(frame[HeaderSize:], chunk)
|
||||
|
||||
if err := m.onSend(frame); err != nil {
|
||||
return fmt.Errorf("onSend failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseStream signals that a stream should be terminated.
|
||||
func (m *Multiplexer) CloseStream(sid uint16) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if stream, exists := m.streams[sid]; exists {
|
||||
stream.closed = true
|
||||
}
|
||||
|
||||
m.sendSeqMu.Lock()
|
||||
delete(m.sendSeq, sid)
|
||||
m.sendSeqMu.Unlock()
|
||||
|
||||
// Notify anyone waiting for buffer space that a stream is closed
|
||||
m.bufferCond.Broadcast()
|
||||
|
||||
frame := make([]byte, HeaderSize)
|
||||
binary.BigEndian.PutUint32(frame[0:4], m.clientID)
|
||||
binary.BigEndian.PutUint16(frame[4:6], sid)
|
||||
binary.BigEndian.PutUint16(frame[6:8], 0)
|
||||
binary.BigEndian.PutUint32(frame[8:12], 0)
|
||||
|
||||
if err := m.onSend(frame); err != nil {
|
||||
return fmt.Errorf("onSend failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendClientReset sends a control frame to reset all streams for this client.
|
||||
func (m *Multiplexer) SendClientReset() error {
|
||||
if m.clientID == 0 {
|
||||
return ErrClientResetID
|
||||
}
|
||||
if err := m.onSend(BuildControlFrame(m.clientID, ControlResetClient)); err != nil {
|
||||
return fmt.Errorf("onSend failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildControlFrame constructs a raw control frame.
|
||||
func BuildControlFrame(clientID uint32, controlType uint32) []byte {
|
||||
frame := make([]byte, HeaderSize)
|
||||
binary.BigEndian.PutUint32(frame[0:4], clientID)
|
||||
binary.BigEndian.PutUint16(frame[4:6], ControlStreamID)
|
||||
binary.BigEndian.PutUint16(frame[6:8], 0xFFFF) // Use 0xFFFF as a marker for control
|
||||
binary.BigEndian.PutUint32(frame[8:12], controlType)
|
||||
return frame
|
||||
}
|
||||
|
||||
// ParseControlFrame attempts to extract control information from a frame.
|
||||
func ParseControlFrame(frame []byte) (ControlFrame, bool) {
|
||||
if len(frame) < HeaderSize {
|
||||
return ControlFrame{}, false
|
||||
}
|
||||
|
||||
sid := binary.BigEndian.Uint16(frame[4:6])
|
||||
length := binary.BigEndian.Uint16(frame[6:8])
|
||||
if sid != ControlStreamID || length != 0xFFFF {
|
||||
return ControlFrame{}, false
|
||||
}
|
||||
|
||||
return ControlFrame{
|
||||
ClientID: binary.BigEndian.Uint32(frame[0:4]),
|
||||
Type: binary.BigEndian.Uint32(frame[8:12]),
|
||||
}, true
|
||||
}
|
||||
|
||||
// HandleFrame processes an incoming frame from the transport.
|
||||
func (m *Multiplexer) HandleFrame(frame []byte) {
|
||||
control, ok := ParseControlFrame(frame)
|
||||
if ok {
|
||||
m.handleControlFrame(control)
|
||||
return
|
||||
}
|
||||
|
||||
if len(frame) < HeaderSize {
|
||||
return
|
||||
}
|
||||
|
||||
clientID := binary.BigEndian.Uint32(frame[0:4])
|
||||
sid := binary.BigEndian.Uint16(frame[4:6])
|
||||
length := binary.BigEndian.Uint16(frame[6:8])
|
||||
seq := binary.BigEndian.Uint32(frame[8:12])
|
||||
|
||||
if length == 0 {
|
||||
m.handleCloseStreamFrame(sid, clientID)
|
||||
return
|
||||
}
|
||||
|
||||
if len(frame) < HeaderSize+int(length) {
|
||||
return
|
||||
}
|
||||
|
||||
m.processDataFrame(sid, clientID, seq, frame[HeaderSize:HeaderSize+int(length)])
|
||||
}
|
||||
|
||||
func (m *Multiplexer) handleCloseStreamFrame(sid uint16, clientID uint32) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if stream, exists := m.streams[sid]; exists && stream.ClientID == clientID {
|
||||
stream.closed = true
|
||||
m.bufferCond.Broadcast()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Multiplexer) processDataFrame(sid uint16, clientID uint32, seq uint32, data []byte) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
stream := m.getOrCreateStream(sid, clientID)
|
||||
if stream == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if seq == stream.nextSeq {
|
||||
if s := m.waitForBufferSpace(sid, clientID, len(data)); s != nil {
|
||||
s.recvBuf = append(s.recvBuf, data...)
|
||||
s.nextSeq++
|
||||
m.applyOutOfOrder(s, sid, clientID)
|
||||
m.notifyDataReady(sid)
|
||||
}
|
||||
} else if seq > stream.nextSeq {
|
||||
if len(stream.outOfOrder) < 100 {
|
||||
stream.outOfOrder[seq] = append([]byte(nil), data...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Multiplexer) getOrCreateStream(sid uint16, clientID uint32) *Stream {
|
||||
stream, exists := m.streams[sid]
|
||||
if !exists {
|
||||
if len(m.streams) >= m.maxStreams {
|
||||
return nil
|
||||
}
|
||||
stream = &Stream{
|
||||
ID: sid,
|
||||
ClientID: clientID,
|
||||
recvBuf: make([]byte, 0),
|
||||
nextSeq: 0,
|
||||
outOfOrder: make(map[uint32][]byte),
|
||||
}
|
||||
m.streams[sid] = stream
|
||||
return stream
|
||||
}
|
||||
|
||||
if stream.ClientID != clientID {
|
||||
stream.ClientID = clientID
|
||||
stream.recvBuf = make([]byte, 0)
|
||||
stream.closed = false
|
||||
stream.nextSeq = 0
|
||||
stream.outOfOrder = make(map[uint32][]byte)
|
||||
m.bufferCond.Broadcast()
|
||||
}
|
||||
return stream
|
||||
}
|
||||
|
||||
func (m *Multiplexer) applyOutOfOrder(stream *Stream, sid uint16, clientID uint32) {
|
||||
for {
|
||||
nextData, ok := stream.outOfOrder[stream.nextSeq]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
if s := m.waitForBufferSpace(sid, clientID, len(nextData)); s == nil {
|
||||
return
|
||||
}
|
||||
stream.recvBuf = append(stream.recvBuf, nextData...)
|
||||
delete(stream.outOfOrder, stream.nextSeq)
|
||||
stream.nextSeq++
|
||||
logger.Verbosef("Applied out-of-order packet sid=%d seq=%d", sid, stream.nextSeq-1)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Multiplexer) notifyDataReady(sid uint16) {
|
||||
m.dataReadyMu.Lock()
|
||||
defer m.dataReadyMu.Unlock()
|
||||
if ch, ok := m.dataReady[sid]; ok {
|
||||
select {
|
||||
case ch <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Multiplexer) handleControlFrame(control ControlFrame) {
|
||||
switch control.Type {
|
||||
case ControlResetClient:
|
||||
m.ResetClient(control.ClientID)
|
||||
default:
|
||||
logger.Debugf("Unknown mux control frame type=%d clientID=%d", control.Type, control.ClientID)
|
||||
}
|
||||
}
|
||||
|
||||
// ResetClient closes and removes all streams associated with a client ID.
|
||||
func (m *Multiplexer) ResetClient(clientID uint32) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for streamSid, stream := range m.streams {
|
||||
if stream.ClientID == clientID {
|
||||
stream.closed = true
|
||||
delete(m.streams, streamSid)
|
||||
}
|
||||
}
|
||||
m.bufferCond.Broadcast()
|
||||
}
|
||||
|
||||
func (m *Multiplexer) waitForBufferSpace(sid uint16, clientID uint32, need int) *Stream {
|
||||
for {
|
||||
stream, ok := m.streams[sid]
|
||||
if !ok || stream.ClientID != clientID || stream.closed {
|
||||
return nil
|
||||
}
|
||||
if len(stream.recvBuf)+need <= m.maxBufferSize {
|
||||
return stream
|
||||
}
|
||||
// Wait for space to become available
|
||||
m.bufferCond.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
// ReadStream retrieves and clears the current receive buffer for a stream.
|
||||
func (m *Multiplexer) ReadStream(sid uint16) []byte {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
stream, exists := m.streams[sid]
|
||||
if !exists || len(stream.recvBuf) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
data := stream.recvBuf
|
||||
stream.recvBuf = make([]byte, 0)
|
||||
|
||||
// Notify producers that space is now available
|
||||
m.bufferCond.Broadcast()
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
// StreamClosed returns true if the stream is closed or doesn't exist.
|
||||
func (m *Multiplexer) StreamClosed(sid uint16) bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
stream, exists := m.streams[sid]
|
||||
return !exists || stream.closed
|
||||
}
|
||||
|
||||
// GetStreams returns a list of all active stream IDs.
|
||||
func (m *Multiplexer) GetStreams() []uint16 {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
sids := make([]uint16, 0, len(m.streams))
|
||||
for sid := range m.streams {
|
||||
sids = append(sids, sid)
|
||||
}
|
||||
return sids
|
||||
}
|
||||
|
||||
// GetStream returns the Stream object for a given ID.
|
||||
func (m *Multiplexer) GetStream(sid uint16) *Stream {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.streams[sid]
|
||||
}
|
||||
|
||||
// Reset clears all multiplexer state and closes all streams.
|
||||
func (m *Multiplexer) Reset() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for _, stream := range m.streams {
|
||||
stream.closed = true
|
||||
}
|
||||
|
||||
m.streams = make(map[uint16]*Stream)
|
||||
m.nextID = 1
|
||||
|
||||
m.sendSeqMu.Lock()
|
||||
m.sendSeq = make(map[uint16]uint32)
|
||||
m.sendSeqMu.Unlock()
|
||||
|
||||
m.bufferCond.Broadcast()
|
||||
}
|
||||
|
||||
// UpdateSendFunc updates the function used to transmit raw frames.
|
||||
func (m *Multiplexer) UpdateSendFunc(onSend func([]byte) error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.onSend = onSend
|
||||
}
|
||||
|
||||
// WaitForData returns a channel that signals when new data is available for a stream.
|
||||
func (m *Multiplexer) WaitForData(sid uint16) <-chan struct{} {
|
||||
m.dataReadyMu.Lock()
|
||||
defer m.dataReadyMu.Unlock()
|
||||
|
||||
if _, ok := m.dataReady[sid]; !ok {
|
||||
m.dataReady[sid] = make(chan struct{}, 1)
|
||||
}
|
||||
return m.dataReady[sid]
|
||||
}
|
||||
|
||||
// CleanupDataChannel removes the data notification channel for a stream.
|
||||
func (m *Multiplexer) CleanupDataChannel(sid uint16) {
|
||||
m.dataReadyMu.Lock()
|
||||
defer m.dataReadyMu.Unlock()
|
||||
|
||||
if ch, ok := m.dataReady[sid]; ok {
|
||||
close(ch)
|
||||
delete(m.dataReady, sid)
|
||||
}
|
||||
}
|
||||
119
internal/muxconn/conn.go
Normal file
119
internal/muxconn/conn.go
Normal file
@@ -0,0 +1,119 @@
|
||||
// Package muxconn adapts a link.Link into an io.ReadWriteCloser suitable for
|
||||
// driving a smux session. The wrapper applies AEAD on every wire-bound write
|
||||
// and inverts it on every received message before exposing the bytes as a
|
||||
// byte stream.
|
||||
//
|
||||
// Link semantics are message-oriented: each Send produces exactly one OnData
|
||||
// on the peer. smux operates on a pure byte stream (header + payload may be
|
||||
// glued or split across reads). We bridge by:
|
||||
//
|
||||
// - Treating each Push as an opaque chunk appended to an internal byte
|
||||
// buffer that Read drains in arbitrary slices.
|
||||
// - Letting smux's sendLoop call Write once per frame; we encrypt and hand
|
||||
// the whole buffer to the link as a single message. Length boundaries
|
||||
// are preserved end-to-end by the transport (KCP length-prefix framing
|
||||
// in vp8channel, native message boundaries in datachannel).
|
||||
package muxconn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/crypto"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/link"
|
||||
)
|
||||
|
||||
// ErrClosed is returned from Read/Write after the conn has been closed.
|
||||
var ErrClosed = errors.New("muxconn: closed")
|
||||
|
||||
// Conn is an io.ReadWriteCloser over a link.Link with optional AEAD wrapping.
|
||||
type Conn struct {
|
||||
ln link.Link
|
||||
cipher *crypto.Cipher
|
||||
|
||||
mu sync.Mutex
|
||||
cond *sync.Cond
|
||||
buf []byte
|
||||
closed bool
|
||||
}
|
||||
|
||||
// New wires a Conn over the given link. Push must be set as the link's OnData
|
||||
// callback before this conn is used.
|
||||
func New(ln link.Link, cipher *crypto.Cipher) *Conn {
|
||||
c := &Conn{ln: ln, cipher: cipher}
|
||||
c.cond = sync.NewCond(&c.mu)
|
||||
return c
|
||||
}
|
||||
|
||||
// Push hands an encrypted wire payload (one OnData event) to the conn.
|
||||
func (c *Conn) Push(ciphertext []byte) {
|
||||
pt, err := c.cipher.Decrypt(ciphertext)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.closed {
|
||||
return
|
||||
}
|
||||
c.buf = append(c.buf, pt...)
|
||||
c.cond.Broadcast()
|
||||
}
|
||||
|
||||
// Read implements io.Reader. Blocks until at least one byte is available.
|
||||
func (c *Conn) Read(p []byte) (int, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
for !c.closed && len(c.buf) == 0 {
|
||||
c.cond.Wait()
|
||||
}
|
||||
if len(c.buf) == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n := copy(p, c.buf)
|
||||
c.buf = c.buf[n:]
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Write encrypts p and ships it to the link as a single message. Blocks while
|
||||
// the link signals back-pressure.
|
||||
func (c *Conn) Write(p []byte) (int, error) {
|
||||
for {
|
||||
if c.isClosed() {
|
||||
return 0, ErrClosed
|
||||
}
|
||||
if c.ln.CanSend() {
|
||||
break
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
enc, err := c.cipher.Encrypt(p)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if err := c.ln.Send(enc); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// Close unblocks any pending Read with io.EOF.
|
||||
func (c *Conn) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.closed {
|
||||
return nil
|
||||
}
|
||||
c.closed = true
|
||||
c.cond.Broadcast()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) isClosed() bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.closed
|
||||
}
|
||||
@@ -11,44 +11,32 @@ import (
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/openlibrecommunity/olcrtc/internal/crypto"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/link"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/logger"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/mux"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
|
||||
"github.com/openlibrecommunity/olcrtc/internal/names"
|
||||
"github.com/xtaci/smux"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrKeySize is returned when the encryption key is not 32 bytes.
|
||||
ErrKeySize = errors.New("key must be 32 bytes")
|
||||
// ErrKeyStringLength is returned when the encryption key string length is not 32.
|
||||
ErrKeyStringLength = errors.New("key string length must be 32")
|
||||
// ErrSocks5AuthFailed is returned when SOCKS5 authentication fails.
|
||||
ErrSocks5AuthFailed = errors.New("SOCKS5 auth failed")
|
||||
// ErrSocks5ConnectFailed is returned when SOCKS5 connection fails.
|
||||
ErrSocks5ConnectFailed = errors.New("SOCKS5 connect failed")
|
||||
// ErrNoLinks is returned when no links are available.
|
||||
ErrNoLinks = errors.New("no links available")
|
||||
// ErrDialProxy is returned when dialing the proxy fails.
|
||||
ErrDialProxy = errors.New("failed to dial proxy")
|
||||
// ErrEncryptFailed is returned when encryption fails.
|
||||
ErrEncryptFailed = errors.New("encrypt failed")
|
||||
)
|
||||
|
||||
// Server handles incoming tunnel connections and proxies their traffic.
|
||||
type Server struct {
|
||||
links []link.Link
|
||||
ln link.Link
|
||||
cipher *crypto.Cipher
|
||||
mux *mux.Multiplexer
|
||||
connections map[uint16]net.Conn
|
||||
connMu sync.RWMutex
|
||||
streamPumps map[uint16]net.Conn
|
||||
pumpMu sync.Mutex
|
||||
linkIdx atomic.Uint32
|
||||
activeClients atomic.Int32
|
||||
conn *muxconn.Conn
|
||||
session *smux.Session
|
||||
sessMu sync.RWMutex
|
||||
wg sync.WaitGroup
|
||||
dnsServer string
|
||||
resolver *net.Resolver
|
||||
@@ -97,25 +85,22 @@ func Run(
|
||||
|
||||
s := &Server{
|
||||
cipher: cipher,
|
||||
connections: make(map[uint16]net.Conn),
|
||||
streamPumps: make(map[uint16]net.Conn),
|
||||
links: make([]link.Link, 0),
|
||||
dnsServer: dnsServer,
|
||||
socksProxyAddr: socksProxyAddr,
|
||||
socksProxyPort: socksProxyPort,
|
||||
}
|
||||
|
||||
s.setupResolver()
|
||||
s.setupMux()
|
||||
|
||||
const linkCount = 1
|
||||
for i := range linkCount {
|
||||
if err := s.addLink(runCtx, linkName, transportName, carrierName, roomURL, i, cancel, videoWidth, videoHeight, videoFPS, videoBitrate, videoHW, videoQRSize, videoQRRecovery, videoCodec, videoTileModule, videoTileRS, vp8FPS, vp8BatchSize); err != nil {
|
||||
return fmt.Errorf("addLink failed: %w", err)
|
||||
}
|
||||
if err := s.bringUpLink(
|
||||
runCtx, linkName, transportName, carrierName, roomURL, cancel,
|
||||
videoWidth, videoHeight, videoFPS, videoBitrate, videoHW,
|
||||
videoQRSize, videoQRRecovery, videoCodec, videoTileModule, videoTileRS,
|
||||
vp8FPS, vp8BatchSize,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.runLoop(runCtx)
|
||||
err = s.serve(runCtx)
|
||||
|
||||
s.shutdown()
|
||||
s.wg.Wait()
|
||||
@@ -136,12 +121,7 @@ func setupCipher(keyHex string) (*crypto.Cipher, error) {
|
||||
return nil, fmt.Errorf("%w, got %d", ErrKeySize, len(key))
|
||||
}
|
||||
|
||||
keyStr := string(key)
|
||||
if len(keyStr) != 32 {
|
||||
return nil, fmt.Errorf("%w, got %d", ErrKeyStringLength, len(keyStr))
|
||||
}
|
||||
|
||||
cipher, err := crypto.NewCipher(keyStr)
|
||||
cipher, err := crypto.NewCipher(string(key))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
@@ -158,51 +138,30 @@ func (s *Server) setupResolver() {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) setupMux() {
|
||||
s.mux = mux.New(0, func(frame []byte) error {
|
||||
for {
|
||||
canSend := true
|
||||
for _, ln := range s.links {
|
||||
if !ln.CanSend() {
|
||||
canSend = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if canSend {
|
||||
break
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
encrypted, err := s.cipher.Encrypt(frame)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %w", ErrEncryptFailed, err)
|
||||
}
|
||||
if len(s.links) == 0 {
|
||||
return ErrNoLinks
|
||||
}
|
||||
idx := s.linkIdx.Add(1) % uint32(len(s.links)) //nolint:gosec
|
||||
return s.links[idx].Send(encrypted)
|
||||
})
|
||||
// smuxConfig mirrors the client side. Both peers must agree on Version and
|
||||
// MaxFrameSize.
|
||||
func smuxConfig() *smux.Config {
|
||||
cfg := smux.DefaultConfig()
|
||||
cfg.Version = 2
|
||||
cfg.MaxFrameSize = 32768
|
||||
cfg.MaxReceiveBuffer = 16 * 1024 * 1024
|
||||
cfg.MaxStreamBuffer = 1024 * 1024
|
||||
cfg.KeepAliveInterval = 10 * time.Second
|
||||
cfg.KeepAliveTimeout = 60 * time.Second
|
||||
return cfg
|
||||
}
|
||||
|
||||
func (s *Server) addLink(
|
||||
func (s *Server) bringUpLink(
|
||||
ctx context.Context,
|
||||
linkName,
|
||||
transportName,
|
||||
carrierName,
|
||||
roomURL string,
|
||||
linkID int,
|
||||
linkName, transportName, carrierName, roomURL string,
|
||||
cancel context.CancelFunc,
|
||||
videoWidth, videoHeight, videoFPS int,
|
||||
videoBitrate, videoHW string,
|
||||
videoQRSize int,
|
||||
videoQRRecovery string,
|
||||
videoCodec string,
|
||||
videoTileModule int,
|
||||
videoTileRS int,
|
||||
vp8FPS int,
|
||||
vp8BatchSize int,
|
||||
videoTileModule, videoTileRS int,
|
||||
vp8FPS, vp8BatchSize int,
|
||||
) error {
|
||||
ln, err := link.New(ctx, linkName, link.Config{
|
||||
Transport: transportName,
|
||||
@@ -229,22 +188,21 @@ func (s *Server) addLink(
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create link: %w", err)
|
||||
}
|
||||
s.ln = ln
|
||||
|
||||
ln.SetEndedCallback(func(reason string) {
|
||||
logger.Infof("Server link %d reported conference end: %s", linkID, reason)
|
||||
logger.Infof("Server link reported conference end: %s", reason)
|
||||
cancel()
|
||||
})
|
||||
s.links = append(s.links, ln)
|
||||
ln.SetReconnectCallback(func() { s.handleReconnect() })
|
||||
|
||||
ln.SetReconnectCallback(func() {
|
||||
s.handleLinkReconnect(linkID)
|
||||
})
|
||||
|
||||
logger.Infof("Connecting link %d via %s/%s/%s...", linkID, linkName, transportName, carrierName)
|
||||
logger.Infof("Connecting link via %s/%s/%s...", linkName, transportName, carrierName)
|
||||
if err := ln.Connect(ctx); err != nil {
|
||||
return fmt.Errorf("failed to connect link: %w", err)
|
||||
}
|
||||
logger.Infof("Link %d connected", linkID)
|
||||
logger.Infof("Link connected")
|
||||
|
||||
s.installSession()
|
||||
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
@@ -254,30 +212,195 @@ func (s *Server) addLink(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handleLinkReconnect(linkID int) {
|
||||
logger.Infof("link %d reconnect event", linkID)
|
||||
|
||||
s.connMu.Lock()
|
||||
for sid, conn := range s.connections {
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
delete(s.connections, sid)
|
||||
func (s *Server) installSession() {
|
||||
conn := muxconn.New(s.ln, s.cipher)
|
||||
sess, err := smux.Server(conn, smuxConfig())
|
||||
if err != nil {
|
||||
logger.Warnf("smux server init failed: %v", err)
|
||||
return
|
||||
}
|
||||
s.connMu.Unlock()
|
||||
s.sessMu.Lock()
|
||||
s.conn = conn
|
||||
s.session = sess
|
||||
s.sessMu.Unlock()
|
||||
}
|
||||
|
||||
s.mux.UpdateSendFunc(func(frame []byte) error {
|
||||
encrypted, err := s.cipher.Encrypt(frame)
|
||||
func (s *Server) handleReconnect() {
|
||||
logger.Infof("server link reconnect — tearing down smux session")
|
||||
s.sessMu.Lock()
|
||||
if s.session != nil {
|
||||
_ = s.session.Close()
|
||||
s.session = nil
|
||||
}
|
||||
if s.conn != nil {
|
||||
_ = s.conn.Close()
|
||||
s.conn = nil
|
||||
}
|
||||
s.sessMu.Unlock()
|
||||
s.installSession()
|
||||
}
|
||||
|
||||
func (s *Server) onData(data []byte) {
|
||||
s.sessMu.RLock()
|
||||
conn := s.conn
|
||||
s.sessMu.RUnlock()
|
||||
if conn != nil {
|
||||
conn.Push(data)
|
||||
}
|
||||
}
|
||||
|
||||
// serve drives the smux Accept loop, spawning a tunnel per inbound stream.
|
||||
// The loop tolerates session bounces (reconnects) by waiting until a fresh
|
||||
// session is installed instead of terminating the server.
|
||||
func (s *Server) serve(ctx context.Context) error {
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.sessMu.RLock()
|
||||
sess := s.session
|
||||
s.sessMu.RUnlock()
|
||||
if sess == nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
stream, err := sess.AcceptStream()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %w", ErrEncryptFailed, err)
|
||||
// Session is torn down (reconnect or close). If we're shutting
|
||||
// down, exit; otherwise wait for a new session and retry.
|
||||
if ctx.Err() != nil {
|
||||
return nil
|
||||
}
|
||||
logger.Infof("AcceptStream returned %v — waiting for new session", err)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
if len(s.links) == 0 {
|
||||
return ErrNoLinks
|
||||
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.handleStream(ctx, stream)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) shutdown() {
|
||||
s.sessMu.Lock()
|
||||
if s.session != nil {
|
||||
_ = s.session.Close()
|
||||
}
|
||||
if s.conn != nil {
|
||||
_ = s.conn.Close()
|
||||
}
|
||||
s.sessMu.Unlock()
|
||||
if s.ln != nil {
|
||||
_ = s.ln.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleStream(_ context.Context, stream *smux.Stream) {
|
||||
defer stream.Close()
|
||||
|
||||
// Read the connect JSON. The client writes the whole JSON in one
|
||||
// stream.Write so it usually arrives intact; tolerate fragmentation
|
||||
// by reading incrementally up to a sane cap.
|
||||
const maxConnReq = 4096
|
||||
header := make([]byte, 0, 256)
|
||||
tmp := make([]byte, 256)
|
||||
_ = stream.SetReadDeadline(time.Now().Add(15 * time.Second))
|
||||
for {
|
||||
n, err := stream.Read(tmp)
|
||||
if n > 0 {
|
||||
header = append(header, tmp[:n]...)
|
||||
if req, ok := parseConnectRequest(header); ok {
|
||||
_ = stream.SetReadDeadline(time.Time{})
|
||||
s.dispatch(stream, req)
|
||||
return
|
||||
}
|
||||
}
|
||||
idx := s.linkIdx.Add(1) % uint32(len(s.links)) //nolint:gosec
|
||||
return s.links[idx].Send(encrypted)
|
||||
})
|
||||
s.mux.Reset()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if len(header) > maxConnReq {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseConnectRequest(buf []byte) (ConnectRequest, bool) {
|
||||
var req ConnectRequest
|
||||
if err := json.Unmarshal(buf, &req); err != nil {
|
||||
return req, false
|
||||
}
|
||||
if req.Cmd != "connect" {
|
||||
return req, false
|
||||
}
|
||||
return req, true
|
||||
}
|
||||
|
||||
func (s *Server) dispatch(stream *smux.Stream, req ConnectRequest) {
|
||||
addr := net.JoinHostPort(req.Addr, strconv.Itoa(req.Port))
|
||||
logger.Infof("sid=%d connect %s", stream.ID(), addr)
|
||||
|
||||
dialStart := time.Now()
|
||||
conn, err := s.dial(req)
|
||||
dialElapsed := time.Since(dialStart)
|
||||
|
||||
if err != nil {
|
||||
logger.Infof("sid=%d dial %s failed (%v): %v", stream.ID(), addr, dialElapsed, err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
logger.Infof("sid=%d connected %s in %v", stream.ID(), addr, dialElapsed)
|
||||
|
||||
if _, err := stream.Write([]byte{0x00}); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
_, _ = io.Copy(stream, conn)
|
||||
_ = stream.Close()
|
||||
}()
|
||||
_, _ = io.Copy(conn, stream)
|
||||
}
|
||||
|
||||
func (s *Server) dial(req ConnectRequest) (net.Conn, error) {
|
||||
addr := net.JoinHostPort(req.Addr, strconv.Itoa(req.Port))
|
||||
if s.socksProxyAddr == "" {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
Resolver: s.resolver,
|
||||
}
|
||||
conn, err := dialer.Dial("tcp4", addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial failed: %w", err)
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
proxyAddr := net.JoinHostPort(s.socksProxyAddr, strconv.Itoa(s.socksProxyPort))
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
conn, err := dialer.Dial("tcp4", proxyAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to dial proxy: %w", err)
|
||||
}
|
||||
|
||||
if err := s.socks5Connect(conn, req.Addr, req.Port); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (s *Server) socks5Connect(conn net.Conn, targetAddr string, targetPort int) error {
|
||||
@@ -318,336 +441,3 @@ func (s *Server) socks5Connect(conn net.Conn, targetAddr string, targetPort int)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) onData(data []byte) {
|
||||
plaintext, err := s.cipher.Decrypt(data)
|
||||
if err != nil {
|
||||
logger.Debugf("Decrypt error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if control, ok := mux.ParseControlFrame(plaintext); ok && control.Type == mux.ControlResetClient {
|
||||
logger.Infof("Received reset signal from client (clientID=%d)", control.ClientID)
|
||||
s.closeClientConnections(control.ClientID)
|
||||
}
|
||||
|
||||
s.mux.HandleFrame(plaintext)
|
||||
}
|
||||
|
||||
func (s *Server) closeClientConnections(clientID uint32) {
|
||||
s.connMu.Lock()
|
||||
defer s.connMu.Unlock()
|
||||
|
||||
for streamSid, conn := range s.connections {
|
||||
stream := s.mux.GetStream(streamSid)
|
||||
if stream != nil && stream.ClientID == clientID {
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
delete(s.connections, streamSid)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) runLoop(ctx context.Context) error {
|
||||
ticker := time.NewTicker(10 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-ticker.C:
|
||||
s.processMuxStreams(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) shutdown() {
|
||||
s.connMu.Lock()
|
||||
for _, conn := range s.connections {
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
}
|
||||
s.connMu.Unlock()
|
||||
|
||||
s.pumpMu.Lock()
|
||||
for _, conn := range s.streamPumps {
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
}
|
||||
s.pumpMu.Unlock()
|
||||
|
||||
for i, tr := range s.links {
|
||||
logger.Infof("closing link %d", i)
|
||||
_ = tr.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) processMuxStreams(ctx context.Context) {
|
||||
sids := s.mux.GetStreams()
|
||||
for _, sid := range sids {
|
||||
if s.mux.StreamClosed(sid) {
|
||||
s.closeStreamConnection(sid)
|
||||
continue
|
||||
}
|
||||
|
||||
if s.hasConnection(sid) {
|
||||
continue
|
||||
}
|
||||
|
||||
data := s.mux.ReadStream(sid)
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var req ConnectRequest
|
||||
if err := json.Unmarshal(data, &req); err == nil && req.Cmd == "connect" {
|
||||
logger.Infof("sid=%d connect %s:%d", sid, req.Addr, req.Port)
|
||||
s.closeStreamConnection(sid)
|
||||
go s.handleConnect(ctx, sid, req)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) hasConnection(sid uint16) bool {
|
||||
s.connMu.RLock()
|
||||
defer s.connMu.RUnlock()
|
||||
return s.connections[sid] != nil
|
||||
}
|
||||
|
||||
func (s *Server) closeStreamConnection(sid uint16) {
|
||||
s.connMu.Lock()
|
||||
conn := s.connections[sid]
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
delete(s.connections, sid)
|
||||
}
|
||||
s.connMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *Server) closeStreamConnectionIfCurrent(sid uint16, expected net.Conn) {
|
||||
s.connMu.Lock()
|
||||
conn := s.connections[sid]
|
||||
if conn == expected {
|
||||
_ = conn.Close()
|
||||
delete(s.connections, sid)
|
||||
}
|
||||
s.connMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *Server) markStreamPump(sid uint16, conn net.Conn) bool {
|
||||
s.pumpMu.Lock()
|
||||
defer s.pumpMu.Unlock()
|
||||
if current := s.streamPumps[sid]; current == conn {
|
||||
return false
|
||||
} else if current != nil {
|
||||
_ = current.Close()
|
||||
}
|
||||
s.streamPumps[sid] = conn
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Server) unmarkStreamPump(sid uint16, conn net.Conn) {
|
||||
s.pumpMu.Lock()
|
||||
if s.streamPumps[sid] == conn {
|
||||
delete(s.streamPumps, sid)
|
||||
}
|
||||
s.pumpMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *Server) handleConnect(ctx context.Context, sid uint16, req ConnectRequest) {
|
||||
addr := net.JoinHostPort(req.Addr, strconv.Itoa(req.Port))
|
||||
|
||||
s.closeStreamConnection(sid)
|
||||
|
||||
dialStart := time.Now()
|
||||
conn, err := s.dial(req)
|
||||
dialElapsed := time.Since(dialStart)
|
||||
|
||||
if err != nil {
|
||||
logger.Infof("sid=%d dial %s failed (%v): %v", sid, addr, dialElapsed, err)
|
||||
_ = s.mux.CloseStream(sid)
|
||||
return
|
||||
}
|
||||
|
||||
s.connMu.Lock()
|
||||
s.connections[sid] = conn
|
||||
s.connMu.Unlock()
|
||||
|
||||
logger.Infof("sid=%d connected %s in %v", sid, addr, dialElapsed)
|
||||
|
||||
s.activeClients.Add(1)
|
||||
_ = s.mux.SendData(sid, []byte{0x00})
|
||||
s.startStreamPump(ctx, sid, conn)
|
||||
|
||||
go s.pumpToMux(sid, conn)
|
||||
}
|
||||
|
||||
func (s *Server) dial(req ConnectRequest) (net.Conn, error) {
|
||||
addr := net.JoinHostPort(req.Addr, strconv.Itoa(req.Port))
|
||||
if s.socksProxyAddr == "" {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
Resolver: s.resolver,
|
||||
}
|
||||
conn, err := dialer.Dial("tcp4", addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial failed: %w", err)
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
proxyAddr := net.JoinHostPort(s.socksProxyAddr, strconv.Itoa(s.socksProxyPort))
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
conn, err := dialer.Dial("tcp4", proxyAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to dial proxy: %w", err)
|
||||
}
|
||||
|
||||
if err := s.socks5Connect(conn, req.Addr, req.Port); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (s *Server) pumpToMux(sid uint16, conn net.Conn) {
|
||||
defer func() {
|
||||
s.activeClients.Add(-1)
|
||||
_ = s.mux.CloseStream(sid)
|
||||
s.connMu.Lock()
|
||||
delete(s.connections, sid)
|
||||
s.connMu.Unlock()
|
||||
}()
|
||||
|
||||
// Decoupling queue: Read goroutine pushes here, sender goroutine drains
|
||||
// to mux.SendData. Without this, slow channel back-pressure stalls the
|
||||
// upstream Read which can cause TCP receive window to collapse to zero
|
||||
// and effectively wedge the connection (peer stops sending and never
|
||||
// resumes even though our channel is healthy).
|
||||
type chunk struct{ data []byte }
|
||||
queue := make(chan chunk, 64)
|
||||
doneSender := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(doneSender)
|
||||
for c := range queue {
|
||||
for !s.canSendData() {
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
if err := s.mux.SendData(sid, c.data); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// queueHasSpace blocks until the decoupling queue has room or the
|
||||
// sender goroutine has exited. We wait here *before* arming the
|
||||
// upstream read deadline so that channel back-pressure isn't billed
|
||||
// to the socket as idle time and doesn't trip a spurious i/o timeout.
|
||||
queueHasSpace := func() bool {
|
||||
for {
|
||||
if len(queue) < cap(queue) {
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case <-doneSender:
|
||||
return false
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
buf := make([]byte, 16384)
|
||||
|
||||
// Idle timeout for genuinely dead upstreams. Only armed when we are
|
||||
// actively waiting on the socket (queue has space). During internal
|
||||
// back-pressure the deadline is not in effect, so flow-control pauses
|
||||
// don't get mis-classified as remote death.
|
||||
const idleReadTimeout = 60 * time.Second
|
||||
|
||||
for {
|
||||
if !queueHasSpace() {
|
||||
close(queue)
|
||||
<-doneSender
|
||||
return
|
||||
}
|
||||
|
||||
// Arm the deadline only when we actually want bytes from the peer.
|
||||
_ = conn.SetReadDeadline(time.Now().Add(idleReadTimeout))
|
||||
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
close(queue)
|
||||
<-doneSender
|
||||
return
|
||||
}
|
||||
|
||||
// Clear the deadline so it doesn't fire while we are blocked in
|
||||
// queueHasSpace() on the next iteration (back-pressure path).
|
||||
_ = conn.SetReadDeadline(time.Time{})
|
||||
|
||||
// Copy because buf is reused on next Read.
|
||||
c := make([]byte, n)
|
||||
copy(c, buf[:n])
|
||||
|
||||
// Guaranteed non-blocking thanks to queueHasSpace() above (we are
|
||||
// the sole producer); the blocking fallback is just defensive.
|
||||
select {
|
||||
case queue <- chunk{data: c}:
|
||||
default:
|
||||
queue <- chunk{data: c}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) startStreamPump(ctx context.Context, sid uint16, conn net.Conn) {
|
||||
if !s.markStreamPump(sid, conn) {
|
||||
return
|
||||
}
|
||||
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
defer s.unmarkStreamPump(sid, conn)
|
||||
|
||||
ticker := time.NewTicker(10 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
data := s.mux.ReadStream(sid)
|
||||
if len(data) > 0 {
|
||||
if _, err := conn.Write(data); err != nil {
|
||||
_ = s.mux.CloseStream(sid)
|
||||
s.closeStreamConnectionIfCurrent(sid, conn)
|
||||
return
|
||||
}
|
||||
}
|
||||
if s.mux.StreamClosed(sid) {
|
||||
s.closeStreamConnectionIfCurrent(sid, conn)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Server) canSendData() bool {
|
||||
for _, tr := range s.links {
|
||||
if !tr.CanSend() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user