mirror of
https://github.com/openlibrecommunity/olcrtc.git
synced 2026-05-26 07:08:11 +00:00
feat(protect): add DNS retry logic to HTTP client
This commit is contained in:
@@ -4,6 +4,7 @@ package protect
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@@ -84,14 +85,53 @@ func NewHTTPTransport() *http.Transport {
|
||||
}
|
||||
}
|
||||
|
||||
// NewHTTPClient returns an http.Client using protected sockets.
|
||||
// NewHTTPClient returns an http.Client using protected sockets with DNS retry.
|
||||
func NewHTTPClient() *http.Client {
|
||||
return &http.Client{
|
||||
Transport: NewHTTPTransport(),
|
||||
Transport: &retryTransport{base: NewHTTPTransport()},
|
||||
Timeout: defaultHTTPClientTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// retryTransport retries requests on transient DNS/dial errors.
|
||||
type retryTransport struct {
|
||||
base http.RoundTripper
|
||||
}
|
||||
|
||||
func (t *retryTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
const maxRetries = 3
|
||||
var resp *http.Response
|
||||
var err error
|
||||
for i := range maxRetries {
|
||||
if i > 0 {
|
||||
time.Sleep(time.Duration(i) * 500 * time.Millisecond)
|
||||
}
|
||||
resp, err = t.base.RoundTrip(req)
|
||||
if err == nil || !isRetriableError(err) {
|
||||
return resp, err
|
||||
}
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func isRetriableError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var dnsErr *net.DNSError
|
||||
if errors.As(err, &dnsErr) {
|
||||
return true
|
||||
}
|
||||
var opErr *net.OpError
|
||||
if errors.As(err, &opErr) {
|
||||
return opErr.Timeout() || strings.Contains(opErr.Error(), "connection refused")
|
||||
}
|
||||
s := err.Error()
|
||||
return strings.Contains(s, "no such host") ||
|
||||
strings.Contains(s, "connection reset") ||
|
||||
strings.Contains(s, "i/o timeout")
|
||||
}
|
||||
|
||||
// NewWebSocketDialer returns a WebSocket dialer using protected sockets and shared TLS policy.
|
||||
func NewWebSocketDialer(handshakeTimeout time.Duration) websocket.Dialer {
|
||||
if handshakeTimeout <= 0 {
|
||||
|
||||
@@ -86,9 +86,13 @@ func TestNewDialerAndHTTPClient(t *testing.T) {
|
||||
}
|
||||
|
||||
client := NewHTTPClient()
|
||||
tr, ok := client.Transport.(*http.Transport)
|
||||
rt, ok := client.Transport.(*retryTransport)
|
||||
if !ok {
|
||||
t.Fatalf("Transport type = %T, want *http.Transport", client.Transport)
|
||||
t.Fatalf("Transport type = %T, want *protect.retryTransport", client.Transport)
|
||||
}
|
||||
tr, ok := rt.base.(*http.Transport)
|
||||
if !ok {
|
||||
t.Fatalf("base Transport type = %T, want *http.Transport", rt.base)
|
||||
}
|
||||
if tr.Proxy == nil || tr.DialContext == nil || tr.TLSClientConfig == nil ||
|
||||
tr.TLSClientConfig.MinVersion != tls.VersionTLS12 || !tr.ForceAttemptHTTP2 || tr.MaxIdleConns != 10 ||
|
||||
|
||||
Reference in New Issue
Block a user