diff --git a/py/selenium/webdriver/remote/client_config.py b/py/selenium/webdriver/remote/client_config.py index 62ba82076946b..f571d0886df54 100644 --- a/py/selenium/webdriver/remote/client_config.py +++ b/py/selenium/webdriver/remote/client_config.py @@ -17,6 +17,7 @@ import base64 import os import socket +from enum import Enum from typing import Optional from urllib import parse @@ -26,6 +27,12 @@ from selenium.webdriver.common.proxy import ProxyType +class AuthType(Enum): + BASIC = "Basic" + BEARER = "Bearer" + X_API_KEY = "X-API-Key" + + class ClientConfig: def __init__( self, @@ -38,8 +45,10 @@ def __init__( ca_certs: Optional[str] = None, username: Optional[str] = None, password: Optional[str] = None, - auth_type: Optional[str] = "Basic", + auth_type: Optional[AuthType] = AuthType.BASIC, token: Optional[str] = None, + user_agent: Optional[str] = None, + extra_headers: Optional[dict] = None, ) -> None: self.remote_server_addr = remote_server_addr self.keep_alive = keep_alive @@ -51,6 +60,8 @@ def __init__( self.password = password self.auth_type = auth_type self.token = token + self.user_agent = user_agent + self.extra_headers = extra_headers self.timeout = ( ( @@ -198,14 +209,17 @@ def password(self, value: str) -> None: self._password = value @property - def auth_type(self) -> str: + def auth_type(self) -> AuthType: """Returns the type of authentication to the remote server.""" return self._auth_type @auth_type.setter - def auth_type(self, value: str) -> None: + def auth_type(self, value: AuthType) -> None: """Sets the type of authentication to the remote server if it is not - using basic with username and password.""" + using basic with username and password. + + :Args: value - AuthType enum value. For others, please use `extra_headers` instead + """ self._auth_type = value @property @@ -219,6 +233,26 @@ def token(self, value: str) -> None: auth_type is not basic.""" self._token = value + @property + def user_agent(self) -> str: + """Returns user agent to be added to the request headers.""" + return self._user_agent + + @user_agent.setter + def user_agent(self, value: str) -> None: + """Sets user agent to be added to the request headers.""" + self._user_agent = value + + @property + def extra_headers(self) -> dict: + """Returns extra headers to be added to the request.""" + return self._extra_headers + + @extra_headers.setter + def extra_headers(self, value: dict) -> None: + """Sets extra headers to be added to the request.""" + self._extra_headers = value + def get_proxy_url(self) -> Optional[str]: """Returns the proxy URL to use for the connection.""" proxy_type = self.proxy.proxy_type @@ -246,13 +280,12 @@ def get_proxy_url(self) -> Optional[str]: def get_auth_header(self) -> Optional[dict]: """Returns the authorization to add to the request headers.""" - auth_type = self.auth_type.lower() - if auth_type == "basic" and self.username and self.password: + if self.auth_type is AuthType.BASIC and self.username and self.password: credentials = f"{self.username}:{self.password}" encoded_credentials = base64.b64encode(credentials.encode("utf-8")).decode("utf-8") - return {"Authorization": f"Basic {encoded_credentials}"} - if auth_type == "bearer" and self.token: - return {"Authorization": f"Bearer {self.token}"} - if auth_type == "oauth" and self.token: - return {"Authorization": f"OAuth {self.token}"} + return {"Authorization": f"{AuthType.BASIC.value} {encoded_credentials}"} + if self.auth_type is AuthType.BEARER and self.token: + return {"Authorization": f"{AuthType.BEARER.value} {self.token}"} + if self.auth_type is AuthType.X_API_KEY and self.token: + return {f"{AuthType.X_API_KEY.value}": f"{self.token}"} return None diff --git a/py/selenium/webdriver/remote/remote_connection.py b/py/selenium/webdriver/remote/remote_connection.py index 04786c39c2673..5404c53a4139e 100644 --- a/py/selenium/webdriver/remote/remote_connection.py +++ b/py/selenium/webdriver/remote/remote_connection.py @@ -22,6 +22,7 @@ from base64 import b64encode from typing import Optional from urllib import parse +from urllib.parse import urlparse import urllib3 @@ -243,6 +244,9 @@ def get_remote_connection_headers(cls, parsed_url, keep_alive=False): } if parsed_url.username: + warnings.warn( + "Embedding username and password in URL could be insecure, use ClientConfig instead", stacklevel=2 + ) base64string = b64encode(f"{parsed_url.username}:{parsed_url.password}".encode()) headers.update({"Authorization": f"Basic {base64string.decode()}"}) @@ -255,16 +259,14 @@ def get_remote_connection_headers(cls, parsed_url, keep_alive=False): return headers def _identify_http_proxy_auth(self): - url = self._proxy_url - url = url[url.find(":") + 3 :] - return "@" in url and len(url[: url.find("@")]) > 0 + parsed_url = urlparse(self._proxy_url) + if parsed_url.username and parsed_url.password: + return True def _separate_http_proxy_auth(self): - url = self._proxy_url - protocol = url[: url.find(":") + 3] - no_protocol = url[len(protocol) :] - auth = no_protocol[: no_protocol.find("@")] - proxy_without_auth = protocol + no_protocol[len(auth) + 1 :] + parsed_url = urlparse(self._proxy_url) + proxy_without_auth = f"{parsed_url.scheme}://{parsed_url.hostname}:{parsed_url.port}" + auth = f"{parsed_url.username}:{parsed_url.password}" return proxy_without_auth, auth def _get_connection_manager(self): @@ -312,6 +314,8 @@ def __init__( RemoteConnection._timeout = self._client_config.timeout RemoteConnection._ca_certs = self._client_config.ca_certs RemoteConnection._client_config = self._client_config + RemoteConnection.extra_headers = self._client_config.extra_headers or RemoteConnection.extra_headers + RemoteConnection.user_agent = self._client_config.user_agent or RemoteConnection.user_agent if remote_server_addr: warnings.warn( diff --git a/py/test/unit/selenium/webdriver/remote/remote_connection_tests.py b/py/test/unit/selenium/webdriver/remote/remote_connection_tests.py index ea6281607b4be..fb7e865c68b04 100644 --- a/py/test/unit/selenium/webdriver/remote/remote_connection_tests.py +++ b/py/test/unit/selenium/webdriver/remote/remote_connection_tests.py @@ -15,13 +15,19 @@ # specific language governing permissions and limitations # under the License. +import os from unittest.mock import patch from urllib import parse import pytest import urllib3 +from urllib3.util import Retry +from urllib3.util import Timeout from selenium import __version__ +from selenium.webdriver import Proxy +from selenium.webdriver.common.proxy import ProxyType +from selenium.webdriver.remote.client_config import AuthType from selenium.webdriver.remote.remote_connection import ClientConfig from selenium.webdriver.remote.remote_connection import RemoteConnection @@ -64,8 +70,13 @@ def test_get_remote_connection_headers_defaults(): def test_get_remote_connection_headers_adds_auth_header_if_pass(): url = "http://user:pass@remote" - headers = RemoteConnection.get_remote_connection_headers(parse.urlparse(url)) + with pytest.warns(None) as record: + headers = RemoteConnection.get_remote_connection_headers(parse.urlparse(url)) assert headers.get("Authorization") == "Basic dXNlcjpwYXNz" + assert ( + record[0].message.args[0] + == "Embedding username and password in URL could be insecure, use ClientConfig instead" + ) def test_get_remote_connection_headers_adds_keep_alive_if_requested(): @@ -81,15 +92,33 @@ def test_get_proxy_url_http(mock_proxy_settings): assert proxy_url == proxy -def test_get_auth_header_if_client_config_pass(): +def test_get_auth_header_if_client_config_pass_basic_auth(): custom_config = ClientConfig( - remote_server_addr="http://remote", keep_alive=True, username="user", password="pass", auth_type="Basic" + remote_server_addr="http://remote", keep_alive=True, username="user", password="pass", auth_type=AuthType.BASIC ) remote_connection = RemoteConnection(custom_config.remote_server_addr, client_config=custom_config) headers = remote_connection._client_config.get_auth_header() assert headers.get("Authorization") == "Basic dXNlcjpwYXNz" +def test_get_auth_header_if_client_config_pass_bearer_token(): + custom_config = ClientConfig( + remote_server_addr="http://remote", keep_alive=True, auth_type=AuthType.BEARER, token="dXNlcjpwYXNz" + ) + remote_connection = RemoteConnection(custom_config.remote_server_addr, client_config=custom_config) + headers = remote_connection._client_config.get_auth_header() + assert headers.get("Authorization") == "Bearer dXNlcjpwYXNz" + + +def test_get_auth_header_if_client_config_pass_x_api_key(): + custom_config = ClientConfig( + remote_server_addr="http://remote", keep_alive=True, auth_type=AuthType.X_API_KEY, token="abcdefgh123456789" + ) + remote_connection = RemoteConnection(custom_config.remote_server_addr, client_config=custom_config) + headers = remote_connection._client_config.get_auth_header() + assert headers.get("X-API-Key") == "abcdefgh123456789" + + def test_get_proxy_url_https(mock_proxy_settings): proxy = "http://https_proxy.com:8080" remote_connection = RemoteConnection("https://remote", keep_alive=False) @@ -97,6 +126,62 @@ def test_get_proxy_url_https(mock_proxy_settings): assert proxy_url == proxy +def test_get_proxy_url_https_via_client_config(): + client_config = ClientConfig( + remote_server_addr="https://localhost:4444", + proxy=Proxy({"proxyType": ProxyType.MANUAL, "sslProxy": "https://admin:admin@http_proxy.com:8080"}), + ) + remote_connection = RemoteConnection(client_config=client_config) + conn = remote_connection._get_connection_manager() + assert isinstance(conn, urllib3.ProxyManager) + conn.proxy_url = "https://http_proxy.com:8080" + conn.connection_pool_kw["proxy_headers"] = urllib3.make_headers(proxy_basic_auth="admin:admin") + + +def test_get_proxy_url_http_via_client_config(): + client_config = ClientConfig( + remote_server_addr="http://localhost:4444", + proxy=Proxy( + { + "proxyType": ProxyType.MANUAL, + "httpProxy": "http://admin:admin@http_proxy.com:8080", + "sslProxy": "https://admin:admin@http_proxy.com:8080", + } + ), + ) + remote_connection = RemoteConnection(client_config=client_config) + conn = remote_connection._get_connection_manager() + assert isinstance(conn, urllib3.ProxyManager) + conn.proxy_url = "http://http_proxy.com:8080" + conn.connection_pool_kw["proxy_headers"] = urllib3.make_headers(proxy_basic_auth="admin:admin") + + +def test_get_proxy_direct_via_client_config(): + client_config = ClientConfig( + remote_server_addr="http://localhost:4444", proxy=Proxy({"proxyType": ProxyType.DIRECT}) + ) + remote_connection = RemoteConnection(client_config=client_config) + conn = remote_connection._get_connection_manager() + assert isinstance(conn, urllib3.PoolManager) + proxy_url = remote_connection._client_config.get_proxy_url() + assert proxy_url is None + + +def test_get_proxy_system_matches_no_proxy_via_client_config(): + os.environ["HTTP_PROXY"] = "http://admin:admin@system_proxy.com:8080" + os.environ["NO_PROXY"] = "localhost,127.0.0.1" + client_config = ClientConfig( + remote_server_addr="http://localhost:4444", proxy=Proxy({"proxyType": ProxyType.SYSTEM}) + ) + remote_connection = RemoteConnection(client_config=client_config) + conn = remote_connection._get_connection_manager() + assert isinstance(conn, urllib3.PoolManager) + proxy_url = remote_connection._client_config.get_proxy_url() + assert proxy_url is None + os.environ.pop("HTTP_PROXY") + os.environ.pop("NO_PROXY") + + def test_get_proxy_url_none(mock_proxy_settings_missing): remote_connection = RemoteConnection("https://remote", keep_alive=False) proxy_url = remote_connection._client_config.get_proxy_url() @@ -295,6 +380,28 @@ def test_override_user_agent_in_headers(mock_get_remote_connection_headers, remo assert headers.get("Content-Type") == "application/json;charset=UTF-8" +@patch("selenium.webdriver.remote.remote_connection.RemoteConnection.get_remote_connection_headers") +def test_override_user_agent_via_client_config(mock_get_remote_connection_headers): + client_config = ClientConfig( + remote_server_addr="http://localhost:4444", + user_agent="custom-agent/1.0 (python 3.8)", + extra_headers={"Content-Type": "application/xml;charset=UTF-8"}, + ) + remote_connection = RemoteConnection(client_config=client_config) + + mock_get_remote_connection_headers.return_value = { + "Accept": "application/json", + "Content-Type": "application/xml;charset=UTF-8", + "User-Agent": "custom-agent/1.0 (python 3.8)", + } + + headers = remote_connection.get_remote_connection_headers(parse.urlparse("http://localhost:4444")) + + assert headers.get("User-Agent") == "custom-agent/1.0 (python 3.8)" + assert headers.get("Accept") == "application/json" + assert headers.get("Content-Type") == "application/xml;charset=UTF-8" + + @patch("selenium.webdriver.remote.remote_connection.RemoteConnection._request") def test_register_extra_headers(mock_request, remote_connection): RemoteConnection.extra_headers = {"Foo": "bar"} @@ -307,6 +414,26 @@ def test_register_extra_headers(mock_request, remote_connection): assert headers["Foo"] == "bar" +@patch("selenium.webdriver.remote.remote_connection.RemoteConnection._request") +def test_register_extra_headers_via_client_config(mock_request): + client_config = ClientConfig( + remote_server_addr="http://localhost:4444", + extra_headers={ + "Authorization": "AWS4-HMAC-SHA256", + "Credential": "abc/20200618/us-east-1/execute-api/aws4_request", + }, + ) + remote_connection = RemoteConnection(client_config=client_config) + + mock_request.return_value = {"status": 200, "value": "OK"} + remote_connection.execute("newSession", {}) + + mock_request.assert_called_once_with("POST", "http://localhost:4444/session", body="{}") + headers = remote_connection.get_remote_connection_headers(parse.urlparse("http://localhost:4444"), False) + assert headers["Authorization"] == "AWS4-HMAC-SHA256" + assert headers["Credential"] == "abc/20200618/us-east-1/execute-api/aws4_request" + + def test_backwards_compatibility_with_appium_connection(): # Keep backward compatibility for AppiumConnection - https://github.com/SeleniumHQ/selenium/issues/14694 client_config = ClientConfig(remote_server_addr="http://remote", ca_certs="/path/to/cacert.pem", timeout=300) @@ -328,6 +455,8 @@ def test_get_connection_manager_with_timeout_from_client_config(): assert conn.connection_pool_kw["timeout"] == 10 assert isinstance(conn, urllib3.PoolManager) + +def test_connection_manager_with_timeout_via_client_config(): client_config = ClientConfig("http://remote", timeout=300) remote_connection = RemoteConnection(client_config=client_config) conn = remote_connection._get_connection_manager() @@ -335,7 +464,7 @@ def test_get_connection_manager_with_timeout_from_client_config(): assert isinstance(conn, urllib3.PoolManager) -def test_get_connection_manager_with_ca_certs_from_client_config(): +def test_get_connection_manager_with_ca_certs(): remote_connection = RemoteConnection(remote_server_addr="http://remote") remote_connection.set_certificate_bundle_path("/path/to/cacert.pem") conn = remote_connection._get_connection_manager() @@ -344,6 +473,8 @@ def test_get_connection_manager_with_ca_certs_from_client_config(): assert conn.connection_pool_kw["ca_certs"] == "/path/to/cacert.pem" assert isinstance(conn, urllib3.PoolManager) + +def test_connection_manager_with_ca_certs_via_client_config(): client_config = ClientConfig(remote_server_addr="http://remote", ca_certs="/path/to/cacert.pem") remote_connection = RemoteConnection(client_config=client_config) conn = remote_connection._get_connection_manager() @@ -361,15 +492,17 @@ def test_get_connection_manager_ignores_certificates(): assert conn.connection_pool_kw["cert_reqs"] == "CERT_NONE" assert isinstance(conn, urllib3.PoolManager) + remote_connection.reset_timeout() + assert remote_connection.get_timeout() is None + + +def test_connection_manager_ignores_certificates_via_client_config(): client_config = ClientConfig(remote_server_addr="http://remote", ignore_certificates=True, timeout=10) remote_connection = RemoteConnection(client_config=client_config) conn = remote_connection._get_connection_manager() + assert isinstance(conn, urllib3.PoolManager) assert conn.connection_pool_kw["timeout"] == 10 assert conn.connection_pool_kw["cert_reqs"] == "CERT_NONE" - assert isinstance(conn, urllib3.PoolManager) - - remote_connection.reset_timeout() - assert remote_connection.get_timeout() is None def test_get_connection_manager_with_custom_args(): @@ -383,11 +516,16 @@ def test_get_connection_manager_with_custom_args(): assert conn.connection_pool_kw["retries"] == 3 assert conn.connection_pool_kw["block"] is True + +def test_connection_manager_with_custom_args_via_client_config(): + retries = Retry(connect=2, read=2, redirect=2) + timeout = Timeout(connect=300, read=3600) client_config = ClientConfig( - remote_server_addr="http://remote", keep_alive=False, init_args_for_pool_manager=custom_args + remote_server_addr="http://localhost:4444", + init_args_for_pool_manager={"init_args_for_pool_manager": {"retries": retries, "timeout": timeout}}, ) remote_connection = RemoteConnection(client_config=client_config) conn = remote_connection._get_connection_manager() assert isinstance(conn, urllib3.PoolManager) - assert conn.connection_pool_kw["retries"] == 3 - assert conn.connection_pool_kw["block"] is True + assert conn.connection_pool_kw["retries"] == retries + assert conn.connection_pool_kw["timeout"] == timeout