From 7675f5375beee3d074a43b732af3a42e8884da4e Mon Sep 17 00:00:00 2001
From: alya
Date: Thu, 14 Nov 2024 15:12:53 +0200
Subject: [PATCH 1/7] conn.py: move the logic for interface timeout to its own function

---
 modules/flowalerts/conn.py        | 52 +++++++++++++++----------------
 slips_files/common/slips_utils.py |  1 -
 2 files changed, 26 insertions(+), 27 deletions(-)

diff --git a/modules/flowalerts/conn.py b/modules/flowalerts/conn.py
index b908783e2..5cac7963e 100644
--- a/modules/flowalerts/conn.py
+++ b/modules/flowalerts/conn.py
@@ -28,21 +28,21 @@ def init(self):
         # thread (we waited for the dns resolution for these connections)
         self.connections_checked_in_conn_dns_timer_thread = []
         self.whitelist = self.flowalerts.whitelist
-        # Threshold how much time to wait when capturing in an interface,
-        # to start reporting connections without DNS
-        # Usually the computer resolved DNS already, so we need to wait a little to report
+        # how much time to wait when running on an interface before reporting
+        # connections without DNS. Usually the computer resolved DNS
+        # already, so we need to wait a little before reporting
         # In mins
         self.conn_without_dns_interface_wait_time = 30
         self.dns_analyzer = DNS(self.db, flowalerts=self)
         self.is_running_non_stop: bool = self.db.is_running_non_stop()
         self.classifier = FlowClassifier()
+        self.our_ips = utils.get_own_ips()

     def read_configuration(self):
         conf = ConfigParser()
         self.long_connection_threshold = conf.long_connection_threshold()
         self.data_exfiltration_threshold = conf.data_exfiltration_threshold()
         self.data_exfiltration_threshold = conf.data_exfiltration_threshold()
-        self.our_ips = utils.get_own_ips()
         self.client_ips: List[str] = conf.client_ips()

     def name(self) -> str:
@@ -377,8 +377,6 @@ def should_ignore_conn_without_dns(self, flow) -> bool:
         """
         checks for the cases that we should ignore the connection without dns
         """
-        # we should ignore this evidence if the ip is ours, whether it's a
-        # private ip or in the list of client_ips
         return (
             flow.type_ != "conn"
             or flow.appproto in ("dns", "icmp")
@@ -391,6 +389,10 @@ def should_ignore_conn_without_dns(self, flow) -> bool:
             # because there's no dns.log to know if the dns was made
             or self.db.get_input_type() == "zeek_log_file"
             or self.db.is_doh_server(flow.daddr)
+            # connection without dns in case of an interface,
+            # should only be detected from the srcip of this device,
+            # not all ips, to avoid so many alerts of this type when port scanning
+            or (self.is_running_non_stop and flow.saddr not in self.our_ips)
         )

     def check_if_resolution_was_made_by_different_version(
@@ -416,6 +418,22 @@ def check_if_resolution_was_made_by_different_version(
             pass
         return False

+    def is_interface_timeout_reached(self, flow) -> bool:
+        """
+        To avoid false positives in case of an interface
+        don't alert ConnectionWithoutDNS until 30 minutes has passed after
+        starting slips because the dns may have happened before starting slips
+        """
+        if not self.is_running_non_stop:
+            # no timeout
+            return True
+
+        start_time = self.db.get_slips_start_time()
+        now = datetime.now()
+        diff = utils.get_time_diff(start_time, now, return_type="minutes")
+        # less than 30 minutes have passed?
+        return diff >= self.conn_without_dns_interface_wait_time
+
     async def check_connection_without_dns_resolution(
         self, profileid, twid, flow
     ) -> bool:
@@ -434,26 +452,8 @@ async def check_connection_without_dns_resolution(
         # We dont have yet the dhcp in the redis, when is there check it
         # if self.db.get_dhcp_servers(daddr):
         #     continue
-
-        # To avoid false positives in case of an interface
-        # don't alert ConnectionWithoutDNS
-        # until 30 minutes has passed
-        # after starting slips because the dns may have happened before
-        # starting slips
-        if self.is_running_non_stop:
-            # connection without dns in case of an interface,
-            # should only be detected from the srcip of this device,
-            # not all ips, to avoid so many alerts of this type when port scanning
-            saddr = profileid.split("_")[-1]
-            if saddr not in self.our_ips:
-                return False
-
-            start_time = self.db.get_slips_start_time()
-            now = datetime.now()
-            diff = utils.get_time_diff(start_time, now, return_type="minutes")
-            if diff < self.conn_without_dns_interface_wait_time:
-                # less than 30 minutes have passed
-                return False
+        if not self.is_interface_timeout_reached(flow):
+            return False

         # search 24hs back for a dns resolution
         if self.db.is_ip_resolved(flow.daddr, 24):
diff --git a/slips_files/common/slips_utils.py b/slips_files/common/slips_utils.py
index 473c15cd2..bc34d8935 100644
--- a/slips_files/common/slips_utils.py
+++ b/slips_files/common/slips_utils.py
@@ -415,7 +415,6 @@ def is_ignored_ip(self, ip: str) -> bool:
         except (ipaddress.AddressValueError, ValueError):
             return True
         # Is the IP multicast, private? (including localhost)
-        # local_link or reserved?
         # The broadcast address 255.255.255.255 is reserved.
         return bool(
             (

From 4cb7b22da5150032446b65d7514a9cb73272c6d6 Mon Sep 17 00:00:00 2001
From: alya
Date: Thu, 14 Nov 2024 15:19:32 +0200
Subject: [PATCH 2/7] dns.py: wait 30 mins before the first "dns without conn" evidence

---
 modules/flowalerts/conn.py |  2 +-
 modules/flowalerts/dns.py  | 31 +++++++++++++++++++++++++++++--
 2 files changed, 30 insertions(+), 3 deletions(-)

diff --git a/modules/flowalerts/conn.py b/modules/flowalerts/conn.py
index 5cac7963e..1afaf26b7 100644
--- a/modules/flowalerts/conn.py
+++ b/modules/flowalerts/conn.py
@@ -431,7 +431,7 @@ def is_interface_timeout_reached(self, flow) -> bool:
         start_time = self.db.get_slips_start_time()
         now = datetime.now()
         diff = utils.get_time_diff(start_time, now, return_type="minutes")
-        # less than 30 minutes have passed?
+        # 30 minutes have passed?
         return diff >= self.conn_without_dns_interface_wait_time

     async def check_connection_without_dns_resolution(
diff --git a/modules/flowalerts/dns.py b/modules/flowalerts/dns.py
index 30f889f83..c20f951dd 100644
--- a/modules/flowalerts/dns.py
+++ b/modules/flowalerts/dns.py
@@ -2,6 +2,7 @@
 import collections
 import json
 import math
+from datetime import datetime
 from typing import List

 import validators
@@ -30,7 +31,11 @@ def init(self):
         self.dns_arpa_queries = {}
         # after this number of arpa queries, slips will detect an arpa scan
         self.arpa_scan_threshold = 10
+        self.is_running_non_stop: bool = self.db.is_running_non_stop()
         self.classifier = FlowClassifier()
+        self.our_ips = utils.get_own_ips()
+        # In mins
+        self.dns_without_conn_interface_wait_time = 30

     def name(self) -> str:
         return "DNS_analyzer"
@@ -39,8 +44,7 @@ def read_configuration(self):
         conf = ConfigParser()
         self.shannon_entropy_threshold = conf.get_entropy_threshold()

-    @staticmethod
-    def should_detect_dns_without_conn(flow) -> bool:
+    def should_detect_dns_without_conn(self, flow) -> bool:
         """
         returns False in the following cases
          - All reverse dns resolutions
@@ -65,6 +69,10 @@ def should_detect_dns_without_conn(flow) -> bool:
             or flow.query == "WPAD"
             or flow.rcode_name != "NOERROR"
             or not flow.answers
+            # dns without conn in case of an interface,
+            # should only be detected from the srcip of this device,
+            # not all ips, to avoid so many alerts of this type when port scanning
+            or (self.is_running_non_stop and flow.saddr not in self.our_ips)
         ):
             return False
         return True
@@ -216,6 +224,22 @@ def is_any_flow_answer_contacted(self, profileid, twid, flow) -> bool:
             # this is not a DNS without resolution
             return True

+    def is_interface_timeout_reached(self, flow):
+        """
+        To avoid false positives in case of an interface
+        don't alert ConnectionWithoutDNS until 30 minutes has passed after
+        starting slips because the dns may have happened before starting slips
+        """
+        if not self.is_running_non_stop:
+            # no timeout
+            return True
+
+        start_time = self.db.get_slips_start_time()
+        now = datetime.now()
+        diff = utils.get_time_diff(start_time, now, return_type="minutes")
+        # 30 minutes have passed?
+        return diff >= self.dns_without_conn_interface_wait_time
+
     async def check_dns_without_connection(
         self, profileid, twid, flow
     ) -> bool:
@@ -225,6 +249,9 @@ async def check_dns_without_connection(
         if not self.should_detect_dns_without_conn(flow):
             return False

+        if not self.is_interface_timeout_reached(flow):
+            return False
+
         if self.is_any_flow_answer_contacted(profileid, twid, flow):
             return False

From 3608438d04af52a01a5b6e991696737d5bea0757 Mon Sep 17 00:00:00 2001
From: alya
Date: Thu, 14 Nov 2024 15:36:50 +0200
Subject: [PATCH 3/7] database.py: don't store slips start time in the local timezone, store it as unix timestamp to avoid datetime errors

---
 slips_files/core/database/redis_db/database.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/slips_files/core/database/redis_db/database.py b/slips_files/core/database/redis_db/database.py
index 7559d2329..114668919 100644
--- a/slips_files/core/database/redis_db/database.py
+++ b/slips_files/core/database/redis_db/database.py
@@ -206,7 +206,7 @@ def set_slips_internal_time(cls, timestamp):
     def get_slips_start_time(cls) -> str:
         """get the time slips started"""
         if start_time := cls.r.get("slips_start_time"):
-            start_time = utils.convert_format(start_time, utils.alerts_format)
+            start_time = utils.convert_format(start_time, "unixtimestamp")
             return start_time

     @classmethod

From f2728daf4b7d47c4fd5466263cc29ca2f65be931 Mon Sep 17 00:00:00 2001
From: alya
Date: Thu, 14 Nov 2024 15:56:00 +0200
Subject: [PATCH 4/7] update existing and add more unit tests

---
 modules/flowalerts/conn.py |  4 ++--
 modules/flowalerts/dns.py  |  4 ++--
 tests/test_conn.py         | 41 ++++++++++++++++++++++++++++----------
 tests/test_dns.py          |  1 +
 4 files changed, 35 insertions(+), 15 deletions(-)

diff --git a/modules/flowalerts/conn.py b/modules/flowalerts/conn.py
index 1afaf26b7..bd0c4e4ad 100644
--- a/modules/flowalerts/conn.py
+++ b/modules/flowalerts/conn.py
@@ -418,7 +418,7 @@ def check_if_resolution_was_made_by_different_version(
             pass
         return False

-    def is_interface_timeout_reached(self, flow) -> bool:
+    def is_interface_timeout_reached(self) -> bool:
         """
         To avoid false positives in case of an interface
         don't alert ConnectionWithoutDNS until 30 minutes has passed after
@@ -452,7 +452,7 @@ async def check_connection_without_dns_resolution(
         # We dont have yet the dhcp in the redis, when is there check it
         # if self.db.get_dhcp_servers(daddr):
         #     continue
-        if not self.is_interface_timeout_reached(flow):
+        if not self.is_interface_timeout_reached():
             return False

         # search 24hs back for a dns resolution
diff --git a/modules/flowalerts/dns.py b/modules/flowalerts/dns.py
index c20f951dd..7c50180c0 100644
--- a/modules/flowalerts/dns.py
+++ b/modules/flowalerts/dns.py
@@ -224,7 +224,7 @@ def is_any_flow_answer_contacted(self, profileid, twid, flow) -> bool:
             # this is not a DNS without resolution
             return True

-    def is_interface_timeout_reached(self, flow):
+    def is_interface_timeout_reached(self):
         """
         To avoid false positives in case of an interface
         don't alert ConnectionWithoutDNS until 30 minutes has passed after
@@ -249,7 +249,7 @@ async def check_dns_without_connection(
         if not self.should_detect_dns_without_conn(flow):
             return False

-        if not self.is_interface_timeout_reached(flow):
+        if not self.is_interface_timeout_reached():
             return False

         if self.is_any_flow_answer_contacted(profileid, twid, flow):
diff --git a/tests/test_conn.py b/tests/test_conn.py
index 6b19eb2ec..145cf6e9d 100644
--- a/tests/test_conn.py
+++ b/tests/test_conn.py
@@ -1,9 +1,12 @@
 """Unit test for modules/flowalerts/conn.py"""

+from slips_files.common.slips_utils import utils
 from slips_files.core.flows.zeek import Conn
 from tests.module_factory import ModuleFactory
 import json
-from unittest.mock import Mock
+from unittest.mock import (
+    Mock,
+)
 import pytest
 from ipaddress import ip_address

@@ -357,21 +360,36 @@ def test_check_data_upload(
     assert mock_set_evidence.call_count == expected_call_count


+@pytest.mark.parametrize(
+    "mock_time_diff, expected_result",
+    [
+        (40, True),  # Timeout reached
+        (20, False),  # Timeout not reached
+    ],
+)
+def test_is_interface_timeout_reached(mock_time_diff, expected_result):
+    conn = ModuleFactory().create_conn_analyzer_obj()
+    conn.is_running_non_stop = True
+    conn.conn_without_dns_interface_wait_time = 30
+    utils.get_time_diff = Mock(return_value=mock_time_diff)
+    assert conn.is_interface_timeout_reached() == expected_result
+
+
 @pytest.mark.parametrize(
     "flow_type, appproto, daddr, input_type, "
     "is_doh_server, is_dns_server, "
     "client_ips, expected_result",
     [
         # # Testcase 1: Not a 'conn' flow type
         # ("dns", "dns", "8.8.8.8", "pcap", False, False, [], True),
         # # Testcase 2: DNS application protocol
         # ("conn", "dns", "8.8.8.8", "pcap", False, False, [], True),
         # # Testcase 3: Ignored IP
         # ("conn", "http", "192.168.1.1", "pcap", False, False, [], True),
         # # Testcase 4: Client IP
         # ("conn", "http", "10.0.0.1", "pcap", False, False, ["10.0.0.1"], True),
         # # Testcase 5: DoH server
         # ("conn", "http", "1.1.1.1", "pcap", True, False, [], True),
         # Testcase 7: Should not ignore
         ("conn", "http", "93.184.216.34", "pcap", False, False, [], False),
     ],
 )
@@ -390,6 +408,7 @@ def test_should_ignore_conn_without_dns(
     """Tests the should_ignore_conn_without_dns function with various scenarios."""
     conn = ModuleFactory().create_conn_analyzer_obj()
+    conn.is_running_non_stop = False
     flow = Conn(
         starttime="1726249372.312124",
         uid=uid,
diff --git a/tests/test_dns.py b/tests/test_dns.py
index 386cb3bbb..077c658b0 100644
--- a/tests/test_dns.py
+++ b/tests/test_dns.py
@@ -32,6 +32,7 @@
 )
 def test_should_detect_dns_without_conn(domain, rcode_name, expected_result):
     dns = ModuleFactory().create_dns_analyzer_obj()
+    dns.is_running_non_stop = False
     flow = DNS(
         starttime="1726568479.5997488",
         uid="1234",

From bbefbdb27efe8fce962b2667075f64f55c102a7a Mon Sep 17 00:00:00 2001
From: alya
Date: Thu, 14 Nov 2024 16:00:54 +0200
Subject: [PATCH 5/7] update the docs

---
 docs/flowalerts.md        | 4 +++-
 modules/flowalerts/dns.py | 2 +-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/docs/flowalerts.md b/docs/flowalerts.md
index 41e1f7905..a797a2493 100644
--- a/docs/flowalerts.md
+++ b/docs/flowalerts.md
@@ -66,7 +66,7 @@ so we simply ignore alerts of this type when connected to well known organizatio
 Slips uses it's own lists of organizations and information about them (IPs, IP ranges, domains, and ASNs).
 They are stored in ```slips_files/organizations_info``` and they are used to check whether the IP/domain of each flow belong to a known org or not.
 Slips doesn't detect 'connection without DNS' when running
-on an interface except for when it's done by this instance's own IP.
+on an interface except for when it's done by this instance's own IP, and only after 30 minutes have passed, to avoid false positives (assuming the DNS resolution of these connections happened before slips started).

 check [DoH section](https://stratospherelinuxips.readthedocs.io/en/develop/detection_modules.html#detect-doh)
 of the docs for info on how slips detects DoH.
@@ -91,6 +91,8 @@ The domains that are excepted are:
 - Ignore WPAD domain from Windows
 - Ignore domains without a TLD such as the Chrome test domains.

+Slips doesn't detect 'DNS resolutions without a connection' when running
+on an interface except for when it's done by this instance's own IP, and only after 5 minutes have passed, to avoid false positives (assuming the connection did happen but has not been logged yet).


 ## Connection to unknown ports
diff --git a/modules/flowalerts/dns.py b/modules/flowalerts/dns.py
index 7c50180c0..e028f6034 100644
--- a/modules/flowalerts/dns.py
+++ b/modules/flowalerts/dns.py
@@ -35,7 +35,7 @@ def init(self):
         self.classifier = FlowClassifier()
         self.our_ips = utils.get_own_ips()
         # In mins
-        self.dns_without_conn_interface_wait_time = 30
+        self.dns_without_conn_interface_wait_time = 5

     def name(self) -> str:
         return "DNS_analyzer"

From fae67b267a747c8f93ad835cdcb6ae52d51f9982 Mon Sep 17 00:00:00 2001
From: alya
Date: Thu, 14 Nov 2024 22:16:42 +0200
Subject: [PATCH 6/7] Fix the problem of comparing timezone-aware with timezone-naive dates

---
 managers/metadata_manager.py                   |  2 +-
 managers/process_manager.py                    |  5 ++---
 modules/update_manager/update_manager.py       | 10 +++++-----
 slips_files/common/slips_utils.py              |  4 ++--
 slips_files/core/database/redis_db/database.py |  9 +++------
 5 files changed, 13 insertions(+), 17 deletions(-)

diff --git a/managers/metadata_manager.py b/managers/metadata_manager.py
index 82e3aa77b..ae5bec310 100644
--- a/managers/metadata_manager.py
+++ b/managers/metadata_manager.py
@@ -74,7 +74,7 @@ def set_analysis_end_date(self, end_date):
         """
         if not self.main.conf.enable_metadata():
             return
-
+        end_date = utils.convert_format(datetime.now(), utils.alerts_format)
         self.main.db.set_input_metadata({"analysis_end": end_date})

         # add slips end date in the metadata dir
diff --git a/managers/process_manager.py b/managers/process_manager.py
index c3a007717..f7038c8b9 100644
--- a/managers/process_manager.py
+++ b/managers/process_manager.py
@@ -475,7 +475,7 @@ def get_analysis_time(self) -> Tuple[str, str]:
         returns analysis_time in minutes and slips end_time as a date
         """
         start_time = self.main.db.get_slips_start_time()
-        end_time = utils.convert_format(datetime.now(), utils.alerts_format)
+        end_time = utils.convert_format(datetime.now(), "unixtimestamp")
         return (
             utils.get_time_diff(start_time, end_time, return_type="minutes"),
             end_time,
@@ -711,11 +711,10 @@ def shutdown_gracefully(self):
         if self.main.conf.export_labeled_flows():
             format_ = self.main.conf.export_labeled_flows_to().lower()
             self.main.db.export_labeled_flows(format_)
-
+
         self.main.profilers_manager.cpu_profiler_release()
         self.main.profilers_manager.memory_profiler_release()
-
         # if store_a_copy_of_zeek_files is set to yes in slips.yaml
         # copy the whole zeek_files dir to the output dir
         self.main.store_zeek_dir_copy()

diff --git a/modules/update_manager/update_manager.py b/modules/update_manager/update_manager.py
index fe5c90413..d06bc6cff 100644
--- a/modules/update_manager/update_manager.py
+++ b/modules/update_manager/update_manager.py
@@ -912,17 +912,17 @@ def parse_json_ti_feed(self, link_to_download, ti_file_path: str) -> bool:
             return False

         for ioc in file:
-            date = ioc["InsertDate"]
-            date = utils.convert_ts_to_tz_aware(date)
-            diff = utils.get_time_diff(
-                date, time.time(), return_type="days"
-            )
+            date = utils.convert_ts_to_tz_aware(ioc["InsertDate"])
+            now = utils.convert_ts_to_tz_aware(time.time())
+            diff = utils.get_time_diff(date, now, return_type="days")
             if diff > self.interval:
                 continue
+
             domain = ioc["DomainAddress"]
             if not utils.is_valid_domain(domain):
                 continue
+
             malicious_domains_dict[domain] = json.dumps(
                 {
                     "description": "",
diff --git a/slips_files/common/slips_utils.py b/slips_files/common/slips_utils.py
index bc34d8935..9a3127938 100644
--- a/slips_files/common/slips_utils.py
+++ b/slips_files/common/slips_utils.py
@@ -265,7 +265,7 @@ def convert_format(self, ts, required_format: str):

         # convert to the req format
         if required_format == "iso":
-            return datetime_obj.astimezone(self.local_tz).isoformat()
+            return datetime_obj.astimezone().isoformat()
         elif required_format == "unixtimestamp":
             return datetime_obj.timestamp()
         else:
@@ -302,7 +302,7 @@ def convert_to_datetime(self, ts):

         given_format = self.get_time_format(ts)
         return (
-            datetime.fromtimestamp(float(ts), tz=self.local_tz)
+            datetime.fromtimestamp(float(ts))
             if given_format == "unixtimestamp"
             else datetime.strptime(ts, given_format)
         )
diff --git a/slips_files/core/database/redis_db/database.py b/slips_files/core/database/redis_db/database.py
index 114668919..195f1a6de 100644
--- a/slips_files/core/database/redis_db/database.py
+++ b/slips_files/core/database/redis_db/database.py
@@ -15,7 +15,6 @@
 import time
 import json
 import subprocess
-from datetime import datetime
 import ipaddress
 import sys
 import validators
@@ -204,10 +203,8 @@ def set_slips_internal_time(cls, timestamp):

     @classmethod
     def get_slips_start_time(cls) -> str:
-        """get the time slips started"""
-        if start_time := cls.r.get("slips_start_time"):
-            start_time = utils.convert_format(start_time, "unixtimestamp")
-            return start_time
+        """get the time slips started in unix format"""
+        return cls.r.get("slips_start_time")

     @classmethod
     def init_redis_server(cls) -> Tuple[bool, str]:
@@ -363,7 +360,7 @@ def change_redis_limits(cls, client: redis.StrictRedis):
     @classmethod
     def _set_slips_start_time(cls):
         """store the time slips started (datetime obj)"""
-        now = utils.convert_format(datetime.now(), utils.alerts_format)
+        now = time.time()
         cls.r.set("slips_start_time", now)

     def publish(self, channel, msg):

From 2862508d3c44b245fe79ee2d1c8b66b6e780f31c Mon Sep 17 00:00:00 2001
From: alya
Date: Thu, 14 Nov 2024 23:49:51 +0200
Subject: [PATCH 7/7] update the database MAC address unit tests

---
 .../core/database/redis_db/profile_handler.py |  8 +-
 tests/module_factory.py                       |  3 +-
 tests/test_database.py                        | 91 ++++++++++++++-----
 3 files changed, 75 insertions(+), 27 deletions(-)

diff --git a/slips_files/core/database/redis_db/profile_handler.py b/slips_files/core/database/redis_db/profile_handler.py
index ccaa4b8fb..d454c33f0 100644
--- a/slips_files/core/database/redis_db/profile_handler.py
+++ b/slips_files/core/database/redis_db/profile_handler.py
@@ -1229,8 +1229,8 @@ def add_mac_addr_to_profile(self, profileid: str, mac_addr: str):
             return False

         # get the ips that belong to this mac
-        cached_ip = self.r.hmget("MAC", mac_addr)[0]
-        if not cached_ip:
+        cached_ips: Optional[List] = self.r.hmget("MAC", mac_addr)[0]
+        if not cached_ips:
             # no mac info stored for profileid
             ip = json.dumps([incoming_ip])
             self.r.hset("MAC", mac_addr, ip)
@@ -1241,10 +1241,10 @@ def add_mac_addr_to_profile(self, profileid: str, mac_addr: str):
         else:
             # we found another profile that has the same mac as this one
             # get all the ips, v4 and 6, that are stored with this mac
-            cached_ips = json.loads(cached_ip)
+            cached_ips: List[str] = json.loads(cached_ips)
             # get the last one of them
             found_ip = cached_ips[-1]
-            cached_ips = set(cached_ips)
+            cached_ips: Set[str] = set(cached_ips)

             if incoming_ip in cached_ips:
                 # this is the case where we have the given ip already
diff --git a/tests/module_factory.py b/tests/module_factory.py
index 5cc644197..d60d54de6 100644
--- a/tests/module_factory.py
+++ b/tests/module_factory.py
@@ -130,8 +130,9 @@ def create_db_manager_obj(
         flush_db=flush_db,
         start_redis_server=start_redis_server,
     )
-    db.r = db.rdb.r
     db.print = Mock()
+    # for easier access to redis db
+    db.r = db.rdb.r
     assert db.get_used_redis_port() == port
     return db

diff --git a/tests/test_database.py b/tests/test_database.py
index 84a5b0391..91d737b37 100644
--- a/tests/test_database.py
+++ b/tests/test_database.py
@@ -1,3 +1,8 @@
+from unittest.mock import (
+    Mock,
+    call,
+)
+
 import redis
 import json
 import time
@@ -178,38 +183,80 @@ def test_profile_moddule_labels():
     assert labels["test"] == "malicious"


-def test_add_mac_addr_to_profile():
+def test_add_mac_addr_with_new_ipv4():
+    """
+    adding an ipv4 to no cached ip
+    """
     db = ModuleFactory().create_db_manager_obj(
         get_random_port(), flush_db=True
     )
     ipv4 = "192.168.1.5"
     profileid_ipv4 = f"profile_{ipv4}"
     mac_addr = "00:00:5e:00:53:af"
-    # first associate this ip with some mac
+
+    db.rdb.is_gw_mac = Mock(return_value=False)
+    db.r.hget = Mock()
+    db.r.hset = Mock()
+    db.r.hmget = Mock(return_value=[None])  # No entry initially
+
+    # simulate adding a new MAC and IPv4 address
     assert db.add_mac_addr_to_profile(profileid_ipv4, mac_addr) is True
-    assert ipv4 in str(db.r.hget("MAC", mac_addr))
-    # now claim that we found another profile
-    # that has the same mac as this one
-    # both ipv4
-    profileid = "profile_192.168.1.6"
-    assert db.add_mac_addr_to_profile(profileid, mac_addr) is False
-    # this ip shouldnt be added to the profile as they're both ipv4
-    assert "192.168.1.6" not in db.r.hget("MAC", mac_addr)
+    # Ensure the IP is associated in the 'MAC' hash
+    db.r.hmget.assert_called_with("MAC", mac_addr)
+    db.r.hset.assert_any_call("MAC", mac_addr, json.dumps([ipv4]))
+
+
+def test_add_mac_addr_with_existing_ipv4():
+    """
+    adding an ipv4 to a cached ipv4
+    """
+    db = ModuleFactory().create_db_manager_obj(
+        get_random_port(), flush_db=True
+    )
+    ipv4 = "192.168.1.5"
+    mac_addr = "00:00:5e:00:53:af"
+    db.rdb.is_gw_mac = Mock(return_value=False)
+    db.r.hget = Mock()
+    db.r.hset = Mock()
+    db.r.hmget = Mock(return_value=[json.dumps([ipv4])])
+
+    new_profile = "profile_192.168.1.6"
+
+    # try to add a new profile with the same MAC but another IPv4 address
+    assert db.add_mac_addr_to_profile(new_profile, mac_addr) is False
+
+
+def test_add_mac_addr_with_ipv6_association():
+    """
+    adding an ipv6 to a cached ipv4
+    """
+    db = ModuleFactory().create_db_manager_obj(
+        get_random_port(), flush_db=True
+    )
+    ipv4 = "192.168.1.5"
+    profile_ipv4 = "profile_192.168.1.5"
+    mac_addr = "00:00:5e:00:53:af"
+
+    # mock existing entry with ipv6
+    db.rdb.is_gw_mac = Mock(return_value=False)
+    db.rdb.update_mac_of_profile = Mock()
+    db.r.hmget = Mock(return_value=[json.dumps([ipv4])])
+    db.r.hset = Mock()
+    db.r.hget = Mock()
-    # now claim that another ipv6 has this mac
     ipv6 = "2001:0db8:85a3:0000:0000:8a2e:0370:7334"
-    profileid_ipv6 = f"profile_{ipv6}"
-    db.add_mac_addr_to_profile(profileid_ipv6, mac_addr)
-    # make sure the mac is associated with his ipv6
-    assert ipv6 in db.r.hget("MAC", mac_addr)
-    # make sure the ipv4 is associated with this
-    # ipv6 profile
-    assert ipv4 in db.get_ipv4_from_profile(profileid_ipv6)
-
-    # make sure the ipv6 is associated with the
-    # profile that has the same ipv4 as the mac
-    assert ipv6 in str(db.r.hmget(profileid_ipv4, "IPv6"))
+    profile_ipv6 = f"profile_{ipv6}"
+    # try to associate an ipv6 with the same MAC address
+    assert db.add_mac_addr_to_profile(profile_ipv6, mac_addr)
+
+    expected_calls = [
+        call(profile_ipv4, mac_addr),  # call with ipv4 profile id
+        call(profile_ipv6, mac_addr),  # call with ipv6 profile id
+    ]
+    db.rdb.update_mac_of_profile.assert_has_calls(
+        expected_calls, any_order=True
+    )


 def test_get_the_other_ip_version():