diff --git a/modules/flowalerts/flowalerts.py b/modules/flowalerts/flowalerts.py index fe75fbcd8..f55ba852d 100644 --- a/modules/flowalerts/flowalerts.py +++ b/modules/flowalerts/flowalerts.py @@ -564,7 +564,7 @@ def is_well_known_org(self, ip): """get the SNI, ASN, and rDNS of the IP to check if it belongs to a well-known org""" - ip_data = self.db.getIPData(ip) + ip_data = self.db.get_ip_info(ip) try: SNI = ip_data['SNI'] if type(SNI) == list: diff --git a/modules/ip_info/ip_info.py b/modules/ip_info/ip_info.py index 9ddad6efe..0e9ebc37d 100644 --- a/modules/ip_info/ip_info.py +++ b/modules/ip_info/ip_info.py @@ -537,7 +537,7 @@ def handle_new_ip(self, ip): if not ip_addr.is_multicast: # Do we have cached info about this ip in redis? # If yes, load it - cached_ip_info = self.db.getIPData(ip) + cached_ip_info = self.db.get_ip_info(ip) if not cached_ip_info: cached_ip_info = {} diff --git a/modules/p2ptrust/testing/test_p2p.py b/modules/p2ptrust/testing/test_p2p.py index 951078b33..059693775 100644 --- a/modules/p2ptrust/testing/test_p2p.py +++ b/modules/p2ptrust/testing/test_p2p.py @@ -173,13 +173,13 @@ def test_ip_info_changed(): def test_ip_data_save_to_redis(): print('Data in slips for ip 1.2.3.4') - print(__database__.getIPData('1.2.3.4')) + print(__database__.get_ip_info('1.2.3.4')) print('Update data') save_ip_report_to_db('1.2.3.4', 1, 0.4, 0.4) print('Data in slips for ip 1.2.3.4') - print(__database__.getIPData('1.2.3.4')) + print(__database__.get_ip_info('1.2.3.4')) def test_inputs(): diff --git a/modules/p2ptrust/utils/utils.py b/modules/p2ptrust/utils/utils.py index 95257dc5b..d55a8ca40 100644 --- a/modules/p2ptrust/utils/utils.py +++ b/modules/p2ptrust/utils/utils.py @@ -96,7 +96,7 @@ def get_ip_info_from_slips(ip_address: str, db) -> (float, float): """ # poll new info from redis - ip_info = db.getIPData(ip_address) + ip_info = db.get_ip_info(ip_address) # There is a bug in the database where sometimes False is returned when key is not found. Correctly, dictionary # should be always returned, even if it is empty. This check cannot be simplified to `if not ip_info`, because I diff --git a/modules/threat_intelligence/threat_intelligence.py b/modules/threat_intelligence/threat_intelligence.py index 2275db312..aebea1f44 100644 --- a/modules/threat_intelligence/threat_intelligence.py +++ b/modules/threat_intelligence/threat_intelligence.py @@ -769,7 +769,7 @@ def ip_has_blacklisted_ASN( Check if this ip has any of our blacklisted ASNs. blacklisted asns are taken from own_malicious_iocs.csv """ - ip_info = self.db.getIPData(ip) + ip_info = self.db.get_ip_info(ip) if not ip_info: # we dont know the asn of this ip return diff --git a/modules/virustotal/virustotal.py b/modules/virustotal/virustotal.py index 3e422f73d..3ba03723d 100644 --- a/modules/virustotal/virustotal.py +++ b/modules/virustotal/virustotal.py @@ -204,7 +204,7 @@ def API_calls_thread(self): ioc = self.api_call_queue.pop(0) ioc_type = self.get_ioc_type(ioc) if ioc_type == 'ip': - cached_data = self.db.getIPData(ioc) + cached_data = self.db.get_ip_info(ioc) # return an IPv4Address or IPv6Address object depending on the IP address passed as argument. ip_addr = ipaddress.ip_address(ioc) # if VT data of this IP (not multicast) is not in the IPInfo, ask VT. @@ -531,7 +531,7 @@ def main(self): for key, value in flow.items(): flow_data = json.loads(value) ip = flow_data['daddr'] - cached_data = self.db.getIPData(ip) + cached_data = self.db.get_ip_info(ip) if not cached_data: cached_data = {} diff --git a/slips_files/common/slips_utils.py b/slips_files/common/slips_utils.py index c6ef9dd60..a601010b5 100644 --- a/slips_files/common/slips_utils.py +++ b/slips_files/common/slips_utils.py @@ -78,7 +78,7 @@ def get_cidr_of_ip(self, ip): return network_range - def threat_level_to_string(self, threat_level: float): + def threat_level_to_string(self, threat_level: float) -> str: for str_lvl, int_value in self.threat_levels.items(): if threat_level <= int_value: return str_lvl diff --git a/slips_files/core/database/database_manager.py b/slips_files/core/database/database_manager.py index 662db15d0..80045f3cb 100644 --- a/slips_files/core/database/database_manager.py +++ b/slips_files/core/database/database_manager.py @@ -67,8 +67,8 @@ def get_message(self, *args, **kwargs): def print(self, *args, **kwargs): return self.rdb.print(*args, **kwargs) - def getIPData(self, *args, **kwargs): - return self.rdb.getIPData(*args, **kwargs) + def get_ip_info(self, *args, **kwargs): + return self.rdb.get_ip_info(*args, **kwargs) def set_new_ip(self, *args, **kwargs): return self.rdb.set_new_ip(*args, **kwargs) @@ -831,6 +831,12 @@ def execute_query(self, *args, **kwargs): def get_pid_of(self, *args, **kwargs): return self.rdb.get_pid_of(*args, **kwargs) + def set_max_threat_level(self, *args, **kwargs): + return self.rdb.set_max_threat_level(*args, **kwargs) + + def update_max_threat_level(self, *args, **kwargs): + return self.rdb.update_max_threat_level(*args, **kwargs) + def get_name_of_module_at(self, *args, **kwargs): return self.rdb.get_name_of_module_at(*args, **kwargs) diff --git a/slips_files/core/database/redis_db/alert_handler.py b/slips_files/core/database/redis_db/alert_handler.py index ba940a13a..c4919ef40 100644 --- a/slips_files/core/database/redis_db/alert_handler.py +++ b/slips_files/core/database/redis_db/alert_handler.py @@ -1,6 +1,7 @@ import time import json from uuid import uuid4 +from typing import List, Tuple from slips_files.common.slips_utils import utils @@ -398,7 +399,42 @@ def getEvidenceForTW(self, profileid, twid): evidence = self.remove_whitelisted_evidence(evidence) return evidence - def update_threat_level(self, profileid: str, threat_level: str, confidence: int): + def set_max_threat_level(self, profileid: str, threat_level: str): + self.r.hset(profileid, 'max_threat_level', threat_level) + + def update_max_threat_level( + self, profileid: str, threat_level: str + ) -> float: + """ + given the current threat level of a profileid, this method sets the + max_threaty_level value to the given val if that max is less than + the given + :returns: the numerical val of the max threat level + """ + threat_level_float = utils.threat_levels[threat_level] + + old_max_threat_level: str = self.r.hget( + profileid, + 'max_threat_level' + ) + + if not old_max_threat_level: + # first time setting max tl + self.set_max_threat_level(profileid, threat_level) + return threat_level_float + + old_max_threat_level_float = utils.threat_levels[old_max_threat_level] + + if old_max_threat_level_float < threat_level_float: + self.set_max_threat_level(profileid, threat_level) + return threat_level_float + + return old_max_threat_level_float + + + def update_threat_level( + self, profileid: str, threat_level: str, confidence: int + ): """ Update the threat level of a certain profile Updates the profileid key and the IPsInfo key with the @@ -407,49 +443,62 @@ def update_threat_level(self, profileid: str, threat_level: str, confidence: int """ self.r.hset(profileid, 'threat_level', threat_level) - now = time.time() - now = utils.convert_format(now, utils.alerts_format) - # keep track of old threat levels + + now = utils.convert_format(time.time(), utils.alerts_format) confidence = f'confidence: {confidence}' - past_threat_levels = self.r.hget(profileid, 'past_threat_levels') + # this is what we'll be storing in the db, tl, ts, and confidence threat_level_data = (threat_level, now, confidence) + + past_threat_levels: List[Tuple] = self.r.hget( + profileid, + 'past_threat_levels' + ) if past_threat_levels: - # get the lists of ts and past threat levels + # get the list of ts and past threat levels past_threat_levels = json.loads(past_threat_levels) - latest_threat_level, latest_ts, latest_confidence = past_threat_levels[-1] + + latest: Tuple = past_threat_levels[-1] + latest_threat_level: str = latest[0] + latest_confidence: str = latest[2] + if ( latest_threat_level == threat_level and latest_confidence == confidence ): - # if the past threat level and confidence are the same as the ones we wanna store, + # if the past threat level and confidence + # are the same as the ones we wanna store, # replace the timestamp only past_threat_levels[-1] = threat_level_data + # dont change the old max tl else: # add this threat level to the list of past threat levels past_threat_levels.append(threat_level_data) else: # first time setting a threat level for this profile past_threat_levels = [threat_level_data] - # threat_levels_update_time = [now] past_threat_levels = json.dumps(past_threat_levels) self.r.hset(profileid, 'past_threat_levels', past_threat_levels) + max_threat_lvl: float = self.update_max_threat_level( + profileid, threat_level + ) + + score_confidence = { + # get the numerical value of this threat level + 'score': max_threat_lvl, + 'confidence': confidence + } # set the score and confidence of the given ip in the db # when it causes an evidence # these 2 values will be needed when sharing with peers ip = profileid.split('_')[-1] - # get the numerical value of this threat level - score = utils.threat_levels[threat_level.lower()] - score_confidence = { - 'score': score, - 'confidence': confidence - } - if cached_ip_data := self.getIPData(ip): - # append the score and conf. to the already existing data - cached_ip_data.update(score_confidence) - self.rcache.hset('IPsInfo', ip, json.dumps(cached_ip_data)) - else: - self.rcache.hset('IPsInfo', ip, json.dumps(score_confidence)) + + if cached_ip_info := self.get_ip_info(ip): + # append the score and confidence to the already existing data + cached_ip_info.update(score_confidence) + score_confidence = cached_ip_info + + self.rcache.hset('IPsInfo', ip, json.dumps(score_confidence)) diff --git a/slips_files/core/database/redis_db/database.py b/slips_files/core/database/redis_db/database.py index 2d22349f6..76f237d7c 100644 --- a/slips_files/core/database/redis_db/database.py +++ b/slips_files/core/database/redis_db/database.py @@ -377,16 +377,10 @@ def print(self, text, verbose=1, debug=0): } ) - - - def getIPData(self, ip: str) -> dict: + def get_ip_info(self, ip: str) -> dict: """ - Return information about this IP from IPsInfo - Returns a dictionary or False if there is no IP in the database - We need to separate these three cases: - 1- IP is in the DB without data. Return empty dict. - 2- IP is in the DB with data. Return dict. - 3- IP is not in the DB. Return False + Return information about this IP from IPsInfo key + :return: a dictionary or False if there is no IP in the database """ data = self.rcache.hget('IPsInfo', ip) return json.loads(data) if data else False @@ -400,7 +394,7 @@ def set_new_ip(self, ip: str): accessed as str, it is automatically converted to str """ - data = self.getIPData(ip) + data = self.get_ip_info(ip) if data is False: # If there is no data about this IP # Set this IP for the first time in the IPsInfo @@ -564,7 +558,7 @@ def setInfoForIPs(self, ip: str, to_store: dict): overwrite it """ # Get the previous info already stored - cached_ip_info = self.getIPData(ip) + cached_ip_info = self.get_ip_info(ip) if cached_ip_info is False: # This IP is not in the dictionary, add it first: self.set_new_ip(ip) @@ -940,7 +934,7 @@ def get_ip_identification(self, ip: str, get_ti_data=True): on the data stored so far :param get_ti_data: do we want to get info about this IP from out TI lists? """ - current_data = self.getIPData(ip) + current_data = self.get_ip_info(ip) identification = '' if current_data: if 'asn' in current_data: diff --git a/slips_files/core/database/redis_db/ioc_handler.py b/slips_files/core/database/redis_db/ioc_handler.py index 9a128bd42..858c4f534 100644 --- a/slips_files/core/database/redis_db/ioc_handler.py +++ b/slips_files/core/database/redis_db/ioc_handler.py @@ -471,7 +471,7 @@ def setInfoForURLs(self, url: str, urldata: dict): # This URL is not in the dictionary, add it first: self.setNewURL(url) # Now get the data, which should be empty, but just in case - data = self.getIPData(url) + data = self.get_ip_info(url) # empty dicts evaluate to False dict_has_keys = bool(data) if dict_has_keys: diff --git a/slips_files/core/database/redis_db/profile_handler.py b/slips_files/core/database/redis_db/profile_handler.py index ae495c2b4..a8114ef88 100644 --- a/slips_files/core/database/redis_db/profile_handler.py +++ b/slips_files/core/database/redis_db/profile_handler.py @@ -1128,7 +1128,7 @@ def add_out_ssl( flow.uid, flow.daddr, lookup=flow.server_name) # Save new server name in the IPInfo. There might be several server_name per IP. - if ipdata := self.getIPData(flow.daddr): + if ipdata := self.get_ip_info(flow.daddr): sni_ipdata = ipdata.get('SNI', []) else: sni_ipdata = [] diff --git a/slips_files/core/evidence.py b/slips_files/core/evidence.py index 7684aaa02..9bf67d732 100644 --- a/slips_files/core/evidence.py +++ b/slips_files/core/evidence.py @@ -208,7 +208,7 @@ def get_domains_of_flow(self, flow: dict): domains_to_check_dst = [] try: domains_to_check_src.append( - self.db.getIPData(flow['saddr']) + self.db.get_ip_info(flow['saddr']) .get('SNI', [{}])[0] .get('server_name') ) @@ -224,10 +224,8 @@ def get_domains_of_flow(self, flow: dict): except (KeyError, TypeError): pass try: - # self.print(f"IPData of dst IP {self.column_values['daddr']}: - # {self.db.getIPData(self.column_values['daddr'])}") domains_to_check_dst.append( - self.db.getIPData(flow['daddr']) + self.db.get_ip_info(flow['daddr']) .get('SNI', [{}])[0] .get('server_name') ) @@ -245,7 +243,8 @@ def show_popup(self, alert_to_log: str): os.system(f'{self.notify_cmd} "Slips" "{alert_to_log}"') elif platform.system() == 'Darwin': os.system( - f'osascript -e \'display notification "{alert_to_log}" with title "Slips"\' ' + f'osascript -e \'display notification "{alert_to_log}" ' + f'with title "Slips"\' ' ) diff --git a/slips_files/core/helpers/whitelist.py b/slips_files/core/helpers/whitelist.py index 1c4637950..2eace9f8a 100644 --- a/slips_files/core/helpers/whitelist.py +++ b/slips_files/core/helpers/whitelist.py @@ -54,7 +54,7 @@ def read_configuration(self): self.whitelist_path = conf.whitelist_path() def is_whitelisted_asn(self, ip, org): - ip_data = self.db.getIPData(ip) + ip_data = self.db.get_ip_info(ip) try: ip_asn = ip_data['asn']['asnorg'] org_asn = json.loads(self.db.get_org_info(org, 'asn')) @@ -537,7 +537,7 @@ def get_domains_of_flow(self, saddr, daddr): domains_to_check_src = [] domains_to_check_dst = [] try: - if ip_data := self.db.getIPData(saddr): + if ip_data := self.db.get_ip_info(saddr): if sni_info := ip_data.get('SNI', [{}])[0]: domains_to_check_src.append(sni_info.get('server_name', '')) except (KeyError, TypeError): @@ -550,7 +550,7 @@ def get_domains_of_flow(self, saddr, daddr): except (KeyError, TypeError): pass try: - if ip_data := self.db.getIPData(daddr): + if ip_data := self.db.get_ip_info(daddr): if sni_info := ip_data.get('SNI', [{}])[0]: domains_to_check_dst.append(sni_info.get('server_name')) except (KeyError, TypeError): @@ -626,7 +626,7 @@ def is_ip_asn_in_org_asn(self, ip, org): returns true if the ASN of the given IP is listed in the ASNs of the given org ASNs """ # Check if the IP in the content of the alert has ASN info in the db - ip_data = self.db.getIPData(ip) + ip_data = self.db.get_ip_info(ip) if not ip_data: return try: diff --git a/tests/test_database.py b/tests/test_database.py index 700da02ef..15aaebc29 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -1,5 +1,6 @@ from slips_files.common.slips_utils import utils from slips_files.core.flows.zeek import Conn +from slips_files.common.slips_utils import utils from tests.module_factory import ModuleFactory import redis import os @@ -221,3 +222,19 @@ def test_get_the_other_ip_version(): def test_add_tuple(tupleid: str, symbol, expected_direction, role, flow): db.add_tuple(profileid, twid, tupleid, symbol, role, flow) assert symbol[0] in db.r.hget(f'profile_{flow.saddr}_{twid}', expected_direction) + + +@pytest.mark.parametrize( + 'max_threat_level, cur_threat_level, expected_max', + [ + ('info', 'info', utils.threat_levels['info']), + ('critical', 'info', utils.threat_levels['critical']), + ('high', 'critical', utils.threat_levels['critical']), + ], +) +def test_update_max_threat_level( + max_threat_level, cur_threat_level, expected_max + ): + db.set_max_threat_level(profileid, max_threat_level) + assert db.update_max_threat_level( + profileid, cur_threat_level) == expected_max \ No newline at end of file