import threading
import time
from collections import deque
[docs]
def retry(retries: int = 3, delay: float = 0.1, retry_on_status_code=None):
    """Middleware that retries failed requests.

    Re-sends a request whose response carries one of the configured status
    codes, or status code 0 (treated as a connection error). Exceptions raised
    by the downstream handler are not caught and propagate immediately.

    Args:
        retries (int, optional): Total number of attempts, including the first
            one. Values below 1 are treated as 1, so the request is always sent
            at least once. Defaults to 3.
        delay (float, optional): Delay in seconds between attempts. Defaults to 0.1.
        retry_on_status_code (Iterable[int], optional): HTTP status codes to
            retry on. ``None`` selects the default {500, 502, 503, 504}; an
            empty collection retries only on status code 0.

    Returns:
        function: Middleware function to be passed to HttpClient
    """
    # only None selects the defaults, so an explicit empty set is respected
    if retry_on_status_code is None:
        retry_on_status_code = {500, 502, 503, 504}
    # snapshot into a frozenset: accepts any iterable of codes and is not
    # affected by later mutation of the caller's collection
    retryable = frozenset(retry_on_status_code)
    # always make at least one attempt; range(0) would otherwise return None
    # without ever sending the request
    attempts = max(1, retries)

    def middleware(next):
        def handler(request):
            response = None
            for attempt in range(attempts):
                response = next(request)
                # status code 0 signals a connection error and is always retried
                if response.status_code != 0 and response.status_code not in retryable:
                    return response
                # no point sleeping after the final attempt
                if attempt < attempts - 1:
                    time.sleep(delay)
            # every attempt failed: hand back the last failing response
            return response
        return handler
    return middleware
[docs]
def logging(logger=print):
    """Middleware that logs request and response information.

    Logs the HTTP method and URL before the request is passed on. It then logs
    the response status code and the elapsed time in milliseconds, measured
    with ``time.perf_counter``. If the downstream handler raises, no response
    line is logged.

    Args:
        logger (callable, optional): Function to use for logging. Defaults to print.
            Called with one formatted string per log line.

    Returns:
        function: Middleware function to be passed to HttpClient
    """
    def middleware(next):
        def handler(request):
            start = time.perf_counter()
            logger(f"[REQUEST] {request.method} {request.url}")
            response = next(request)
            elapsed_ms = (time.perf_counter() - start) * 1000
            # format the float directly; truncating with int() first made the
            # ".2f" precision always render as ".00"
            logger(
                f"[RESPONSE] Status Code: {response.status_code} | "
                f"Elapsed Time: {elapsed_ms:.2f} ms"
            )
            return response
        return handler
    return middleware
[docs]
def timeout(seconds: float = 10.0):
    """Middleware that enforces a timeout during request processing.

    This is adapter-agnostic and works with any adapter that supports the
    timeout parameter, such as `requests`, `urllib` or `httpx`.

    The downstream handler runs in a separate daemon thread. The caller waits
    for it for at most ``seconds``. On timeout, the worker thread is not
    cancelled: it keeps running in the background and its eventual result is
    discarded.

    Args:
        seconds (float, optional): Timeout duration in seconds. Defaults to 10.0.

    Raises:
        TimeoutError: If request processing doesn't complete within the
            specified timeout duration.

    Returns:
        function: Middleware function to be passed to HttpClient
    """
    def middleware(next):
        def handler(request):
            # filled in by the worker thread: "result" on success, "error" on failure
            outcome = {}

            def target():
                try:
                    outcome["result"] = next(request)
                except BaseException as exc:
                    # capture everything (not only Exception) so nothing dies
                    # silently in the worker; it is re-raised in the caller below
                    outcome["error"] = exc

            # run the request processing in a separate thread to allow timeout enforcement
            thread = threading.Thread(target=target, daemon=True)
            thread.start()
            thread.join(seconds)
            if thread.is_alive():
                raise TimeoutError(f"Request timeout after {seconds} seconds")
            # the worker finished but raised: propagate its exception to the caller
            if "error" in outcome:
                raise outcome["error"]
            return outcome["result"]
        return handler
    # sentinel value to detect whether the timeout middleware is present in the client or not;
    # it is read by the HttpClient class in the `_build_chain` method to suppress request.timeout,
    # preventing a double-timer conflict which can cause unexpected behavior
    middleware._is_timeout_middleware = True
    return middleware
[docs]
def rate_limit(calls: int = 10, period: float = 1.0):
    """Middleware that enforces a rate limit across all requests through this client.

    Keeps a thread-safe sliding-window log of call start times. This ensures
    that no more than ``calls`` requests start within any ``period``-second
    window. Requests over the limit are not dropped; they are delayed until a
    slot frees up.

    The timestamp log and its lock are created once per ``rate_limit(...)``
    call. Every handler built from the returned middleware therefore shares the
    same budget.

    Args:
        calls (int): Maximum number of calls allowed within the specified
            period. Defaults to 10.
        period (float): Time window in seconds for the rate limit. Defaults to 1.0.

    Returns:
        function: Middleware function to be passed to HttpClient

    Raises:
        ValueError: If `calls` is not positive or `period` is not positive.
    """
    if calls < 1:
        raise ValueError("calls must be positive number")
    if period <= 0:
        raise ValueError("period must be more than 0")

    lock = threading.Lock()
    # start times of recent calls, oldest first; a deque makes eviction O(1)
    call_times = deque()

    def middleware(next):
        def handler(request):
            with lock:
                # monotonic clock: unaffected by system clock adjustments
                now = time.monotonic()
                while True:
                    # evict calls that are no longer inside the window (now - period, now]
                    window_start = now - period
                    while call_times and call_times[0] <= window_start:
                        call_times.popleft()
                    if len(call_times) < calls:
                        break
                    # wait until the oldest call leaves the window (always > 0
                    # after eviction). The lock is held while sleeping, so other
                    # callers queue behind this one; loop to re-check afterwards.
                    time.sleep(period - (now - call_times[0]))
                    now = time.monotonic()
                # record the current call timestamp
                call_times.append(now)
            return next(request)
        return handler
    return middleware