Files
olcrtc/internal/server/server.go

954 lines
24 KiB
Go

// Package server implements the olcrtc tunnel server logic.
package server
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"strconv"
"sync"
"time"
"github.com/google/uuid"
"github.com/openlibrecommunity/olcrtc/internal/control"
"github.com/openlibrecommunity/olcrtc/internal/crypto"
"github.com/openlibrecommunity/olcrtc/internal/handshake"
"github.com/openlibrecommunity/olcrtc/internal/logger"
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
"github.com/openlibrecommunity/olcrtc/internal/names"
"github.com/openlibrecommunity/olcrtc/internal/runtime"
"github.com/openlibrecommunity/olcrtc/internal/transport"
"github.com/xtaci/smux"
)
// connectCommand is the only Cmd value parseConnectRequest accepts in a
// ConnectRequest.
const connectCommand = "connect"
var (
// ErrKeyRequired re-exports runtime.ErrKeyRequired for compatibility with
// pre-runtime callers that errors.Is-checked it.
ErrKeyRequired = runtime.ErrKeyRequired
// ErrKeySize re-exports runtime.ErrKeySize for the same reason.
ErrKeySize = runtime.ErrKeySize
// ErrSocks5AuthFailed is returned when the SOCKS5 proxy does not select the
// "no authentication" method offered during method negotiation.
ErrSocks5AuthFailed = errors.New("SOCKS5 auth failed")
// ErrSocks5ConnectFailed is returned when a SOCKS5 CONNECT cannot be
// completed, e.g. the proxy answers with a non-success reply code (the code
// is appended to the wrapping error).
ErrSocks5ConnectFailed = errors.New("SOCKS5 connect failed")
)
// Session lifecycle and accounting callbacks.
type (
	// SessionOpenFunc is invoked once a handshake has succeeded and before the
	// server starts accepting tunnel streams on that session.
	SessionOpenFunc func(sessionID, deviceID string, claims map[string]any)

	// SessionCloseFunc is invoked when a session is torn down. The reason is
	// either "reconnect" (the carrier dropped and was reestablished) or
	// "closed" (graceful shutdown or ctx cancel).
	SessionCloseFunc func(sessionID, reason string)

	// TrafficFunc is invoked once per tunnel stream after its copy loops have
	// finished. bytesIn is the client→target byte count and bytesOut is the
	// target→client byte count.
	TrafficFunc func(sessionID, addr string, bytesIn, bytesOut uint64)
)
// HealthFunc is called when the server control health snapshot changes.
// It is supplied through Config.OnHealth, which Run hands to
// runtime.NewHealthTracker.
type HealthFunc func(control.Status)
// Server handles incoming tunnel connections and proxies their traffic.
//
// The per-session fields (conn, session, controlStrm, controlStop, deviceID,
// sessionID, peerSessions) are read and written under sessMu.
type Server struct {
// ln is the transport created by bringUpLink.
ln transport.Transport
// peerLn is ln viewed as a PeerTransport; it is only set when the transport
// implements that interface and reports SupportsPeerRouting().
peerLn transport.PeerTransport
// cipher is handed to every muxconn built on top of the transport.
cipher *crypto.Cipher
// conn and session are the muxconn and smux session of the single-link
// (non peer-routed) mode.
conn *muxconn.Conn
session *smux.Session
// controlStrm is the control stream of the current single-link session and
// controlStop cancels the control loop running on it.
controlStrm *smux.Stream
controlStop context.CancelFunc
// sessMu guards the per-session fields listed in the type comment.
sessMu sync.RWMutex
// peerSessions holds one session per remote peer, keyed by peer ID.
peerSessions map[string]*peerSession
// reinstallMu serializes reinstallSession calls.
reinstallMu sync.Mutex
// wg tracks the goroutines started by the server; Run waits on it.
wg sync.WaitGroup
// authHook authorizes clients during the handshake.
authHook handshake.AuthFunc
// onOpen, onClose and onTraffic are the user callbacks; Run replaces nil
// values with no-ops.
onOpen SessionOpenFunc
onClose SessionCloseFunc
onTraffic TrafficFunc
// deviceID and sessionID describe the current single-link session. An
// empty sessionID means no handshake has completed on it yet.
deviceID string
sessionID string
// dnsServer is the address dialed by resolver for every DNS query.
dnsServer string
resolver *net.Resolver
// socksProxyAddr, when non-empty, makes dial go through a SOCKS5 proxy at
// socksProxyAddr:socksProxyPort.
socksProxyAddr string
socksProxyPort int
// liveness is the base control configuration copied for each control loop.
liveness control.Config
// health records session, pong and reconnect events for Status.
health *runtime.HealthTracker
// done is closed once by shutdown (guarded by doneOnce); stopping polls it.
done chan struct{}
doneOnce sync.Once
}
// peerSession is the per-peer counterpart of the single-link session state
// kept on Server; instances are created by getPeerSession and stored in
// Server.peerSessions.
type peerSession struct {
// peerID is the key of this session in Server.peerSessions.
peerID string
// conn receives the peer's inbound data (see onPeerData) and carries the
// smux session below.
conn *muxconn.Conn
session *smux.Session
// controlStrm and controlStop are filled in once the peer handshake has
// succeeded and its control loop has been started.
controlStrm *smux.Stream
controlStop context.CancelFunc
// sessionID and deviceID come from the peer's handshake; sessionID stays
// empty until the handshake completes.
sessionID string
deviceID string
}
// ConnectRequest is a message from the client to establish a new connection.
// It is decoded from the JSON read at the start of each tunnel stream.
type ConnectRequest struct {
// Cmd must equal connectCommand; requests with any other value are not
// accepted by parseConnectRequest.
Cmd string `json:"cmd"`
// Addr is the target host. It is joined with Port for a direct dial, or
// sent to the SOCKS5 proxy as a domain-name address.
Addr string `json:"addr"`
// Port is the target port.
Port int `json:"port"`
}
// Config holds runtime configuration for [Run].
type Config struct {
// Transport is the transport name passed to transport.New.
Transport string
// Carrier, RoomURL and ChannelID are forwarded unchanged into
// transport.Config.
Carrier string
RoomURL string
ChannelID string
// KeyHex is the key material used to build the tunnel cipher
// (presumably hex-encoded — see runtime.SetupCipher).
KeyHex string
// DNSServer is the address the server's resolver dials for DNS queries; it
// is also forwarded to the transport.
DNSServer string
// SOCKSProxyAddr, when non-empty, routes outbound target connections through
// a SOCKS5 proxy at SOCKSProxyAddr:SOCKSProxyPort. Both values are also
// forwarded to the transport.
SOCKSProxyAddr string
SOCKSProxyPort int
// TransportOptions, Engine, URL and Token are forwarded unchanged into
// transport.Config.
TransportOptions transport.Options
Engine string
URL string
Token string
// Liveness is the base configuration for the control-stream liveness loop.
Liveness control.Config
// Traffic is forwarded unchanged into transport.Config.
Traffic transport.TrafficConfig
// AuthHook is invoked after CLIENT_HELLO to authorize the client and
// return a session ID. If nil, every client is admitted with a random UUID.
AuthHook handshake.AuthFunc
// OnSessionOpen fires after a successful handshake. Nil means no-op.
OnSessionOpen SessionOpenFunc
// OnSessionClose fires when the session is torn down (reconnect, closed). Nil means no-op.
OnSessionClose SessionCloseFunc
// OnTraffic fires once per tunnel stream after both copy loops finish. Nil means no-op.
OnTraffic TrafficFunc
// OnHealth fires when liveness/reconnect status changes. Nil means no-op.
OnHealth HealthFunc
}
// Run starts the server with the given configuration and blocks until ctx is
// cancelled or the link reports that the conference ended.
//
// It returns an error when the cipher cannot be built from cfg.KeyHex or when
// the transport link cannot be brought up; otherwise it returns nil after the
// server has been shut down and all tracked goroutines have finished.
func Run(ctx context.Context, cfg Config) error {
	runCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	cipher, err := setupCipher(cfg.KeyHex)
	if err != nil {
		return fmt.Errorf("setupCipher failed: %w", err)
	}

	s := &Server{
		cipher:         cipher,
		authHook:       cfg.AuthHook,
		onOpen:         cfg.OnSessionOpen,
		onClose:        cfg.OnSessionClose,
		onTraffic:      cfg.OnTraffic,
		dnsServer:      cfg.DNSServer,
		socksProxyAddr: cfg.SOCKSProxyAddr,
		socksProxyPort: cfg.SOCKSProxyPort,
		liveness:       cfg.Liveness,
		health:         runtime.NewHealthTracker(cfg.OnHealth),
		peerSessions:   make(map[string]*peerSession),
		done:           make(chan struct{}),
	}

	// Unset hooks fall back to the admit-everyone auth hook and to no-op
	// callbacks so the rest of the server can call them unconditionally.
	if s.authHook == nil {
		s.authHook = defaultAuthHook
	}
	if s.onOpen == nil {
		s.onOpen = func(string, string, map[string]any) {}
	}
	if s.onClose == nil {
		s.onClose = func(string, string) {}
	}
	if s.onTraffic == nil {
		s.onTraffic = func(string, string, uint64, uint64) {}
	}
	s.setupResolver()

	// The teardown is registered before bringUpLink on purpose: if the link
	// was created but connecting failed (e.g. ln.Connect timed out), it must
	// still be closed so the already-joined MUC presence is withdrawn rather
	// than lingering as a ghost participant for later runs against the same
	// room. shutdown is idempotent and safe to call before serve has run.
	defer func() {
		s.shutdown()
		s.wg.Wait()
	}()

	if err := s.bringUpLink(runCtx, cfg, cancel); err != nil {
		return err
	}

	// Close the active session as soon as the run context ends so blocked
	// accept loops wake up.
	go func() {
		<-runCtx.Done()
		s.closeSession()
	}()

	s.serve(runCtx)
	return nil
}
// setupCipher builds the tunnel cipher from keyHex via runtime.SetupCipher.
// Failures are returned wrapped with a "server: " prefix.
func setupCipher(keyHex string) (*crypto.Cipher, error) {
	c, err := runtime.SetupCipher(keyHex)
	if err == nil {
		return c, nil
	}
	return nil, fmt.Errorf("server: %w", err)
}
// setupResolver installs a pure-Go resolver whose queries are always dialed
// to s.dnsServer (the address suggested by the resolver is ignored), with a
// three second dial timeout.
func (s *Server) setupResolver() {
	dialDNS := func(ctx context.Context, network, _ string) (net.Conn, error) {
		dialer := net.Dialer{Timeout: 3 * time.Second}
		return dialer.DialContext(ctx, network, s.dnsServer)
	}
	s.resolver = &net.Resolver{PreferGo: true, Dial: dialDNS}
}
// smuxConfig returns the smux configuration for a link whose maximum wire
// payload is maxWirePayload; it delegates to runtime.SmuxConfig.
func smuxConfig(maxWirePayload int) *smux.Config {
return runtime.SmuxConfig(maxWirePayload)
}
// linkMaxPayload returns the maximum payload size of tr as reported by
// runtime.MaxPayload; the result is fed into smuxConfig.
func linkMaxPayload(tr transport.Transport) int {
return runtime.MaxPayload(tr)
}
// bringUpLink creates the transport described by cfg, wires its callbacks to
// this server and connects it.
//
// cancel is invoked when the link reports that the conference ended. On
// success a goroutine tracked by s.wg runs ln.WatchConnection(ctx). Errors
// from transport.New and ln.Connect are returned wrapped.
func (s *Server) bringUpLink(
ctx context.Context,
cfg Config,
cancel context.CancelFunc,
) error {
ln, err := transport.New(ctx, cfg.Transport, transport.Config{
Carrier: cfg.Carrier,
RoomURL: cfg.RoomURL,
Engine: cfg.Engine,
URL: cfg.URL,
Token: cfg.Token,
ChannelID: cfg.ChannelID,
DeviceID: "",
Name: names.Generate(),
OnData: s.onData,
OnPeerData: s.onPeerData,
DNSServer: s.dnsServer,
ProxyAddr: s.socksProxyAddr,
ProxyPort: s.socksProxyPort,
Options: cfg.TransportOptions,
Traffic: cfg.Traffic,
})
if err != nil {
return fmt.Errorf("failed to create transport: %w", err)
}
// Store the link before connecting so shutdown can close it even when
// Connect fails below.
s.ln = ln
// Peer routing is only enabled when the transport both implements
// PeerTransport and says it supports it.
if peerLn, ok := ln.(transport.PeerTransport); ok && peerLn.SupportsPeerRouting() {
s.peerLn = peerLn
}
ln.SetEndedCallback(func(reason string) {
logger.Infof("Server link reported conference end: %s", reason)
cancel()
})
// Reconnects are allowed, and acted upon, only while ctx is still live.
ln.SetShouldReconnect(func() bool { return ctx.Err() == nil })
ln.SetReconnectCallback(func() {
if ctx.Err() != nil {
return
}
s.handleReconnect()
})
logger.Infof("Connecting transport=%s carrier=%s ...", cfg.Transport, cfg.Carrier)
// In single-link mode the smux session is installed before Connect, so
// s.conn is already set when onData starts delivering. In peer-routing mode
// sessions are created per peer by getPeerSession instead.
if s.peerLn == nil {
s.installSession()
}
if err := ln.Connect(ctx); err != nil {
return fmt.Errorf("failed to connect link: %w", err)
}
logger.Infof("Link connected")
s.wg.Add(1)
go func() {
defer s.wg.Done()
ln.WatchConnection(ctx)
}()
return nil
}
// installSession builds the initial single-link muxconn and smux session on
// top of s.ln and publishes them under sessMu.
//
// If the smux session cannot be created the failure is logged, the freshly
// built muxconn is closed (matching the same failure path in
// reinstallSession and getPeerSession) and the server is left without a
// session; serveSingle polls until one appears.
func (s *Server) installSession() {
	conn := muxconn.New(s.ln, s.cipher)
	sess, err := smux.Server(conn, smuxConfig(linkMaxPayload(s.ln)))
	if err != nil {
		logger.Warnf("smux server init failed: %v", err)
		// Do not leak the connection wrapper that will never be used.
		_ = conn.Close()
		return
	}

	s.sessMu.Lock()
	s.conn = conn
	s.session = sess
	s.sessMu.Unlock()
}
// handleReconnect reacts to a carrier-level reconnect: it records the event
// in the health tracker and replaces whatever smux session is current at that
// moment.
func (s *Server) handleReconnect() {
	s.recordReconnect()
	logger.Infof("server reconnect reason=carrier - tearing down smux session")

	s.sessMu.RLock()
	stale := s.session
	s.sessMu.RUnlock()

	s.reinstallSession(stale)
}
// reinstallSession replaces the single-link smux session, but only if the
// current session is still dead (the session the caller observed failing).
//
// Calls are serialized by reinstallMu. The replacement is built first and
// swapped in under sessMu; if another caller already replaced dead, the new
// build is discarded. After the swap the old control loop is cancelled, the
// old session, connection and control stream are closed, and onClose fires
// with reason "reconnect" when the old session had completed a handshake.
func (s *Server) reinstallSession(dead *smux.Session) {
s.reinstallMu.Lock()
defer s.reinstallMu.Unlock()
// Pre-build the replacement so we can swap atomically below.
newConn := muxconn.New(s.ln, s.cipher)
newSess, err := smux.Server(newConn, smuxConfig(linkMaxPayload(s.ln)))
if err != nil {
logger.Warnf("smux server init failed: %v", err)
_ = newConn.Close()
return
}
s.sessMu.Lock()
if s.session != dead {
// Someone else already reinstalled — discard our build.
s.sessMu.Unlock()
_ = newSess.Close()
_ = newConn.Close()
return
}
// Snapshot the old state, then reset the per-session fields so the next
// handshake starts from scratch (an empty sessionID marks "no handshake").
oldSess := s.session
oldConn := s.conn
oldControl := s.controlStrm
oldControlStop := s.controlStop
oldSID := s.sessionID
s.session = newSess
s.conn = newConn
s.controlStrm = nil
s.controlStop = nil
s.sessionID = ""
s.deviceID = ""
s.sessMu.Unlock()
// Tear the old session down outside the lock.
if oldControlStop != nil {
oldControlStop()
}
if oldSess != nil {
_ = oldSess.Close()
}
if oldConn != nil {
_ = oldConn.Close()
}
if oldControl != nil {
_ = oldControl.Close()
}
if oldSID != "" {
s.onClose(oldSID, "reconnect")
}
}
// closeSession tears down the single-link session and every peer session.
//
// All session state is detached under sessMu and then closed outside the
// lock: the control loop is cancelled, the client is told about the close on
// the control stream, the smux session and its connection are closed, onClose
// fires with reason "closed" if a handshake had completed, and finally each
// peer session is closed with the same reason. Calling it again is harmless
// because the detached state is gone after the first call.
func (s *Server) closeSession() {
	s.sessMu.Lock()
	var (
		oldSess   = s.session
		oldConn   = s.conn
		ctrl      = s.controlStrm
		stopCtrl  = s.controlStop
		oldPeers  = s.peerSessions
		closedSID = s.sessionID
	)
	s.peerSessions = make(map[string]*peerSession)
	s.session, s.conn = nil, nil
	s.controlStrm, s.controlStop = nil, nil
	s.sessionID, s.deviceID = "", ""
	s.sessMu.Unlock()

	if stopCtrl != nil {
		stopCtrl()
	}
	notifyControlClose(ctrl)
	if oldSess != nil {
		_ = oldSess.Close()
	}
	if oldConn != nil {
		_ = oldConn.Close()
	}
	if closedSID != "" {
		s.onClose(closedSID, "closed")
	}
	for _, ps := range oldPeers {
		s.closePeerSession(ps, "closed")
	}
}
// removePeerSession unregisters the session of peerID, if any, and closes it
// with the given reason. Unknown peers are ignored.
func (s *Server) removePeerSession(peerID, reason string) {
	s.sessMu.Lock()
	victim := s.peerSessions[peerID]
	delete(s.peerSessions, peerID)
	s.sessMu.Unlock()

	if victim == nil {
		return
	}
	s.closePeerSession(victim, reason)
}
// closePeerSession releases everything owned by ps: it cancels the control
// loop, notifies the peer on the control stream, closes the smux session, the
// connection and the control stream, and reports the close through onClose
// when the peer had completed its handshake.
func (s *Server) closePeerSession(ps *peerSession, reason string) {
	if stop := ps.controlStop; stop != nil {
		stop()
	}
	notifyControlClose(ps.controlStrm)
	if sess := ps.session; sess != nil {
		_ = sess.Close()
	}
	if c := ps.conn; c != nil {
		_ = c.Close()
	}
	if ctrl := ps.controlStrm; ctrl != nil {
		_ = ctrl.Close()
	}
	if sid := ps.sessionID; sid != "" {
		s.onClose(sid, reason)
	}
}
// notifyControlClose makes a best-effort attempt to tell the remote side that
// the session is closing. A nil stream is ignored.
//
// The close message is written under a two second write deadline. When the
// write succeeds the function pauses for 200ms before half-closing the
// stream, presumably to give the message time to leave before the session is
// torn down.
func notifyControlClose(stream *smux.Stream) {
	if stream == nil {
		return
	}

	_ = stream.SetWriteDeadline(time.Now().Add(2 * time.Second))
	sendErr := control.SendClose(stream)
	if sendErr == nil {
		time.Sleep(200 * time.Millisecond)
	}
	_ = stream.SetWriteDeadline(time.Time{})
	_ = stream.CloseWrite()
}
// onData is the transport's inbound-data callback in single-link mode. The
// payload is pushed into the current muxconn; it is dropped when no session
// is installed.
func (s *Server) onData(data []byte) {
	s.sessMu.RLock()
	active := s.conn
	s.sessMu.RUnlock()

	if active == nil {
		return
	}
	active.Push(data)
}
// onPeerData is the transport's inbound-data callback in peer-routing mode.
// The payload is pushed into the muxconn of the sending peer; it is dropped
// when getPeerSession yields no session for peerID.
func (s *Server) onPeerData(peerID string, data []byte) {
	if ps := s.getPeerSession(peerID); ps != nil {
		ps.conn.Push(data)
	}
}
// getPeerSession returns the session registered for peerID, creating it (and
// starting its servePeer goroutine) on first use.
//
// It returns nil when peerID is empty, when the transport has no peer
// routing, when the server is shutting down, or when the smux session cannot
// be created.
func (s *Server) getPeerSession(peerID string) *peerSession {
	if peerID == "" || s.peerLn == nil {
		return nil
	}

	// Fast path: this runs for every inbound packet, so an existing peer is
	// looked up under the shared lock only.
	s.sessMu.RLock()
	ps := s.peerSessions[peerID]
	s.sessMu.RUnlock()
	if ps != nil {
		return ps
	}

	s.sessMu.Lock()
	// Re-check: another goroutine may have registered the peer meanwhile.
	if ps = s.peerSessions[peerID]; ps != nil {
		s.sessMu.Unlock()
		return ps
	}
	// shutdown closes s.done before it drains peerSessions under sessMu, so
	// once stopping is observed here a new session would never be cleaned up.
	if s.stopping() {
		s.sessMu.Unlock()
		return nil
	}

	conn := muxconn.NewPeer(s.peerLn, s.cipher, peerID)
	sess, err := smux.Server(conn, smuxConfig(linkMaxPayload(s.ln)))
	if err != nil {
		s.sessMu.Unlock()
		logger.Warnf("smux server init failed for peer %s: %v", peerID, err)
		_ = conn.Close()
		return nil
	}

	ps = &peerSession{peerID: peerID, conn: conn, session: sess}
	s.peerSessions[peerID] = ps
	// Register the goroutine while still holding the lock: closeSession needs
	// the same lock, so the Add is ordered before Run's final wg.Wait.
	s.wg.Add(1)
	s.sessMu.Unlock()

	go func() {
		defer s.wg.Done()
		s.servePeer(ps)
	}()
	return ps
}
// serve runs the server until ctx is done.
//
// In single-link mode it drives the smux accept loop: the first stream
// accepted on a session is the control stream, on which the handshake runs,
// and every later stream is a tunnel stream whose traffic is proxied. In
// peer-routing mode the per-peer goroutines do that work, so serve only
// waits for ctx.
func (s *Server) serve(ctx context.Context) {
	if s.peerLn == nil {
		s.serveSingle(ctx)
		return
	}
	<-ctx.Done()
}
// serveSingle is the accept loop of single-link mode. Each iteration takes a
// snapshot of the current session, runs the handshake on it if none has
// completed yet, then accepts one tunnel stream and hands it to handleStream
// on a goroutine tracked by s.wg. The loop ends when ctx is done.
func (s *Server) serveSingle(ctx context.Context) {
for {
if contextDone(ctx) {
return
}
s.sessMu.RLock()
sess := s.session
s.sessMu.RUnlock()
// No session installed (initial smux setup failed or the session was
// closed): poll again shortly instead of spinning.
if sess == nil {
select {
case <-ctx.Done():
return
case <-time.After(50 * time.Millisecond):
continue
}
}
// The first stream of a session carries the handshake; on failure start
// over with whatever session is current by then.
if !s.handshakeReady() {
if !s.acceptHandshake(ctx, sess) {
continue
}
}
stream, err := sess.AcceptStream()
if err != nil {
if contextDone(ctx) {
return
}
logger.Debugf("AcceptStream returned %v - reinstalling session", err)
s.reinstallSession(sess)
continue
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
// The session ID is read when the goroutine runs, not when the stream
// was accepted.
s.handleStream(ctx, stream, s.currentSessionID())
}()
}
}
// currentSessionID returns the ID of the single-link session, or "" when no
// handshake has completed on the current session.
func (s *Server) currentSessionID() string {
	s.sessMu.RLock()
	sid := s.sessionID
	s.sessMu.RUnlock()
	return sid
}
// contextDone reports, without blocking, whether ctx has been cancelled or
// has expired. Err is non-nil exactly when the Done channel is closed.
func contextDone(ctx context.Context) bool {
	return ctx.Err() != nil
}
// handshakeReady reports whether a handshake has completed on the current
// session, i.e. whether a session ID is set. Reinstalling or closing the
// session clears the ID, so the answer is recomputed on every call.
func (s *Server) handshakeReady() bool {
	s.sessMu.RLock()
	ready := s.sessionID != ""
	s.sessMu.RUnlock()
	return ready
}
// acceptHandshake accepts the control stream of sess, runs the server side of
// the handshake on it and, on success, records the session, fires onOpen and
// starts the control loop.
//
// It returns true only when the handshake succeeded and sess is still the
// current session. On accept or handshake failure the link peer is reset and
// the session is reinstalled (unless ctx is already done), and false is
// returned so the caller retries with the then-current session.
func (s *Server) acceptHandshake(ctx context.Context, sess *smux.Session) bool {
	stream, err := sess.AcceptStream()
	if err != nil {
		if contextDone(ctx) {
			return false
		}
		logger.Debugf("AcceptStream(control) returned %v - reinstalling session", err)
		s.resetLinkPeer()
		s.reinstallSession(sess)
		return false
	}

	// Bound the handshake; the deadline is lifted again for the control loop.
	_ = stream.SetDeadline(time.Now().Add(handshake.DefaultTimeout))
	hello, sid, err := handshake.Server(stream, s.authHook)
	_ = stream.SetDeadline(time.Time{})
	if err != nil {
		logger.Warnf("handshake failed: %v", err)
		_ = stream.Close()
		s.resetLinkPeer()
		s.reinstallSession(sess)
		return false
	}

	s.sessMu.Lock()
	if s.session != sess {
		// sess was replaced or closed while the handshake was in flight.
		// Recording the ID now would mark the replacement session as already
		// handshaken, and its real control stream would then be mistaken for
		// a tunnel stream.
		s.sessMu.Unlock()
		logger.Debugf("handshake finished on a replaced session - discarding")
		_ = stream.Close()
		return false
	}
	s.deviceID = hello.DeviceID
	s.sessionID = sid
	s.sessMu.Unlock()

	s.recordSession(sid)
	s.onOpen(sid, hello.DeviceID, hello.Claims)
	logger.Infof("session %s opened (device=%s)", sid, hello.DeviceID)
	s.startControlLoop(ctx, sess, stream)
	return true
}
// servePeer is the per-peer accept loop. It first runs the handshake on the
// peer's control stream; if that fails the peer session is removed. It then
// accepts tunnel streams and proxies each on its own tracked goroutine until
// the server stops or accepting fails, in which case the peer session is
// removed with reason "closed".
func (s *Server) servePeer(ps *peerSession) {
	if !s.acceptPeerHandshake(ps) {
		s.removePeerSession(ps.peerID, "closed")
		return
	}

	for !s.stopping() {
		stream, err := ps.session.AcceptStream()
		if err != nil {
			// During shutdown the session is closed by closeSession; there is
			// nothing left to clean up here.
			if !s.stopping() {
				logger.Debugf("AcceptStream(peer=%s) returned %v - closing peer session", ps.peerID, err)
				s.removePeerSession(ps.peerID, "closed")
			}
			return
		}

		s.wg.Add(1)
		go func() {
			defer s.wg.Done()
			s.handleStream(context.Background(), stream, ps.sessionID)
		}()
	}
}
// acceptPeerHandshake accepts the control stream of ps, runs the server side
// of the handshake on it and, on success, stores the resulting identity on
// ps, fires onOpen and starts the peer control loop. It reports whether the
// handshake succeeded.
//
// NOTE(review): the ps fields are written here without holding a lock, while
// closePeerSession may read them from another goroutine — confirm whether
// that needs synchronization.
func (s *Server) acceptPeerHandshake(ps *peerSession) bool {
	ctrl, err := ps.session.AcceptStream()
	if err != nil {
		if s.stopping() {
			return false
		}
		logger.Debugf("AcceptStream(control peer=%s) returned %v", ps.peerID, err)
		return false
	}

	// Bound the handshake; the deadline is lifted again for the control loop.
	deadline := time.Now().Add(handshake.DefaultTimeout)
	_ = ctrl.SetDeadline(deadline)
	hello, sid, err := handshake.Server(ctrl, s.authHook)
	_ = ctrl.SetDeadline(time.Time{})
	if err != nil {
		logger.Warnf("handshake failed peer=%s: %v", ps.peerID, err)
		_ = ctrl.Close()
		return false
	}

	ps.controlStrm = ctrl
	ps.deviceID = hello.DeviceID
	ps.sessionID = sid

	s.recordSession(sid)
	s.onOpen(sid, hello.DeviceID, hello.Claims)
	logger.Infof("session %s opened (device=%s peer=%s)", sid, hello.DeviceID, ps.peerID)
	s.startPeerControlLoop(ps, ctrl)
	return true
}
// resetLinkPeer calls ResetPeer on the transport when the transport offers
// that optional method; otherwise it does nothing.
func (s *Server) resetLinkPeer() {
	s.sessMu.RLock()
	link := s.ln
	s.sessMu.RUnlock()

	type peerResetter interface{ ResetPeer() }
	resetter, ok := link.(peerResetter)
	if !ok {
		return
	}
	resetter.ResetPeer()
}
// startControlLoop registers stream as the control stream of the single-link
// session and runs control.Run on it in a goroutine tracked by s.wg.
//
// The liveness callbacks configured by the user are wrapped so that every
// event is also recorded in the health tracker and logged. When the loop ends
// for any reason other than cancellation, the session sess is reinstalled and
// the transport is asked to reconnect.
func (s *Server) startControlLoop(ctx context.Context, sess *smux.Session, stream *smux.Stream) {
controlCtx, stop := context.WithCancel(ctx)
s.sessMu.Lock()
s.controlStrm = stream
s.controlStop = stop
s.sessMu.Unlock()
// Work on a local copy of the liveness config so the wrappers below do not
// modify s.liveness; the user's callbacks are captured and chained.
liveness := s.liveness
onPong := liveness.OnPong
onMissedPong := liveness.OnMissedPong
onUnhealthy := liveness.OnUnhealthy
liveness.OnPong = func(h control.Health) {
s.sessMu.RLock()
sid := s.sessionID
s.sessMu.RUnlock()
s.recordPong(h)
logger.Debugf("control alive session=%s rtt=%v seq=%d", sid, h.RTT, h.Seq)
if onPong != nil {
onPong(h)
}
}
liveness.OnMissedPong = func(missed int) {
s.recordMissed(missed)
logger.Warnf("control missed pong on server: missed_pongs=%d", missed)
if onMissedPong != nil {
onMissedPong(missed)
}
}
liveness.OnUnhealthy = func(missed int) {
s.recordUnhealthy(missed)
logger.Warnf("control stream unhealthy on server: missed_pongs=%d", missed)
if onUnhealthy != nil {
onUnhealthy(missed)
}
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
defer func() { _ = stream.Close() }()
err := control.Run(controlCtx, stream, liveness)
// A cancelled loop means the session was torn down on purpose
// (reinstall, close or shutdown); nothing more to do.
if controlCtx.Err() != nil || ctx.Err() != nil {
return
}
if err != nil {
logger.Warnf("server control stream ended: %v", err)
}
s.recordReconnect()
logger.Infof("server reconnect reason=liveness - reinstalling smux session")
s.resetLinkPeer()
s.reinstallSession(sess)
// Tell the carrier to rebuild itself too. Without this the SFU side
// keeps its dead PC around and the client's reconnect handshakes
// keep landing in the void until the carrier eventually notices on
// its own (which observationally takes ~40s on a Telemost room).
if s.ln != nil {
s.ln.Reconnect("liveness")
}
}()
}
// startPeerControlLoop runs control.Run on the control stream of ps in a
// goroutine tracked by s.wg and stores the loop's cancel function on ps.
//
// The liveness callbacks configured by the user are wrapped so that every
// event is also recorded in the health tracker and logged with the session
// and peer IDs. When the loop ends for any reason other than cancellation or
// server shutdown, the peer session is removed with reason "reconnect".
func (s *Server) startPeerControlLoop(ps *peerSession, stream *smux.Stream) {
// The loop is not tied to the run context; it is stopped through
// ps.controlStop when the peer session is closed.
controlCtx, stop := context.WithCancel(context.Background())
ps.controlStop = stop
// Work on a local copy of the liveness config so the wrappers below do not
// modify s.liveness; the user's callbacks are captured and chained.
liveness := s.liveness
onPong := liveness.OnPong
onMissedPong := liveness.OnMissedPong
onUnhealthy := liveness.OnUnhealthy
liveness.OnPong = func(h control.Health) {
s.recordPong(h)
logger.Debugf("control alive session=%s peer=%s rtt=%v seq=%d", ps.sessionID, ps.peerID, h.RTT, h.Seq)
if onPong != nil {
onPong(h)
}
}
liveness.OnMissedPong = func(missed int) {
s.recordMissed(missed)
logger.Warnf("control missed pong on server: session=%s peer=%s missed_pongs=%d",
ps.sessionID, ps.peerID, missed)
if onMissedPong != nil {
onMissedPong(missed)
}
}
liveness.OnUnhealthy = func(missed int) {
s.recordUnhealthy(missed)
logger.Warnf("control stream unhealthy on server: session=%s peer=%s missed_pongs=%d",
ps.sessionID, ps.peerID, missed)
if onUnhealthy != nil {
onUnhealthy(missed)
}
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
defer func() { _ = stream.Close() }()
err := control.Run(controlCtx, stream, liveness)
// Cancelled loop or server shutdown: the peer session is already being
// torn down elsewhere.
if controlCtx.Err() != nil || s.stopping() {
return
}
if err != nil {
logger.Warnf("server control stream ended session=%s peer=%s: %v", ps.sessionID, ps.peerID, err)
}
s.recordReconnect()
s.removePeerSession(ps.peerID, "reconnect")
}()
}
// stopping reports, without blocking, whether shutdown has closed s.done.
// A nil done channel is never ready, so it reports false in that case.
func (s *Server) stopping() bool {
select {
case <-s.done:
return true
default:
return false
}
}
// Status returns the latest server-side control health snapshot, as kept by
// the server's health tracker.
func (s *Server) Status() control.Status {
return s.health.Status()
}
// recordSession forwards a newly opened session ID to the health tracker.
func (s *Server) recordSession(sessionID string) {
	s.health.RecordSession(sessionID)
}

// recordPong forwards a received pong to the health tracker.
func (s *Server) recordPong(h control.Health) {
	s.health.RecordPong(h)
}

// recordMissed forwards the current missed-pong count to the health tracker.
func (s *Server) recordMissed(missed int) {
	s.health.RecordMissed(missed)
}

// recordUnhealthy tells the health tracker that the control stream was
// declared unhealthy after the given number of missed pongs.
func (s *Server) recordUnhealthy(missed int) {
	s.health.RecordUnhealthy(missed)
}

// recordReconnect tells the health tracker that a reconnect took place.
func (s *Server) recordReconnect() {
	s.health.RecordReconnect()
}
// shutdown stops the server: it marks it as stopping by closing the done
// channel (at most once), closes all sessions and finally closes the
// transport if one was created. It tolerates a server whose done channel or
// transport was never set.
func (s *Server) shutdown() {
	if done := s.done; done != nil {
		s.doneOnce.Do(func() { close(done) })
	}

	s.closeSession()

	if link := s.ln; link != nil {
		_ = link.Close()
	}
}
// handleStream serves one tunnel stream: it reads the client's connect
// request and hands the stream to dispatch. The stream is always closed on
// return.
//
// An empty sessionID is replaced by the current single-link session ID. The
// request must arrive within 15 seconds; the stream is dropped on a read
// error or once more than 4096 bytes have arrived without forming a valid
// request.
func (s *Server) handleStream(_ context.Context, stream *smux.Stream, sessionID string) {
	defer func() { _ = stream.Close() }()

	if sessionID == "" {
		sessionID = s.currentSessionID()
	}

	const (
		maxConnReq  = 4096
		chunkSize   = 256
		readTimeout = 15 * time.Second
	)
	var (
		pending = make([]byte, 0, chunkSize)
		chunk   = make([]byte, chunkSize)
	)

	// The client normally sends the whole JSON document in one write, but it
	// may still arrive in pieces, so parsing is retried after every read
	// until the size cap is exceeded.
	_ = stream.SetReadDeadline(time.Now().Add(readTimeout))
	for len(pending) <= maxConnReq {
		n, err := stream.Read(chunk)
		if n > 0 {
			pending = append(pending, chunk[:n]...)
			if req, ok := parseConnectRequest(pending); ok {
				// Lift the deadline before the long-lived copy phase.
				_ = stream.SetReadDeadline(time.Time{})
				s.dispatch(stream, req, sessionID)
				return
			}
		}
		if err != nil {
			return
		}
	}
}
// parseConnectRequest decodes buf as a JSON ConnectRequest. The boolean is
// true only when buf is a complete, valid JSON document whose cmd equals
// connectCommand; the returned request is meaningful only in that case.
func parseConnectRequest(buf []byte) (ConnectRequest, bool) {
	var req ConnectRequest
	err := json.Unmarshal(buf, &req)
	accepted := err == nil && req.Cmd == connectCommand
	return req, accepted
}
// defaultAuthHook admits every client and assigns a random session ID.
// Replace it via [Config.AuthHook] to plug in real authorization.
// Both arguments are ignored and the returned error is always nil.
func defaultAuthHook(_ string, _ map[string]any) (string, error) {
return uuid.NewString(), nil
}
// dispatch connects to the target named in req and proxies data between it
// and stream until both directions have finished, then reports the byte
// counts through onTraffic.
//
// After a successful dial a single 0x00 byte is written to the stream before
// any payload. Dial failures are logged and end the call; the caller closes
// the stream. onTraffic is not called when the dial or that first write
// fails.
func (s *Server) dispatch(stream *smux.Stream, req ConnectRequest, sessionID string) {
addr := net.JoinHostPort(req.Addr, strconv.Itoa(req.Port))
logger.Infof("sid=%d connect %s", stream.ID(), addr)
dialStart := time.Now()
conn, err := s.dial(req)
dialElapsed := time.Since(dialStart)
if err != nil {
logger.Infof("sid=%d dial %s failed (%v): %v", stream.ID(), addr, dialElapsed, err)
return
}
defer func() { _ = conn.Close() }()
logger.Infof("sid=%d connected %s in %v", stream.ID(), addr, dialElapsed)
// Acknowledge the connect to the client.
if _, err := stream.Write([]byte{0x00}); err != nil {
return
}
// bytesOut is written by the goroutine below and read only after done is
// closed, so the channel close orders the two accesses.
var bytesOut uint64
done := make(chan struct{})
go func() {
// target → client
n, _ := io.Copy(stream, conn)
if n > 0 {
bytesOut = uint64(n)
}
// Closing the stream unblocks the client → target copy below.
_ = stream.Close()
close(done)
}()
// client → target
in, _ := io.Copy(conn, stream)
// Closing the target connection unblocks the target → client copy above.
_ = conn.Close()
<-done
bytesIn := uint64(0)
if in > 0 {
bytesIn = uint64(in)
}
if s.onTraffic != nil {
s.onTraffic(sessionID, addr, bytesIn, bytesOut)
}
}
// dial opens the outbound TCP (IPv4) connection for req.
//
// Without a configured SOCKS5 proxy the target is dialed directly, resolving
// names through the server's resolver. With a proxy, the proxy is dialed and
// asked to connect to the target; the proxy connection is closed if that
// negotiation fails. Both dials use a 10s timeout and 30s keep-alive.
func (s *Server) dial(req ConnectRequest) (net.Conn, error) {
	const (
		dialTimeout = 10 * time.Second
		keepAlive   = 30 * time.Second
	)

	if s.socksProxyAddr != "" {
		proxy := net.JoinHostPort(s.socksProxyAddr, strconv.Itoa(s.socksProxyPort))
		viaProxy := &net.Dialer{Timeout: dialTimeout, KeepAlive: keepAlive}
		conn, err := viaProxy.Dial("tcp4", proxy)
		if err != nil {
			return nil, fmt.Errorf("failed to dial proxy: %w", err)
		}
		if err := s.socks5Connect(conn, req.Addr, req.Port); err != nil {
			_ = conn.Close()
			return nil, err
		}
		return conn, nil
	}

	target := net.JoinHostPort(req.Addr, strconv.Itoa(req.Port))
	direct := &net.Dialer{
		Timeout:   dialTimeout,
		KeepAlive: keepAlive,
		Resolver:  s.resolver,
	}
	conn, err := direct.Dial("tcp4", target)
	if err != nil {
		return nil, fmt.Errorf("dial failed: %w", err)
	}
	return conn, nil
}
// socks5Connect performs a SOCKS5 (RFC 1928) handshake on conn, offering only
// the "no authentication" method and asking the proxy to CONNECT to
// targetAddr:targetPort. The target is always sent as a domain-name address.
//
// It returns ErrSocks5AuthFailed when the proxy does not select the offered
// method, and an error wrapping ErrSocks5ConnectFailed when the target cannot
// be encoded (host not 1-255 bytes, port outside 0-65535), when the proxy
// answers with a non-success reply code, or when the reply carries an unknown
// address type. I/O failures are returned wrapped. The whole exchange is
// bounded by a 10 second deadline, which is cleared again before returning.
// The caller owns conn and is expected to close it on error.
func (s *Server) socks5Connect(conn net.Conn, targetAddr string, targetPort int) error {
	const (
		socksVersion     = 5
		methodNoAuth     = 0
		cmdConnect       = 1
		atypIPv4         = 1
		atypDomain       = 3
		atypIPv6         = 4
		handshakeTimeout = 10 * time.Second
	)

	// The domain length is a single byte and the port two bytes on the wire.
	// Truncating either would silently connect to a different destination,
	// so refuse instead.
	if targetAddr == "" || len(targetAddr) > 255 {
		return fmt.Errorf("%w: target host must be 1-255 bytes, got %d",
			ErrSocks5ConnectFailed, len(targetAddr))
	}
	if targetPort < 0 || targetPort > 65535 {
		return fmt.Errorf("%w: target port %d out of range", ErrSocks5ConnectFailed, targetPort)
	}

	// A proxy that accepts the TCP connection but never answers must not
	// block the caller forever.
	_ = conn.SetDeadline(time.Now().Add(handshakeTimeout))
	defer func() { _ = conn.SetDeadline(time.Time{}) }()

	// Method negotiation: version 5, one method, "no authentication".
	if _, err := conn.Write([]byte{socksVersion, 1, methodNoAuth}); err != nil {
		return fmt.Errorf("failed to write socks5 auth: %w", err)
	}
	method := make([]byte, 2)
	if _, err := io.ReadFull(conn, method); err != nil {
		return fmt.Errorf("failed to read socks5 auth resp: %w", err)
	}
	if method[0] != socksVersion || method[1] != methodNoAuth {
		return ErrSocks5AuthFailed
	}

	// CONNECT request: VER CMD RSV ATYP LEN HOST PORT(big endian).
	req := make([]byte, 0, 7+len(targetAddr))
	req = append(req, socksVersion, cmdConnect, 0, atypDomain, byte(len(targetAddr)))
	req = append(req, targetAddr...)
	req = append(req, byte(targetPort>>8), byte(targetPort)) //nolint:gosec // G115: range checked above
	if _, err := conn.Write(req); err != nil {
		return fmt.Errorf("failed to write socks5 connect req: %w", err)
	}

	// Reply: VER REP RSV ATYP BND.ADDR BND.PORT. The length of BND.ADDR
	// depends on ATYP, so read the fixed header first.
	head := make([]byte, 4)
	if _, err := io.ReadFull(conn, head); err != nil {
		return fmt.Errorf("failed to read socks5 connect resp: %w", err)
	}
	if head[0] != socksVersion || head[1] != 0 {
		return fmt.Errorf("%w: %d", ErrSocks5ConnectFailed, head[1])
	}

	var boundLen int
	switch head[3] {
	case atypIPv4:
		boundLen = net.IPv4len
	case atypIPv6:
		boundLen = net.IPv6len
	case atypDomain:
		var l [1]byte
		if _, err := io.ReadFull(conn, l[:]); err != nil {
			return fmt.Errorf("failed to read socks5 connect resp: %w", err)
		}
		boundLen = int(l[0])
	default:
		return fmt.Errorf("%w: unsupported bound address type %d", ErrSocks5ConnectFailed, head[3])
	}

	// Consume the bound address and port completely; leftover bytes would
	// otherwise be delivered to the client as tunnel payload.
	if _, err := io.CopyN(io.Discard, conn, int64(boundLen)+2); err != nil {
		return fmt.Errorf("failed to read socks5 connect resp: %w", err)
	}
	return nil
}