Skip to content

Commit

Permalink
Merge pull request #257 from dsgnr/concurrency
Browse files Browse the repository at this point in the history
Add threading to socket check
  • Loading branch information
dsgnr authored Nov 9, 2024
2 parents c0004aa + 4416c5d commit d442da9
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 90 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ All notable changes to this project will be documented in this file.

## [Unreleased]
- Fixes incorrect environment variable defaults in README
- Adds threading to `query_ipv4` method. Uses default worker value (CPU count).
This will improve performance where there are more than one port that does not response,
thus reaching the timeout limit.

## [3.1.0] - 2024-11-08
- Bump gunicorn from 22.0.0 to 23.0.0
Expand Down
49 changes: 37 additions & 12 deletions backend/api/app/helpers/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import os
import socket
from concurrent.futures import ThreadPoolExecutor
from ipaddress import ip_address
from urllib.parse import urlparse

Expand Down Expand Up @@ -89,6 +90,41 @@ def is_valid_hostname(hostname: str) -> bool:
raise ValueError("Hostname does not appear to resolve") from socket_err


def _check_port_status(address: str, port: int) -> list[str, int]:
"""Check if a specific port on the provided address is open.
Returns a dictionary with the port and the connection status.
"""
with socket.socket() as sock:
sock.settimeout(1) # Set a timeout of 1 second
result = sock.connect_ex((address, port)) # 0 if open, non-zero if closed
return {"port": port, "status": result == 0}


def check_ports(address: str, ports: dict[int]) -> list[dict[str, int]]:
"""Check multiple ports for the provided address with threading.
Args:
address (str): The IP address to check.
ports (list[int]): List of ports to check on the given address.
max_threads (int): Maximum number of threads to use. Default is 10.
Returns:
list[dict[str, int]]: A list of dictionaries containing port numbers and their statuses.
"""
results = []
with ThreadPoolExecutor() as executor:
futures = {
executor.submit(_check_port_status, address, port): port for port in ports
}
for future in futures:
result = (
future.result()
) # Retrieve the result of each future as it completes
results.append(result)
return results


def query_ipv4(address: str, ports: list[int]) -> list[dict]:
"""
Checks whether the specified ports on a given IPv4 address or hostname are connectable.
Expand All @@ -115,18 +151,7 @@ def query_ipv4(address: str, ports: list[int]) -> list[dict]:
is_valid_hostname(address)
except Exception as ex:
raise JsonAPIException(key="host", message=str(ex)) from ex

results = []
for port in ports:
result = {"port": port, "status": False}
sock = socket.socket()
sock.settimeout(1)
port_check = sock.connect_ex((address, int(port)))
if port_check == 0:
result["status"] = True
sock.close()
results.append(result)
return results
return check_ports(address, ports)


def get_requester(request: Request) -> str:
Expand Down
51 changes: 25 additions & 26 deletions backend/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Fixtures for testing the application routes and helper functions"""

from unittest.mock import patch

import pytest
Expand All @@ -14,50 +15,48 @@
VALID_DOMAIN = "example.com"
INVALID_HOST = "foo"
LOCALHOST_IPV4 = "127.0.0.1"
PORTS = [80, 22, 8080, 443]
OPEN_PORTS = [80, 443]
CLOSED_PORTS = [22, 8080]
SOCKET_OPEN = 0
SOCKET_CLOSED = 1


@pytest.fixture
def client():
"""Fixture to provide a test client for app requests."""
return TestClient(app)


def mock_connect(address_port_tuple: tuple[str, int]) -> int:
"""Simulate mixed open/closed ports based on port numbers."""
return SOCKET_OPEN if address_port_tuple[1] in OPEN_PORTS else SOCKET_CLOSED


@pytest.fixture(autouse=True)
def mock_socket():
"""
Simulate the socket connection.
Uses the `mock_connect` method above to return the state value
"""
with patch("socket.socket.connect_ex", side_effect=mock_connect):
yield


@pytest.fixture
def mock_request_path():
"""Fixture to create a mock Request object with a specified path."""
return Request(scope={"method": "GET", "path": "/test-path"})


@pytest.fixture
def mock_socket():
"""Fixture to mock socket.socket calls."""
with patch("socket.socket") as sock:
yield sock


@pytest.fixture
def mock_request():
"""Fixture to create a mock request with customizable headers."""
class MockRequest: # pylint: disable=too-few-public-methods

class MockRequest: # pylint: disable=too-few-public-methods
"""The MockRequest class"""

def __init__(self, headers):
self.headers = headers

return MockRequest


@pytest.fixture
def mock_is_ip_address(mocker):
"""Fixture to mock the is_ip_address helper function."""
return mocker.patch("app.helpers.query.is_ip_address")


@pytest.fixture
def mock_is_address_valid(mocker):
"""Fixture to mock the is_address_valid helper function."""
return mocker.patch("app.helpers.query.is_address_valid")


@pytest.fixture
def mock_is_valid_hostname(mocker):
"""Fixture to mock the is_valid_hostname helper function."""
return mocker.patch("app.helpers.query.is_valid_hostname")
129 changes: 77 additions & 52 deletions backend/tests/test_query_ipv4.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,43 @@
"""Tests for query_ipv4"""
import socket
from unittest.mock import MagicMock, patch

import pytest

from api.app.helpers.query import JsonAPIException, query_ipv4

from .conftest import INVALID_HOST, VALID_PUBLIC_IPV4


def test_query_ipv4_single_open_port(mock_socket):
"""Mock socket connection to return 0, indicating the port is open"""
mock_sock_instance = MagicMock()
mock_sock_instance.connect_ex.return_value = 0
mock_socket.return_value = mock_sock_instance

ports = [80]
assert query_ipv4(VALID_PUBLIC_IPV4, ports) == [{"port": ports[0], "status": True}]


def test_query_ipv4_single_closed_port(mock_socket):
from api.app.helpers.query import (
JsonAPIException,
_check_port_status,
check_ports,
query_ipv4,
)

from .conftest import (
CLOSED_PORTS,
INVALID_HOST,
OPEN_PORTS,
PORTS,
SOCKET_OPEN,
VALID_PUBLIC_IPV4,
mock_connect,
)


def test_query_ipv4_single_closed_port():
"""Mock socket connection to return non-zero, indicating the port is closed"""
mock_sock_instance = MagicMock()
mock_sock_instance.connect_ex.return_value = 1
mock_socket.return_value = mock_sock_instance

ports = [81]
assert query_ipv4(VALID_PUBLIC_IPV4, ports) == [{"port": ports[0], "status": False}]


def test_query_ipv4_multiple_ports_mixed_status(mock_socket):
"""Simulate one open port and one closed port"""
mock_sock_instance = MagicMock()
mock_sock_instance.connect_ex.side_effect = [
0,
1,
] # Open for port 80, closed for port 443
mock_socket.return_value = mock_sock_instance

ports = [80, 443]
assert query_ipv4(VALID_PUBLIC_IPV4, ports) == [
{"port": ports[0], "status": True},
{"port": ports[1], "status": False},
result = query_ipv4(VALID_PUBLIC_IPV4, [CLOSED_PORTS[0]])
assert result == [{"port": CLOSED_PORTS[0], "status": False}]


def test_query_ipv4_multiple_ports_mixed_status():
"""Test when some ports are open and some are closed."""
result = query_ipv4(VALID_PUBLIC_IPV4, PORTS)
expected = [
{
"port": port,
"status": mock_connect((VALID_PUBLIC_IPV4, port)) == SOCKET_OPEN,
}
for port in PORTS
]
assert result == expected


def test_query_ipv4_empty_ports_list():
"""Test query_ipv4 returns empty list when ports list is empty."""
Expand All @@ -52,16 +47,46 @@ def test_query_ipv4_empty_ports_list():

def test_query_ipv4_invalid_address():
"""Test query_ipv4 raises JsonAPIException for an invalid hostname."""
with (
patch("socket.gethostbyname", side_effect=socket.gaierror),
pytest.raises(JsonAPIException, match=".*Hostname does not appear to resolve"),
):
query_ipv4(INVALID_HOST, [443])


def test_query_ipv4_valid_address(mock_socket):
"""Test query_ipv4 returns correct status for a valid IP and port."""
mock_sock_instance = MagicMock()
mock_sock_instance.connect_ex.return_value = 0
mock_socket.return_value = mock_sock_instance
assert query_ipv4(VALID_PUBLIC_IPV4, [443]) == [{"port": 443, "status": True}]
with pytest.raises(JsonAPIException, match=".*Hostname does not appear to resolve"):
query_ipv4(INVALID_HOST, [OPEN_PORTS[0]])


def test_check_ports_all_open():
"""Test when all ports are open."""
result = check_ports(VALID_PUBLIC_IPV4, OPEN_PORTS)
expected = [{"port": port, "status": True} for port in OPEN_PORTS]
assert result == expected


def test_check_ports_all_closed():
"""Test when all ports are closed."""
result = check_ports(VALID_PUBLIC_IPV4, CLOSED_PORTS)
expected = [{"port": port, "status": False} for port in CLOSED_PORTS]
assert result == expected


def test_check_ports_mixed():
"""Test when some ports are open and some are closed."""
result = check_ports(VALID_PUBLIC_IPV4, PORTS)
expected = [
{
"port": port,
"status": mock_connect((VALID_PUBLIC_IPV4, port)) == SOCKET_OPEN,
}
for port in PORTS
]
assert result == expected


def test_check_port_status_open():
"""Test _check_port_status with an open port."""
result = _check_port_status(VALID_PUBLIC_IPV4, OPEN_PORTS[0])
expected = {"port": OPEN_PORTS[0], "status": True}
assert result == expected


def test_check_port_status_closed():
"""Test _check_port_status with a closed port."""
result = _check_port_status(VALID_PUBLIC_IPV4, CLOSED_PORTS[0])
expected = {"port": CLOSED_PORTS[0], "status": False}
assert result == expected

0 comments on commit d442da9

Please sign in to comment.