Source code for src.middleware

import collections
import threading
import time


def retry(retries: int = 3, delay: float = 0.1, retry_on_status_code=None):
    """Middleware that retries failed requests.

    Re-sends a request whose response carries one of the configured status
    codes, or status code 0 (connection errors).

    Args:
        retries (int, optional): Maximum number of attempts, including the
            first one. Values below 1 are treated as 1 so the request is
            always sent at least once. Defaults to 3.
        delay (float, optional): Delay in seconds between attempts.
            Defaults to 0.1.
        retry_on_status_code (set, optional): HTTP status codes to retry on.
            Defaults to {500, 502, 503, 504} when omitted (``None``). An
            explicitly empty collection disables status-based retries, while
            connection errors (status code 0) are still retried.

    Returns:
        function: Middleware function to be passed to HttpClient. The handler
        returns the first non-retryable response, or the last response if
        every attempt was retryable.
    """
    # Compare against None so an explicit empty set is honoured instead of
    # being silently replaced by the defaults.
    if retry_on_status_code is None:
        retry_on_status_code = {500, 502, 503, 504}
    # Guarantee at least one attempt; otherwise the request would never be
    # sent and the handler would return None.
    attempts = max(1, retries)

    def middleware(next):
        def handler(request):
            for attempt in range(attempts):
                response = next(request)
                status = response.status_code
                if status != 0 and status not in retry_on_status_code:
                    return response
                # No pause after the final attempt: nothing follows it.
                if attempt < attempts - 1:
                    time.sleep(delay)
            # Every attempt produced a retryable status; surface the last one.
            return response

        return handler

    return middleware
def logging(logger=print):
    """Middleware that logs request and response information.

    Logs the HTTP method and URL before the request is handed downstream.
    After the downstream handler returns, it logs the response status code
    and the elapsed time in milliseconds, measured with
    ``time.perf_counter``. If the downstream handler raises, no response line
    is logged and the exception propagates unchanged.

    Args:
        logger (callable, optional): Function to use for logging. Defaults to
            print. Called once per log line with a single formatted string.

    Returns:
        function: Middleware function to be passed to HttpClient
    """
    def middleware(next):
        def handler(request):
            start = time.perf_counter()
            logger(f"[REQUEST] {request.method} {request.url}")
            response = next(request)
            elapsed_ms = (time.perf_counter() - start) * 1000
            # Format the float directly. Truncating to int first (as before)
            # made the two decimal places always read ".00".
            logger(
                f"[RESPONSE] Status Code: {response.status_code} | "
                f"Elapsed Time: {elapsed_ms:.2f} ms"
            )
            return response

        return handler

    return middleware
def timeout(seconds: float = 10.0):
    """Middleware that enforces a timeout during request processing.

    The downstream handler runs in a daemon worker thread, and the caller
    waits at most ``seconds`` for it to finish. This is adapter-agnostic and
    works with any adapter that supports the timeout parameter, such as
    `requests`, `urllib` or `httpx`.

    Note that a timed-out worker is not stopped. It keeps running in the
    background as a daemon thread, and its eventual result is discarded.

    Args:
        seconds (float, optional): Timeout duration in seconds. Defaults to 10.0.

    Raises:
        TimeoutError: If request processing does not complete within the
            specified timeout duration.
        Exception: Any ``Exception`` raised by the downstream handler is
            re-raised in the calling thread.

    Returns:
        function: Middleware function to be passed to HttpClient
    """
    def middleware(next):
        def handler(request):
            value = None
            error = None

            def run():
                nonlocal value, error
                try:
                    value = next(request)
                except Exception as exc:
                    error = exc

            # A separate thread lets the caller stop waiting after `seconds`.
            worker = threading.Thread(target=run, daemon=True)
            worker.start()
            worker.join(seconds)

            if worker.is_alive():
                raise TimeoutError(f"Request timeout after {seconds} seconds")
            # The worker finished; propagate its failure, if any.
            if error is not None:
                raise error
            return value

        return handler

    # Marker attribute that tells HttpClient this timeout middleware is in the
    # chain. HttpClient reads it in `_build_chain` to suppress
    # request.timeout, avoiding two competing timers and unexpected behavior.
    middleware._is_timeout_middleware = True
    return middleware
def rate_limit(calls: int = 10, period: float = 1.0):
    """Middleware that enforces a rate limit across all requests through this client.

    Keeps a lock-protected sliding-window log of call timestamps so that no
    more than `calls` requests are admitted within any `period`-second
    window. Requests over the limit are never dropped; they are delayed until
    a slot frees up.

    The lock is held while a request waits for a slot, so waiting requests
    are admitted one at a time. The downstream call itself runs outside the
    lock.

    Args:
        calls (int): Maximum number of calls allowed within the specified
            period. Defaults to 10.
        period (float): Time window in seconds for the rate limit.
            Defaults to 1.0.

    Returns:
        function: Middleware function to be passed to HttpClient

    Raises:
        ValueError: If `calls` is less than 1 or `period` is not positive.
    """
    if calls < 1:
        raise ValueError("calls must be positive number")
    if period <= 0:
        raise ValueError("period must be more than 0")

    lock = threading.Lock()
    # Monotonic timestamps of recently admitted calls, oldest first.
    # A deque gives O(1) eviction from the left.
    call_times = collections.deque()

    def middleware(next):
        def handler(request):
            with lock:
                while True:
                    # Use a monotonic clock: wall-clock jumps (NTP, manual
                    # changes) would otherwise yield huge or negative waits.
                    now = time.monotonic()
                    cutoff = now - period
                    while call_times and call_times[0] <= cutoff:
                        call_times.popleft()
                    if len(call_times) < calls:
                        break
                    # Wait for the oldest call to leave the window, then
                    # re-check. A single sleep is not trusted to land exactly
                    # past the boundary.
                    time.sleep(call_times[0] + period - now)
                call_times.append(now)
            return next(request)

        return handler

    return middleware