Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[py] Set user_agent and extra_headers via ClientConfig #14718

Merged
merged 8 commits into from
Nov 9, 2024
33 changes: 30 additions & 3 deletions py/selenium/webdriver/remote/client_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def __init__(
password: Optional[str] = None,
auth_type: Optional[str] = "Basic",
token: Optional[str] = None,
user_agent: Optional[str] = None,
extra_headers: Optional[dict] = None,
) -> None:
self.remote_server_addr = remote_server_addr
self.keep_alive = keep_alive
Expand All @@ -51,6 +53,8 @@ def __init__(
self.password = password
self.auth_type = auth_type
self.token = token
self.user_agent = user_agent
self.extra_headers = extra_headers

self.timeout = (
(
Expand Down Expand Up @@ -205,7 +209,10 @@ def auth_type(self) -> str:
@auth_type.setter
def auth_type(self, value: str) -> None:
"""Sets the type of authentication to the remote server if it is not
using basic with username and password."""
using basic with username and password.

Supported values: Bearer, X-API-Key. For others, please use `extra_headers` instead.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we have these as constants that people can pass in? That would prevent typos and give people examples of what to use.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AutomatedTester, I updated it to use an enum. This argument is newly added and has a default value, so I don't think it will cause a regression. Can you check whether it looks fine?

"""
self._auth_type = value

@property
Expand All @@ -219,6 +226,26 @@ def token(self, value: str) -> None:
auth_type is not basic."""
self._token = value

@property
def user_agent(self) -> str:
"""Returns user agent to be added to the request headers."""
return self._user_agent

@user_agent.setter
def user_agent(self, value: str) -> None:
"""Sets user agent to be added to the request headers."""
self._user_agent = value

@property
def extra_headers(self) -> dict:
"""Returns extra headers to be added to the request."""
return self._extra_headers

@extra_headers.setter
def extra_headers(self, value: dict) -> None:
"""Sets extra headers to be added to the request."""
self._extra_headers = value

def get_proxy_url(self) -> Optional[str]:
"""Returns the proxy URL to use for the connection."""
proxy_type = self.proxy.proxy_type
Expand Down Expand Up @@ -253,6 +280,6 @@ def get_auth_header(self) -> Optional[dict]:
return {"Authorization": f"Basic {encoded_credentials}"}
if auth_type == "bearer" and self.token:
return {"Authorization": f"Bearer {self.token}"}
if auth_type == "oauth" and self.token:
return {"Authorization": f"OAuth {self.token}"}
if auth_type == "x-api-key" and self.token:
return {"X-API-Key": f"{self.token}"}
return None
20 changes: 12 additions & 8 deletions py/selenium/webdriver/remote/remote_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from base64 import b64encode
from typing import Optional
from urllib import parse
from urllib.parse import urlparse

import urllib3

Expand Down Expand Up @@ -243,6 +244,9 @@ def get_remote_connection_headers(cls, parsed_url, keep_alive=False):
}

if parsed_url.username:
warnings.warn(
"Embedding username and password in URL could be insecure, use ClientConfig instead", stacklevel=2
)
base64string = b64encode(f"{parsed_url.username}:{parsed_url.password}".encode())
headers.update({"Authorization": f"Basic {base64string.decode()}"})

Expand All @@ -255,16 +259,14 @@ def get_remote_connection_headers(cls, parsed_url, keep_alive=False):
return headers

def _identify_http_proxy_auth(self):
url = self._proxy_url
url = url[url.find(":") + 3 :]
return "@" in url and len(url[: url.find("@")]) > 0
parsed_url = urlparse(self._proxy_url)
if parsed_url.username and parsed_url.password:
return True

def _separate_http_proxy_auth(self):
url = self._proxy_url
protocol = url[: url.find(":") + 3]
no_protocol = url[len(protocol) :]
auth = no_protocol[: no_protocol.find("@")]
proxy_without_auth = protocol + no_protocol[len(auth) + 1 :]
parsed_url = urlparse(self._proxy_url)
proxy_without_auth = f"{parsed_url.scheme}://{parsed_url.hostname}:{parsed_url.port}"
auth = f"{parsed_url.username}:{parsed_url.password}"
return proxy_without_auth, auth

def _get_connection_manager(self):
Expand Down Expand Up @@ -312,6 +314,8 @@ def __init__(
RemoteConnection._timeout = self._client_config.timeout
RemoteConnection._ca_certs = self._client_config.ca_certs
RemoteConnection._client_config = self._client_config
RemoteConnection.extra_headers = self._client_config.extra_headers or RemoteConnection.extra_headers
RemoteConnection.user_agent = self._client_config.user_agent or RemoteConnection.user_agent

if remote_server_addr:
warnings.warn(
Expand Down
157 changes: 147 additions & 10 deletions py/test/unit/selenium/webdriver/remote/remote_connection_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,18 @@
# specific language governing permissions and limitations
# under the License.

import os
from unittest.mock import patch
from urllib import parse

import pytest
import urllib3
from urllib3.util import Retry
from urllib3.util import Timeout

from selenium import __version__
from selenium.webdriver import Proxy
from selenium.webdriver.common.proxy import ProxyType
from selenium.webdriver.remote.remote_connection import ClientConfig
from selenium.webdriver.remote.remote_connection import RemoteConnection

Expand Down Expand Up @@ -64,8 +69,13 @@ def test_get_remote_connection_headers_defaults():

def test_get_remote_connection_headers_adds_auth_header_if_pass():
url = "http://user:pass@remote"
headers = RemoteConnection.get_remote_connection_headers(parse.urlparse(url))
with pytest.warns(None) as record:
headers = RemoteConnection.get_remote_connection_headers(parse.urlparse(url))
assert headers.get("Authorization") == "Basic dXNlcjpwYXNz"
assert (
record[0].message.args[0]
== "Embedding username and password in URL could be insecure, use ClientConfig instead"
)


def test_get_remote_connection_headers_adds_keep_alive_if_requested():
Expand All @@ -81,7 +91,7 @@ def test_get_proxy_url_http(mock_proxy_settings):
assert proxy_url == proxy


def test_get_auth_header_if_client_config_pass():
def test_get_auth_header_if_client_config_pass_basic_auth():
custom_config = ClientConfig(
remote_server_addr="http://remote", keep_alive=True, username="user", password="pass", auth_type="Basic"
)
Expand All @@ -90,13 +100,87 @@ def test_get_auth_header_if_client_config_pass():
assert headers.get("Authorization") == "Basic dXNlcjpwYXNz"


def test_get_auth_header_if_client_config_pass_bearer_token():
custom_config = ClientConfig(
remote_server_addr="http://remote", keep_alive=True, auth_type="Bearer", token="dXNlcjpwYXNz"
)
remote_connection = RemoteConnection(custom_config.remote_server_addr, client_config=custom_config)
headers = remote_connection._client_config.get_auth_header()
assert headers.get("Authorization") == "Bearer dXNlcjpwYXNz"


def test_get_auth_header_if_client_config_pass_x_api_key():
custom_config = ClientConfig(
remote_server_addr="http://remote", keep_alive=True, auth_type="X-API-Key", token="abcdefgh123456789"
)
remote_connection = RemoteConnection(custom_config.remote_server_addr, client_config=custom_config)
headers = remote_connection._client_config.get_auth_header()
assert headers.get("X-API-Key") == "abcdefgh123456789"


def test_get_proxy_url_https(mock_proxy_settings):
proxy = "http://https_proxy.com:8080"
remote_connection = RemoteConnection("https://remote", keep_alive=False)
proxy_url = remote_connection._client_config.get_proxy_url()
assert proxy_url == proxy


def test_get_proxy_url_https_via_client_config():
client_config = ClientConfig(
remote_server_addr="https://localhost:4444",
proxy=Proxy({"proxyType": ProxyType.MANUAL, "sslProxy": "https://admin:admin@http_proxy.com:8080"}),
)
remote_connection = RemoteConnection(client_config=client_config)
conn = remote_connection._get_connection_manager()
assert isinstance(conn, urllib3.ProxyManager)
conn.proxy_url = "https://http_proxy.com:8080"
conn.connection_pool_kw["proxy_headers"] = urllib3.make_headers(proxy_basic_auth="admin:admin")


def test_get_proxy_url_http_via_client_config():
client_config = ClientConfig(
remote_server_addr="http://localhost:4444",
proxy=Proxy(
{
"proxyType": ProxyType.MANUAL,
"httpProxy": "http://admin:admin@http_proxy.com:8080",
"sslProxy": "https://admin:admin@http_proxy.com:8080",
}
),
)
remote_connection = RemoteConnection(client_config=client_config)
conn = remote_connection._get_connection_manager()
assert isinstance(conn, urllib3.ProxyManager)
conn.proxy_url = "http://http_proxy.com:8080"
conn.connection_pool_kw["proxy_headers"] = urllib3.make_headers(proxy_basic_auth="admin:admin")


def test_get_proxy_direct_via_client_config():
client_config = ClientConfig(
remote_server_addr="http://localhost:4444", proxy=Proxy({"proxyType": ProxyType.DIRECT})
)
remote_connection = RemoteConnection(client_config=client_config)
conn = remote_connection._get_connection_manager()
assert isinstance(conn, urllib3.PoolManager)
proxy_url = remote_connection._client_config.get_proxy_url()
assert proxy_url is None


def test_get_proxy_system_matches_no_proxy_via_client_config():
os.environ["HTTP_PROXY"] = "http://admin:admin@system_proxy.com:8080"
os.environ["NO_PROXY"] = "localhost,127.0.0.1"
client_config = ClientConfig(
remote_server_addr="http://localhost:4444", proxy=Proxy({"proxyType": ProxyType.SYSTEM})
)
remote_connection = RemoteConnection(client_config=client_config)
conn = remote_connection._get_connection_manager()
assert isinstance(conn, urllib3.PoolManager)
proxy_url = remote_connection._client_config.get_proxy_url()
assert proxy_url is None
os.environ.pop("HTTP_PROXY")
os.environ.pop("NO_PROXY")


def test_get_proxy_url_none(mock_proxy_settings_missing):
remote_connection = RemoteConnection("https://remote", keep_alive=False)
proxy_url = remote_connection._client_config.get_proxy_url()
Expand Down Expand Up @@ -295,6 +379,28 @@ def test_override_user_agent_in_headers(mock_get_remote_connection_headers, remo
assert headers.get("Content-Type") == "application/json;charset=UTF-8"


@patch("selenium.webdriver.remote.remote_connection.RemoteConnection.get_remote_connection_headers")
def test_override_user_agent_via_client_config(mock_get_remote_connection_headers):
client_config = ClientConfig(
remote_server_addr="http://localhost:4444",
user_agent="custom-agent/1.0 (python 3.8)",
extra_headers={"Content-Type": "application/xml;charset=UTF-8"},
)
remote_connection = RemoteConnection(client_config=client_config)

mock_get_remote_connection_headers.return_value = {
"Accept": "application/json",
"Content-Type": "application/xml;charset=UTF-8",
"User-Agent": "custom-agent/1.0 (python 3.8)",
}

headers = remote_connection.get_remote_connection_headers(parse.urlparse("http://localhost:4444"))

assert headers.get("User-Agent") == "custom-agent/1.0 (python 3.8)"
assert headers.get("Accept") == "application/json"
assert headers.get("Content-Type") == "application/xml;charset=UTF-8"


@patch("selenium.webdriver.remote.remote_connection.RemoteConnection._request")
def test_register_extra_headers(mock_request, remote_connection):
RemoteConnection.extra_headers = {"Foo": "bar"}
Expand All @@ -307,6 +413,26 @@ def test_register_extra_headers(mock_request, remote_connection):
assert headers["Foo"] == "bar"


@patch("selenium.webdriver.remote.remote_connection.RemoteConnection._request")
def test_register_extra_headers_via_client_config(mock_request):
client_config = ClientConfig(
remote_server_addr="http://localhost:4444",
extra_headers={
"Authorization": "AWS4-HMAC-SHA256",
"Credential": "abc/20200618/us-east-1/execute-api/aws4_request",
},
)
remote_connection = RemoteConnection(client_config=client_config)

mock_request.return_value = {"status": 200, "value": "OK"}
remote_connection.execute("newSession", {})

mock_request.assert_called_once_with("POST", "http://localhost:4444/session", body="{}")
headers = remote_connection.get_remote_connection_headers(parse.urlparse("http://localhost:4444"), False)
assert headers["Authorization"] == "AWS4-HMAC-SHA256"
assert headers["Credential"] == "abc/20200618/us-east-1/execute-api/aws4_request"


def test_backwards_compatibility_with_appium_connection():
# Keep backward compatibility for AppiumConnection - https://github.com/SeleniumHQ/selenium/issues/14694
client_config = ClientConfig(remote_server_addr="http://remote", ca_certs="/path/to/cacert.pem", timeout=300)
Expand All @@ -328,14 +454,16 @@ def test_get_connection_manager_with_timeout_from_client_config():
assert conn.connection_pool_kw["timeout"] == 10
assert isinstance(conn, urllib3.PoolManager)


def test_connection_manager_with_timeout_via_client_config():
client_config = ClientConfig("http://remote", timeout=300)
remote_connection = RemoteConnection(client_config=client_config)
conn = remote_connection._get_connection_manager()
assert conn.connection_pool_kw["timeout"] == 300
assert isinstance(conn, urllib3.PoolManager)


def test_get_connection_manager_with_ca_certs_from_client_config():
def test_get_connection_manager_with_ca_certs():
remote_connection = RemoteConnection(remote_server_addr="http://remote")
remote_connection.set_certificate_bundle_path("/path/to/cacert.pem")
conn = remote_connection._get_connection_manager()
Expand All @@ -344,6 +472,8 @@ def test_get_connection_manager_with_ca_certs_from_client_config():
assert conn.connection_pool_kw["ca_certs"] == "/path/to/cacert.pem"
assert isinstance(conn, urllib3.PoolManager)


def test_connection_manager_with_ca_certs_via_client_config():
client_config = ClientConfig(remote_server_addr="http://remote", ca_certs="/path/to/cacert.pem")
remote_connection = RemoteConnection(client_config=client_config)
conn = remote_connection._get_connection_manager()
Expand All @@ -361,15 +491,17 @@ def test_get_connection_manager_ignores_certificates():
assert conn.connection_pool_kw["cert_reqs"] == "CERT_NONE"
assert isinstance(conn, urllib3.PoolManager)

remote_connection.reset_timeout()
assert remote_connection.get_timeout() is None


def test_connection_manager_ignores_certificates_via_client_config():
client_config = ClientConfig(remote_server_addr="http://remote", ignore_certificates=True, timeout=10)
remote_connection = RemoteConnection(client_config=client_config)
conn = remote_connection._get_connection_manager()
assert isinstance(conn, urllib3.PoolManager)
assert conn.connection_pool_kw["timeout"] == 10
assert conn.connection_pool_kw["cert_reqs"] == "CERT_NONE"
assert isinstance(conn, urllib3.PoolManager)

remote_connection.reset_timeout()
assert remote_connection.get_timeout() is None


def test_get_connection_manager_with_custom_args():
Expand All @@ -383,11 +515,16 @@ def test_get_connection_manager_with_custom_args():
assert conn.connection_pool_kw["retries"] == 3
assert conn.connection_pool_kw["block"] is True


def test_connection_manager_with_custom_args_via_client_config():
retries = Retry(connect=2, read=2, redirect=2)
timeout = Timeout(connect=300, read=3600)
client_config = ClientConfig(
remote_server_addr="http://remote", keep_alive=False, init_args_for_pool_manager=custom_args
remote_server_addr="http://localhost:4444",
init_args_for_pool_manager={"init_args_for_pool_manager": {"retries": retries, "timeout": timeout}},
)
remote_connection = RemoteConnection(client_config=client_config)
conn = remote_connection._get_connection_manager()
assert isinstance(conn, urllib3.PoolManager)
assert conn.connection_pool_kw["retries"] == 3
assert conn.connection_pool_kw["block"] is True
assert conn.connection_pool_kw["retries"] == retries
assert conn.connection_pool_kw["timeout"] == timeout
Loading