mirror of
https://github.com/openlibrecommunity/olcrtc.git
synced 2026-05-26 15:13:40 +00:00
706 lines
18 KiB
Go
706 lines
18 KiB
Go
// Package server implements the olcrtc tunnel server logic.
|
|
package server
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/openlibrecommunity/olcrtc/internal/control"
|
|
"github.com/openlibrecommunity/olcrtc/internal/crypto"
|
|
"github.com/openlibrecommunity/olcrtc/internal/handshake"
|
|
"github.com/openlibrecommunity/olcrtc/internal/logger"
|
|
"github.com/openlibrecommunity/olcrtc/internal/muxconn"
|
|
"github.com/openlibrecommunity/olcrtc/internal/names"
|
|
"github.com/openlibrecommunity/olcrtc/internal/runtime"
|
|
"github.com/openlibrecommunity/olcrtc/internal/transport"
|
|
"github.com/xtaci/smux"
|
|
)
|
|
|
|
const connectCommand = "connect"
|
|
|
|
var (
|
|
// ErrKeyRequired re-exports runtime.ErrKeyRequired for compatibility with
|
|
// pre-runtime callers that errors.Is-checked it.
|
|
ErrKeyRequired = runtime.ErrKeyRequired
|
|
// ErrKeySize re-exports runtime.ErrKeySize for the same reason.
|
|
ErrKeySize = runtime.ErrKeySize
|
|
// ErrSocks5AuthFailed is returned when SOCKS5 authentication fails.
|
|
ErrSocks5AuthFailed = errors.New("SOCKS5 auth failed")
|
|
// ErrSocks5ConnectFailed is returned when SOCKS5 connection fails.
|
|
ErrSocks5ConnectFailed = errors.New("SOCKS5 connect failed")
|
|
)
|
|
|
|
// SessionOpenFunc is called after a successful handshake, before the server
|
|
// accepts tunnel streams on that session.
|
|
type SessionOpenFunc func(sessionID, deviceID string, claims map[string]any)
|
|
|
|
// SessionCloseFunc is called when a session is torn down. Possible reasons:
|
|
// "reconnect" (carrier dropped and was reestablished), "closed" (graceful
|
|
// shutdown or ctx cancel).
|
|
type SessionCloseFunc func(sessionID, reason string)
|
|
|
|
// TrafficFunc is called once per tunnel stream, after the copy loops finish.
|
|
// bytesIn counts client→target bytes; bytesOut counts target→client bytes.
|
|
type TrafficFunc func(sessionID, addr string, bytesIn, bytesOut uint64)
|
|
|
|
// HealthFunc is called when the server control health snapshot changes.
|
|
type HealthFunc func(control.Status)
|
|
|
|
// Server handles incoming tunnel connections and proxies their traffic.
|
|
type Server struct {
|
|
ln transport.Transport
|
|
cipher *crypto.Cipher
|
|
conn *muxconn.Conn
|
|
session *smux.Session
|
|
controlStop context.CancelFunc
|
|
sessMu sync.RWMutex
|
|
reinstallMu sync.Mutex
|
|
wg sync.WaitGroup
|
|
authHook handshake.AuthFunc
|
|
onOpen SessionOpenFunc
|
|
onClose SessionCloseFunc
|
|
onTraffic TrafficFunc
|
|
deviceID string
|
|
sessionID string
|
|
dnsServer string
|
|
resolver *net.Resolver
|
|
socksProxyAddr string
|
|
socksProxyPort int
|
|
liveness control.Config
|
|
health *runtime.HealthTracker
|
|
}
|
|
|
|
// ConnectRequest is a message from the client to establish a new connection.
|
|
type ConnectRequest struct {
|
|
Cmd string `json:"cmd"`
|
|
Addr string `json:"addr"`
|
|
Port int `json:"port"`
|
|
}
|
|
|
|
// Config holds runtime configuration for [Run].
|
|
type Config struct {
|
|
Transport string
|
|
Carrier string
|
|
RoomURL string
|
|
ChannelID string
|
|
KeyHex string
|
|
DNSServer string
|
|
SOCKSProxyAddr string
|
|
SOCKSProxyPort int
|
|
TransportOptions transport.Options
|
|
Engine string
|
|
URL string
|
|
Token string
|
|
Liveness control.Config
|
|
Traffic transport.TrafficConfig
|
|
|
|
// AuthHook is invoked after CLIENT_HELLO to authorize the client and
|
|
// return a session ID. If nil, every client is admitted with a random UUID.
|
|
AuthHook handshake.AuthFunc
|
|
|
|
// OnSessionOpen fires after a successful handshake. Nil means no-op.
|
|
OnSessionOpen SessionOpenFunc
|
|
// OnSessionClose fires when the session is torn down (reconnect, closed). Nil means no-op.
|
|
OnSessionClose SessionCloseFunc
|
|
// OnTraffic fires once per tunnel stream after both copy loops finish. Nil means no-op.
|
|
OnTraffic TrafficFunc
|
|
// OnHealth fires when liveness/reconnect status changes. Nil means no-op.
|
|
OnHealth HealthFunc
|
|
}
|
|
|
|
// Run starts the server with the given configuration.
|
|
func Run(ctx context.Context, cfg Config) error {
|
|
runCtx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
cipher, err := setupCipher(cfg.KeyHex)
|
|
if err != nil {
|
|
return fmt.Errorf("setupCipher failed: %w", err)
|
|
}
|
|
|
|
hook := cfg.AuthHook
|
|
if hook == nil {
|
|
hook = defaultAuthHook
|
|
}
|
|
onOpen := cfg.OnSessionOpen
|
|
if onOpen == nil {
|
|
onOpen = func(string, string, map[string]any) {}
|
|
}
|
|
onClose := cfg.OnSessionClose
|
|
if onClose == nil {
|
|
onClose = func(string, string) {}
|
|
}
|
|
onTraffic := cfg.OnTraffic
|
|
if onTraffic == nil {
|
|
onTraffic = func(string, string, uint64, uint64) {}
|
|
}
|
|
s := &Server{
|
|
cipher: cipher,
|
|
authHook: hook,
|
|
onOpen: onOpen,
|
|
onClose: onClose,
|
|
onTraffic: onTraffic,
|
|
dnsServer: cfg.DNSServer,
|
|
socksProxyAddr: cfg.SOCKSProxyAddr,
|
|
socksProxyPort: cfg.SOCKSProxyPort,
|
|
liveness: cfg.Liveness,
|
|
health: runtime.NewHealthTracker(cfg.OnHealth),
|
|
}
|
|
s.setupResolver()
|
|
|
|
// Register shutdown BEFORE bringUpLink so a partial setup (e.g.
|
|
// link.New succeeded but ln.Connect timed out) still tears the
|
|
// link down and sends MUC presence-unavailable. Without this, an
|
|
// early bringUpLink error returns straight to the caller and the
|
|
// already-joined MUC presence stays behind as a ghost participant
|
|
// for subsequent tests against the same room. shutdown is
|
|
// idempotent and safe to call before s.serve runs.
|
|
defer func() {
|
|
s.shutdown()
|
|
s.wg.Wait()
|
|
}()
|
|
|
|
if err := s.bringUpLink(runCtx, cfg, cancel); err != nil {
|
|
return err
|
|
}
|
|
|
|
go func() {
|
|
<-runCtx.Done()
|
|
s.closeSession()
|
|
}()
|
|
|
|
s.serve(runCtx)
|
|
|
|
return nil
|
|
}
|
|
|
|
func setupCipher(keyHex string) (*crypto.Cipher, error) {
|
|
cipher, err := runtime.SetupCipher(keyHex)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("server: %w", err)
|
|
}
|
|
return cipher, nil
|
|
}
|
|
|
|
func (s *Server) setupResolver() {
|
|
s.resolver = &net.Resolver{
|
|
PreferGo: true,
|
|
Dial: func(ctx context.Context, network, _ string) (net.Conn, error) {
|
|
d := net.Dialer{Timeout: 3 * time.Second}
|
|
return d.DialContext(ctx, network, s.dnsServer)
|
|
},
|
|
}
|
|
}
|
|
|
|
func smuxConfig(maxWirePayload int) *smux.Config {
|
|
return runtime.SmuxConfig(maxWirePayload)
|
|
}
|
|
|
|
func linkMaxPayload(tr transport.Transport) int {
|
|
return runtime.MaxPayload(tr)
|
|
}
|
|
|
|
func (s *Server) bringUpLink(
|
|
ctx context.Context,
|
|
cfg Config,
|
|
cancel context.CancelFunc,
|
|
) error {
|
|
ln, err := transport.New(ctx, cfg.Transport, transport.Config{
|
|
Carrier: cfg.Carrier,
|
|
RoomURL: cfg.RoomURL,
|
|
Engine: cfg.Engine,
|
|
URL: cfg.URL,
|
|
Token: cfg.Token,
|
|
ChannelID: cfg.ChannelID,
|
|
DeviceID: "",
|
|
Name: names.Generate(),
|
|
OnData: s.onData,
|
|
DNSServer: s.dnsServer,
|
|
ProxyAddr: s.socksProxyAddr,
|
|
ProxyPort: s.socksProxyPort,
|
|
Options: cfg.TransportOptions,
|
|
Traffic: cfg.Traffic,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create transport: %w", err)
|
|
}
|
|
s.ln = ln
|
|
|
|
ln.SetEndedCallback(func(reason string) {
|
|
logger.Infof("Server link reported conference end: %s", reason)
|
|
cancel()
|
|
})
|
|
ln.SetShouldReconnect(func() bool { return ctx.Err() == nil })
|
|
ln.SetReconnectCallback(func() {
|
|
if ctx.Err() != nil {
|
|
return
|
|
}
|
|
s.handleReconnect()
|
|
})
|
|
|
|
logger.Infof("Connecting transport=%s carrier=%s ...", cfg.Transport, cfg.Carrier)
|
|
s.installSession()
|
|
|
|
if err := ln.Connect(ctx); err != nil {
|
|
return fmt.Errorf("failed to connect link: %w", err)
|
|
}
|
|
logger.Infof("Link connected")
|
|
|
|
s.wg.Add(1)
|
|
go func() {
|
|
defer s.wg.Done()
|
|
ln.WatchConnection(ctx)
|
|
}()
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) installSession() {
|
|
conn := muxconn.New(s.ln, s.cipher)
|
|
sess, err := smux.Server(conn, smuxConfig(linkMaxPayload(s.ln)))
|
|
if err != nil {
|
|
logger.Warnf("smux server init failed: %v", err)
|
|
return
|
|
}
|
|
s.sessMu.Lock()
|
|
s.conn = conn
|
|
s.session = sess
|
|
s.sessMu.Unlock()
|
|
}
|
|
|
|
func (s *Server) handleReconnect() {
|
|
s.recordReconnect()
|
|
logger.Infof("server reconnect reason=carrier - tearing down smux session")
|
|
s.sessMu.RLock()
|
|
current := s.session
|
|
s.sessMu.RUnlock()
|
|
s.reinstallSession(current)
|
|
}
|
|
|
|
func (s *Server) reinstallSession(dead *smux.Session) {
|
|
s.reinstallMu.Lock()
|
|
defer s.reinstallMu.Unlock()
|
|
|
|
// Pre-build the replacement so we can swap atomically below.
|
|
newConn := muxconn.New(s.ln, s.cipher)
|
|
newSess, err := smux.Server(newConn, smuxConfig(linkMaxPayload(s.ln)))
|
|
if err != nil {
|
|
logger.Warnf("smux server init failed: %v", err)
|
|
_ = newConn.Close()
|
|
return
|
|
}
|
|
|
|
s.sessMu.Lock()
|
|
if s.session != dead {
|
|
// Someone else already reinstalled — discard our build.
|
|
s.sessMu.Unlock()
|
|
_ = newSess.Close()
|
|
_ = newConn.Close()
|
|
return
|
|
}
|
|
oldSess := s.session
|
|
oldConn := s.conn
|
|
oldControlStop := s.controlStop
|
|
oldSID := s.sessionID
|
|
s.session = newSess
|
|
s.conn = newConn
|
|
s.controlStop = nil
|
|
s.sessionID = ""
|
|
s.deviceID = ""
|
|
s.sessMu.Unlock()
|
|
|
|
if oldControlStop != nil {
|
|
oldControlStop()
|
|
}
|
|
if oldSess != nil {
|
|
_ = oldSess.Close()
|
|
}
|
|
if oldConn != nil {
|
|
_ = oldConn.Close()
|
|
}
|
|
if oldSID != "" {
|
|
s.onClose(oldSID, "reconnect")
|
|
}
|
|
}
|
|
|
|
func (s *Server) closeSession() {
|
|
s.sessMu.Lock()
|
|
sess := s.session
|
|
conn := s.conn
|
|
controlStop := s.controlStop
|
|
s.session = nil
|
|
s.conn = nil
|
|
s.controlStop = nil
|
|
oldSID := s.sessionID
|
|
s.sessionID = ""
|
|
s.deviceID = ""
|
|
s.sessMu.Unlock()
|
|
|
|
if controlStop != nil {
|
|
controlStop()
|
|
}
|
|
if sess != nil {
|
|
_ = sess.Close()
|
|
}
|
|
if conn != nil {
|
|
_ = conn.Close()
|
|
}
|
|
if oldSID != "" {
|
|
s.onClose(oldSID, "closed")
|
|
}
|
|
}
|
|
|
|
func (s *Server) onData(data []byte) {
|
|
s.sessMu.RLock()
|
|
conn := s.conn
|
|
s.sessMu.RUnlock()
|
|
if conn != nil {
|
|
conn.Push(data)
|
|
}
|
|
}
|
|
|
|
// serve drives the smux Accept loop. The first accepted stream on a given
|
|
// smux session is the control stream — the handshake runs there. Subsequent
|
|
// streams are tunnel streams and proxy traffic.
|
|
func (s *Server) serve(ctx context.Context) {
|
|
for {
|
|
if contextDone(ctx) {
|
|
return
|
|
}
|
|
|
|
s.sessMu.RLock()
|
|
sess := s.session
|
|
s.sessMu.RUnlock()
|
|
if sess == nil {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-time.After(50 * time.Millisecond):
|
|
continue
|
|
}
|
|
}
|
|
|
|
if !s.handshakeReady() {
|
|
if !s.acceptHandshake(ctx, sess) {
|
|
continue
|
|
}
|
|
}
|
|
|
|
stream, err := sess.AcceptStream()
|
|
if err != nil {
|
|
if contextDone(ctx) {
|
|
return
|
|
}
|
|
logger.Debugf("AcceptStream returned %v - reinstalling session", err)
|
|
s.reinstallSession(sess)
|
|
continue
|
|
}
|
|
|
|
s.wg.Add(1)
|
|
go func() {
|
|
defer s.wg.Done()
|
|
s.handleStream(ctx, stream)
|
|
}()
|
|
}
|
|
}
|
|
|
|
func contextDone(ctx context.Context) bool {
|
|
select {
|
|
case <-ctx.Done():
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
// handshakeReady reports whether the current session has completed its
|
|
// handshake. The session is reset on reconnect, so this is recomputed.
|
|
func (s *Server) handshakeReady() bool {
|
|
s.sessMu.RLock()
|
|
defer s.sessMu.RUnlock()
|
|
return s.sessionID != ""
|
|
}
|
|
|
|
func (s *Server) acceptHandshake(ctx context.Context, sess *smux.Session) bool {
|
|
stream, err := sess.AcceptStream()
|
|
if err != nil {
|
|
select {
|
|
case <-ctx.Done():
|
|
return false
|
|
default:
|
|
}
|
|
logger.Debugf("AcceptStream(control) returned %v - reinstalling session", err)
|
|
s.resetLinkPeer()
|
|
s.reinstallSession(sess)
|
|
return false
|
|
}
|
|
_ = stream.SetDeadline(time.Now().Add(handshake.DefaultTimeout))
|
|
hello, sid, err := handshake.Server(stream, s.authHook)
|
|
_ = stream.SetDeadline(time.Time{})
|
|
if err != nil {
|
|
logger.Warnf("handshake failed: %v", err)
|
|
_ = stream.Close()
|
|
s.resetLinkPeer()
|
|
s.reinstallSession(sess)
|
|
return false
|
|
}
|
|
s.sessMu.Lock()
|
|
s.deviceID = hello.DeviceID
|
|
s.sessionID = sid
|
|
s.sessMu.Unlock()
|
|
s.recordSession(sid)
|
|
s.onOpen(sid, hello.DeviceID, hello.Claims)
|
|
logger.Infof("session %s opened (device=%s)", sid, hello.DeviceID)
|
|
s.startControlLoop(ctx, sess, stream)
|
|
return true
|
|
}
|
|
|
|
func (s *Server) resetLinkPeer() {
|
|
s.sessMu.RLock()
|
|
ln := s.ln
|
|
s.sessMu.RUnlock()
|
|
if resetter, ok := ln.(interface{ ResetPeer() }); ok {
|
|
resetter.ResetPeer()
|
|
}
|
|
}
|
|
|
|
func (s *Server) startControlLoop(ctx context.Context, sess *smux.Session, stream *smux.Stream) {
|
|
controlCtx, stop := context.WithCancel(ctx)
|
|
s.sessMu.Lock()
|
|
s.controlStop = stop
|
|
s.sessMu.Unlock()
|
|
|
|
liveness := s.liveness
|
|
onPong := liveness.OnPong
|
|
onMissedPong := liveness.OnMissedPong
|
|
onUnhealthy := liveness.OnUnhealthy
|
|
liveness.OnPong = func(h control.Health) {
|
|
s.sessMu.RLock()
|
|
sid := s.sessionID
|
|
s.sessMu.RUnlock()
|
|
s.recordPong(h)
|
|
logger.Debugf("control alive session=%s rtt=%v seq=%d", sid, h.RTT, h.Seq)
|
|
if onPong != nil {
|
|
onPong(h)
|
|
}
|
|
}
|
|
liveness.OnMissedPong = func(missed int) {
|
|
s.recordMissed(missed)
|
|
logger.Warnf("control missed pong on server: missed_pongs=%d", missed)
|
|
if onMissedPong != nil {
|
|
onMissedPong(missed)
|
|
}
|
|
}
|
|
liveness.OnUnhealthy = func(missed int) {
|
|
s.recordUnhealthy(missed)
|
|
logger.Warnf("control stream unhealthy on server: missed_pongs=%d", missed)
|
|
if onUnhealthy != nil {
|
|
onUnhealthy(missed)
|
|
}
|
|
}
|
|
|
|
s.wg.Add(1)
|
|
go func() {
|
|
defer s.wg.Done()
|
|
defer func() { _ = stream.Close() }()
|
|
err := control.Run(controlCtx, stream, liveness)
|
|
if controlCtx.Err() != nil || ctx.Err() != nil {
|
|
return
|
|
}
|
|
if err != nil {
|
|
logger.Warnf("server control stream ended: %v", err)
|
|
}
|
|
s.recordReconnect()
|
|
logger.Infof("server reconnect reason=liveness - reinstalling smux session")
|
|
s.reinstallSession(sess)
|
|
}()
|
|
}
|
|
|
|
// Status returns the latest server-side control health snapshot.
|
|
func (s *Server) Status() control.Status {
|
|
return s.health.Status()
|
|
}
|
|
|
|
func (s *Server) recordSession(sessionID string) { s.health.RecordSession(sessionID) }
|
|
func (s *Server) recordPong(h control.Health) { s.health.RecordPong(h) }
|
|
func (s *Server) recordMissed(missed int) { s.health.RecordMissed(missed) }
|
|
func (s *Server) recordUnhealthy(missed int) { s.health.RecordUnhealthy(missed) }
|
|
func (s *Server) recordReconnect() { s.health.RecordReconnect() }
|
|
|
|
func (s *Server) shutdown() {
|
|
s.closeSession()
|
|
if s.ln != nil {
|
|
_ = s.ln.Close()
|
|
}
|
|
}
|
|
|
|
func (s *Server) handleStream(_ context.Context, stream *smux.Stream) {
|
|
defer func() { _ = stream.Close() }()
|
|
|
|
// Read the connect JSON. The client writes the whole JSON in one
|
|
// stream.Write so it usually arrives intact; tolerate fragmentation
|
|
// by reading incrementally up to a sane cap.
|
|
const maxConnReq = 4096
|
|
header := make([]byte, 0, 256)
|
|
tmp := make([]byte, 256)
|
|
_ = stream.SetReadDeadline(time.Now().Add(15 * time.Second))
|
|
for {
|
|
n, err := stream.Read(tmp)
|
|
if n > 0 {
|
|
header = append(header, tmp[:n]...)
|
|
if req, ok := parseConnectRequest(header); ok {
|
|
_ = stream.SetReadDeadline(time.Time{})
|
|
s.dispatch(stream, req)
|
|
return
|
|
}
|
|
}
|
|
if err != nil {
|
|
return
|
|
}
|
|
if len(header) > maxConnReq {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func parseConnectRequest(buf []byte) (ConnectRequest, bool) {
|
|
var req ConnectRequest
|
|
if err := json.Unmarshal(buf, &req); err != nil {
|
|
return req, false
|
|
}
|
|
if req.Cmd != connectCommand {
|
|
return req, false
|
|
}
|
|
return req, true
|
|
}
|
|
|
|
// defaultAuthHook admits every client and assigns a random session ID.
|
|
// Replace it via [Config.AuthHook] to plug in real authorization.
|
|
func defaultAuthHook(_ string, _ map[string]any) (string, error) {
|
|
return uuid.NewString(), nil
|
|
}
|
|
|
|
func (s *Server) dispatch(stream *smux.Stream, req ConnectRequest) {
|
|
addr := net.JoinHostPort(req.Addr, strconv.Itoa(req.Port))
|
|
logger.Infof("sid=%d connect %s", stream.ID(), addr)
|
|
|
|
s.sessMu.RLock()
|
|
sid := s.sessionID
|
|
s.sessMu.RUnlock()
|
|
|
|
dialStart := time.Now()
|
|
conn, err := s.dial(req)
|
|
dialElapsed := time.Since(dialStart)
|
|
|
|
if err != nil {
|
|
logger.Infof("sid=%d dial %s failed (%v): %v", stream.ID(), addr, dialElapsed, err)
|
|
return
|
|
}
|
|
defer func() { _ = conn.Close() }()
|
|
|
|
logger.Infof("sid=%d connected %s in %v", stream.ID(), addr, dialElapsed)
|
|
|
|
if _, err := stream.Write([]byte{0x00}); err != nil {
|
|
return
|
|
}
|
|
|
|
var bytesOut uint64
|
|
done := make(chan struct{})
|
|
go func() {
|
|
n, _ := io.Copy(stream, conn)
|
|
if n > 0 {
|
|
bytesOut = uint64(n)
|
|
}
|
|
_ = stream.Close()
|
|
close(done)
|
|
}()
|
|
in, _ := io.Copy(conn, stream)
|
|
_ = conn.Close()
|
|
<-done
|
|
bytesIn := uint64(0)
|
|
if in > 0 {
|
|
bytesIn = uint64(in)
|
|
}
|
|
if s.onTraffic != nil {
|
|
s.onTraffic(sid, addr, bytesIn, bytesOut)
|
|
}
|
|
}
|
|
|
|
func (s *Server) dial(req ConnectRequest) (net.Conn, error) {
|
|
addr := net.JoinHostPort(req.Addr, strconv.Itoa(req.Port))
|
|
if s.socksProxyAddr == "" {
|
|
dialer := &net.Dialer{
|
|
Timeout: 10 * time.Second,
|
|
KeepAlive: 30 * time.Second,
|
|
Resolver: s.resolver,
|
|
}
|
|
conn, err := dialer.Dial("tcp4", addr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("dial failed: %w", err)
|
|
}
|
|
return conn, nil
|
|
}
|
|
|
|
proxyAddr := net.JoinHostPort(s.socksProxyAddr, strconv.Itoa(s.socksProxyPort))
|
|
dialer := &net.Dialer{
|
|
Timeout: 10 * time.Second,
|
|
KeepAlive: 30 * time.Second,
|
|
}
|
|
conn, err := dialer.Dial("tcp4", proxyAddr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to dial proxy: %w", err)
|
|
}
|
|
|
|
if err := s.socks5Connect(conn, req.Addr, req.Port); err != nil {
|
|
_ = conn.Close()
|
|
return nil, err
|
|
}
|
|
return conn, nil
|
|
}
|
|
|
|
func (s *Server) socks5Connect(conn net.Conn, targetAddr string, targetPort int) error {
|
|
if _, err := conn.Write([]byte{5, 1, 0}); err != nil {
|
|
return fmt.Errorf("failed to write socks5 auth: %w", err)
|
|
}
|
|
|
|
resp := make([]byte, 2)
|
|
if _, err := io.ReadFull(conn, resp); err != nil {
|
|
return fmt.Errorf("failed to read socks5 auth resp: %w", err)
|
|
}
|
|
if resp[0] != 5 || resp[1] != 0 {
|
|
return ErrSocks5AuthFailed
|
|
}
|
|
|
|
addrLen := len(targetAddr)
|
|
if addrLen > 255 {
|
|
addrLen = 255
|
|
targetAddr = targetAddr[:255]
|
|
}
|
|
|
|
req := make([]byte, 0, 7+addrLen)
|
|
req = append(req, 5, 1, 0, 3, byte(addrLen))
|
|
req = append(req, []byte(targetAddr)...)
|
|
req = append(req, byte(targetPort>>8), byte(targetPort)) //nolint:gosec,lll // G115: bounded conversion verified by surrounding logic
|
|
|
|
if _, err := conn.Write(req); err != nil {
|
|
return fmt.Errorf("failed to write socks5 connect req: %w", err)
|
|
}
|
|
|
|
resp = make([]byte, 10)
|
|
if _, err := io.ReadFull(conn, resp); err != nil {
|
|
return fmt.Errorf("failed to read socks5 connect resp: %w", err)
|
|
}
|
|
if resp[0] != 5 || resp[1] != 0 {
|
|
return fmt.Errorf("%w: %d", ErrSocks5ConnectFailed, resp[1])
|
|
}
|
|
|
|
return nil
|
|
}
|