// Source listing: olcrtc/internal/protect/protect.go
// 2026-05-16 04:06:55 +03:00 · 152 lines · 4.4 KiB · Go

// Package protect provides functions to protect sockets from VPN routing.
package protect
import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"regexp"
"strings"
"syscall"
"time"
"github.com/gorilla/websocket"
)
// Default timeouts and limits used by the constructors in this package.
const (
// defaultDialTimeout is the Timeout set on dialers returned by NewDialer.
defaultDialTimeout = 10 * time.Second
// defaultKeepAlive is the KeepAlive set on dialers returned by NewDialer.
defaultKeepAlive = 30 * time.Second
// defaultIdleConnTimeout is the IdleConnTimeout set by NewHTTPTransport.
defaultIdleConnTimeout = 30 * time.Second
// defaultTLSHandshake is the TLSHandshakeTimeout set by NewHTTPTransport.
defaultTLSHandshake = 10 * time.Second
// defaultResponseHeader is the ResponseHeaderTimeout set by NewHTTPTransport.
defaultResponseHeader = 10 * time.Second
// defaultWebSocketTimeout is the handshake timeout NewWebSocketDialer uses
// when it is given a non-positive value.
defaultWebSocketTimeout = 10 * time.Second
// defaultHTTPClientTimeout is the Timeout set on clients returned by NewHTTPClient.
defaultHTTPClientTimeout = 30 * time.Second
// defaultStatusBodyLimit is the number of body bytes StatusError reads when
// it is given a non-positive limit.
defaultStatusBodyLimit = 1024
)
// Patterns used by RedactSensitive to find secret values in text.
var (
// sensitiveFieldRE matches, case-insensitively, one of the names
// access_token / room_token (with "_", "-" or no separator), token, or
// credentials, followed by an optional quote, ':' or '=', and an optional
// opening quote. Everything up to that point is capture group 1. The value
// is the following run of characters that are not a double quote, comma,
// whitespace or '}'.
sensitiveFieldRE = regexp.MustCompile(
`(?i)((?:access[_-]?token|room[_-]?token|token|credentials)"?\s*[:=]\s*"?)` +
`[^",\s}]+`,
)
// sensitiveBearerRE matches, case-insensitively, the word "bearer" plus the
// whitespace after it (capture group 1), followed by a run of letters,
// digits and the characters . _ ~ + / = -.
sensitiveBearerRE = regexp.MustCompile(`(?i)(bearer\s+)[A-Za-z0-9._~+/=-]+`)
)
// Protector is called with a socket file descriptor before connect.
// On Android, this calls VpnService.protect(fd) to bypass VPN routing.
// When Protector is nil, controlFunc leaves sockets untouched.
var Protector func(fd int) bool //nolint:gochecknoglobals // package-level state intentional

// controlFunc has the signature of a net.Dialer Control hook. It hands the
// raw file descriptor of c to Protector and reports an error when Protector
// rejects it.
//
// It returns nil when Protector is unset or accepts the descriptor, a wrapped
// error when c.Control itself fails, and a *net.OpError with Op "protect"
// (wrapping net.ErrClosed) when Protector returns false.
func controlFunc(network, _ string, c syscall.RawConn) error {
	if Protector == nil {
		return nil
	}

	// rejected is sticky: once Protector has refused the descriptor the
	// verdict is not overwritten by a later invocation of the callback.
	rejected := false
	if ctrlErr := c.Control(func(fd uintptr) {
		if !Protector(int(fd)) {
			rejected = true
		}
	}); ctrlErr != nil {
		return fmt.Errorf("control failed: %w", ctrlErr)
	}

	if !rejected {
		return nil
	}
	return &net.OpError{Op: "protect", Net: network, Err: net.ErrClosed}
}
// NewDialer builds a fresh net.Dialer whose Control hook is controlFunc, so
// Protector is consulted for each new socket. The dialer uses the package's
// default dial timeout and keep-alive interval.
func NewDialer() *net.Dialer {
	dialer := new(net.Dialer)
	dialer.Timeout = defaultDialTimeout
	dialer.KeepAlive = defaultKeepAlive
	dialer.Control = controlFunc
	return dialer
}
// NewTLSConfig returns a new TLS configuration for provider HTTP/WebSocket
// clients. The only policy it sets is a minimum protocol version of TLS 1.2;
// every other field keeps its zero value.
func NewTLSConfig() *tls.Config {
	cfg := new(tls.Config)
	cfg.MinVersion = tls.VersionTLS12
	return cfg
}
// NewHTTPTransport builds an http.Transport that dials through NewDialer, so
// its sockets pass through Protector, and applies the shared TLS policy and
// the package's default timeouts. Proxies are taken from the environment and
// HTTP/2 is attempted despite the custom dialer and TLS configuration.
func NewHTTPTransport() *http.Transport {
	transport := &http.Transport{
		Proxy:             http.ProxyFromEnvironment,
		DialContext:       NewDialer().DialContext,
		TLSClientConfig:   NewTLSConfig(),
		ForceAttemptHTTP2: true,
		MaxIdleConns:      10,
	}

	// Timeouts are grouped here so they are easy to audit together.
	transport.IdleConnTimeout = defaultIdleConnTimeout
	transport.TLSHandshakeTimeout = defaultTLSHandshake
	transport.ResponseHeaderTimeout = defaultResponseHeader

	return transport
}
// NewHTTPClient returns a new http.Client backed by NewHTTPTransport, with
// the package's default overall request timeout.
func NewHTTPClient() *http.Client {
	client := new(http.Client)
	client.Transport = NewHTTPTransport()
	client.Timeout = defaultHTTPClientTimeout
	return client
}
// NewWebSocketDialer returns a websocket.Dialer that opens connections via
// DialContext (protected sockets), honours proxy settings from the
// environment and uses the shared TLS policy.
//
// A non-positive handshakeTimeout is replaced by defaultWebSocketTimeout.
func NewWebSocketDialer(handshakeTimeout time.Duration) websocket.Dialer {
	timeout := handshakeTimeout
	if timeout <= 0 {
		timeout = defaultWebSocketTimeout
	}

	var dialer websocket.Dialer
	dialer.NetDialContext = DialContext
	dialer.Proxy = http.ProxyFromEnvironment
	dialer.TLSClientConfig = NewTLSConfig()
	dialer.HandshakeTimeout = timeout
	return dialer
}
// StatusError formats an upstream HTTP error while bounding and redacting the body.
//
// The returned error wraps base and carries resp's status code. When the
// response has a non-blank body, an excerpt of it is appended after passing
// through RedactSensitive.
//
// At most limit bytes of the body are read; a non-positive limit selects
// defaultStatusBodyLimit. Because the cut is made on a byte boundary it can
// split a multi-byte character, so invalid UTF-8 in the excerpt is replaced
// with U+FFFD to keep the error message valid text.
//
// A nil resp yields "<base>: no response" instead of panicking, and a nil
// resp.Body is treated as an empty body. StatusError does not close
// resp.Body; the caller remains responsible for that.
func StatusError(base error, resp *http.Response, limit int64) error {
	if resp == nil {
		return fmt.Errorf("%w: no response", base)
	}
	if limit <= 0 {
		limit = defaultStatusBodyLimit
	}

	var raw []byte
	if resp.Body != nil {
		// Best effort: the body only enriches the message, so a read failure
		// is ignored and whatever bytes arrived before it are still used.
		raw, _ = io.ReadAll(io.LimitReader(resp.Body, limit))
	}

	excerpt := strings.ToValidUTF8(string(raw), "\uFFFD")
	bodyText := RedactSensitive(strings.TrimSpace(excerpt))
	if bodyText == "" {
		return fmt.Errorf("%w: status %d", base, resp.StatusCode)
	}
	return fmt.Errorf("%w: status %d: %s", base, resp.StatusCode, bodyText)
}
// RedactSensitive replaces token-like values in text with "<redacted>".
// Bearer credentials are masked first, then values of token/credentials
// fields; in both cases the matched prefix (capture group 1) is kept so the
// surrounding text stays readable.
func RedactSensitive(text string) string {
	const masked = "${1}<redacted>"

	withoutBearer := sensitiveBearerRE.ReplaceAllString(text, masked)
	return sensitiveFieldRE.ReplaceAllString(withoutBearer, masked)
}
// DialContext opens a connection to address on network using a fresh dialer
// from NewDialer, so the socket passes through Protector. Any dial error is
// wrapped with the prefix "dial failed".
func DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	dialer := NewDialer()
	conn, dialErr := dialer.DialContext(ctx, network, address)
	if dialErr == nil {
		return conn, nil
	}
	return nil, fmt.Errorf("dial failed: %w", dialErr)
}
// ProxyDialer implements golang.org/x/net/proxy.Dialer for pion ICE.
// It carries no state; every Dial call builds its own protected dialer.
type ProxyDialer struct{}

// Dial connects to addr on the named network through a dialer from NewDialer,
// so the socket passes through Protector. Any dial error is wrapped with the
// prefix "dial failed".
func (d *ProxyDialer) Dial(network, addr string) (net.Conn, error) {
	dialer := NewDialer()
	conn, dialErr := dialer.Dial(network, addr)
	if dialErr == nil {
		return conn, nil
	}
	return nil, fmt.Errorf("dial failed: %w", dialErr)
}
// NewProxyDialer returns a new ProxyDialer, whose Dial method protects the
// sockets it opens (used for ICE).
func NewProxyDialer() *ProxyDialer {
	return new(ProxyDialer)
}