feat(protect): add DNS retry logic to HTTP client

This commit is contained in:
zarazaex69
2026-05-24 17:23:04 +03:00
parent fe5dbb55b1
commit cb6fe0980d
2 changed files with 48 additions and 4 deletions

View File

@@ -4,6 +4,7 @@ package protect
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
@@ -84,14 +85,53 @@ func NewHTTPTransport() *http.Transport {
}
}
// NewHTTPClient returns an http.Client using protected sockets.
// NewHTTPClient returns an http.Client using protected sockets with DNS retry.
func NewHTTPClient() *http.Client {
return &http.Client{
Transport: NewHTTPTransport(),
Transport: &retryTransport{base: NewHTTPTransport()},
Timeout: defaultHTTPClientTimeout,
}
}
// retryTransport retries requests on transient DNS/dial errors.
type retryTransport struct {
base http.RoundTripper
}
func (t *retryTransport) RoundTrip(req *http.Request) (*http.Response, error) {
const maxRetries = 3
var resp *http.Response
var err error
for i := range maxRetries {
if i > 0 {
time.Sleep(time.Duration(i) * 500 * time.Millisecond)
}
resp, err = t.base.RoundTrip(req)
if err == nil || !isRetriableError(err) {
return resp, err
}
}
return resp, err
}
func isRetriableError(err error) bool {
if err == nil {
return false
}
var dnsErr *net.DNSError
if errors.As(err, &dnsErr) {
return true
}
var opErr *net.OpError
if errors.As(err, &opErr) {
return opErr.Timeout() || strings.Contains(opErr.Error(), "connection refused")
}
s := err.Error()
return strings.Contains(s, "no such host") ||
strings.Contains(s, "connection reset") ||
strings.Contains(s, "i/o timeout")
}
// NewWebSocketDialer returns a WebSocket dialer using protected sockets and shared TLS policy.
func NewWebSocketDialer(handshakeTimeout time.Duration) websocket.Dialer {
if handshakeTimeout <= 0 {

View File

@@ -86,9 +86,13 @@ func TestNewDialerAndHTTPClient(t *testing.T) {
}
client := NewHTTPClient()
tr, ok := client.Transport.(*http.Transport)
rt, ok := client.Transport.(*retryTransport)
if !ok {
t.Fatalf("Transport type = %T, want *http.Transport", client.Transport)
t.Fatalf("Transport type = %T, want *protect.retryTransport", client.Transport)
}
tr, ok := rt.base.(*http.Transport)
if !ok {
t.Fatalf("base Transport type = %T, want *http.Transport", rt.base)
}
if tr.Proxy == nil || tr.DialContext == nil || tr.TLSClientConfig == nil ||
tr.TLSClientConfig.MinVersion != tls.VersionTLS12 || !tr.ForceAttemptHTTP2 || tr.MaxIdleConns != 10 ||