diff --git a/.github/workflows/typecheck.yml b/.github/workflows/typecheck.yml new file mode 100644 index 0000000000..984f2d0162 --- /dev/null +++ b/.github/workflows/typecheck.yml @@ -0,0 +1,32 @@ +name: Type Check + +on: [push, pull_request] + +permissions: + contents: read + +jobs: + typecheck: + runs-on: ubuntu-24.04 + timeout-minutes: 10 + strategy: + matrix: + python-version: ["3.10", "3.14"] + + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + persist-credentials: false + + - name: Set up Python + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install pip==26.0.1 + python -m pip install -e . --group typecheck + + - name: Run pyright + run: python -m pyright src/requests/ diff --git a/pyproject.toml b/pyproject.toml index dcde263dd8..7c89412645 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,14 +51,20 @@ Source = "https://github.com/psf/requests" [project.optional-dependencies] security = [] socks = ["PySocks>=1.5.6, !=1.5.7"] -use_chardet_on_py3 = ["chardet>=3.0.2,<8"] +use_chardet_on_py3 = ["chardet>=3.0.2,<7"] + +[dependency-groups] test = [ "pytest-httpbin==2.1.0", "pytest-cov", "pytest-mock", "pytest-xdist", "PySocks>=1.5.6, !=1.5.7", - "pytest>=3" + "pytest>=3", +] +typecheck = [ + "pyright", + "typing_extensions", ] [tool.setuptools] @@ -100,3 +106,8 @@ addopts = "--doctest-modules" doctest_optionflags = "NORMALIZE_WHITESPACE ELLIPSIS" minversion = "6.2" testpaths = ["tests"] + + +[tool.pyright] +include = ["src/requests"] +typeCheckingMode = "strict" diff --git a/src/requests/__init__.py b/src/requests/__init__.py index 8ecd8b8149..0a90fb0941 100644 --- a/src/requests/__init__.py +++ b/src/requests/__init__.py @@ -38,6 +38,8 @@ :license: Apache 2.0, see LICENSE for more details. """ +from __future__ import annotations + import warnings import urllib3 @@ -50,21 +52,25 @@ charset_normalizer_version = None try: - from chardet import __version__ as chardet_version + from chardet import __version__ as chardet_version # type: ignore[import-not-found] except ImportError: chardet_version = None -def check_compatibility(urllib3_version, chardet_version, charset_normalizer_version): - urllib3_version = urllib3_version.split(".") - assert urllib3_version != ["dev"] # Verify urllib3 isn't installed from git. +def check_compatibility( + urllib3_version: str, + chardet_version: str | None, + charset_normalizer_version: str | None, +) -> None: + urllib3_version_list = urllib3_version.split(".") + assert urllib3_version_list != ["dev"] # Verify urllib3 isn't installed from git. # Sometimes, urllib3 only reports its version as 16.1. - if len(urllib3_version) == 2: - urllib3_version.append("0") + if len(urllib3_version_list) == 2: + urllib3_version_list.append("0") # Check urllib3 for compatibility. - major, minor, patch = urllib3_version # noqa: F811 + major, minor, patch = urllib3_version_list # noqa: F811 major, minor, patch = int(major), int(minor), int(patch) # urllib3 >= 1.21.1 assert major >= 1 @@ -90,28 +96,28 @@ def check_compatibility(urllib3_version, chardet_version, charset_normalizer_ver ) -def _check_cryptography(cryptography_version): +def _check_cryptography(cryptography_version: str) -> None: # cryptography < 1.3.4 try: - cryptography_version = list(map(int, cryptography_version.split("."))) + cryptography_version_list = list(map(int, cryptography_version.split("."))) except ValueError: return - if cryptography_version < [1, 3, 4]: - warning = ( - f"Old version of cryptography ({cryptography_version}) may cause slowdown." - ) + if cryptography_version_list < [1, 3, 4]: + warning = f"Old version of cryptography ({cryptography_version_list}) may cause slowdown." warnings.warn(warning, RequestsDependencyWarning) # Check imported dependencies for compatibility. try: check_compatibility( - urllib3.__version__, chardet_version, charset_normalizer_version + urllib3.__version__, # type: ignore[reportPrivateImportUsage] + chardet_version, # type: ignore[reportUnknownArgumentType] + charset_normalizer_version, ) except (AssertionError, ValueError): warnings.warn( - f"urllib3 ({urllib3.__version__}) or chardet " + f"urllib3 ({urllib3.__version__}) or chardet " # type: ignore[reportPrivateImportUsage] f"({chardet_version})/charset_normalizer ({charset_normalizer_version}) " "doesn't match a supported version!", RequestsDependencyWarning, @@ -132,9 +138,11 @@ def _check_cryptography(cryptography_version): pyopenssl.inject_into_urllib3() # Check cryptography version - from cryptography import __version__ as cryptography_version + from cryptography import ( # type: ignore[reportMissingImports] + __version__ as cryptography_version, # type: ignore[reportUnknownVariableType] + ) - _check_cryptography(cryptography_version) + _check_cryptography(cryptography_version) # type: ignore[reportUnknownArgumentType] except ImportError: pass @@ -177,6 +185,34 @@ def _check_cryptography(cryptography_version): from .sessions import Session, session from .status_codes import codes +__all__ = ( + "ConnectionError", + "ConnectTimeout", + "HTTPError", + "JSONDecodeError", + "PreparedRequest", + "ReadTimeout", + "Request", + "RequestException", + "Response", + "Session", + "Timeout", + "TooManyRedirects", + "URLRequired", + "codes", + "delete", + "get", + "head", + "options", + "packages", + "patch", + "post", + "put", + "request", + "session", + "utils", +) + logging.getLogger(__name__).addHandler(NullHandler()) # FileModeWarnings go off per the default. diff --git a/src/requests/_internal_utils.py b/src/requests/_internal_utils.py index 8c7c05190c..613420eabe 100644 --- a/src/requests/_internal_utils.py +++ b/src/requests/_internal_utils.py @@ -23,7 +23,7 @@ } -def to_native_string(string, encoding="ascii"): +def to_native_string(string: str | bytes, encoding: str = "ascii") -> str: """Given a string object, regardless of type, returns a representation of that string in the native string type, encoding and decoding where necessary. This assumes ASCII unless told otherwise. @@ -36,7 +36,7 @@ def to_native_string(string, encoding="ascii"): return out -def unicode_is_ascii(u_string): +def unicode_is_ascii(u_string: str) -> bool: """Determine if unicode string only contains ASCII characters. :param str u_string: unicode string to check. Must be unicode diff --git a/src/requests/_types.py b/src/requests/_types.py new file mode 100644 index 0000000000..1b37c7b492 --- /dev/null +++ b/src/requests/_types.py @@ -0,0 +1,170 @@ +""" +requests._types +~~~~~~~~~~~~~~~ + +This module contains type aliases used internally by the Requests library. +These types are not part of the public API and must not be relied upon +by external code. +""" + +from __future__ import annotations + +from collections.abc import Callable, Iterable, Mapping, MutableMapping +from typing import ( + TYPE_CHECKING, + Any, + Protocol, + TypeAlias, + TypeVar, + runtime_checkable, +) + +_T_co = TypeVar("_T_co", covariant=True) + + +@runtime_checkable +class SupportsRead(Protocol[_T_co]): + def read(self, length: int = ..., /) -> _T_co: ... + + +@runtime_checkable +class SupportsItems(Protocol): + def items(self) -> Iterable[tuple[Any, Any]]: ... + + +# These are needed at runtime for default_hooks() return type +HookType: TypeAlias = Callable[["Response"], Any] +HooksInputType: TypeAlias = Mapping[str, Iterable[HookType] | HookType] + + +def is_prepared(request: PreparedRequest) -> TypeIs[_ValidatedRequest]: + """Verify a PreparedRequest has been fully prepared.""" + if TYPE_CHECKING: + return request.url is not None and request.method is not None + # noop at runtime to avoid AssertionError + return True + + +if TYPE_CHECKING: + from http.cookiejar import CookieJar + from typing import TypeAlias, TypedDict + + from typing_extensions import TypeIs # move to typing when Python >= 3.13 + + from .auth import AuthBase + from .cookies import RequestsCookieJar + from .models import PreparedRequest, Response + from .structures import CaseInsensitiveDict + + class _ValidatedRequest(PreparedRequest): + """Subtype asserting a PreparedRequest has been fully prepared before calling. + + The override suppression is required because mutable attribute types are + invariant (Liskov), but we only narrow after preparation is complete. This + is the explicit contract for Requests but Python's typing doesn't have a + better way to represent the requirement. + """ + + url: str # type: ignore[reportIncompatibleVariableOverride] + method: str # type: ignore[reportIncompatibleVariableOverride] + + # Type aliases for core API concepts (ordered by request() signature) + UriType: TypeAlias = str | bytes + + _ParamsMappingKeyType: TypeAlias = str | bytes | int | float + _ParamsMappingValueType: TypeAlias = ( + str | bytes | int | float | Iterable[str | bytes | int | float] | None + ) + ParamsType: TypeAlias = ( + Mapping[_ParamsMappingKeyType, _ParamsMappingValueType] + | tuple[tuple[_ParamsMappingKeyType, _ParamsMappingValueType], ...] + | Iterable[tuple[_ParamsMappingKeyType, _ParamsMappingValueType]] + | str + | bytes + | None + ) + + KVDataType: TypeAlias = Iterable[tuple[Any, Any]] | Mapping[Any, Any] + + EncodableDataType: TypeAlias = KVDataType | str | bytes | SupportsRead[str | bytes] + + DataType: TypeAlias = ( + KVDataType + | Iterable[bytes | str] + | str + | bytes + | SupportsRead[str | bytes] + | None + ) + + BodyType: TypeAlias = ( + bytes | str | Iterable[bytes | str] | SupportsRead[bytes | str] | None + ) + + HeadersType: TypeAlias = CaseInsensitiveDict[str] | Mapping[str, str | bytes] + HeadersUpdateType: TypeAlias = Mapping[str, str | bytes | None] + + CookiesType: TypeAlias = RequestsCookieJar | Mapping[str, str] + + # Building blocks for FilesType + _FileName: TypeAlias = str | None + _FileContent: TypeAlias = SupportsRead[str | bytes] | str | bytes + _FileSpecBasic: TypeAlias = tuple[_FileName, _FileContent] + _FileSpecWithContentType: TypeAlias = tuple[_FileName, _FileContent, str] + _FileSpecWithHeaders: TypeAlias = tuple[ + _FileName, _FileContent, str, CaseInsensitiveDict[str] | Mapping[str, str] + ] + _FileSpec: TypeAlias = ( + _FileContent | _FileSpecBasic | _FileSpecWithContentType | _FileSpecWithHeaders + ) + FilesType: TypeAlias = ( + Mapping[str, _FileSpec] | Iterable[tuple[str, _FileSpec]] | None + ) + + AuthType: TypeAlias = ( + tuple[str, str] | AuthBase | Callable[[PreparedRequest], PreparedRequest] | None + ) + + TimeoutType: TypeAlias = float | tuple[float | None, float | None] | None + ProxiesType: TypeAlias = MutableMapping[str, str] + HooksType: TypeAlias = dict[str, list[HookType]] | None + VerifyType: TypeAlias = bool | str + CertType: TypeAlias = str | tuple[str, str] | None + JsonType: TypeAlias = ( + None | bool | int | float | str | list["JsonType"] | dict[str, "JsonType"] + ) + + # TypedDicts for Unpack kwargs (PEP 692) + + class BaseRequestKwargs(TypedDict, total=False): + headers: Mapping[str, str | bytes] | None + cookies: RequestsCookieJar | CookieJar | dict[str, str] | None + files: FilesType + auth: AuthType + timeout: TimeoutType + allow_redirects: bool + proxies: dict[str, str] | None + hooks: HooksType + stream: bool | None + verify: VerifyType | None + cert: CertType + + class RequestKwargs(BaseRequestKwargs, total=False): + """kwargs for request(), options(), head(), delete().""" + + params: ParamsType + data: DataType + json: JsonType + + class GetKwargs(BaseRequestKwargs, total=False): + data: DataType + json: JsonType + + class PostKwargs(BaseRequestKwargs, total=False): + params: ParamsType + + class DataKwargs(BaseRequestKwargs, total=False): + """kwargs for put(), patch().""" + + params: ParamsType + json: JsonType diff --git a/src/requests/adapters.py b/src/requests/adapters.py index 130154f8dc..07f087435a 100644 --- a/src/requests/adapters.py +++ b/src/requests/adapters.py @@ -6,10 +6,13 @@ and maintain connections. """ +from __future__ import annotations + import os.path -import socket # noqa: F401 +import socket # noqa: F401 # type: ignore[reportUnusedImport] import typing import warnings +from typing import Any from urllib3.exceptions import ( ClosedPoolError, @@ -30,7 +33,7 @@ from urllib3.util import parse_url from urllib3.util.retry import Retry -from .auth import _basic_auth_str +from .auth import _basic_auth_str # type: ignore[reportPrivateUsage] from .compat import basestring, urlparse from .cookies import extract_cookies_to_jar from .exceptions import ( @@ -58,16 +61,21 @@ ) try: - from urllib3.contrib.socks import SOCKSProxyManager + from urllib3.contrib.socks import SOCKSProxyManager # type: ignore[assignment] except ImportError: - def SOCKSProxyManager(*args, **kwargs): + def SOCKSProxyManager(*args: Any, **kwargs: Any) -> None: raise InvalidSchema("Missing dependencies for SOCKS support.") if typing.TYPE_CHECKING: + from urllib3.connectionpool import HTTPConnectionPool + from urllib3.poolmanager import PoolManager as _PoolManager + + from ._types import CertType, TimeoutType, VerifyType from .models import PreparedRequest +from ._types import is_prepared DEFAULT_POOLBLOCK = False DEFAULT_POOLSIZE = 10 @@ -76,13 +84,13 @@ def SOCKSProxyManager(*args, **kwargs): def _urllib3_request_context( - request: "PreparedRequest", - verify: "bool | str | None", - client_cert: "tuple[str, str] | str | None", - poolmanager: "PoolManager", -) -> "(dict[str, typing.Any], dict[str, typing.Any])": - host_params = {} - pool_kwargs = {} + request: PreparedRequest, + verify: bool | str | None, + client_cert: tuple[str, str] | str | None, + poolmanager: PoolManager, +) -> tuple[dict[str, Any], dict[str, Any]]: + host_params: dict[str, Any] = {} + pool_kwargs: dict[str, Any] = {} parsed_request_url = urlparse(request.url) scheme = parsed_request_url.scheme.lower() port = parsed_request_url.port @@ -115,12 +123,18 @@ def _urllib3_request_context( class BaseAdapter: """The Base Transport Adapter""" - def __init__(self): + def __init__(self) -> None: super().__init__() def send( - self, request, stream=False, timeout=None, verify=True, cert=None, proxies=None - ): + self, + request: PreparedRequest, + stream: bool = False, + timeout: TimeoutType = None, + verify: VerifyType = True, + cert: CertType = None, + proxies: dict[str, str] | None = None, + ) -> Response: """Sends PreparedRequest object. Returns Response object. :param request: The :class:`PreparedRequest ` being sent. @@ -137,7 +151,7 @@ def send( """ raise NotImplementedError - def close(self): + def close(self) -> None: """Cleans up adapter specific items.""" raise NotImplementedError @@ -169,7 +183,7 @@ class HTTPAdapter(BaseAdapter): >>> s.mount('http://', a) """ - __attrs__ = [ + __attrs__: list[str] = [ "max_retries", "config", "_pool_connections", @@ -177,13 +191,21 @@ class HTTPAdapter(BaseAdapter): "_pool_block", ] + max_retries: Retry + config: dict[str, Any] + proxy_manager: dict[str, Any] + _pool_connections: int + _pool_maxsize: int + _pool_block: bool + poolmanager: _PoolManager + def __init__( self, - pool_connections=DEFAULT_POOLSIZE, - pool_maxsize=DEFAULT_POOLSIZE, - max_retries=DEFAULT_RETRIES, - pool_block=DEFAULT_POOLBLOCK, - ): + pool_connections: int = DEFAULT_POOLSIZE, + pool_maxsize: int = DEFAULT_POOLSIZE, + max_retries: int | Retry = DEFAULT_RETRIES, + pool_block: bool = DEFAULT_POOLBLOCK, + ) -> None: if max_retries == DEFAULT_RETRIES: self.max_retries = Retry(0, read=False) else: @@ -199,10 +221,10 @@ def __init__( self.init_poolmanager(pool_connections, pool_maxsize, block=pool_block) - def __getstate__(self): + def __getstate__(self) -> dict[str, Any]: return {attr: getattr(self, attr, None) for attr in self.__attrs__} - def __setstate__(self, state): + def __setstate__(self, state: dict[str, Any]) -> None: # Can't handle by adding 'proxy_manager' to self.__attrs__ because # self.poolmanager uses a lambda function, which isn't pickleable. self.proxy_manager = {} @@ -216,8 +238,12 @@ def __setstate__(self, state): ) def init_poolmanager( - self, connections, maxsize, block=DEFAULT_POOLBLOCK, **pool_kwargs - ): + self, + connections: int, + maxsize: int, + block: bool = DEFAULT_POOLBLOCK, + **pool_kwargs: Any, + ) -> None: """Initializes a urllib3 PoolManager. This method should not be called from user code, and is only @@ -241,7 +267,7 @@ def init_poolmanager( **pool_kwargs, ) - def proxy_manager_for(self, proxy, **proxy_kwargs): + def proxy_manager_for(self, proxy: str, **proxy_kwargs: Any) -> Any: """Return urllib3 ProxyManager for the given proxy. This method should not be called from user code, and is only @@ -279,7 +305,9 @@ def proxy_manager_for(self, proxy, **proxy_kwargs): return manager - def cert_verify(self, conn, url, verify, cert): + def cert_verify( + self, conn: Any, url: str, verify: VerifyType, cert: CertType + ) -> None: """Verify a SSL certificate. This method should not be called from user code, and is only exposed for use when subclassing the :class:`HTTPAdapter `. @@ -335,7 +363,7 @@ def cert_verify(self, conn, url, verify, cert): f"Could not find the TLS key file, invalid path: {conn.key_file}" ) - def build_response(self, req, resp): + def build_response(self, req: PreparedRequest, resp: Any) -> Response: """Builds a :class:`Response ` object from a urllib3 response. This should not be called from user code, and is only exposed for use when subclassing the @@ -345,10 +373,11 @@ def build_response(self, req, resp): :param resp: The urllib3 response object. :rtype: requests.Response """ + assert is_prepared(req) response = Response() # Fallback to None if there's no status_code, for whatever reason. - response.status_code = getattr(resp, "status", None) + response.status_code = getattr(resp, "status", None) # type: ignore[assignment] # Make headers case-insensitive. response.headers = CaseInsensitiveDict(getattr(resp, "headers", {})) @@ -372,7 +401,9 @@ def build_response(self, req, resp): return response - def build_connection_pool_key_attributes(self, request, verify, cert=None): + def build_connection_pool_key_attributes( + self, request: PreparedRequest, verify: VerifyType, cert: CertType = None + ) -> tuple[dict[str, Any], dict[str, Any]]: """Build the PoolKey attributes used by urllib3 to return a connection. This looks at the PreparedRequest, the user-specified verify value, @@ -422,7 +453,13 @@ def build_connection_pool_key_attributes(self, request, verify, cert=None): """ return _urllib3_request_context(request, verify, cert, self.poolmanager) - def get_connection_with_tls_context(self, request, verify, proxies=None, cert=None): + def get_connection_with_tls_context( + self, + request: PreparedRequest, + verify: VerifyType, + proxies: dict[str, str] | None = None, + cert: CertType = None, + ) -> HTTPConnectionPool: """Returns a urllib3 connection for the given request and TLS settings. This should not be called from user code, and is only exposed for use when subclassing the :class:`HTTPAdapter `. @@ -440,8 +477,10 @@ def get_connection_with_tls_context(self, request, verify, proxies=None, cert=No (optional) Any user-provided SSL certificate to be used for client authentication (a.k.a., mTLS). :rtype: - urllib3.ConnectionPool + urllib3.HTTPConnectionPool """ + assert is_prepared(request) + proxy = select_proxy(request.url, proxies) try: host_params, pool_kwargs = self.build_connection_pool_key_attributes( @@ -471,7 +510,9 @@ def get_connection_with_tls_context(self, request, verify, proxies=None, cert=No return conn - def get_connection(self, url, proxies=None): + def get_connection( + self, url: str, proxies: dict[str, str] | None = None + ) -> HTTPConnectionPool: """DEPRECATED: Users should move to `get_connection_with_tls_context` for all subclasses of HTTPAdapter using Requests>=2.32.2. @@ -481,7 +522,7 @@ def get_connection(self, url, proxies=None): :param url: The URL to connect to. :param proxies: (optional) A Requests-style dictionary of proxies used on this request. - :rtype: urllib3.ConnectionPool + :rtype: urllib3.HTTPConnectionPool """ warnings.warn( ( @@ -512,7 +553,7 @@ def get_connection(self, url, proxies=None): return conn - def close(self): + def close(self) -> None: """Disposes of any internal state. Currently, this closes the PoolManager and any active ProxyManager, @@ -522,7 +563,9 @@ def close(self): for proxy in self.proxy_manager.values(): proxy.clear() - def request_url(self, request, proxies): + def request_url( + self, request: PreparedRequest, proxies: dict[str, str] | None + ) -> str: """Obtain the url to use when making the final request. If the message is being sent through a HTTP proxy, the full URL has to @@ -536,6 +579,8 @@ def request_url(self, request, proxies): :param proxies: A dictionary of schemes or schemes and hosts to proxy URLs. :rtype: str """ + assert is_prepared(request) + proxy = select_proxy(request.url, proxies) scheme = urlparse(request.url).scheme @@ -554,7 +599,7 @@ def request_url(self, request, proxies): return url - def add_headers(self, request, **kwargs): + def add_headers(self, request: PreparedRequest, **kwargs: Any) -> None: """Add any headers needed by the connection. As of v2.0 this does nothing by default, but is left for overriding by users that subclass the :class:`HTTPAdapter `. @@ -568,7 +613,7 @@ def add_headers(self, request, **kwargs): """ pass - def proxy_headers(self, proxy): + def proxy_headers(self, proxy: str) -> dict[str, str]: """Returns a dictionary of the headers to add to any request sent through a proxy. This works with urllib3 magic to ensure that they are correctly sent to the proxy, rather than in a tunnelled request if @@ -581,7 +626,7 @@ def proxy_headers(self, proxy): :param proxy: The url of the proxy being used for this request. :rtype: dict """ - headers = {} + headers: dict[str, str] = {} username, password = get_auth_from_url(proxy) if username: @@ -590,8 +635,14 @@ def proxy_headers(self, proxy): return headers def send( - self, request, stream=False, timeout=None, verify=True, cert=None, proxies=None - ): + self, + request: PreparedRequest, + stream: bool = False, + timeout: TimeoutType = None, + verify: VerifyType = True, + cert: CertType = None, + proxies: dict[str, str] | None = None, + ) -> Response: """Sends PreparedRequest object. Returns Response object. :param request: The :class:`PreparedRequest ` being sent. @@ -608,6 +659,8 @@ def send( :rtype: requests.Response """ + assert is_prepared(request) + try: conn = self.get_connection_with_tls_context( request, verify, proxies=proxies, cert=cert @@ -631,29 +684,29 @@ def send( if isinstance(timeout, tuple): try: connect, read = timeout - timeout = TimeoutSauce(connect=connect, read=read) + resolved_timeout = TimeoutSauce(connect=connect, read=read) except ValueError: raise ValueError( f"Invalid timeout {timeout}. Pass a (connect, read) timeout tuple, " f"or a single float to set both timeouts to the same value." ) elif isinstance(timeout, TimeoutSauce): - pass + resolved_timeout = timeout else: - timeout = TimeoutSauce(connect=timeout, read=timeout) + resolved_timeout = TimeoutSauce(connect=timeout, read=timeout) try: resp = conn.urlopen( method=request.method, url=url, - body=request.body, + body=request.body, # type: ignore[arg-type] # urllib3 stubs don't accept Iterable[bytes | str] headers=request.headers, redirect=False, assert_same_host=False, preload_content=False, decode_content=False, retries=self.max_retries, - timeout=timeout, + timeout=resolved_timeout, chunked=chunked, ) diff --git a/src/requests/api.py b/src/requests/api.py index 5960744552..50af0b9feb 100644 --- a/src/requests/api.py +++ b/src/requests/api.py @@ -8,10 +8,28 @@ :license: Apache2, see LICENSE for more details. """ +from __future__ import annotations + +from typing import TYPE_CHECKING + from . import sessions +from .models import Response + +if TYPE_CHECKING: + from typing_extensions import Unpack # move to typing when Python >= 3.12 + + from ._types import ( + DataKwargs, + DataType, + GetKwargs, + JsonType, + ParamsType, + PostKwargs, + RequestKwargs, + ) -def request(method, url, **kwargs): +def request(method: str, url: str, **kwargs: Unpack[RequestKwargs]) -> Response: """Constructs and sends a :class:`Request `. :param method: method for the new :class:`Request` object: ``GET``, ``OPTIONS``, ``HEAD``, ``POST``, ``PUT``, ``PATCH``, or ``DELETE``. @@ -59,7 +77,7 @@ def request(method, url, **kwargs): return session.request(method=method, url=url, **kwargs) -def get(url, params=None, **kwargs): +def get(url: str, params: ParamsType = None, **kwargs: Unpack[GetKwargs]) -> Response: r"""Sends a GET request. :param url: URL for the new :class:`Request` object. @@ -73,7 +91,7 @@ def get(url, params=None, **kwargs): return request("get", url, params=params, **kwargs) -def options(url, **kwargs): +def options(url: str, **kwargs: Unpack[RequestKwargs]) -> Response: r"""Sends an OPTIONS request. :param url: URL for the new :class:`Request` object. @@ -85,7 +103,7 @@ def options(url, **kwargs): return request("options", url, **kwargs) -def head(url, **kwargs): +def head(url: str, **kwargs: Unpack[RequestKwargs]) -> Response: r"""Sends a HEAD request. :param url: URL for the new :class:`Request` object. @@ -100,7 +118,9 @@ def head(url, **kwargs): return request("head", url, **kwargs) -def post(url, data=None, json=None, **kwargs): +def post( + url: str, data: DataType = None, json: JsonType = None, **kwargs: Unpack[PostKwargs] +) -> Response: r"""Sends a POST request. :param url: URL for the new :class:`Request` object. @@ -115,7 +135,7 @@ def post(url, data=None, json=None, **kwargs): return request("post", url, data=data, json=json, **kwargs) -def put(url, data=None, **kwargs): +def put(url: str, data: DataType = None, **kwargs: Unpack[DataKwargs]) -> Response: r"""Sends a PUT request. :param url: URL for the new :class:`Request` object. @@ -130,7 +150,7 @@ def put(url, data=None, **kwargs): return request("put", url, data=data, **kwargs) -def patch(url, data=None, **kwargs): +def patch(url: str, data: DataType = None, **kwargs: Unpack[DataKwargs]) -> Response: r"""Sends a PATCH request. :param url: URL for the new :class:`Request` object. @@ -145,7 +165,7 @@ def patch(url, data=None, **kwargs): return request("patch", url, data=data, **kwargs) -def delete(url, **kwargs): +def delete(url: str, **kwargs: Unpack[RequestKwargs]) -> Response: r"""Sends a DELETE request. :param url: URL for the new :class:`Request` object. diff --git a/src/requests/auth.py b/src/requests/auth.py index c39b645189..bb7b577efb 100644 --- a/src/requests/auth.py +++ b/src/requests/auth.py @@ -5,6 +5,8 @@ This module contains the authentication handlers for Requests. """ +from __future__ import annotations + import hashlib import os import re @@ -12,17 +14,24 @@ import time import warnings from base64 import b64encode +from typing import TYPE_CHECKING, Any, Final, cast, overload from ._internal_utils import to_native_string from .compat import basestring, str, urlparse from .cookies import extract_cookies_to_jar from .utils import parse_dict_header -CONTENT_TYPE_FORM_URLENCODED = "application/x-www-form-urlencoded" -CONTENT_TYPE_MULTI_PART = "multipart/form-data" +if TYPE_CHECKING: + from http.cookiejar import CookieJar + + from .adapters import HTTPAdapter + from .models import PreparedRequest, Response + +CONTENT_TYPE_FORM_URLENCODED: Final = "application/x-www-form-urlencoded" +CONTENT_TYPE_MULTI_PART: Final = "multipart/form-data" -def _basic_auth_str(username, password): +def _basic_auth_str(username: bytes | str, password: bytes | str) -> str: """Returns a Basic Auth string.""" # "I want us to put a big-ol' comment on top of it that @@ -32,7 +41,7 @@ def _basic_auth_str(username, password): # # These are here solely to maintain backwards compatibility # for things like ints. This will be removed in 3.0.0. - if not isinstance(username, basestring): + if not isinstance(username, basestring): # type: ignore[reportUnnecessaryIsInstance] # runtime guard for non-str/bytes warnings.warn( "Non-string usernames will no longer be supported in Requests " f"3.0.0. Please convert the object you've passed in ({username!r}) to " @@ -42,7 +51,7 @@ def _basic_auth_str(username, password): ) username = str(username) - if not isinstance(password, basestring): + if not isinstance(password, basestring): # type: ignore[reportUnnecessaryIsInstance] # runtime guard for non-str/bytes warnings.warn( "Non-string passwords will no longer be supported in Requests " f"3.0.0. Please convert the object you've passed in ({type(password)!r}) to " @@ -69,18 +78,26 @@ def _basic_auth_str(username, password): class AuthBase: """Base class that all auth implementations derive from""" - def __call__(self, r): + def __call__(self, r: PreparedRequest) -> PreparedRequest: raise NotImplementedError("Auth hooks must be callable.") class HTTPBasicAuth(AuthBase): """Attaches HTTP Basic Authentication to the given Request object.""" - def __init__(self, username, password): + username: bytes | str + password: bytes | str + + @overload + def __init__(self, username: str, password: str) -> None: ... + @overload + def __init__(self, username: bytes, password: bytes) -> None: ... + + def __init__(self, username: bytes | str, password: bytes | str) -> None: self.username = username self.password = password - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return all( [ self.username == getattr(other, "username", None), @@ -88,10 +105,10 @@ def __eq__(self, other): ] ) - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self == other - def __call__(self, r): + def __call__(self, r: PreparedRequest) -> PreparedRequest: r.headers["Authorization"] = _basic_auth_str(self.username, self.password) return r @@ -99,7 +116,7 @@ def __call__(self, r): class HTTPProxyAuth(HTTPBasicAuth): """Attaches HTTP Proxy Authentication to a given Request object.""" - def __call__(self, r): + def __call__(self, r: PreparedRequest) -> PreparedRequest: r.headers["Proxy-Authorization"] = _basic_auth_str(self.username, self.password) return r @@ -107,13 +124,27 @@ def __call__(self, r): class HTTPDigestAuth(AuthBase): """Attaches HTTP Digest Authentication to the given Request object.""" - def __init__(self, username, password): + username: bytes | str + password: bytes | str + _thread_local: threading.local + last_nonce: str + nonce_count: int + chal: dict[str, str] + pos: int | None + num_401_calls: int | None + + @overload + def __init__(self, username: str, password: str) -> None: ... + @overload + def __init__(self, username: bytes, password: bytes) -> None: ... + + def __init__(self, username: bytes | str, password: bytes | str) -> None: self.username = username self.password = password # Keep state in per-thread local storage self._thread_local = threading.local() - def init_per_thread_state(self): + def init_per_thread_state(self) -> None: # Ensure state is initialized just once per-thread if not hasattr(self._thread_local, "init"): self._thread_local.init = True @@ -123,7 +154,7 @@ def init_per_thread_state(self): self._thread_local.pos = None self._thread_local.num_401_calls = None - def build_digest_header(self, method, url): + def build_digest_header(self, method: str, url: str) -> str | None: """ :rtype: str """ @@ -142,7 +173,7 @@ def build_digest_header(self, method, url): # lambdas assume digest modules are imported at the top level if _algorithm == "MD5" or _algorithm == "MD5-SESS": - def md5_utf8(x): + def md5_utf8(x: str | bytes) -> str: if isinstance(x, str): x = x.encode("utf-8") return hashlib.md5(x).hexdigest() @@ -150,7 +181,7 @@ def md5_utf8(x): hash_utf8 = md5_utf8 elif _algorithm == "SHA": - def sha_utf8(x): + def sha_utf8(x: str | bytes) -> str: if isinstance(x, str): x = x.encode("utf-8") return hashlib.sha1(x).hexdigest() @@ -158,7 +189,7 @@ def sha_utf8(x): hash_utf8 = sha_utf8 elif _algorithm == "SHA-256": - def sha256_utf8(x): + def sha256_utf8(x: str | bytes) -> str: if isinstance(x, str): x = x.encode("utf-8") return hashlib.sha256(x).hexdigest() @@ -166,18 +197,19 @@ def sha256_utf8(x): hash_utf8 = sha256_utf8 elif _algorithm == "SHA-512": - def sha512_utf8(x): + def sha512_utf8(x: str | bytes) -> str: if isinstance(x, str): x = x.encode("utf-8") return hashlib.sha512(x).hexdigest() hash_utf8 = sha512_utf8 - KD = lambda s, d: hash_utf8(f"{s}:{d}") # noqa:E731 - if hash_utf8 is None: return None + def KD(s: str, d: str) -> str: + return hash_utf8(f"{s}:{d}") + # XXX not implemented yet entdig = None p_parsed = urlparse(url) @@ -204,7 +236,7 @@ def sha512_utf8(x): cnonce = hashlib.sha1(s).hexdigest()[:16] if _algorithm == "MD5-SESS": - HA1 = hash_utf8(f"{HA1}:{nonce}:{cnonce}") + HA1 = hash_utf8(f"{HA1}:{nonce}:{cnonce}") # type: ignore[reportConstantRedefinition] # RFC 2617 terminology if not qop: respdig = KD(HA1, f"{nonce}:{HA2}") @@ -233,12 +265,12 @@ def sha512_utf8(x): return f"Digest {base}" - def handle_redirect(self, r, **kwargs): + def handle_redirect(self, r: Response, **kwargs: Any) -> None: """Reset num_401_calls counter on redirects.""" if r.is_redirect: self._thread_local.num_401_calls = 1 - def handle_401(self, r, **kwargs): + def handle_401(self, r: Response, **kwargs: Any) -> Response: """ Takes the given response and tries digest-auth, if needed. @@ -254,7 +286,8 @@ def handle_401(self, r, **kwargs): if self._thread_local.pos is not None: # Rewind the file position indicator of the body to where # it was to resend the request. - r.request.body.seek(self._thread_local.pos) + if (seek := getattr(r.request.body, "seek", None)) is not None: + seek(self._thread_local.pos) s_auth = r.headers.get("www-authenticate", "") if "digest" in s_auth.lower() and self._thread_local.num_401_calls < 2: @@ -267,13 +300,17 @@ def handle_401(self, r, **kwargs): r.content r.close() prep = r.request.copy() - extract_cookies_to_jar(prep._cookies, r.request, r.raw) - prep.prepare_cookies(prep._cookies) + cookie_jar = cast("CookieJar", prep._cookies) # type: ignore[reportPrivateUsage] + extract_cookies_to_jar(cookie_jar, r.request, r.raw) + prep.prepare_cookies(cookie_jar) - prep.headers["Authorization"] = self.build_digest_header( - prep.method, prep.url + _digest_auth = self.build_digest_header( + cast(str, prep.method), cast(str, prep.url) ) - _r = r.connection.send(prep, **kwargs) + if _digest_auth: + prep.headers["Authorization"] = _digest_auth + conn = cast("HTTPAdapter", r.connection) + _r = conn.send(prep, **kwargs) _r.history.append(r) _r.request = prep @@ -282,15 +319,19 @@ def handle_401(self, r, **kwargs): self._thread_local.num_401_calls = 1 return r - def __call__(self, r): + def __call__(self, r: PreparedRequest) -> PreparedRequest: # Initialize per-thread state, if needed self.init_per_thread_state() # If we have a saved nonce, skip the 401 if self._thread_local.last_nonce: - r.headers["Authorization"] = self.build_digest_header(r.method, r.url) - try: - self._thread_local.pos = r.body.tell() - except AttributeError: + _digest_auth = self.build_digest_header( + cast(str, r.method), cast(str, r.url) + ) + if _digest_auth: + r.headers["Authorization"] = _digest_auth + if (tell := getattr(r.body, "tell", None)) is not None: + self._thread_local.pos = tell() + else: # In the case of HTTPDigestAuth being reused and the body of # the previous request was a file-like object, pos has the # file position of the previous body. Ensure it's set to @@ -302,7 +343,7 @@ def __call__(self, r): return r - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return all( [ self.username == getattr(other, "username", None), @@ -310,5 +351,5 @@ def __eq__(self, other): ] ) - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self == other diff --git a/src/requests/compat.py b/src/requests/compat.py index 7f9d754350..deab3c091f 100644 --- a/src/requests/compat.py +++ b/src/requests/compat.py @@ -7,13 +7,20 @@ compatibility until the next major version. """ +# pyright: reportUnusedImport=false + +from __future__ import annotations + import importlib import sys +from types import ModuleType # ------- # urllib3 # ------- -from urllib3 import __version__ as urllib3_version +from urllib3 import ( + __version__ as urllib3_version, # type: ignore[reportPrivateImportUsage] +) # Detect which major version of urllib3 is being used. try: @@ -27,7 +34,7 @@ # ------------------- -def _resolve_char_detection(): +def _resolve_char_detection() -> ModuleType | None: """Find supported character detection libraries.""" chardet = None for lib in ("chardet", "charset_normalizer"): @@ -57,14 +64,14 @@ def _resolve_char_detection(): # json/simplejson module import resolution has_simplejson = False try: - import simplejson as json + import simplejson as json # type: ignore[import-not-found] has_simplejson = True except ImportError: import json if has_simplejson: - from simplejson import JSONDecodeError + from simplejson import JSONDecodeError # type: ignore[import-not-found] else: from json import JSONDecodeError @@ -95,7 +102,7 @@ def _resolve_char_detection(): getproxies_environment, parse_http_list, proxy_bypass, - proxy_bypass_environment, + proxy_bypass_environment, # type: ignore[attr-defined] # https://github.com/python/cpython/issues/145331 ) builtin_str = str diff --git a/src/requests/cookies.py b/src/requests/cookies.py index f69d0cda9e..78aa510e52 100644 --- a/src/requests/cookies.py +++ b/src/requests/cookies.py @@ -7,17 +7,24 @@ requests.utils imports from here, so be careful with imports. """ +from __future__ import annotations + import calendar import copy import time +from collections.abc import Iterator, MutableMapping +from http.cookiejar import Cookie, CookieJar, CookiePolicy +from typing import TYPE_CHECKING, Any, TypeVar, cast, overload from ._internal_utils import to_native_string -from .compat import Morsel, MutableMapping, cookielib, urlparse, urlunparse +from .compat import Morsel, cookielib, urlparse, urlunparse + +if TYPE_CHECKING: + from _typeshed import SupportsKeysAndGetItem -try: - import threading -except ImportError: - import dummy_threading as threading + from .models import PreparedRequest + +import threading class MockRequest: @@ -32,31 +39,33 @@ class MockRequest: probably want `get_cookie_header`, defined below. """ - def __init__(self, request): + type: str + + def __init__(self, request: PreparedRequest) -> None: self._r = request - self._new_headers = {} - self.type = urlparse(self._r.url).scheme + self._new_headers: dict[str, str] = {} + self.type = urlparse(self._r.url).scheme # type: ignore[assignment] # TODO(typing): str|bytes URL handling - def get_type(self): + def get_type(self) -> str: return self.type - def get_host(self): - return urlparse(self._r.url).netloc + def get_host(self) -> str: + return urlparse(self._r.url).netloc # type: ignore[return-value] # TODO(typing): str|bytes URL handling - def get_origin_req_host(self): + def get_origin_req_host(self) -> str: return self.get_host() - def get_full_url(self): + def get_full_url(self) -> str: # Only return the response's URL if the user hadn't set the Host # header if not self._r.headers.get("Host"): - return self._r.url + return cast(str, self._r.url) # If they did set it, retrieve it and reconstruct the expected domain host = to_native_string(self._r.headers["Host"], encoding="utf-8") parsed = urlparse(self._r.url) # Reconstruct the URL as we expect it return urlunparse( - [ + [ # type: ignore[arg-type] # TODO(typing): str|bytes URL handling parsed.scheme, host, parsed.path, @@ -66,37 +75,37 @@ def get_full_url(self): ] ) - def is_unverifiable(self): + def is_unverifiable(self) -> bool: return True - def has_header(self, name): + def has_header(self, name: str) -> bool: return name in self._r.headers or name in self._new_headers - def get_header(self, name, default=None): + def get_header(self, name: str, default: str | None = None) -> str | None: return self._r.headers.get(name, self._new_headers.get(name, default)) - def add_header(self, key, val): + def add_header(self, key: str, val: str) -> None: """cookiejar has no legitimate use for this method; add it back if you find one.""" raise NotImplementedError( "Cookie headers should be added with add_unredirected_header()" ) - def add_unredirected_header(self, name, value): + def add_unredirected_header(self, name: str, value: str) -> None: self._new_headers[name] = value - def get_new_headers(self): + def get_new_headers(self) -> dict[str, str]: return self._new_headers @property - def unverifiable(self): + def unverifiable(self) -> bool: return self.is_unverifiable() @property - def origin_req_host(self): + def origin_req_host(self) -> str: return self.get_origin_req_host() @property - def host(self): + def host(self) -> str: return self.get_host() @@ -107,21 +116,23 @@ class MockResponse: the way `http.cookiejar` expects to see them. """ - def __init__(self, headers): + def __init__(self, headers: Any) -> None: """Make a MockResponse for `cookiejar` to read. :param headers: a httplib.HTTPMessage or analogous carrying the headers """ self._headers = headers - def info(self): + def info(self) -> Any: return self._headers - def getheaders(self, name): + def getheaders(self, name: str) -> Any: self._headers.getheaders(name) -def extract_cookies_to_jar(jar, request, response): +def extract_cookies_to_jar( + jar: CookieJar, request: PreparedRequest, response: Any +) -> None: """Extract the cookies from the response into a CookieJar. :param jar: http.cookiejar.CookieJar (not necessarily a RequestsCookieJar) @@ -134,26 +145,28 @@ def extract_cookies_to_jar(jar, request, response): req = MockRequest(request) # pull out the HTTPMessage with the headers and put it in the mock: res = MockResponse(response._original_response.msg) - jar.extract_cookies(res, req) + jar.extract_cookies(res, req) # type: ignore[arg-type] -def get_cookie_header(jar, request): +def get_cookie_header(jar: CookieJar, request: PreparedRequest) -> str | None: """ Produce an appropriate Cookie header string to be sent with `request`, or None. :rtype: str """ r = MockRequest(request) - jar.add_cookie_header(r) + jar.add_cookie_header(r) # type: ignore[arg-type] return r.get_new_headers().get("Cookie") -def remove_cookie_by_name(cookiejar, name, domain=None, path=None): +def remove_cookie_by_name( + cookiejar: CookieJar, name: str, domain: str | None = None, path: str | None = None +) -> None: """Unsets a cookie by name, by default over all domains and paths. Wraps CookieJar.clear(), is O(n). """ - clearables = [] + clearables: list[tuple[str, str, str]] = [] for cookie in cookiejar: if cookie.name != name: continue @@ -173,7 +186,7 @@ class CookieConflictError(RuntimeError): """ -class RequestsCookieJar(cookielib.CookieJar, MutableMapping): +class RequestsCookieJar(CookieJar, MutableMapping[str, str | None]): # type: ignore[misc] """Compatibility class; is a http.cookiejar.CookieJar, but exposes a dict interface. @@ -191,7 +204,15 @@ class RequestsCookieJar(cookielib.CookieJar, MutableMapping): .. warning:: dictionary operations that are normally O(1) may be O(n). """ - def get(self, name, default=None, domain=None, path=None): + _policy: CookiePolicy + + def get( # type: ignore[override] + self, + name: str, + default: str | None = None, + domain: str | None = None, + path: str | None = None, + ) -> str | None: """Dict-like get() that also supports optional domain and path args in order to resolve naming collisions from using one cookie jar over multiple domains. @@ -203,7 +224,9 @@ def get(self, name, default=None, domain=None, path=None): except KeyError: return default - def set(self, name, value, **kwargs): + def set( + self, name: str, value: str | Morsel[dict[str, str]] | None, **kwargs: Any + ) -> Cookie | None: """Dict-like set() that also supports optional domain and path args in order to resolve naming collisions from using one cookie jar over multiple domains. @@ -222,7 +245,7 @@ def set(self, name, value, **kwargs): self.set_cookie(c) return c - def iterkeys(self): + def iterkeys(self) -> Iterator[str]: """Dict-like iterkeys() that returns an iterator of names of cookies from the jar. @@ -231,7 +254,7 @@ def iterkeys(self): for cookie in iter(self): yield cookie.name - def keys(self): + def keys(self) -> list[str]: # type: ignore[override] """Dict-like keys() that returns a list of names of cookies from the jar. @@ -239,7 +262,7 @@ def keys(self): """ return list(self.iterkeys()) - def itervalues(self): + def itervalues(self) -> Iterator[str | None]: """Dict-like itervalues() that returns an iterator of values of cookies from the jar. @@ -248,7 +271,7 @@ def itervalues(self): for cookie in iter(self): yield cookie.value - def values(self): + def values(self) -> list[str | None]: # type: ignore[override] """Dict-like values() that returns a list of values of cookies from the jar. @@ -256,7 +279,7 @@ def values(self): """ return list(self.itervalues()) - def iteritems(self): + def iteritems(self) -> Iterator[tuple[str, str | None]]: """Dict-like iteritems() that returns an iterator of name-value tuples from the jar. @@ -265,7 +288,7 @@ def iteritems(self): for cookie in iter(self): yield cookie.name, cookie.value - def items(self): + def items(self) -> list[tuple[str, str | None]]: # type: ignore[override] """Dict-like items() that returns a list of name-value tuples from the jar. Allows client-code to call ``dict(RequestsCookieJar)`` and get a vanilla python dict of key value pairs. @@ -274,43 +297,45 @@ def items(self): """ return list(self.iteritems()) - def list_domains(self): + def list_domains(self) -> list[str]: """Utility method to list all the domains in the jar.""" - domains = [] + domains: list[str] = [] for cookie in iter(self): if cookie.domain not in domains: domains.append(cookie.domain) return domains - def list_paths(self): + def list_paths(self) -> list[str]: """Utility method to list all the paths in the jar.""" - paths = [] + paths: list[str] = [] for cookie in iter(self): if cookie.path not in paths: paths.append(cookie.path) return paths - def multiple_domains(self): + def multiple_domains(self) -> bool: """Returns True if there are multiple domains in the jar. Returns False otherwise. :rtype: bool """ - domains = [] + domains: list[str] = [] for cookie in iter(self): - if cookie.domain is not None and cookie.domain in domains: + if cookie.domain is not None and cookie.domain in domains: # type: ignore[reportUnnecessaryComparison] # defensive check return True domains.append(cookie.domain) return False # there is only one domain in jar - def get_dict(self, domain=None, path=None): + def get_dict( + self, domain: str | None = None, path: str | None = None + ) -> dict[str, str | None]: """Takes as an argument an optional domain and path and returns a plain old Python dict of name-value pairs of cookies that meet the requirements. :rtype: dict """ - dictionary = {} + dictionary: dict[str, str | None] = {} for cookie in iter(self): if (domain is None or cookie.domain == domain) and ( path is None or cookie.path == path @@ -318,13 +343,17 @@ def get_dict(self, domain=None, path=None): dictionary[cookie.name] = cookie.value return dictionary - def __contains__(self, name): + def __iter__(self) -> Iterator[Cookie]: # type: ignore[override] + """RequestCookieJar's __iter__ comes from CookieJar not MutableMapping.""" + return super().__iter__() + + def __contains__(self, name: object) -> bool: try: return super().__contains__(name) except CookieConflictError: return True - def __getitem__(self, name): + def __getitem__(self, name: str) -> str | None: """Dict-like __getitem__() for compatibility with client code. Throws exception if there are more than one cookie with name. In that case, use the more explicit get() method instead. @@ -333,29 +362,33 @@ def __getitem__(self, name): """ return self._find_no_duplicates(name) - def __setitem__(self, name, value): + def __setitem__( + self, name: str, value: str | Morsel[dict[str, str]] | None + ) -> None: """Dict-like __setitem__ for compatibility with client code. Throws exception if there is already a cookie of that name in the jar. In that case, use the more explicit set() method instead. """ self.set(name, value) - def __delitem__(self, name): + def __delitem__(self, name: str) -> None: """Deletes a cookie given a name. Wraps ``http.cookiejar.CookieJar``'s ``remove_cookie_by_name()``. """ remove_cookie_by_name(self, name) - def set_cookie(self, cookie, *args, **kwargs): + def set_cookie(self, cookie: Cookie, *args: Any, **kwargs: Any) -> None: if ( - hasattr(cookie.value, "startswith") - and cookie.value.startswith('"') - and cookie.value.endswith('"') + (value := cookie.value) is not None + and value.startswith('"') + and value.endswith('"') ): - cookie.value = cookie.value.replace('\\"', "") + cookie.value = value.replace('\\"', "") return super().set_cookie(cookie, *args, **kwargs) - def update(self, other): + def update( # type: ignore[override] + self, other: CookieJar | SupportsKeysAndGetItem[str, str] + ) -> None: """Updates this jar with cookies from another CookieJar or dict-like""" if isinstance(other, cookielib.CookieJar): for cookie in other: @@ -363,7 +396,9 @@ def update(self, other): else: super().update(other) - def _find(self, name, domain=None, path=None): + def _find( + self, name: str, domain: str | None = None, path: str | None = None + ) -> str | None: """Requests uses this method internally to get cookie values. If there are conflicting cookies, _find arbitrarily chooses one. @@ -383,7 +418,9 @@ def _find(self, name, domain=None, path=None): raise KeyError(f"name={name!r}, domain={domain!r}, path={path!r}") - def _find_no_duplicates(self, name, domain=None, path=None): + def _find_no_duplicates( + self, name: str, domain: str | None = None, path: str | None = None + ) -> str: """Both ``__get_item__`` and ``get`` call this function: it's never used elsewhere in Requests. @@ -408,42 +445,42 @@ def _find_no_duplicates(self, name, domain=None, path=None): # we will eventually return this as long as no cookie conflict toReturn = cookie.value - if toReturn: + if toReturn is not None: return toReturn raise KeyError(f"name={name!r}, domain={domain!r}, path={path!r}") - def __getstate__(self): + def __getstate__(self) -> dict[str, Any]: """Unlike a normal CookieJar, this class is pickleable.""" state = self.__dict__.copy() # remove the unpickleable RLock object state.pop("_cookies_lock") return state - def __setstate__(self, state): + def __setstate__(self, state: dict[str, Any]) -> None: """Unlike a normal CookieJar, this class is pickleable.""" self.__dict__.update(state) if "_cookies_lock" not in self.__dict__: self._cookies_lock = threading.RLock() - def copy(self): + def copy(self) -> RequestsCookieJar: """Return a copy of this RequestsCookieJar.""" new_cj = RequestsCookieJar() new_cj.set_policy(self.get_policy()) new_cj.update(self) return new_cj - def get_policy(self): + def get_policy(self) -> CookiePolicy: """Return the CookiePolicy instance used.""" return self._policy -def _copy_cookie_jar(jar): +def _copy_cookie_jar(jar: CookieJar | None) -> CookieJar | None: # type: ignore[reportUnusedFunction] # cross-module usage in models.py if jar is None: return None - if hasattr(jar, "copy"): + if copy_method := getattr(jar, "copy", None): # We're dealing with an instance of RequestsCookieJar - return jar.copy() + return copy_method() # We're dealing with a generic CookieJar instance new_jar = copy.copy(jar) new_jar.clear() @@ -452,13 +489,13 @@ def _copy_cookie_jar(jar): return new_jar -def create_cookie(name, value, **kwargs): +def create_cookie(name: str, value: str, **kwargs: Any) -> Cookie: """Make a cookie from underspecified parameters. By default, the pair of `name` and `value` will be set for the domain '' and sent on every request (this is sometimes called a "supercookie"). """ - result = { + result: dict[str, Any] = { "version": 0, "name": name, "value": value, @@ -489,10 +526,10 @@ def create_cookie(name, value, **kwargs): return cookielib.Cookie(**result) -def morsel_to_cookie(morsel): +def morsel_to_cookie(morsel: Morsel[Any]) -> Cookie: """Convert a Morsel object into a Cookie containing the one k/v pair.""" - expires = None + expires: int | None = None if morsel["max-age"]: try: expires = int(time.time() + int(morsel["max-age"])) @@ -518,7 +555,30 @@ def morsel_to_cookie(morsel): ) -def cookiejar_from_dict(cookie_dict, cookiejar=None, overwrite=True): +_CookieJarT = TypeVar("_CookieJarT", bound=CookieJar) + + +@overload +def cookiejar_from_dict( + cookie_dict: dict[str, str] | None, + cookiejar: None = None, + overwrite: bool = True, +) -> RequestsCookieJar: ... + + +@overload +def cookiejar_from_dict( + cookie_dict: dict[str, str] | None, + cookiejar: _CookieJarT, + overwrite: bool = True, +) -> _CookieJarT: ... + + +def cookiejar_from_dict( + cookie_dict: dict[str, str] | None, + cookiejar: CookieJar | None = None, + overwrite: bool = True, +) -> CookieJar: """Returns a CookieJar from a key/value dictionary. :param cookie_dict: Dict of key/values to insert into CookieJar. @@ -539,22 +599,24 @@ def cookiejar_from_dict(cookie_dict, cookiejar=None, overwrite=True): return cookiejar -def merge_cookies(cookiejar, cookies): +def merge_cookies( + cookiejar: CookieJar, cookies: dict[str, str] | CookieJar | None +) -> CookieJar: """Add cookies to cookiejar and returns a merged CookieJar. :param cookiejar: CookieJar object to add the cookies to. :param cookies: Dictionary or CookieJar object to be added. :rtype: CookieJar """ - if not isinstance(cookiejar, cookielib.CookieJar): + if not isinstance(cookiejar, cookielib.CookieJar): # type: ignore[reportUnnecessaryIsInstance] # runtime guard raise ValueError("You can only merge into CookieJar") if isinstance(cookies, dict): cookiejar = cookiejar_from_dict(cookies, cookiejar=cookiejar, overwrite=False) elif isinstance(cookies, cookielib.CookieJar): - try: - cookiejar.update(cookies) - except AttributeError: + if update_method := getattr(cookiejar, "update", None): + update_method(cookies) + else: for cookie_in_jar in cookies: cookiejar.set_cookie(cookie_in_jar) diff --git a/src/requests/exceptions.py b/src/requests/exceptions.py index 6e71506e96..cb5e9510e3 100644 --- a/src/requests/exceptions.py +++ b/src/requests/exceptions.py @@ -5,23 +5,33 @@ This module contains the set of Requests' exceptions. """ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + from urllib3.exceptions import HTTPError as BaseHTTPError from .compat import JSONDecodeError as CompatJSONDecodeError +if TYPE_CHECKING: + from .models import PreparedRequest, Request, Response + class RequestException(IOError): """There was an ambiguous exception that occurred while handling your request. """ - def __init__(self, *args, **kwargs): + response: Response | None + request: Request | PreparedRequest | None + + def __init__(self, *args: Any, **kwargs: Any) -> None: """Initialize RequestException with `request` and `response` objects.""" - response = kwargs.pop("response", None) + response: Response | None = kwargs.pop("response", None) self.response = response self.request = kwargs.pop("request", None) if response is not None and not self.request and hasattr(response, "request"): - self.request = self.response.request + self.request = response.request super().__init__(*args, **kwargs) @@ -32,7 +42,7 @@ class InvalidJSONError(RequestException): class JSONDecodeError(InvalidJSONError, CompatJSONDecodeError): """Couldn't decode the text into json""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: """ Construct the JSONDecodeError instance first with all args. Then use it's args to construct the IOError so that @@ -42,7 +52,7 @@ def __init__(self, *args, **kwargs): CompatJSONDecodeError.__init__(self, *args) InvalidJSONError.__init__(self, *self.args, **kwargs) - def __reduce__(self): + def __reduce__(self) -> tuple[Any, ...] | str: """ The __reduce__ method called when pickling the object must be the one from the JSONDecodeError (be it json/simplejson) diff --git a/src/requests/help.py b/src/requests/help.py index 5d5107895e..9269cc7126 100644 --- a/src/requests/help.py +++ b/src/requests/help.py @@ -1,9 +1,12 @@ """Module containing bug report helper(s).""" +# pyright: reportUnknownMemberType=false + import json import platform import ssl import sys +from typing import Any import idna import urllib3 @@ -16,7 +19,7 @@ charset_normalizer = None try: - import chardet + import chardet # type: ignore[import-not-found] except ImportError: chardet = None @@ -27,8 +30,8 @@ OpenSSL = None cryptography = None else: - import cryptography - import OpenSSL + import cryptography # type: ignore[import-not-found] + import OpenSSL # type: ignore[import-not-found] def _implementation(): @@ -47,11 +50,11 @@ def _implementation(): if implementation == "CPython": implementation_version = platform.python_version() elif implementation == "PyPy": - pypy = sys.pypy_version_info + pypy = sys.pypy_version_info # type: ignore[attr-defined] implementation_version = f"{pypy.major}.{pypy.minor}.{pypy.micro}" - if sys.pypy_version_info.releaselevel != "final": + if sys.pypy_version_info.releaselevel != "final": # type: ignore[attr-defined] implementation_version = "".join( - [implementation_version, sys.pypy_version_info.releaselevel] + [implementation_version, sys.pypy_version_info.releaselevel] # type: ignore[attr-defined] ) elif implementation == "Jython": implementation_version = platform.python_version() # Complete Guess @@ -63,7 +66,7 @@ def _implementation(): return {"name": implementation, "version": implementation_version} -def info(): +def info() -> dict[str, Any]: """Generate information for a bug report.""" try: platform_info = { @@ -77,15 +80,15 @@ def info(): } implementation_info = _implementation() - urllib3_info = {"version": urllib3.__version__} + urllib3_info = {"version": urllib3.__version__} # type: ignore[reportPrivateImportUsage] charset_normalizer_info = {"version": None} - chardet_info = {"version": None} + chardet_info: dict[str, str | None] = {"version": None} if charset_normalizer: charset_normalizer_info = {"version": charset_normalizer.__version__} if chardet: chardet_info = {"version": chardet.__version__} - pyopenssl_info = { + pyopenssl_info: dict[str, str | None] = { "version": None, "openssl_version": "", } @@ -102,7 +105,7 @@ def info(): } system_ssl = ssl.OPENSSL_VERSION_NUMBER - system_ssl_info = {"version": f"{system_ssl:x}" if system_ssl is not None else ""} + system_ssl_info = {"version": f"{system_ssl:x}" if system_ssl is not None else ""} # type: ignore[reportUnnecessaryComparison] return { "platform": platform_info, diff --git a/src/requests/hooks.py b/src/requests/hooks.py index 5976bc7d0f..39d251d469 100644 --- a/src/requests/hooks.py +++ b/src/requests/hooks.py @@ -10,24 +10,39 @@ The response generated from a Request. """ -HOOKS = ["response"] +from __future__ import annotations +from collections.abc import Callable, Iterable +from typing import TYPE_CHECKING, Any -def default_hooks(): +from ._types import HooksInputType, HookType + +if TYPE_CHECKING: + from .models import Response + +HOOKS: list[str] = ["response"] + + +def default_hooks() -> dict[str, list[HookType]]: return {event: [] for event in HOOKS} # TODO: response is the only one -def dispatch_hook(key, hooks, hook_data, **kwargs): +def dispatch_hook( + key: str, + hooks: HooksInputType | None, + hook_data: Response, + **kwargs: Any, +) -> Response: """Dispatches a hook dictionary on a given piece of data.""" - hooks = hooks or {} - hooks = hooks.get(key) - if hooks: - if hasattr(hooks, "__call__"): - hooks = [hooks] - for hook in hooks: + hooks_dict = hooks or {} + hook_list: Iterable[HookType] | HookType | None = hooks_dict.get(key) + if hook_list: + if isinstance(hook_list, Callable): + hook_list = [hook_list] + for hook in hook_list: _hook_data = hook(hook_data, **kwargs) if _hook_data is not None: hook_data = _hook_data diff --git a/src/requests/models.py b/src/requests/models.py index 2d043f59cf..33a195d694 100644 --- a/src/requests/models.py +++ b/src/requests/models.py @@ -5,13 +5,24 @@ This module contains the primary objects that power Requests. """ +from __future__ import annotations + import datetime # Import encoding now, to avoid implicit import later. # Implicit import within threads may cause LookupError when standard library is in a ZIP, # such as in Embedded Python. See https://github.com/psf/requests/issues/3578. -import encodings.idna # noqa: F401 +import encodings.idna # noqa: F401 # type: ignore[reportUnusedImport] +from collections.abc import Callable, Generator, Iterable, Iterator, Mapping from io import UnsupportedOperation +from typing import ( + TYPE_CHECKING, + Any, + Final, + Literal, + cast, + overload, +) from urllib3.exceptions import ( DecodeError, @@ -25,11 +36,10 @@ from urllib3.util import parse_url from ._internal_utils import to_native_string, unicode_is_ascii +from ._types import SupportsRead from .auth import HTTPBasicAuth from .compat import ( - Callable, JSONDecodeError, - Mapping, basestring, builtin_str, chardet, @@ -39,7 +49,11 @@ urlunparse, ) from .compat import json as complexjson -from .cookies import _copy_cookie_jar, cookiejar_from_dict, get_cookie_header +from .cookies import ( + _copy_cookie_jar, # type: ignore[reportPrivateUsage] + cookiejar_from_dict, + get_cookie_header, +) from .exceptions import ( ChunkedEncodingError, ConnectionError, @@ -68,9 +82,27 @@ to_key_val_list, ) +if TYPE_CHECKING: + from http.cookiejar import CookieJar + + from ._types import ( + AuthType, + BodyType, + DataType, + EncodableDataType, + FilesType, + HooksInputType, + HookType, + JsonType, + KVDataType, + ParamsType, + ) + from .adapters import HTTPAdapter + from .cookies import RequestsCookieJar + #: The set of HTTP status codes that indicate an automatically #: processable redirect. -REDIRECT_STATI = ( +REDIRECT_STATI: Final[tuple[int, ...]] = ( # type: ignore[assignment] codes.moved, # 301 codes.found, # 302 codes.other, # 303 @@ -78,19 +110,21 @@ codes.permanent_redirect, # 308 ) -DEFAULT_REDIRECT_LIMIT = 30 -CONTENT_CHUNK_SIZE = 10 * 1024 -ITER_CHUNK_SIZE = 512 +DEFAULT_REDIRECT_LIMIT: int = 30 +CONTENT_CHUNK_SIZE: int = 10 * 1024 +ITER_CHUNK_SIZE: int = 512 class RequestEncodingMixin: + url: str | None + @property - def path_url(self): + def path_url(self) -> str: """Build the path URL to use.""" - url = [] + url: list[str] = [] - p = urlsplit(self.url) + p = urlsplit(cast(str, self.url)) path = p.path if not path: @@ -106,7 +140,9 @@ def path_url(self): return "".join(url) @staticmethod - def _encode_params(data): + def _encode_params( + data: EncodableDataType, + ) -> str | bytes | SupportsRead[str | bytes]: """Encode parameters in a piece of data. Will successfully encode parameters when passed as a dict or a list of @@ -116,10 +152,10 @@ def _encode_params(data): if isinstance(data, (str, bytes)): return data - elif hasattr(data, "read"): + elif isinstance(data, SupportsRead): return data elif hasattr(data, "__iter__"): - result = [] + result: list[tuple[bytes, bytes]] = [] for k, vs in to_key_val_list(data): if isinstance(vs, basestring) or not hasattr(vs, "__iter__"): vs = [vs] @@ -133,10 +169,12 @@ def _encode_params(data): ) return urlencode(result, doseq=True) else: - return data + return data # type: ignore[return-value] # unreachable for valid DataType @staticmethod - def _encode_files(files, data): + def _encode_files( + files: FilesType, data: KVDataType | str | bytes | None + ) -> tuple[bytes, str]: """Build the body for a multipart/form-data request. Will successfully encode files when passed as a dict or a list of @@ -150,7 +188,7 @@ def _encode_files(files, data): elif isinstance(data, basestring): raise ValueError("Data must not be a string.") - new_fields = [] + new_fields: list[RequestField | tuple[str, bytes]] = [] fields = to_key_val_list(data or {}) files = to_key_val_list(files or {}) @@ -189,14 +227,14 @@ def _encode_files(files, data): if isinstance(fp, (str, bytes, bytearray)): fdata = fp - elif hasattr(fp, "read"): + elif isinstance(fp, SupportsRead): # type: ignore[reportUnnecessaryIsInstance] # defensive check for untyped callers fdata = fp.read() - elif fp is None: + elif fp is None: # type: ignore[reportUnnecessaryComparison] # defensive check for untyped callers continue else: fdata = fp - rf = RequestField(name=k, data=fdata, filename=fn, headers=fh) + rf = RequestField(name=k, data=fdata, filename=fn, headers=fh) # type: ignore[arg-type] # TODO(typing): str|bytes URL handling rf.make_multipart(content_type=ft) new_fields.append(rf) @@ -206,7 +244,9 @@ def _encode_files(files, data): class RequestHooksMixin: - def register_hook(self, event, hook): + hooks: dict[str, list[HookType]] + + def register_hook(self, event: str, hook: Iterable[HookType] | HookType) -> None: """Properly register a hook.""" if event not in self.hooks: @@ -215,9 +255,9 @@ def register_hook(self, event, hook): if isinstance(hook, Callable): self.hooks[event].append(hook) elif hasattr(hook, "__iter__"): - self.hooks[event].extend(h for h in hook if isinstance(h, Callable)) + self.hooks[event].extend(h for h in hook if isinstance(h, Callable)) # type: ignore[reportUnnecessaryIsInstance] # defensive runtime filter - def deregister_hook(self, event, hook): + def deregister_hook(self, event: str, hook: HookType) -> bool: """Deregister a previously registered hook. Returns True if the hook existed, False if not. """ @@ -257,19 +297,29 @@ class Request(RequestHooksMixin): """ + method: str | None + url: str | None + headers: CaseInsensitiveDict[str] | Mapping[str, str | bytes] | None + files: FilesType + data: DataType + json: JsonType + params: ParamsType + auth: AuthType + cookies: RequestsCookieJar | CookieJar | dict[str, str] | None + def __init__( self, - method=None, - url=None, - headers=None, - files=None, - data=None, - params=None, - auth=None, - cookies=None, - hooks=None, - json=None, - ): + method: str | None = None, + url: str | None = None, + headers: Mapping[str, str | bytes] | None = None, + files: FilesType = None, + data: DataType = None, + params: ParamsType = None, + auth: AuthType = None, + cookies: RequestsCookieJar | CookieJar | dict[str, str] | None = None, + hooks: HooksInputType | None = None, + json: JsonType = None, + ) -> None: # Default empty dicts for dict params. data = [] if data is None else data files = [] if files is None else files @@ -291,10 +341,10 @@ def __init__( self.auth = auth self.cookies = cookies - def __repr__(self): + def __repr__(self) -> str: return f"" - def prepare(self): + def prepare(self) -> PreparedRequest: """Constructs a :class:`PreparedRequest ` for transmission and returns it.""" p = PreparedRequest() p.prepare( @@ -333,13 +383,22 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin): """ - def __init__(self): + method: str | None + url: str | None + headers: CaseInsensitiveDict[str] + _cookies: RequestsCookieJar | CookieJar | None + body: BodyType + hooks: dict[str, list[HookType]] + _body_position: int | object | None + + def __init__(self) -> None: #: HTTP verb to send to the server. self.method = None #: HTTP URL to send the request to. self.url = None #: dictionary of HTTP headers. - self.headers = None + # TODO: Revisit pattern of None-init for attributes that are always set before use + self.headers = None # type: ignore[assignment] # The `CookieJar` used to create the Cookie header will be stored here # after prepare_cookies is called self._cookies = None @@ -352,19 +411,20 @@ def __init__(self): def prepare( self, - method=None, - url=None, - headers=None, - files=None, - data=None, - params=None, - auth=None, - cookies=None, - hooks=None, - json=None, - ): + method: str | None = None, + url: str | None = None, + headers: Mapping[str, str | bytes] | None = None, + files: FilesType = None, + data: DataType = None, + params: ParamsType = None, + auth: AuthType = None, + cookies: RequestsCookieJar | CookieJar | dict[str, str] | None = None, + hooks: HooksInputType | None = None, + json: JsonType = None, + ) -> None: """Prepares the entire request with the given parameters.""" + url = cast(str, url) self.prepare_method(method) self.prepare_url(url, params) self.prepare_headers(headers) @@ -378,28 +438,28 @@ def prepare( # This MUST go after prepare_auth. Authenticators could add a hook self.prepare_hooks(hooks) - def __repr__(self): + def __repr__(self) -> str: return f"" - def copy(self): + def copy(self) -> PreparedRequest: p = PreparedRequest() p.method = self.method p.url = self.url - p.headers = self.headers.copy() if self.headers is not None else None + p.headers = self.headers.copy() if self.headers is not None else None # type: ignore[assignment] p._cookies = _copy_cookie_jar(self._cookies) p.body = self.body p.hooks = self.hooks p._body_position = self._body_position return p - def prepare_method(self, method): + def prepare_method(self, method: str | None) -> None: """Prepares the given HTTP method.""" self.method = method if self.method is not None: self.method = to_native_string(self.method.upper()) @staticmethod - def _get_idna_encoded_host(host): + def _get_idna_encoded_host(host: str) -> str: import idna try: @@ -408,7 +468,11 @@ def _get_idna_encoded_host(host): raise UnicodeError return host - def prepare_url(self, url, params): + def prepare_url( + self, + url: str, + params: ParamsType, + ) -> None: """Prepares the given HTTP URL.""" #: Accept objects that have string representations. #: We're unable to blindly call unicode/str functions @@ -472,28 +536,34 @@ def prepare_url(self, url, params): if isinstance(params, (str, bytes)): params = to_native_string(params) - enc_params = self._encode_params(params) + if params is not None: + enc_params = self._encode_params(params) + else: + enc_params = "" + if enc_params: if query: query = f"{query}&{enc_params}" else: query = enc_params - url = requote_uri(urlunparse([scheme, netloc, path, None, query, fragment])) + url = requote_uri(urlunparse([scheme, netloc, path, None, query, fragment])) # type: ignore[arg-type] # TODO(typing): str|bytes URL handling self.url = url - def prepare_headers(self, headers): + def prepare_headers(self, headers: Mapping[str, str | bytes] | None) -> None: """Prepares the given HTTP headers.""" self.headers = CaseInsensitiveDict() if headers: for header in headers.items(): # Raise exception on invalid header value. - check_header_validity(header) + check_header_validity(header) # type: ignore[arg-type] # TODO(typing): str|bytes URL handling name, value = header - self.headers[to_native_string(name)] = value + self.headers[to_native_string(name)] = value # type: ignore[arg-type] # TODO(typing): str|bytes URL handling - def prepare_body(self, data, files, json=None): + def prepare_body( + self, data: DataType, files: FilesType, json: JsonType = None + ) -> None: """Prepares the given HTTP body data.""" # Check if file, fo, generator, iterator. @@ -516,14 +586,9 @@ def prepare_body(self, data, files, json=None): if not isinstance(body, bytes): body = body.encode("utf-8") - is_stream = all( - [ - hasattr(data, "__iter__"), - not isinstance(data, (basestring, list, tuple, Mapping)), - ] - ) - - if is_stream: + if isinstance(data, Iterable) and not isinstance( + data, (str, bytes, list, tuple, Mapping) + ): try: length = super_len(data) except (TypeError, AttributeError, UnsupportedOperation): @@ -536,7 +601,7 @@ def prepare_body(self, data, files, json=None): # This will allow us to rewind a file in the event # of a redirect. try: - self._body_position = body.tell() + self._body_position = body.tell() # type: ignore[union-attr] # guarded by getattr check except OSError: # This differentiates from None, allowing us to catch # a failed `tell()` later when trying to rewind the body @@ -554,11 +619,11 @@ def prepare_body(self, data, files, json=None): else: # Multi-part file uploads. if files: - (body, content_type) = self._encode_files(files, data) + (body, content_type) = self._encode_files(files, data) # type: ignore[arg-type] # is_stream filters non-encodable iterables else: if data: - body = self._encode_params(data) - if isinstance(data, basestring) or hasattr(data, "read"): + body = self._encode_params(data) # type: ignore[arg-type] # is_stream filters non-encodable iterables + if isinstance(data, basestring) or isinstance(data, SupportsRead): content_type = None else: content_type = "application/x-www-form-urlencoded" @@ -569,9 +634,9 @@ def prepare_body(self, data, files, json=None): if content_type and ("content-type" not in self.headers): self.headers["Content-Type"] = content_type - self.body = body + self.body = body # type: ignore[assignment] # body transforms from DataType to BodyType - def prepare_content_length(self, body): + def prepare_content_length(self, body: BodyType) -> None: """Prepare Content-Length header based on request method and body""" if body is not None: length = super_len(body) @@ -587,21 +652,28 @@ def prepare_content_length(self, body): # but don't provide one. (i.e. not GET or HEAD) self.headers["Content-Length"] = "0" - def prepare_auth(self, auth, url=""): + def prepare_auth( + self, + auth: AuthType, + url: str = "", + ) -> None: """Prepares the given HTTP auth data.""" # If no Auth is explicitly provided, extract it from the URL first. if auth is None: - url_auth = get_auth_from_url(self.url) + url_auth = get_auth_from_url(cast(str, self.url)) auth = url_auth if any(url_auth) else None if auth: - if isinstance(auth, tuple) and len(auth) == 2: + if isinstance(auth, tuple) and len(auth) == 2: # type: ignore[arg-type] # pyright widens tuple from Callable in AuthType # special-case basic HTTP auth - auth = HTTPBasicAuth(*auth) + auth_handler = HTTPBasicAuth(*auth) # type: ignore[arg-type] # pyright widens tuple from Callable in AuthType + else: + # TODO: can be fixed by flipping the conditionals + auth_handler = cast("Callable[..., PreparedRequest]", auth) # Allow auth to make its changes. - r = auth(self) + r = auth_handler(self) # Update self to reflect the auth changes. self.__dict__.update(r.__dict__) @@ -609,7 +681,9 @@ def prepare_auth(self, auth, url=""): # Recompute Content-Length self.prepare_content_length(self.body) - def prepare_cookies(self, cookies): + def prepare_cookies( + self, cookies: RequestsCookieJar | CookieJar | dict[str, str] | None + ) -> None: """Prepares the given HTTP cookie data. This function eventually generates a ``Cookie`` header from the @@ -625,16 +699,17 @@ def prepare_cookies(self, cookies): else: self._cookies = cookiejar_from_dict(cookies) - cookie_header = get_cookie_header(self._cookies, self) + cookies_jar = cast("CookieJar", self._cookies) + cookie_header = get_cookie_header(cookies_jar, self) if cookie_header is not None: self.headers["Cookie"] = cookie_header - def prepare_hooks(self, hooks): + def prepare_hooks(self, hooks: HooksInputType | None) -> None: """Prepares the given hooks.""" # hooks can be passed as None to the prepare method and to this # method. To prevent iterating over None, simply use an empty list # if hooks is False-y - hooks = hooks or [] + hooks = hooks or {} for event in hooks: self.register_hook(event, hooks[event]) @@ -644,7 +719,22 @@ class Response: server's response to an HTTP request. """ - __attrs__ = [ + _content: bytes | Literal[False] | None + _content_consumed: bool + _next: PreparedRequest | None + status_code: int + headers: CaseInsensitiveDict[str] + raw: Any + url: str + encoding: str | None + history: list[Response] + reason: str | None + cookies: RequestsCookieJar + elapsed: datetime.timedelta + request: PreparedRequest + connection: HTTPAdapter | None + + __attrs__: list[str] = [ "_content", "status_code", "headers", @@ -657,13 +747,13 @@ class Response: "request", ] - def __init__(self): + def __init__(self) -> None: self._content = False self._content_consumed = False self._next = None #: Integer Code of responded HTTP Status, e.g. 404 or 200. - self.status_code = None + self.status_code = None # type: ignore[assignment] #: Case-insensitive Dictionary of Response Headers. #: For example, ``headers['content-encoding']`` will return the @@ -676,7 +766,7 @@ def __init__(self): self.raw = None #: Final URL location of Response. - self.url = None + self.url = None # type: ignore[assignment] #: Encoding to decode with when accessing r.text. self.encoding = None @@ -702,15 +792,15 @@ def __init__(self): #: The :class:`PreparedRequest ` object to which this #: is a response. - self.request = None + self.request = None # type: ignore[assignment] - def __enter__(self): + def __enter__(self) -> Response: return self - def __exit__(self, *args): + def __exit__(self, *args: Any) -> None: self.close() - def __getstate__(self): + def __getstate__(self) -> dict[str, Any]: # Consume everything; accessing the content attribute makes # sure the content has been fully read. if not self._content_consumed: @@ -718,7 +808,7 @@ def __getstate__(self): return {attr: getattr(self, attr, None) for attr in self.__attrs__} - def __setstate__(self, state): + def __setstate__(self, state: dict[str, Any]) -> None: for name, value in state.items(): setattr(self, name, value) @@ -726,10 +816,10 @@ def __setstate__(self, state): setattr(self, "_content_consumed", True) setattr(self, "raw", None) - def __repr__(self): + def __repr__(self) -> str: return f"" - def __bool__(self): + def __bool__(self) -> bool: """Returns True if :attr:`status_code` is less than 400. This attribute checks if the status code of the response is between @@ -739,7 +829,7 @@ def __bool__(self): """ return self.ok - def __nonzero__(self): + def __nonzero__(self) -> bool: """Returns True if :attr:`status_code` is less than 400. This attribute checks if the status code of the response is between @@ -749,12 +839,12 @@ def __nonzero__(self): """ return self.ok - def __iter__(self): + def __iter__(self) -> Iterator[bytes]: """Allows you to use a response as an iterator.""" return self.iter_content(128) @property - def ok(self): + def ok(self) -> bool: """Returns True if :attr:`status_code` is less than 400, False if not. This attribute checks if the status code of the response is between @@ -769,14 +859,14 @@ def ok(self): return True @property - def is_redirect(self): + def is_redirect(self) -> bool: """True if this Response is a well-formed HTTP redirect that could have been processed automatically (by :meth:`Session.resolve_redirects`). """ return "location" in self.headers and self.status_code in REDIRECT_STATI @property - def is_permanent_redirect(self): + def is_permanent_redirect(self) -> bool: """True if this Response one of the permanent versions of redirect.""" return "location" in self.headers and self.status_code in ( codes.moved_permanently, @@ -784,12 +874,12 @@ def is_permanent_redirect(self): ) @property - def next(self): + def next(self) -> PreparedRequest | None: """Returns a PreparedRequest for the next request in a redirect chain, if there is one.""" return self._next @property - def apparent_encoding(self): + def apparent_encoding(self) -> str | None: """The apparent encoding, provided by the charset_normalizer or chardet libraries.""" if chardet is not None: return chardet.detect(self.content)["encoding"] @@ -798,7 +888,17 @@ def apparent_encoding(self): # to a standard Python utf-8 str. return "utf-8" - def iter_content(self, chunk_size=1, decode_unicode=False): + @overload + def iter_content( + self, chunk_size: int | None = 1, decode_unicode: Literal[False] = False + ) -> Iterator[bytes]: ... + @overload + def iter_content( + self, chunk_size: int | None = 1, *, decode_unicode: Literal[True] + ) -> Iterator[str | bytes]: ... + def iter_content( + self, chunk_size: int | None = 1, decode_unicode: bool = False + ) -> Iterator[str | bytes]: """Iterates over the response data. When stream=True is set on the request, this avoids reading the content at once into memory for large responses. The chunk size is the number of bytes it should @@ -815,7 +915,7 @@ def iter_content(self, chunk_size=1, decode_unicode=False): available encoding based on the response. """ - def generate(): + def generate() -> Generator[bytes, None, None]: # Special case for urllib3. if hasattr(self.raw, "stream"): try: @@ -840,25 +940,44 @@ def generate(): if self._content_consumed and isinstance(self._content, bool): raise StreamConsumedError() - elif chunk_size is not None and not isinstance(chunk_size, int): + elif chunk_size is not None and not isinstance(chunk_size, int): # type: ignore[reportUnnecessaryIsInstance] # runtime guard for untyped callers raise TypeError( f"chunk_size must be an int, it is instead a {type(chunk_size)}." ) - # simulate reading small chunks of the content - reused_chunks = iter_slices(self._content, chunk_size) - - stream_chunks = generate() - chunks = reused_chunks if self._content_consumed else stream_chunks + if self._content_consumed: + # simulate reading small chunks of the content + content = cast(bytes, self._content) + chunks = iter_slices(content, chunk_size) + else: + chunks = generate() if decode_unicode: chunks = stream_decode_response_unicode(chunks, self) return chunks + @overload + def iter_lines( + self, + chunk_size: int = ITER_CHUNK_SIZE, + decode_unicode: Literal[False] = False, + delimiter: bytes | None = None, + ) -> Iterator[bytes]: ... + @overload def iter_lines( - self, chunk_size=ITER_CHUNK_SIZE, decode_unicode=False, delimiter=None - ): + self, + chunk_size: int = ITER_CHUNK_SIZE, + *, + decode_unicode: Literal[True], + delimiter: str | bytes | None = None, + ) -> Iterator[str | bytes]: ... + def iter_lines( + self, + chunk_size: int = ITER_CHUNK_SIZE, + decode_unicode: bool = False, + delimiter: str | bytes | None = None, + ) -> Iterator[str | bytes]: """Iterates over the response data, one line at a time. When stream=True is set on the request, this avoids reading the content at once into memory for large responses. @@ -866,16 +985,17 @@ def iter_lines( .. note:: This method is not reentrant safe. """ - pending = None + pending: str | bytes | None = None for chunk in self.iter_content( chunk_size=chunk_size, decode_unicode=decode_unicode ): if pending is not None: - chunk = pending + chunk + # TODO: remove cast after iter_lines rewrite + chunk = cast("str | bytes", pending + chunk) # type: ignore[operator] if delimiter: - lines = chunk.split(delimiter) + lines = chunk.split(delimiter) # type: ignore[arg-type] else: lines = chunk.splitlines() @@ -890,7 +1010,7 @@ def iter_lines( yield pending @property - def content(self): + def content(self) -> bytes | None: """Content of the response, in bytes.""" if self._content is False: @@ -909,7 +1029,7 @@ def content(self): return self._content @property - def text(self): + def text(self) -> str: """Content of the response, in unicode. If Response.encoding is None, encoding will be guessed using @@ -934,7 +1054,7 @@ def text(self): # Decode unicode from given encoding. try: - content = str(self.content, encoding, errors="replace") + content = str(self.content, encoding or "utf-8", errors="replace") except (LookupError, TypeError): # A LookupError is raised if the encoding was not found which could # indicate a misspelling or similar mistake. @@ -946,7 +1066,7 @@ def text(self): return content - def json(self, **kwargs): + def json(self, **kwargs: Any) -> Any: r"""Decodes the JSON response body (if any) as a Python object. This may return a dictionary, list, etc. depending on what is in the response. @@ -982,23 +1102,24 @@ def json(self, **kwargs): raise RequestsJSONDecodeError(e.msg, e.doc, e.pos) @property - def links(self): + def links(self) -> dict[str, dict[str, str]]: """Returns the parsed header links of the response, if any.""" header = self.headers.get("link") - resolved_links = {} + resolved_links: dict[str, dict[str, str]] = {} if header: links = parse_header_links(header) for link in links: key = link.get("rel") or link.get("url") - resolved_links[key] = link + if key is not None: + resolved_links[key] = link return resolved_links - def raise_for_status(self): + def raise_for_status(self) -> None: """Raises :class:`HTTPError`, if one occurred.""" http_error_msg = "" @@ -1027,7 +1148,7 @@ def raise_for_status(self): if http_error_msg: raise HTTPError(http_error_msg, response=self) - def close(self): + def close(self) -> None: """Releases the connection back to the pool. Once this method has been called the underlying ``raw`` object must not be accessed again. diff --git a/src/requests/py.typed b/src/requests/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/requests/sessions.py b/src/requests/sessions.py index 578cc44d5c..62b41f265d 100644 --- a/src/requests/sessions.py +++ b/src/requests/sessions.py @@ -6,16 +6,21 @@ requests (cookies, auth, proxies). """ +from __future__ import annotations + import os import sys import time from collections import OrderedDict +from collections.abc import Generator, Mapping, MutableMapping from datetime import timedelta +from typing import TYPE_CHECKING, Any, cast from ._internal_utils import to_native_string +from ._types import is_prepared from .adapters import HTTPAdapter -from .auth import _basic_auth_str -from .compat import Mapping, cookielib, urljoin, urlparse +from .auth import _basic_auth_str # type: ignore[reportPrivateUsage] +from .compat import cookielib, urljoin, urlparse from .cookies import ( RequestsCookieJar, cookiejar_from_dict, @@ -33,9 +38,10 @@ # formerly defined here, reexposed here for backward compatibility from .models import ( # noqa: F401 DEFAULT_REDIRECT_LIMIT, - REDIRECT_STATI, + REDIRECT_STATI, # type: ignore[reportUnusedImport] PreparedRequest, Request, + Response, ) from .status_codes import codes from .structures import CaseInsensitiveDict @@ -48,10 +54,27 @@ requote_uri, resolve_proxies, rewind_body, - should_bypass_proxies, + should_bypass_proxies, # type: ignore[reportUnusedImport] # re-export for external consumers to_key_val_list, ) +if TYPE_CHECKING: + from http.cookiejar import CookieJar + + from ._types import ( + AuthType, + CertType, + DataType, + FilesType, + HooksType, + HookType, + JsonType, + ParamsType, + TimeoutType, + VerifyType, + ) + from .adapters import BaseAdapter + # Preferred clock, based on which one is more accurate on a given system. if sys.platform == "win32": preferred_clock = time.perf_counter @@ -59,7 +82,9 @@ preferred_clock = time.time -def merge_setting(request_setting, session_setting, dict_class=OrderedDict): +def merge_setting( + request_setting: Any, session_setting: Any, dict_class: type = OrderedDict +) -> Any: """Determines appropriate setting for a given request, taking into account the explicit setting on that request, and the setting in the session. If a setting is a dictionary, they will be merged together using `dict_class` @@ -77,8 +102,8 @@ def merge_setting(request_setting, session_setting, dict_class=OrderedDict): ): return request_setting - merged_setting = dict_class(to_key_val_list(session_setting)) - merged_setting.update(to_key_val_list(request_setting)) + merged_setting = dict_class(to_key_val_list(session_setting)) # type: ignore[arg-type] # isinstance narrows Any to Mapping[Unknown] + merged_setting.update(to_key_val_list(request_setting)) # type: ignore[arg-type] # Remove keys that are set to None. Extract keys first to avoid altering # the dictionary during iteration. @@ -89,7 +114,9 @@ def merge_setting(request_setting, session_setting, dict_class=OrderedDict): return merged_setting -def merge_hooks(request_hooks, session_hooks, dict_class=OrderedDict): +def merge_hooks( + request_hooks: HooksType, session_hooks: HooksType, dict_class: type = OrderedDict +) -> HooksType: """Properly merges both requests and session hooks. This is necessary because when request_hooks == {'response': []}, the @@ -105,7 +132,13 @@ def merge_hooks(request_hooks, session_hooks, dict_class=OrderedDict): class SessionRedirectMixin: - def get_redirect_target(self, resp): + max_redirects: int + trust_env: bool + cookies: RequestsCookieJar + + def send(self, request: PreparedRequest, **kwargs: Any) -> Response: ... + + def get_redirect_target(self, resp: Response) -> str | None: """Receives a Response. Returns a redirect URI or ``None``""" # Due to the nature of how requests processes redirects this method will # be called at least once upon the original response and at least twice @@ -125,7 +158,7 @@ def get_redirect_target(self, resp): return to_native_string(location, "utf8") return None - def should_strip_auth(self, old_url, new_url): + def should_strip_auth(self, old_url: str, new_url: str) -> bool: """Decide whether Authorization header should be removed when redirecting""" old_parsed = urlparse(old_url) new_parsed = urlparse(new_url) @@ -159,19 +192,19 @@ def should_strip_auth(self, old_url, new_url): def resolve_redirects( self, - resp, - req, - stream=False, - timeout=None, - verify=True, - cert=None, - proxies=None, - yield_requests=False, - **adapter_kwargs, - ): + resp: Response, + req: PreparedRequest, + stream: bool = False, + timeout: TimeoutType = None, + verify: VerifyType = True, + cert: CertType = None, + proxies: dict[str, str] | None = None, + yield_requests: bool = False, + **adapter_kwargs: Any, + ) -> Generator[Response, None, None]: """Receives a Response. Returns a generator of Responses or Requests.""" - hist = [] # keep track of history + hist: list[Response] = [] # keep track of history url = self.get_redirect_target(resp) previous_fragment = urlparse(req.url).fragment @@ -238,9 +271,10 @@ def resolve_redirects( # Extract any cookies sent on the response to the cookiejar # in the new request. Because we've mutated our copied prepared # request, use the old one that we haven't yet touched. - extract_cookies_to_jar(prepared_request._cookies, req, resp.raw) - merge_cookies(prepared_request._cookies, self.cookies) - prepared_request.prepare_cookies(prepared_request._cookies) + cookie_jar = cast("CookieJar", prepared_request._cookies) # type: ignore[reportPrivateUsage] + extract_cookies_to_jar(cookie_jar, req, resp.raw) + merge_cookies(cookie_jar, self.cookies) + prepared_request.prepare_cookies(cookie_jar) # Rebuild auth and proxy information. proxies = self.rebuild_proxies(prepared_request, proxies) @@ -249,7 +283,7 @@ def resolve_redirects( # A failed tell() sets `_body_position` to `object()`. This non-None # value ensures `rewindable` will be True, allowing us to raise an # UnrewindableBodyError, instead of hanging the connection. - rewindable = prepared_request._body_position is not None and ( + rewindable = prepared_request._body_position is not None and ( # type: ignore[reportPrivateUsage] "Content-Length" in headers or "Transfer-Encoding" in headers ) @@ -261,7 +295,7 @@ def resolve_redirects( req = prepared_request if yield_requests: - yield req + yield req # type: ignore[misc] # Internal use only, returns PreparedRequest else: resp = self.send( req, @@ -280,17 +314,22 @@ def resolve_redirects( url = self.get_redirect_target(resp) yield resp - def rebuild_auth(self, prepared_request, response): + def rebuild_auth( + self, prepared_request: PreparedRequest, response: Response + ) -> None: """When being redirected we may want to strip authentication from the request to avoid leaking credentials. This method intelligently removes and reapplies authentication where possible to avoid credential loss. """ + original_request = response.request + assert is_prepared(original_request) + assert is_prepared(prepared_request) + headers = prepared_request.headers + original_url = original_request.url url = prepared_request.url - if "Authorization" in headers and self.should_strip_auth( - response.request.url, url - ): + if "Authorization" in headers and self.should_strip_auth(original_url, url): # If we get redirected to a new host, we should strip out any # authentication headers. del headers["Authorization"] @@ -300,7 +339,11 @@ def rebuild_auth(self, prepared_request, response): if new_auth is not None: prepared_request.prepare_auth(new_auth) - def rebuild_proxies(self, prepared_request, proxies): + def rebuild_proxies( + self, + prepared_request: PreparedRequest, + proxies: dict[str, str] | None, + ) -> dict[str, str]: """This method re-evaluates the proxy configuration by considering the environment variables. If we are redirected to a URL covered by NO_PROXY, we strip the proxy configuration. Otherwise, we set missing @@ -320,18 +363,20 @@ def rebuild_proxies(self, prepared_request, proxies): del headers["Proxy-Authorization"] try: - username, password = get_auth_from_url(new_proxies[scheme]) + username, password = get_auth_from_url(new_proxies[scheme]) # type: ignore[arg-type] # TODO(typing): str|bytes URL handling except KeyError: username, password = None, None # urllib3 handles proxy authorization for us in the standard adapter. # Avoid appending this to TLS tunneled requests where it may be leaked. - if not scheme.startswith("https") and username and password: + if not scheme.startswith("https") and username and password: # type: ignore[arg-type] # TODO(typing): str|bytes URL handling headers["Proxy-Authorization"] = _basic_auth_str(username, password) return new_proxies - def rebuild_method(self, prepared_request, response): + def rebuild_method( + self, prepared_request: PreparedRequest, response: Response + ) -> None: """When being redirected we may want to change the method of the request based on certain specs or browser behavior. """ @@ -373,7 +418,20 @@ class Session(SessionRedirectMixin): """ - __attrs__ = [ + headers: CaseInsensitiveDict[str] + auth: AuthType + proxies: dict[str, str] + hooks: dict[str, list[HookType]] + params: MutableMapping[str, Any] + stream: bool + verify: VerifyType + cert: CertType + max_redirects: int + trust_env: bool + cookies: RequestsCookieJar + adapters: MutableMapping[str, BaseAdapter] + + __attrs__: list[str] = [ "headers", "cookies", "auth", @@ -388,7 +446,7 @@ class Session(SessionRedirectMixin): "max_redirects", ] - def __init__(self): + def __init__(self) -> None: #: A case-insensitive dictionary of headers to be sent on each #: :class:`Request ` sent from this #: :class:`Session `. @@ -451,13 +509,13 @@ def __init__(self): self.mount("https://", HTTPAdapter()) self.mount("http://", HTTPAdapter()) - def __enter__(self): + def __enter__(self) -> Session: return self - def __exit__(self, *args): + def __exit__(self, *args: Any) -> None: self.close() - def prepare_request(self, request): + def prepare_request(self, request: Request) -> PreparedRequest: """Constructs a :class:`PreparedRequest ` for transmission and returns it. The :class:`PreparedRequest` has settings merged from the :class:`Request ` instance and those of the @@ -467,6 +525,9 @@ def prepare_request(self, request): session's settings. :rtype: requests.PreparedRequest """ + url = cast(str, request.url) + method = cast(str, request.method) + cookies = request.cookies or {} # Bootstrap CookieJar. @@ -481,12 +542,12 @@ def prepare_request(self, request): # Set environment's basic authentication if not explicitly set. auth = request.auth if self.trust_env and not auth and not self.auth: - auth = get_netrc_auth(request.url) + auth = get_netrc_auth(url) p = PreparedRequest() p.prepare( - method=request.method.upper(), - url=request.url, + method=method.upper(), + url=url, files=request.files, data=request.data, json=request.json, @@ -502,23 +563,23 @@ def prepare_request(self, request): def request( self, - method, - url, - params=None, - data=None, - headers=None, - cookies=None, - files=None, - auth=None, - timeout=None, - allow_redirects=True, - proxies=None, - hooks=None, - stream=None, - verify=None, - cert=None, - json=None, - ): + method: str, + url: str, + params: ParamsType = None, + data: DataType = None, + headers: Mapping[str, str | bytes] | None = None, + cookies: RequestsCookieJar | CookieJar | dict[str, str] | None = None, + files: FilesType = None, + auth: AuthType = None, + timeout: TimeoutType = None, + allow_redirects: bool = True, + proxies: dict[str, str] | None = None, + hooks: HooksType = None, + stream: bool | None = None, + verify: VerifyType | None = None, + cert: CertType = None, + json: JsonType = None, + ) -> Response: """Constructs a :class:`Request `, prepares it and sends it. Returns :class:`Response ` object. @@ -577,6 +638,8 @@ def request( ) prep = self.prepare_request(req) + assert is_prepared(prep) + proxies = proxies or {} settings = self.merge_environment_settings( @@ -593,7 +656,7 @@ def request( return resp - def get(self, url, **kwargs): + def get(self, url: str, **kwargs: Any) -> Response: r"""Sends a GET request. Returns :class:`Response` object. :param url: URL for the new :class:`Request` object. @@ -604,7 +667,7 @@ def get(self, url, **kwargs): kwargs.setdefault("allow_redirects", True) return self.request("GET", url, **kwargs) - def options(self, url, **kwargs): + def options(self, url: str, **kwargs: Any) -> Response: r"""Sends a OPTIONS request. Returns :class:`Response` object. :param url: URL for the new :class:`Request` object. @@ -615,7 +678,7 @@ def options(self, url, **kwargs): kwargs.setdefault("allow_redirects", True) return self.request("OPTIONS", url, **kwargs) - def head(self, url, **kwargs): + def head(self, url: str, **kwargs: Any) -> Response: r"""Sends a HEAD request. Returns :class:`Response` object. :param url: URL for the new :class:`Request` object. @@ -626,7 +689,9 @@ def head(self, url, **kwargs): kwargs.setdefault("allow_redirects", False) return self.request("HEAD", url, **kwargs) - def post(self, url, data=None, json=None, **kwargs): + def post( + self, url: str, data: DataType = None, json: JsonType = None, **kwargs: Any + ) -> Response: r"""Sends a POST request. Returns :class:`Response` object. :param url: URL for the new :class:`Request` object. @@ -639,7 +704,7 @@ def post(self, url, data=None, json=None, **kwargs): return self.request("POST", url, data=data, json=json, **kwargs) - def put(self, url, data=None, **kwargs): + def put(self, url: str, data: DataType = None, **kwargs: Any) -> Response: r"""Sends a PUT request. Returns :class:`Response` object. :param url: URL for the new :class:`Request` object. @@ -651,7 +716,7 @@ def put(self, url, data=None, **kwargs): return self.request("PUT", url, data=data, **kwargs) - def patch(self, url, data=None, **kwargs): + def patch(self, url: str, data: DataType = None, **kwargs: Any) -> Response: r"""Sends a PATCH request. Returns :class:`Response` object. :param url: URL for the new :class:`Request` object. @@ -663,7 +728,7 @@ def patch(self, url, data=None, **kwargs): return self.request("PATCH", url, data=data, **kwargs) - def delete(self, url, **kwargs): + def delete(self, url: str, **kwargs: Any) -> Response: r"""Sends a DELETE request. Returns :class:`Response` object. :param url: URL for the new :class:`Request` object. @@ -673,7 +738,7 @@ def delete(self, url, **kwargs): return self.request("DELETE", url, **kwargs) - def send(self, request, **kwargs): + def send(self, request: PreparedRequest, **kwargs: Any) -> Response: """Send a given PreparedRequest. :rtype: requests.Response @@ -691,6 +756,8 @@ def send(self, request, **kwargs): if isinstance(request, Request): raise ValueError("You can only send PreparedRequests.") + assert is_prepared(request) + # Set up variables needed for resolve_redirects and dispatching of hooks allow_redirects = kwargs.pop("allow_redirects", True) stream = kwargs.get("stream") @@ -739,7 +806,7 @@ def send(self, request, **kwargs): # If redirects aren't being followed, store the response on the Request for Response.next(). if not allow_redirects: try: - r._next = next( + r._next = next( # type: ignore[assignment] # yield_requests=True returns PreparedRequest self.resolve_redirects(r, request, yield_requests=True, **kwargs) ) except StopIteration: @@ -750,7 +817,14 @@ def send(self, request, **kwargs): return r - def merge_environment_settings(self, url, proxies, stream, verify, cert): + def merge_environment_settings( + self, + url: str, + proxies: dict[str, str] | None, + stream: bool | None, + verify: VerifyType | None, + cert: CertType, + ) -> dict[str, Any]: """ Check the environment and merge it with some settings. @@ -761,8 +835,9 @@ def merge_environment_settings(self, url, proxies, stream, verify, cert): # Set environment's proxies. no_proxy = proxies.get("no_proxy") if proxies is not None else None env_proxies = get_environ_proxies(url, no_proxy=no_proxy) - for k, v in env_proxies.items(): - proxies.setdefault(k, v) + if proxies is not None: + for k, v in env_proxies.items(): + proxies.setdefault(k, v) # Look for requests environment configuration # and be compatible with cURL. @@ -781,7 +856,7 @@ def merge_environment_settings(self, url, proxies, stream, verify, cert): return {"proxies": proxies, "stream": stream, "verify": verify, "cert": cert} - def get_adapter(self, url): + def get_adapter(self, url: str) -> BaseAdapter: """ Returns the appropriate connection adapter for the given URL. @@ -794,12 +869,12 @@ def get_adapter(self, url): # Nothing matches :-/ raise InvalidSchema(f"No connection adapters were found for {url!r}") - def close(self): + def close(self) -> None: """Closes all adapters and as such the session""" for v in self.adapters.values(): v.close() - def mount(self, prefix, adapter): + def mount(self, prefix: str, adapter: BaseAdapter) -> None: """Registers a connection adapter to a prefix. Adapters are sorted in descending order by prefix length. @@ -810,16 +885,16 @@ def mount(self, prefix, adapter): for key in keys_to_move: self.adapters[key] = self.adapters.pop(key) - def __getstate__(self): + def __getstate__(self) -> dict[str, Any]: state = {attr: getattr(self, attr, None) for attr in self.__attrs__} return state - def __setstate__(self, state): + def __setstate__(self, state: dict[str, Any]) -> None: for attr, value in state.items(): setattr(self, attr, value) -def session(): +def session() -> Session: """ Returns a :class:`Session` for context-management. diff --git a/src/requests/status_codes.py b/src/requests/status_codes.py index c7945a2f06..6c59d6baec 100644 --- a/src/requests/status_codes.py +++ b/src/requests/status_codes.py @@ -103,7 +103,7 @@ 511: ("network_authentication_required", "network_auth", "network_authentication"), } -codes = LookupDict(name="status_codes") +codes: LookupDict[int] = LookupDict(name="status_codes") def _init(): @@ -113,7 +113,7 @@ def _init(): if not title.startswith(("\\", "/")): setattr(codes, title.upper(), code) - def doc(code): + def doc(code: int) -> str: names = ", ".join(f"``{n}``" for n in _codes[code]) return "* %d: %s" % (code, names) diff --git a/src/requests/structures.py b/src/requests/structures.py index 188e13e482..04fd300235 100644 --- a/src/requests/structures.py +++ b/src/requests/structures.py @@ -5,12 +5,19 @@ Data structures that power Requests. """ +from __future__ import annotations + from collections import OrderedDict +from collections.abc import Iterable, Iterator, Mapping +from typing import Any, Generic, TypeVar, overload + +from .compat import MutableMapping -from .compat import Mapping, MutableMapping +_VT = TypeVar("_VT") +_D = TypeVar("_D") -class CaseInsensitiveDict(MutableMapping): +class CaseInsensitiveDict(MutableMapping[str, _VT], Generic[_VT]): """A case-insensitive ``dict``-like object. Implements all methods and operations of @@ -37,63 +44,81 @@ class CaseInsensitiveDict(MutableMapping): behavior is undefined. """ - def __init__(self, data=None, **kwargs): + _store: OrderedDict[str, tuple[str, _VT]] + + def __init__( + self, + data: Mapping[str, _VT] | Iterable[tuple[str, _VT]] | None = None, + **kwargs: _VT, + ) -> None: self._store = OrderedDict() if data is None: data = {} self.update(data, **kwargs) - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: _VT) -> None: # Use the lowercased key for lookups, but store the actual # key alongside the value. self._store[key.lower()] = (key, value) - def __getitem__(self, key): + def __getitem__(self, key: str) -> _VT: return self._store[key.lower()][1] - def __delitem__(self, key): + def __delitem__(self, key: str) -> None: del self._store[key.lower()] - def __iter__(self): - return (casedkey for casedkey, mappedvalue in self._store.values()) + def __iter__(self) -> Iterator[str]: + return (casedkey for casedkey, _ in self._store.values()) - def __len__(self): + def __len__(self) -> int: return len(self._store) - def lower_items(self): + def lower_items(self) -> Iterator[tuple[str, _VT]]: """Like iteritems(), but with all lowercase keys.""" return ((lowerkey, keyval[1]) for (lowerkey, keyval) in self._store.items()) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, Mapping): - other = CaseInsensitiveDict(other) + other_dict: CaseInsensitiveDict[Any] = CaseInsensitiveDict(other) # type: ignore[reportUnknownArgumentType] else: return NotImplemented # Compare insensitively - return dict(self.lower_items()) == dict(other.lower_items()) + return dict(self.lower_items()) == dict(other_dict.lower_items()) # Copy is required - def copy(self): + def copy(self) -> CaseInsensitiveDict[_VT]: return CaseInsensitiveDict(self._store.values()) - def __repr__(self): + def __repr__(self) -> str: return str(dict(self.items())) -class LookupDict(dict): +class LookupDict(dict[str, _VT]): """Dictionary lookup object.""" - def __init__(self, name=None): + name: Any + + def __init__(self, name: Any = None) -> None: self.name = name super().__init__() - def __repr__(self): + def __repr__(self) -> str: return f"" - def __getitem__(self, key): + def __getattr__(self, key: str) -> _VT | None: + # Allow attribute-style access to values + return self.__dict__.get(key, None) + + def __getitem__(self, key: str) -> _VT | None: # type: ignore[override] # We allow fall-through here, so values default to None return self.__dict__.get(key, None) - def get(self, key, default=None): + @overload + def get(self, key: str, default: None = None) -> _VT | None: ... + + @overload + def get(self, key: str, default: _D | _VT) -> _D | _VT: ... + + def get(self, key: str, default: _D | None = None) -> _VT | _D | None: return self.__dict__.get(key, default) diff --git a/src/requests/utils.py b/src/requests/utils.py index d113a6ff3e..d528b731b8 100644 --- a/src/requests/utils.py +++ b/src/requests/utils.py @@ -6,6 +6,8 @@ that are also useful for external consumption. """ +from __future__ import annotations + import codecs import contextlib import io @@ -18,6 +20,16 @@ import warnings import zipfile from collections import OrderedDict +from collections.abc import Generator, Iterable, Iterator +from typing import ( + TYPE_CHECKING, + Any, + AnyStr, + Final, + TypeVar, + cast, + overload, +) from urllib3.util import make_headers, parse_url @@ -26,11 +38,12 @@ # to_native_string is unused here, but imported here for backwards compatibility from ._internal_utils import ( # noqa: F401 - _HEADER_VALIDATORS_BYTE, - _HEADER_VALIDATORS_STR, - HEADER_VALIDATORS, - to_native_string, + _HEADER_VALIDATORS_BYTE, # type: ignore[reportPrivateUsage] + _HEADER_VALIDATORS_STR, # type: ignore[reportPrivateUsage] + HEADER_VALIDATORS, # type: ignore[reportUnusedImport] + to_native_string, # type: ignore[reportUnusedImport] ) +from ._types import SupportsItems from .compat import ( Mapping, basestring, @@ -40,7 +53,7 @@ integer_types, is_urllib3_1, proxy_bypass, - proxy_bypass_environment, + proxy_bypass_environment, # type: ignore[attr-defined] # https://github.com/python/cpython/issues/145331 quote, str, unquote, @@ -57,14 +70,24 @@ ) from .structures import CaseInsensitiveDict -NETRC_FILES = (".netrc", "_netrc") +if TYPE_CHECKING: + from http.cookiejar import CookieJar + from io import BufferedWriter + + from ._types import SupportsItems, UriType + from .models import PreparedRequest, Request, Response + +NETRC_FILES: Final = (".netrc", "_netrc") + +DEFAULT_CA_BUNDLE_PATH: str = certs.where() -DEFAULT_CA_BUNDLE_PATH = certs.where() +DEFAULT_PORTS: Final = {"http": 80, "https": 443} -DEFAULT_PORTS = {"http": 80, "https": 443} +_KT = TypeVar("_KT") +_VT = TypeVar("_VT") # Ensure that ', ' is used to preserve previous delimiter behavior. -DEFAULT_ACCEPT_ENCODING = ", ".join( +DEFAULT_ACCEPT_ENCODING: Final = ", ".join( re.split(r",\s*", make_headers(accept_encoding=True)["accept-encoding"]) ) @@ -72,7 +95,7 @@ if sys.platform == "win32": # provide a proxy_bypass version on Windows without DNS lookups - def proxy_bypass_registry(host): + def proxy_bypass_registry(host: str) -> bool: try: import winreg except ImportError: @@ -110,7 +133,7 @@ def proxy_bypass_registry(host): return True return False - def proxy_bypass(host): # noqa + def proxy_bypass(host: str) -> bool: # noqa """Return True, if the host should be bypassed. Checks proxy settings gathered from the environment, if specified, @@ -122,16 +145,18 @@ def proxy_bypass(host): # noqa return proxy_bypass_registry(host) -def dict_to_sequence(d): +def dict_to_sequence( + d: SupportsItems | Iterable[tuple[Any, Any]], +) -> Iterable[tuple[Any, Any]]: """Returns an internal sequence dictionary update.""" - if hasattr(d, "items"): - d = d.items() + if isinstance(d, SupportsItems): + return d.items() return d -def super_len(o): +def super_len(o: Any) -> int: total_length = None current_position = 0 @@ -202,7 +227,7 @@ def super_len(o): return max(0, total_length - current_position) -def get_netrc_auth(url, raise_errors=False): +def get_netrc_auth(url: UriType, raise_errors: bool = False) -> tuple[str, str] | None: """Returns the Requests tuple auth for a given url from netrc.""" netrc_file = os.environ.get("NETRC") @@ -230,11 +255,11 @@ def get_netrc_auth(url, raise_errors=False): host = ri.hostname try: - _netrc = netrc(netrc_path).authenticators(host) + _netrc = netrc(netrc_path).authenticators(host) # type: ignore[arg-type] # TODO(typing): str|bytes URL handling if _netrc and any(_netrc): # Return with login / password login_i = 0 if _netrc[0] else 1 - return (_netrc[login_i], _netrc[2]) + return (_netrc[login_i] or "", _netrc[2] or "") except (NetrcParseError, OSError): # If there was a parsing error or a permissions issue reading the file, # we'll just skip netrc auth unless explicitly asked to raise errors. @@ -246,14 +271,14 @@ def get_netrc_auth(url, raise_errors=False): pass -def guess_filename(obj): +def guess_filename(obj: Any) -> str | None: """Tries to guess the filename of the given object.""" name = getattr(obj, "name", None) if name and isinstance(name, basestring) and name[0] != "<" and name[-1] != ">": - return os.path.basename(name) + return os.path.basename(name) # type: ignore[return-value] # TODO(typing): str|bytes URL handling -def extract_zipped_paths(path): +def extract_zipped_paths(path: str) -> str: """Replace nonexistent paths that look like they refer to a member of a zip archive with the location of an extracted copy of the target, or else just return the provided path unchanged. @@ -291,7 +316,7 @@ def extract_zipped_paths(path): @contextlib.contextmanager -def atomic_open(filename): +def atomic_open(filename: str) -> Iterator[BufferedWriter]: """Write a file to the disk in an atomic fashion""" tmp_descriptor, tmp_name = tempfile.mkstemp(dir=os.path.dirname(filename)) try: @@ -303,7 +328,9 @@ def atomic_open(filename): raise -def from_key_val_list(value): +def from_key_val_list( + value: Mapping[Any, Any] | Iterable[tuple[Any, Any]] | None, +) -> dict[Any, Any] | None: """Take an object and test to see if it can be represented as a dictionary. Unless it can not be represented as such, return an OrderedDict, e.g., @@ -330,7 +357,15 @@ def from_key_val_list(value): return OrderedDict(value) -def to_key_val_list(value): +@overload +def to_key_val_list(value: None) -> None: ... +@overload +def to_key_val_list( + value: Mapping[_KT, _VT] | Iterable[tuple[_KT, _VT]], +) -> list[tuple[_KT, _VT]]: ... +def to_key_val_list( + value: Mapping[_KT, _VT] | Iterable[tuple[_KT, _VT]] | None, +) -> list[tuple[_KT, _VT]] | None: """Take an object and test to see if it can be represented as a dictionary. If it can be, return a list of tuples, e.g., @@ -353,14 +388,14 @@ def to_key_val_list(value): if isinstance(value, (str, bytes, bool, int)): raise ValueError("cannot encode objects that are not 2-tuples") - if isinstance(value, Mapping): - value = value.items() + if isinstance(value, SupportsItems): + return list(value.items()) return list(value) # From mitsuhiko/werkzeug (used with permission). -def parse_list_header(value): +def parse_list_header(value: str) -> list[str]: """Parse lists as described by RFC 2068 Section 2. In particular, parse comma-separated lists where the elements of @@ -383,7 +418,7 @@ def parse_list_header(value): :return: :class:`list` :rtype: list """ - result = [] + result: list[str] = [] for item in _parse_list_header(value): if item[:1] == item[-1:] == '"': item = unquote_header_value(item[1:-1]) @@ -392,7 +427,7 @@ def parse_list_header(value): # From mitsuhiko/werkzeug (used with permission). -def parse_dict_header(value): +def parse_dict_header(value: str) -> dict[str, str | None]: """Parse lists of key, value pairs as described by RFC 2068 Section 2 and convert them into a python dict: @@ -414,7 +449,7 @@ def parse_dict_header(value): :return: :class:`dict` :rtype: dict """ - result = {} + result: dict[str, str | None] = {} for item in _parse_list_header(value): if "=" not in item: result[item] = None @@ -427,7 +462,7 @@ def parse_dict_header(value): # From mitsuhiko/werkzeug (used with permission). -def unquote_header_value(value, is_filename=False): +def unquote_header_value(value: str, is_filename: bool = False) -> str: r"""Unquotes a header value. (Reversal of :func:`quote_header_value`). This does not use the real unquoting but what browsers are actually using for quoting. @@ -452,7 +487,7 @@ def unquote_header_value(value, is_filename=False): return value -def dict_from_cookiejar(cj): +def dict_from_cookiejar(cj: CookieJar) -> dict[str, str | None]: """Returns a key/value dictionary from a CookieJar. :param cj: CookieJar object to extract cookies from. @@ -463,7 +498,7 @@ def dict_from_cookiejar(cj): return cookie_dict -def add_dict_to_cookiejar(cj, cookie_dict): +def add_dict_to_cookiejar(cj: CookieJar, cookie_dict: dict[str, str]) -> CookieJar: """Returns a CookieJar from a key/value dictionary. :param cj: CookieJar to insert cookies into. @@ -474,7 +509,7 @@ def add_dict_to_cookiejar(cj, cookie_dict): return cookiejar_from_dict(cookie_dict, cj) -def get_encodings_from_content(content): +def get_encodings_from_content(content: str) -> list[str]: """Returns encodings from given content string. :param content: bytestring to extract encodings from. @@ -499,7 +534,7 @@ def get_encodings_from_content(content): ) -def _parse_content_type_header(header): +def _parse_content_type_header(header: str) -> tuple[str, dict[str, Any]]: """Returns content type and parameters from given header :param header: string @@ -509,7 +544,7 @@ def _parse_content_type_header(header): tokens = header.split(";") content_type, params = tokens[0].strip(), tokens[1:] - params_dict = {} + params_dict: dict[str, str | bool] = {} items_to_strip = "\"' " for param in params: @@ -524,7 +559,7 @@ def _parse_content_type_header(header): return content_type, params_dict -def get_encoding_from_headers(headers): +def get_encoding_from_headers(headers: CaseInsensitiveDict[str]) -> str | None: """Returns encodings from given HTTP Header Dict. :param headers: dictionary to extract encoding from. @@ -549,7 +584,9 @@ def get_encoding_from_headers(headers): return "utf-8" -def stream_decode_response_unicode(iterator, r): +def stream_decode_response_unicode( + iterator: Iterable[bytes], r: Response +) -> Generator[str | bytes, None, None]: """Stream decodes an iterator.""" if r.encoding is None: @@ -566,7 +603,17 @@ def stream_decode_response_unicode(iterator, r): yield rv -def iter_slices(string, slice_length): +@overload +def iter_slices( + string: bytes, slice_length: int | None +) -> Generator[bytes, None, None]: ... +@overload +def iter_slices( + string: str, slice_length: int | None +) -> Generator[str, None, None]: ... +def iter_slices( + string: bytes | str, slice_length: int | None +) -> Generator[bytes | str, None, None]: """Iterate over slices of a string.""" pos = 0 if slice_length is None or slice_length <= 0: @@ -576,7 +623,7 @@ def iter_slices(string, slice_length): pos += slice_length -def get_unicode_from_response(r): +def get_unicode_from_response(r: Response) -> str | bytes | None: """Returns the requested content back in unicode. :param r: Response object to get unicode content from. @@ -597,31 +644,31 @@ def get_unicode_from_response(r): DeprecationWarning, ) - tried_encodings = [] + tried_encodings: list[str] = [] # Try charset from content-type encoding = get_encoding_from_headers(r.headers) if encoding: try: - return str(r.content, encoding) + return str(r.content, encoding) # type: ignore[arg-type] except UnicodeError: tried_encodings.append(encoding) # Fall back: try: - return str(r.content, encoding, errors="replace") + return str(r.content, encoding or "utf-8", errors="replace") # type: ignore[arg-type] except TypeError: return r.content # The unreserved URI characters (RFC 3986) -UNRESERVED_SET = frozenset( +UNRESERVED_SET: Final = frozenset( "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + "0123456789-._~" ) -def unquote_unreserved(uri): +def unquote_unreserved(uri: str) -> str: """Un-escape any percent-escape sequences in a URI that are unreserved characters. This leaves all reserved, illegal and non-ASCII bytes encoded. @@ -645,7 +692,7 @@ def unquote_unreserved(uri): return "".join(parts) -def requote_uri(uri): +def requote_uri(uri: str) -> str: """Re-quote the given URI. This function passes the given URI through an unquote/quote cycle to @@ -667,7 +714,7 @@ def requote_uri(uri): return quote(uri, safe=safe_without_percent) -def address_in_network(ip, net): +def address_in_network(ip: str, net: str) -> bool: """This function allows you to check if an IP belongs to a network subnet Example: returns True if ip = 192.168.1.1 and net = 192.168.1.0/24 @@ -682,7 +729,7 @@ def address_in_network(ip, net): return (ipaddr & netmask) == (network & netmask) -def dotted_netmask(mask): +def dotted_netmask(mask: int) -> str: """Converts mask from /xx format to xxx.xxx.xxx.xxx Example: if mask is 24 function returns 255.255.255.0 @@ -693,7 +740,7 @@ def dotted_netmask(mask): return socket.inet_ntoa(struct.pack(">I", bits)) -def is_ipv4_address(string_ip): +def is_ipv4_address(string_ip: str) -> bool: """ :rtype: bool """ @@ -704,7 +751,7 @@ def is_ipv4_address(string_ip): return True -def is_valid_cidr(string_network): +def is_valid_cidr(string_network: str) -> bool: """ Very simple check of the cidr format in no_proxy variable. @@ -729,7 +776,7 @@ def is_valid_cidr(string_network): @contextlib.contextmanager -def set_environ(env_name, value): +def set_environ(env_name: str, value: str | None) -> Iterator[None]: """Set the environment variable 'env_name' to 'value' Save previous value, yield, and then restore the previous value stored in @@ -737,6 +784,7 @@ def set_environ(env_name, value): If 'value' is None, do nothing""" value_changed = value is not None + old_value: str | None = None if value_changed: old_value = os.environ.get(env_name) os.environ[env_name] = value @@ -750,7 +798,7 @@ def set_environ(env_name, value): os.environ[env_name] = old_value -def should_bypass_proxies(url, no_proxy): +def should_bypass_proxies(url: UriType, no_proxy: str | None) -> bool: """ Returns whether we should bypass proxies or not. @@ -759,7 +807,7 @@ def should_bypass_proxies(url, no_proxy): # Prioritize lowercase environment variables over uppercase # to keep a consistent behaviour with other http projects (curl, wget). - def get_proxy(key): + def get_proxy(key: str) -> str | None: return os.environ.get(key) or os.environ.get(key.upper()) # First check whether no_proxy is defined. If it is, check that the URL @@ -776,12 +824,12 @@ def get_proxy(key): if no_proxy: # We need to check whether we match here. We need to see if we match # the end of the hostname, both with and without the port. - no_proxy = (host for host in no_proxy.replace(" ", "").split(",") if host) + no_proxy_hosts = (host for host in no_proxy.replace(" ", "").split(",") if host) - if is_ipv4_address(parsed.hostname): - for proxy_ip in no_proxy: + if is_ipv4_address(parsed.hostname): # type: ignore[arg-type] # TODO(typing): str|bytes URL handling + for proxy_ip in no_proxy_hosts: if is_valid_cidr(proxy_ip): - if address_in_network(parsed.hostname, proxy_ip): + if address_in_network(parsed.hostname, proxy_ip): # type: ignore[arg-type] # TODO(typing): str|bytes URL handling return True elif parsed.hostname == proxy_ip: # If no_proxy ip was defined in plain IP notation instead of cidr notation & @@ -790,10 +838,10 @@ def get_proxy(key): else: host_with_port = parsed.hostname if parsed.port: - host_with_port += f":{parsed.port}" + host_with_port += f":{parsed.port}" # type: ignore[operator] # TODO(typing): str|bytes URL handling - for host in no_proxy: - if parsed.hostname.endswith(host) or host_with_port.endswith(host): + for host in no_proxy_hosts: + if parsed.hostname.endswith(host) or host_with_port.endswith(host): # type: ignore[arg-type] # TODO(typing): str|bytes URL handling # The URL does match something in no_proxy, so we don't want # to apply the proxies on this URL. return True @@ -801,7 +849,7 @@ def get_proxy(key): with set_environ("no_proxy", no_proxy_arg): # parsed.hostname can be `None` in cases such as a file URI. try: - bypass = proxy_bypass(parsed.hostname) + bypass = proxy_bypass(parsed.hostname) # type: ignore[arg-type] # TODO(typing): str|bytes URL handling except (TypeError, socket.gaierror): bypass = False @@ -811,7 +859,7 @@ def get_proxy(key): return False -def get_environ_proxies(url, no_proxy=None): +def get_environ_proxies(url: UriType, no_proxy: str | None = None) -> dict[str, str]: """ Return a dict of environment proxies. @@ -823,7 +871,7 @@ def get_environ_proxies(url, no_proxy=None): return getproxies() -def select_proxy(url, proxies): +def select_proxy(url: str, proxies: dict[str, str] | None) -> str | None: """Select a proxy for the url, if applicable. :param url: The url being for the request @@ -849,7 +897,11 @@ def select_proxy(url, proxies): return proxy -def resolve_proxies(request, proxies, trust_env=True): +def resolve_proxies( + request: Request | PreparedRequest, + proxies: dict[str, str] | None, + trust_env: bool = True, +) -> dict[str, str]: """This method takes proxy information from a request and configuration input to resolve a mapping of target proxies. This will consider settings such as NO_PROXY to strip proxy configurations. @@ -861,7 +913,7 @@ def resolve_proxies(request, proxies, trust_env=True): :rtype: dict """ proxies = proxies if proxies is not None else {} - url = request.url + url = cast(str, request.url) scheme = urlparse(url).scheme no_proxy = proxies.get("no_proxy") new_proxies = proxies.copy() @@ -869,14 +921,14 @@ def resolve_proxies(request, proxies, trust_env=True): if trust_env and not should_bypass_proxies(url, no_proxy=no_proxy): environ_proxies = get_environ_proxies(url, no_proxy=no_proxy) - proxy = environ_proxies.get(scheme, environ_proxies.get("all")) + proxy = environ_proxies.get(scheme, environ_proxies.get("all")) # type: ignore[arg-type] # TODO(typing): str|bytes URL handling if proxy: - new_proxies.setdefault(scheme, proxy) + new_proxies.setdefault(scheme, proxy) # type: ignore[arg-type] # TODO(typing): str|bytes URL handling return new_proxies -def default_user_agent(name="python-requests"): +def default_user_agent(name: str = "python-requests") -> str: """ Return a string representing the default user agent. @@ -885,7 +937,7 @@ def default_user_agent(name="python-requests"): return f"{name}/{__version__}" -def default_headers(): +def default_headers() -> CaseInsensitiveDict[str]: """ :rtype: requests.structures.CaseInsensitiveDict """ @@ -899,7 +951,7 @@ def default_headers(): ) -def parse_header_links(value): +def parse_header_links(value: str) -> list[dict[str, str]]: """Return a list of parsed link headers proxies. i.e. Link: ; rel=front; type="image/jpeg",; rel=back;type="image/jpeg" @@ -907,7 +959,7 @@ def parse_header_links(value): :rtype: list """ - links = [] + links: list[dict[str, str]] = [] replace_chars = " '\"" @@ -921,7 +973,7 @@ def parse_header_links(value): except ValueError: url, params = val, "" - link = {"url": url.strip("<> '\"")} + link: dict[str, str] = {"url": url.strip("<> '\"")} for param in params.split(";"): try: @@ -942,7 +994,7 @@ def parse_header_links(value): _null3 = _null * 3 -def guess_json_utf(data): +def guess_json_utf(data: bytes) -> str | None: """ :rtype: str """ @@ -974,14 +1026,14 @@ def guess_json_utf(data): return None -def prepend_scheme_if_needed(url, new_scheme): +def prepend_scheme_if_needed(url: str, new_scheme: str) -> str: """Given a URL that may or may not have a scheme, prepend the given scheme. Does not replace a present scheme with the one provided as an argument. :rtype: str """ parsed = parse_url(url) - scheme, auth, host, port, path, query, fragment = parsed + scheme, auth, _host, _port, path, query, fragment = parsed # A defect in urlparse determines that there isn't a netloc present in some # urls. We previously assumed parsing was overly cautious, and swapped the @@ -994,6 +1046,7 @@ def prepend_scheme_if_needed(url, new_scheme): if auth: # parse_url doesn't provide the netloc with auth # so we'll add it ourselves. + netloc = cast(str, netloc) netloc = "@".join([auth, netloc]) if scheme is None: scheme = new_scheme @@ -1003,7 +1056,7 @@ def prepend_scheme_if_needed(url, new_scheme): return urlunparse((scheme, netloc, path, "", query, fragment)) -def get_auth_from_url(url): +def get_auth_from_url(url: UriType) -> tuple[str, str]: """Given a url with authentication components, extract them into a tuple of username,password. @@ -1012,14 +1065,14 @@ def get_auth_from_url(url): parsed = urlparse(url) try: - auth = (unquote(parsed.username), unquote(parsed.password)) + auth = (unquote(parsed.username), unquote(parsed.password)) # type: ignore[arg-type] # TODO(typing): str|bytes URL handling except (AttributeError, TypeError): auth = ("", "") return auth -def check_header_validity(header): +def check_header_validity(header: tuple[AnyStr, AnyStr]) -> None: """Verifies that header parts don't contain leading whitespace reserved characters, or return characters. @@ -1030,10 +1083,12 @@ def check_header_validity(header): _validate_header_part(header, value, 1) -def _validate_header_part(header, header_part, header_validator_index): +def _validate_header_part( + header: tuple[AnyStr, AnyStr], header_part: AnyStr, header_validator_index: int +) -> None: if isinstance(header_part, str): validator = _HEADER_VALIDATORS_STR[header_validator_index] - elif isinstance(header_part, bytes): + elif isinstance(header_part, bytes): # type: ignore[reportUnnecessaryIsInstance] # runtime guard for non-str/bytes validator = _HEADER_VALIDATORS_BYTE[header_validator_index] else: raise InvalidHeader( @@ -1041,7 +1096,7 @@ def _validate_header_part(header, header_part, header_validator_index): f"must be of type str or bytes, not {type(header_part)}" ) - if not validator.match(header_part): + if not validator.match(header_part): # type: ignore[arg-type] header_kind = "name" if header_validator_index == 0 else "value" raise InvalidHeader( f"Invalid leading whitespace, reserved character(s), or return " @@ -1049,33 +1104,34 @@ def _validate_header_part(header, header_part, header_validator_index): ) -def urldefragauth(url): +def urldefragauth(url: UriType) -> str: """ Given a url remove the fragment and the authentication part. :rtype: str """ - scheme, netloc, path, params, query, fragment = urlparse(url) + scheme, netloc, path, params, query, _fragment = urlparse(url) # see func:`prepend_scheme_if_needed` if not netloc: netloc, path = path, netloc - netloc = netloc.rsplit("@", 1)[-1] + netloc = netloc.rsplit("@", 1)[-1] # type: ignore[arg-type] # TODO(typing): str|bytes URL handling - return urlunparse((scheme, netloc, path, params, query, "")) + return urlunparse((scheme, netloc, path, params, query, "")) # type: ignore[arg-type] # TODO(typing): str|bytes URL handling -def rewind_body(prepared_request): +def rewind_body(prepared_request: PreparedRequest) -> None: """Move file pointer back to its recorded starting position so it can be read again on redirect. """ body_seek = getattr(prepared_request.body, "seek", None) if body_seek is not None and isinstance( - prepared_request._body_position, integer_types + prepared_request._body_position, # type: ignore[reportPrivateUsage] + integer_types, ): try: - body_seek(prepared_request._body_position) + body_seek(prepared_request._body_position) # type: ignore[reportPrivateUsage] except OSError: raise UnrewindableBodyError( "An error occurred when rewinding request body for redirect." diff --git a/tests/test_requests.py b/tests/test_requests.py index 257d9d7ab1..17a108b77c 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -2566,6 +2566,7 @@ def send(self, *args, **kwargs): def build_response(self): request = self.calls[-1].args[0] r = requests.Response() + r.url = request.url try: r.status_code = int(self.redirects.pop(0))