Skip to content

Commit

Permalink
Update ClientConfig with more kwargs
Browse files Browse the repository at this point in the history
Signed-off-by: Viet Nguyen Duc <[email protected]>
  • Loading branch information
VietND96 committed Oct 24, 2024
1 parent dbe23c5 commit 6147a29
Show file tree
Hide file tree
Showing 4 changed files with 214 additions and 32 deletions.
114 changes: 112 additions & 2 deletions py/selenium/webdriver/remote/client_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@
# under the License.
import base64
import os
import socket
from typing import Optional
from urllib import parse

import certifi

from selenium.webdriver.common.proxy import Proxy
from selenium.webdriver.common.proxy import ProxyType

Expand All @@ -27,8 +30,12 @@ class ClientConfig:
def __init__(
self,
remote_server_addr: str,
keep_alive: bool = True,
proxy: Proxy = Proxy(raw={"proxyType": ProxyType.SYSTEM}),
keep_alive: Optional[bool] = True,
proxy: Optional[Proxy] = Proxy(raw={"proxyType": ProxyType.SYSTEM}),
ignore_certificates: Optional[bool] = False,
init_args_for_pool_manager: Optional[dict] = None,
timeout: Optional[int] = None,
ca_certs: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
auth_type: Optional[str] = "Basic",
Expand All @@ -37,17 +44,38 @@ def __init__(
self.remote_server_addr = remote_server_addr
self.keep_alive = keep_alive
self.proxy = proxy
self.ignore_certificates = ignore_certificates
self.init_args_for_pool_manager = init_args_for_pool_manager or {}
self.timeout = timeout
self.username = username
self.password = password
self.auth_type = auth_type
self.token = token

self.timeout = (
(
float(os.getenv("GLOBAL_DEFAULT_TIMEOUT", str(socket.getdefaulttimeout())))
if os.getenv("GLOBAL_DEFAULT_TIMEOUT") is not None
else socket.getdefaulttimeout()
)
if timeout is None
else timeout
)

self.ca_certs = (
(os.getenv("REQUESTS_CA_BUNDLE") if "REQUESTS_CA_BUNDLE" in os.environ else certifi.where())
if ca_certs is None
else ca_certs
)

@property
def remote_server_addr(self) -> str:
""":Returns: The address of the remote server."""
return self._remote_server_addr

@remote_server_addr.setter
def remote_server_addr(self, value: str) -> None:
"""Provides the address of the remote server."""
self._remote_server_addr = value

@property
Expand All @@ -73,45 +101,126 @@ def proxy(self) -> Proxy:
def proxy(self, proxy: Proxy) -> None:
"""Provides the information for communicating with the driver or
server.
For example: Proxy(raw={"proxyType": ProxyType.SYSTEM})
:Args:
- value: the proxy information to use to communicate with the driver or server
"""
self._proxy = proxy

@property
def ignore_certificates(self) -> bool:
""":Returns: The ignore certificate check value."""
return self._ignore_certificates

@ignore_certificates.setter
def ignore_certificates(self, ignore_certificates: bool) -> None:
"""Toggles the ignore certificate check.
:Args:
- value: value of ignore certificate check
"""
self._ignore_certificates = ignore_certificates

@property
def init_args_for_pool_manager(self) -> dict:
""":Returns: The dictionary of arguments will be appended while
initializing the pool manager."""
return self._init_args_for_pool_manager

@init_args_for_pool_manager.setter
def init_args_for_pool_manager(self, init_args_for_pool_manager: dict) -> None:
"""Provides dictionary of arguments will be appended while initializing the pool manager.
For example: {"init_args_for_pool_manager": {"retries": 3, "block": True}}
:Args:
- value: the dictionary of arguments will be appended while initializing the pool manager
"""
self._init_args_for_pool_manager = init_args_for_pool_manager

@property
def timeout(self) -> int:
""":Returns: The timeout (in seconds) used for communicating to the
driver/server."""
return self._timeout

@timeout.setter
def timeout(self, timeout: int) -> None:
"""Provides the timeout (in seconds) for communicating with the driver
or server.
:Args:
- value: the timeout (in seconds) to use to communicate with the driver or server
"""
self._timeout = timeout

def reset_timeout(self) -> None:
"""Resets the timeout to the default value of socket."""
self._timeout = socket.getdefaulttimeout()

@property
def ca_certs(self) -> str:
""":Returns: The path to bundle of CA certificates."""
return self._ca_certs

@ca_certs.setter
def ca_certs(self, ca_certs: str) -> None:
"""Provides the path to bundle of CA certificates for establishing
secure connections.
:Args:
- value: the path to bundle of CA certificates for establishing secure connections
"""
self._ca_certs = ca_certs

@property
def username(self) -> str:
"""Returns the username used for basic authentication to the remote
server."""
return self._username

@username.setter
def username(self, value: str) -> None:
"""Sets the username used for basic authentication to the remote
server."""
self._username = value

@property
def password(self) -> str:
"""Returns the password used for basic authentication to the remote
server."""
return self._password

@password.setter
def password(self, value: str) -> None:
"""Sets the password used for basic authentication to the remote
server."""
self._password = value

@property
def auth_type(self) -> str:
"""Returns the type of authentication to the remote server."""
return self._auth_type

@auth_type.setter
def auth_type(self, value: str) -> None:
"""Sets the type of authentication to the remote server if it is not
using basic with username and password."""
self._auth_type = value

@property
def token(self) -> str:
"""Returns the token used for authentication to the remote server."""
return self._token

@token.setter
def token(self, value: str) -> None:
"""Sets the token used for authentication to the remote server if
auth_type is not basic."""
self._token = value

def get_proxy_url(self) -> Optional[str]:
"""Returns the proxy URL to use for the connection."""
proxy_type = self.proxy.proxy_type
remote_add = parse.urlparse(self.remote_server_addr)
if proxy_type is ProxyType.DIRECT:
Expand All @@ -136,6 +245,7 @@ def get_proxy_url(self) -> Optional[str]:
return None

def get_auth_header(self) -> Optional[dict]:
"""Returns the authorization to add to the request headers."""
auth_type = self.auth_type.lower()
if auth_type == "basic" and self.username and self.password:
credentials = f"{self.username}:{self.password}"
Expand Down
94 changes: 69 additions & 25 deletions py/selenium/webdriver/remote/remote_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,13 @@
# under the License.

import logging
import os
import platform
import socket
import string
import warnings
from base64 import b64encode
from typing import Optional
from urllib import parse

import certifi
import urllib3

from selenium import __version__
Expand Down Expand Up @@ -139,12 +136,7 @@ class RemoteConnection:
"""

browser_name = None
_timeout = (
float(os.getenv("GLOBAL_DEFAULT_TIMEOUT", str(socket.getdefaulttimeout())))
if os.getenv("GLOBAL_DEFAULT_TIMEOUT") is not None
else socket.getdefaulttimeout()
)
_ca_certs = os.getenv("REQUESTS_CA_BUNDLE") if "REQUESTS_CA_BUNDLE" in os.environ else certifi.where()
_client_config: ClientConfig = None

system = platform.system().lower()
if system == "darwin":
Expand All @@ -154,14 +146,27 @@ class RemoteConnection:
extra_headers = None
user_agent = f"selenium/{__version__} (python {system})"

@classmethod
def get_client_config(cls):
""":Returns:
ClientConfig instance for the Remote Connection
"""
return cls._client_config

@classmethod
def get_timeout(cls):
""":Returns:
Timeout value in seconds for all http requests made to the
Remote Connection
"""
return None if cls._timeout == socket._GLOBAL_DEFAULT_TIMEOUT else cls._timeout
warnings.warn(
"get_timeout() in RemoteConnection is deprecated, get timeout from ClientConfig instance instead",
DeprecationWarning,
stacklevel=2,
)
return cls._client_config.timeout

@classmethod
def set_timeout(cls, timeout):
Expand All @@ -170,12 +175,22 @@ def set_timeout(cls, timeout):
:Args:
- timeout - timeout value for http requests in seconds
"""
cls._timeout = timeout
warnings.warn(
"set_timeout() in RemoteConnection is deprecated, set timeout to ClientConfig instance in constructor instead",
DeprecationWarning,
stacklevel=2,
)
cls._client_config.timeout = timeout

@classmethod
def reset_timeout(cls):
"""Reset the http request timeout to socket._GLOBAL_DEFAULT_TIMEOUT."""
cls._timeout = socket._GLOBAL_DEFAULT_TIMEOUT
warnings.warn(
"reset_timeout() in RemoteConnection is deprecated, use reset_timeout() in ClientConfig instance instead",
DeprecationWarning,
stacklevel=2,
)
cls._client_config.reset_timeout()

@classmethod
def get_certificate_bundle_path(cls):
Expand All @@ -185,7 +200,12 @@ def get_certificate_bundle_path(cls):
command executor. Defaults to certifi.where() or
REQUESTS_CA_BUNDLE env variable if set.
"""
return cls._ca_certs
warnings.warn(
"get_certificate_bundle_path() in RemoteConnection is deprecated, get ca_certs from ClientConfig instance instead",
DeprecationWarning,
stacklevel=2,
)
return cls._client_config.ca_certs

@classmethod
def set_certificate_bundle_path(cls, path):
Expand All @@ -196,7 +216,12 @@ def set_certificate_bundle_path(cls, path):
:Args:
- path - path of a .pem encoded certificate chain.
"""
cls._ca_certs = path
warnings.warn(
"set_certificate_bundle_path() in RemoteConnection is deprecated, set ca_certs to ClientConfig instance in constructor instead",
DeprecationWarning,
stacklevel=2,
)
cls._client_config.ca_certs = path

@classmethod
def get_remote_connection_headers(cls, parsed_url, keep_alive=False):
Expand Down Expand Up @@ -239,15 +264,17 @@ def _separate_http_proxy_auth(self):
return proxy_without_auth, auth

def _get_connection_manager(self):
pool_manager_init_args = {"timeout": self.get_timeout()}
pool_manager_init_args.update(self._init_args_for_pool_manager.get("init_args_for_pool_manager", {}))
pool_manager_init_args = {"timeout": self._client_config.timeout}
pool_manager_init_args.update(
self._client_config.init_args_for_pool_manager.get("init_args_for_pool_manager", {})
)

if self._ignore_certificates:
if self._client_config.ignore_certificates:
pool_manager_init_args["cert_reqs"] = "CERT_NONE"
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
elif self._ca_certs:
elif self._client_config.ca_certs:
pool_manager_init_args["cert_reqs"] = "CERT_REQUIRED"
pool_manager_init_args["ca_certs"] = self._ca_certs
pool_manager_init_args["ca_certs"] = self._client_config.ca_certs

if self._proxy_url:
if self._proxy_url.lower().startswith("sock"):
Expand All @@ -263,18 +290,21 @@ def _get_connection_manager(self):

def __init__(
self,
remote_server_addr: str,
remote_server_addr: Optional[str] = None,
keep_alive: Optional[bool] = True,
ignore_proxy: Optional[bool] = False,
ignore_certificates: Optional[bool] = False,
init_args_for_pool_manager: Optional[dict] = None,
client_config: Optional[ClientConfig] = None,
):
self.keep_alive = keep_alive
self._url = remote_server_addr
self._ignore_certificates = ignore_certificates
self._init_args_for_pool_manager = init_args_for_pool_manager or {}
self._client_config = client_config or ClientConfig(remote_server_addr, keep_alive)
self._client_config = client_config or ClientConfig(
remote_server_addr=remote_server_addr,
keep_alive=keep_alive,
ignore_certificates=ignore_certificates,
init_args_for_pool_manager=init_args_for_pool_manager,
)

RemoteConnection._client_config = self._client_config

if remote_server_addr:
warnings.warn(
Expand All @@ -290,6 +320,20 @@ def __init__(
stacklevel=2,
)

if ignore_certificates:
warnings.warn(
"setting ignore_certificates in RemoteConnection() is deprecated, set in ClientConfig instance instead",
DeprecationWarning,
stacklevel=2,
)

if init_args_for_pool_manager:
warnings.warn(
"setting init_args_for_pool_manager in RemoteConnection() is deprecated, set in ClientConfig instance instead",
DeprecationWarning,
stacklevel=2,
)

if ignore_proxy:
warnings.warn(
"setting ignore_proxy in RemoteConnection() is deprecated, set in ClientConfig instance instead",
Expand Down
Loading

0 comments on commit 6147a29

Please sign in to comment.