From cb6fe0980d742ee81eafa9df3020f17440787544 Mon Sep 17 00:00:00 2001 From: zarazaex69 Date: Sun, 24 May 2026 17:23:04 +0300 Subject: [PATCH] feat(protect): add DNS retry logic to HTTP client --- internal/protect/protect.go | 44 ++++++++++++++++++++++++++++++-- internal/protect/protect_test.go | 8 ++++-- 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/internal/protect/protect.go b/internal/protect/protect.go index 00b38e3..b3754b0 100644 --- a/internal/protect/protect.go +++ b/internal/protect/protect.go @@ -4,6 +4,7 @@ package protect import ( "context" "crypto/tls" + "errors" "fmt" "io" "net" @@ -84,14 +85,53 @@ func NewHTTPTransport() *http.Transport { } } -// NewHTTPClient returns an http.Client using protected sockets. +// NewHTTPClient returns an http.Client using protected sockets with DNS retry. func NewHTTPClient() *http.Client { return &http.Client{ - Transport: NewHTTPTransport(), + Transport: &retryTransport{base: NewHTTPTransport()}, Timeout: defaultHTTPClientTimeout, } } +// retryTransport retries requests on transient DNS/dial errors. +type retryTransport struct { + base http.RoundTripper +} + +func (t *retryTransport) RoundTrip(req *http.Request) (*http.Response, error) { + const maxRetries = 3 + var resp *http.Response + var err error + for i := range maxRetries { + if i > 0 { + time.Sleep(time.Duration(i) * 500 * time.Millisecond) + } + resp, err = t.base.RoundTrip(req) + if err == nil || !isRetriableError(err) { + return resp, err + } + } + return resp, err +} + +func isRetriableError(err error) bool { + if err == nil { + return false + } + var dnsErr *net.DNSError + if errors.As(err, &dnsErr) { + return true + } + var opErr *net.OpError + if errors.As(err, &opErr) { + return opErr.Timeout() || strings.Contains(opErr.Error(), "connection refused") + } + s := err.Error() + return strings.Contains(s, "no such host") || + strings.Contains(s, "connection reset") || + strings.Contains(s, "i/o timeout") +} + // NewWebSocketDialer returns a WebSocket dialer using protected sockets and shared TLS policy. func NewWebSocketDialer(handshakeTimeout time.Duration) websocket.Dialer { if handshakeTimeout <= 0 { diff --git a/internal/protect/protect_test.go b/internal/protect/protect_test.go index e07a666..1af0cba 100644 --- a/internal/protect/protect_test.go +++ b/internal/protect/protect_test.go @@ -86,9 +86,13 @@ func TestNewDialerAndHTTPClient(t *testing.T) { } client := NewHTTPClient() - tr, ok := client.Transport.(*http.Transport) + rt, ok := client.Transport.(*retryTransport) if !ok { - t.Fatalf("Transport type = %T, want *http.Transport", client.Transport) + t.Fatalf("Transport type = %T, want *protect.retryTransport", client.Transport) + } + tr, ok := rt.base.(*http.Transport) + if !ok { + t.Fatalf("base Transport type = %T, want *http.Transport", rt.base) } if tr.Proxy == nil || tr.DialContext == nil || tr.TLSClientConfig == nil || tr.TLSClientConfig.MinVersion != tls.VersionTLS12 || !tr.ForceAttemptHTTP2 || tr.MaxIdleConns != 10 ||