Skip to content

Commit

Permalink
update ip info unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AlyaGomaa committed Nov 28, 2024
1 parent 82fd2ac commit fe88b62
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 31 deletions.
32 changes: 17 additions & 15 deletions modules/ip_info/ip_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,31 +479,33 @@ def pre_main(self):
if ip := self.get_gateway_ip():
self.db.set_default_gateway("IP", ip)

def handle_new_ip(self, ip):
def handle_new_ip(self, ip: str):
try:
# make sure its a valid ip
ip_addr = ipaddress.ip_address(ip)
except ValueError:
# not a valid ip skip
return

if not ip_addr.is_multicast:
# Do we have cached info about this ip in redis?
# If yes, load it
cached_ip_info = self.db.get_ip_info(ip)
if not cached_ip_info:
cached_ip_info = {}
if ip_addr.is_multicast:
return

# Do we have cached info about this ip in redis?
# If yes, load it
cached_ip_info = self.db.get_ip_info(ip)
if not cached_ip_info:
cached_ip_info = {}

# Get the geocountry
if cached_ip_info == {} or "geocountry" not in cached_ip_info:
self.get_geocountry(ip)
# Get the geocountry
if cached_ip_info == {} or "geocountry" not in cached_ip_info:
self.get_geocountry(ip)

# only update the ASN for this IP if more than 1 month
# passed since last ASN update on this IP
if self.asn.should_update_asn(cached_ip_info):
self.asn.get_asn(ip, cached_ip_info)
# only update the ASN for this IP if more than 1 month
# passed since last ASN update on this IP
if self.asn.should_update_asn(cached_ip_info):
self.asn.get_asn(ip, cached_ip_info)

self.get_rdns(ip)
self.get_rdns(ip)

async def main(self):
if msg := self.get_msg("new_MAC"):
Expand Down
5 changes: 2 additions & 3 deletions tests/module_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
from modules.ip_info.ip_info import IPInfo
from slips_files.common.slips_utils import utils
from slips_files.core.helpers.whitelist.whitelist import Whitelist
from tests.common_test_utils import do_nothing
from modules.virustotal.virustotal import VT
from managers.process_manager import ProcessManager
from managers.redis_manager import RedisManager
Expand Down Expand Up @@ -275,11 +274,11 @@ def create_input_obj(
termination_event=Mock(),
)
input.db = mock_db
input.mark_self_as_done_processing = do_nothing
input.mark_self_as_done_processing = Mock()
input.bro_timeout = 1
# override the print function to avoid broken pipes
input.print = Mock()
input.stop_queues = do_nothing
input.stop_queues = Mock()
input.testing = True

return input
Expand Down
3 changes: 2 additions & 1 deletion tests/test_asn_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ def test_get_cached_asn(ip_address, first_octet, cached_data, expected_result):
)
def test_update_asn(cached_data, update_period, expected_result):
asn_info = ModuleFactory().create_asn_obj()
result = asn_info.update_asn(cached_data, update_period)
asn_info.update_period = update_period
result = asn_info.should_update_asn(cached_data)
assert result == expected_result


Expand Down
37 changes: 25 additions & 12 deletions tests/test_ip_info.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
"""Unit test for modules/ip_info/ip_info.py"""

import asyncio

from tests.module_factory import ModuleFactory
import maxminddb
import pytest
from unittest.mock import Mock, patch
from unittest.mock import (
Mock,
patch,
)
import json
import requests
import socket
Expand Down Expand Up @@ -366,14 +371,20 @@ def test_get_vendor_online(

assert vendor == expected_vendor
mock_requests.assert_called_once_with(
"https://api.macvendors.com/00:11:22:33:44:55", timeout=5
"https://api.macvendors.com/00:11:22:33:44:55", timeout=2
)


def test_shutdown_gracefully(
async def tmp_function():
# Simulating some asynchronous work
await asyncio.sleep(1)


async def test_shutdown_gracefully(
mocker,
):
ip_info = ModuleFactory().create_ip_info_obj()
ip_info.reading_mac_db_task = tmp_function()

mock_asn_db = mocker.Mock()
mock_country_db = mocker.Mock()
Expand All @@ -383,7 +394,8 @@ def test_shutdown_gracefully(
ip_info.country_db = mock_country_db
ip_info.mac_db = mock_mac_db

ip_info.shutdown_gracefully()
await ip_info.shutdown_gracefully()

mock_asn_db.close.assert_called_once()
mock_country_db.close.assert_called_once()
mock_mac_db.close.assert_called_once()
Expand Down Expand Up @@ -446,8 +458,7 @@ def test_handle_new_ip(mocker, ip, is_multicast, cached_info, expected_calls):
mock_get_geocountry = mocker.patch.object(ip_info, "get_geocountry")
mock_get_asn = mocker.patch.object(ip_info.asn, "get_asn")
mock_get_rdns = mocker.patch.object(ip_info, "get_rdns")

mocker.patch.object(ip_info.asn, "update_asn", return_value=True)
ip_info.asn.update_asn = Mock(return_value=True)
ip_info.handle_new_ip(ip)
assert mock_get_geocountry.call_count == expected_calls.get(
"get_geocountry", 0
Expand All @@ -468,11 +479,13 @@ def test_check_if_we_have_pending_mac_queries_with_mac_db(
("AA:BB:CC:DD:EE:FF", "profile_2"),
Exception("Empty queue"),
]
mock_get_vendor = mocker.patch.object(ip_info, "get_vendor")
ip_info.check_if_we_have_pending_mac_queries()
assert mock_get_vendor.call_count == 2
mock_get_vendor.assert_any_call("00:11:22:33:44:55", "profile_1")
mock_get_vendor.assert_any_call("AA:BB:CC:DD:EE:FF", "profile_2")
mock_get_vendor_offline = mocker.patch.object(
ip_info, "get_vendor_offline"
)
ip_info.check_if_we_have_pending_offline_mac_queries()
assert mock_get_vendor_offline.call_count == 2
mock_get_vendor_offline.assert_any_call("00:11:22:33:44:55", "profile_1")
mock_get_vendor_offline.assert_any_call("AA:BB:CC:DD:EE:FF", "profile_2")


def test_check_if_we_have_pending_mac_queries_empty_queue(
Expand All @@ -483,7 +496,7 @@ def test_check_if_we_have_pending_mac_queries_empty_queue(
ip_info.pending_mac_queries = Mock()
ip_info.pending_mac_queries.empty.return_value = True
mock_get_vendor = mocker.patch.object(ip_info, "get_vendor")
ip_info.check_if_we_have_pending_mac_queries()
ip_info.check_if_we_have_pending_offline_mac_queries()
mock_get_vendor.assert_not_called()


Expand Down

0 comments on commit fe88b62

Please sign in to comment.