Merge pull request #1112 from stratosphereips/alya/use-lru-cache
Optimize MAC OUIs and malicious domains lookups
AlyaGomaa authored Dec 10, 2024
2 parents 9bb16a5 + 34ecab9 commit f02645e
Showing 5 changed files with 143 additions and 62 deletions.
27 changes: 15 additions & 12 deletions modules/ip_info/ip_info.py
@@ -198,30 +198,33 @@ def get_vendor_online(self, mac_addr):
):
return False

@staticmethod
@lru_cache(maxsize=700)
def _get_vendor_offline_cached(oui, mac_db_content):
"""
Static helper to perform the actual lookup based on OUI and cached content.
"""
for line in mac_db_content:
if oui in line:
line = json.loads(line)
return line["vendorName"]
return False

def get_vendor_offline(self, mac_addr, profileid):
"""
Gets vendor from Slips' offline database databases/macaddr-db.json
Gets vendor from Slips' offline database at databases/macaddr-db.json.
"""
if not hasattr(self, "mac_db"):
if not hasattr(self, "mac_db") or self.mac_db is None:
# when update manager is done updating the mac db, we should ask
# the db for all these pending queries
self.pending_mac_queries.put((mac_addr, profileid))
return False

oui = mac_addr[:8].upper()
# parse the mac db and search for this oui
self.mac_db.seek(0)
while True:
line = self.mac_db.readline()
if line == "":
# reached the end of file without finding the vendor
# set the vendor to unknown to avoid searching for it again
return False
mac_db_content = self.mac_db.readlines()

if oui in line:
line = json.loads(line)
return line["vendorName"]
return self._get_vendor_offline_cached(oui, tuple(mac_db_content))

def get_vendor(self, mac_addr: str, profileid: str) -> dict:
"""
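The cached helper works because functools.lru_cache memoizes on its arguments and requires them to be hashable: readlines() returns a list, so get_vendor_offline converts it to a tuple before the call. A minimal standalone sketch of the same pattern (the function name and sample data are illustrative, not from this commit):

from functools import lru_cache
import json

@lru_cache(maxsize=700)
def lookup_oui(oui, mac_db_content):
    # linear scan of the DB lines; repeated OUIs are served from the cache
    for line in mac_db_content:
        if oui in line:
            return json.loads(line)["vendorName"]
    return False

db_lines = ('{"oui": "08:00:27", "vendorName": "PCS Systemtechnik GmbH"}',)
print(lookup_oui("08:00:27", db_lines))  # first call scans the lines
print(lookup_oui("08:00:27", db_lines))  # second call is a cache hit

Note that the tuple is part of the cache key, so repeated calls only hit the cache while an equal mac_db_content is passed each time.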
22 changes: 0 additions & 22 deletions modules/threat_intelligence/threat_intelligence.py
@@ -1429,16 +1429,6 @@ def is_malicious_url(self, url, uid, timestamp, daddr, profileid, twid):
"""Determines if a URL is considered malicious by querying online threat
intelligence sources.
Parameters:
- url (str): The URL to check.
- uid (str): Unique identifier for the network flow.
- timestamp (str): Timestamp when the network flow occurred.
- daddr (str): Destination IP address in the network flow.
- profileid (str): Identifier of the profile associated
with the network flow.
- twid (str): Time window identifier for when the network
flow occurred.
Returns:
- None: The function does not return a value but triggers
evidence creation if the URL is found to be malicious.
@@ -1633,18 +1623,6 @@ def is_malicious_domain(
malicious, it records an evidence entry and marks the
domain in the database.
Parameters:
domain (str): The domain name to be evaluated for
malicious activity.
uid (str): Unique identifier of the network flow
associated with this domain query.
timestamp (str): Timestamp when the domain query
was observed.
profileid (str): Identifier of the network profile
that initiated the domain query.
twid (str): Time window identifier during which the
domain query occurred.
Returns:
bool: False if the domain is ignored or not found in the
offline threat intelligence data, indicating no further action
47 changes: 47 additions & 0 deletions slips_files/common/data_structures/trie.py
@@ -0,0 +1,47 @@
from collections import defaultdict
from typing import (
Dict,
Tuple,
Optional,
)


class TrieNode:
def __init__(self):
self.children = defaultdict(TrieNode)
self.is_end_of_word = False
self.domain_info = (
None # store associated domain information if needed
)


class Trie:
def __init__(self):
self.root = TrieNode()

def insert(self, domain: str, domain_info: str):
"""Insert a domain into the trie (using domain parts not chars)."""
node = self.root
parts = domain.split(".")[::-1] # reverse to handle subdomains
for part in parts:
node = node.children[part]
node.is_end_of_word = True
node.domain_info = domain_info

def search(self, domain: str) -> Tuple[bool, Optional[Dict[str, str]]]:
"""
Check if a domain or its subdomain exists in the trie
(using domain parts instead of characters).
Returns a tuple (found, domain_info).
"""
node = self.root
# reverse domain to handle subdomains
parts = domain.split(".")[::-1]
for part in parts:
if part not in node.children:
return False, None

node = node.children[part]
if node.is_end_of_word:
return True, node.domain_info
return False, None
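Because insert() stores domain labels in reverse order, a blacklisted parent domain such as google.com lies on the trie path of every one of its subdomains, so search() visits at most one node per label of the queried domain instead of scanning the whole blacklist. A short usage sketch (the domains are invented for illustration):

trie = Trie()
trie.insert("evil-tracker.com", {"threat_level": "medium"})

print(trie.search("evil-tracker.com"))      # (True, {'threat_level': 'medium'})
print(trie.search("cdn.evil-tracker.com"))  # (True, ...): parent domain match
print(trie.search("example.com"))           # (False, None)

search() returns at the first end-of-word node it meets, which is what makes the subdomain match work: walking cdn.evil-tracker.com reaches the node for evil-tracker.com before it ever consumes the cdn label.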
2 changes: 2 additions & 0 deletions slips_files/core/database/redis_db/database.py
@@ -131,6 +131,8 @@ def __new__(
)

cls._instances[cls.redis_port] = super().__new__(cls)
# ensure mixin initializers (e.g. IoCHandler.__init__) run for the
# newly created singleton and set up the trie cache attributes
super().__init__(cls)

# By default the slips internal time is
# 0 until we receive something
cls.set_slips_internal_time(0)
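The new super().__init__(cls) call exists so that the initializer added to IoCHandler in this commit, which sets up the trie cache attributes, actually runs: instances of the database class are created and cached inside __new__ rather than through the usual constructor path. A simplified sketch of that pattern, with illustrative names and port keying rather than the actual slips classes:

class Mixin:
    def __init__(self):
        # per-singleton state, e.g. the trie cache flags in IoCHandler
        self.cache_ready = False

class DB(Mixin):
    _instances = {}

    def __new__(cls, port=6379):
        if port not in cls._instances:
            inst = super().__new__(cls)
            # run the mixin initializer explicitly when the singleton
            # is first created
            Mixin.__init__(inst)
            cls._instances[port] = inst
        return cls._instances[port]

    def __init__(self, port=6379):
        pass  # deliberately empty: state was set up once, in __new__

db = DB()
assert db is DB() and db.cache_ready is False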
107 changes: 79 additions & 28 deletions slips_files/core/database/redis_db/ioc_handler.py
@@ -7,6 +7,12 @@
Optional,
)

from slips_files.common.data_structures.trie import Trie

# for future developers: remember to call _invalidate_trie_cache() on every
# change to the self.constants.IOC_DOMAINS key, or slips will keep using an
# invalid cache to look up malicious domains


class IoCHandler:
"""
@@ -17,6 +23,38 @@ class IoCHandler:

name = "DB"

def __init__(self):
# used for faster domain lookups
self.trie = None
self.is_trie_cached = False

def _build_trie(self):
"""Retrieve domains from Redis and construct the trie."""
self.trie = Trie()
ioc_domains: Dict[str, str] = self.rcache.hgetall(
self.constants.IOC_DOMAINS
)
for domain, domain_info in ioc_domains.items():
domain: str
domain_info: str
# domain_info is something like this
# {"description": "['hack''malware''phishing']",
# "source": "OCD-Datalake-russia-ukraine_IOCs-ALL.csv",
# "threat_level": "medium",
# "tags": ["Russia-UkraineIoCs"]}

# store parsed domain info
self.trie.insert(domain, json.loads(domain_info))
self.is_trie_cached = True

def _invalidate_trie_cache(self):
"""
Invalidate the trie cache.
Used whenever the IOC_DOMAINS key is updated.
"""
self.trie = None
self.is_trie_cached = False

def set_loaded_ti_files(self, number_of_loaded_files: int):
"""
Stores the number of successfully loaded TI files
@@ -43,6 +81,7 @@ def delete_feed_entries(self, url: str):
if feed_to_delete in domain_description["source"]:
# this entry has the given feed as source, delete it
self.rcache.hdel(self.constants.IOC_DOMAINS, domain)
self._invalidate_trie_cache()

# get all IPs that are read from TI files in our db
ioc_ips = self.rcache.hgetall(self.constants.IOC_IPS)
@@ -139,6 +178,7 @@ def delete_domains_from_ioc_domains(self, domains: List[str]):
Delete old domains from IoC
"""
self.rcache.hdel(self.constants.IOC_DOMAINS, *domains)
self._invalidate_trie_cache()

def add_ips_to_ioc(self, ips_and_description: Dict[str, str]) -> None:
"""
@@ -164,6 +204,7 @@ def add_domains_to_ioc(self, domains_and_description: dict) -> None:
self.rcache.hmset(
self.constants.IOC_DOMAINS, domains_and_description
)
self._invalidate_trie_cache()

def add_ip_range_to_ioc(self, malicious_ip_ranges: dict) -> None:
"""
@@ -239,43 +280,53 @@ def is_blacklisted_ssl(self, sha1):
info = self.rcache.hmget(self.constants.IOC_SSL, sha1)[0]
return False if info is None else info

def _match_exact_domain(self, domain: str) -> Optional[Dict[str, str]]:
"""checks if the given domain is blacklisted.
checks only the exact given domain, no subdomains"""
domain_description = self.rcache.hget(
self.constants.IOC_DOMAINS, domain
)
if not domain_description:
return
return json.loads(domain_description)

def _match_subdomain(self, domain: str) -> Optional[Dict[str, str]]:
"""
Checks if any blacklisted domain is a parent of the given
domain, e.g. a blacklisted google.com matches images.google.com.
Uses a cached trie for optimization.
"""
# The goal is to avoid retrieving the huge set of blacklisted domains
# from the db on every domain lookup. We retrieve it once, cache it in
# memory as a trie, and answer lookups from that structure. Whenever a
# new domain is added to the db, the cache is invalidated; the next
# lookup rebuilds the trie and keeps using it from there.
if not self.is_trie_cached:
self._build_trie()

found, domain_info = self.trie.search(domain)
if found:
return domain_info

def is_blacklisted_domain(
self, domain: str
) -> Tuple[Dict[str, str], bool]:
) -> Tuple[Union[Dict[str, str], bool], bool]:
"""
Search in the dB of malicious domains and return a
description if we found a match
Check if the given domain, or a parent domain of it, is blacklisted.
Returns a tuple (description, is_subdomain):
description: info about the matched blacklisted domain, if any
is_subdomain: False if the exact domain matched, True if the
match was on a blacklisted parent domain
"""
domain_description = self.rcache.hget(
self.constants.IOC_DOMAINS, domain
)
is_subdomain = False
if domain_description:
return json.loads(domain_description), is_subdomain
if match := self._match_exact_domain(domain):
is_subdomain = False
return match, is_subdomain

# try to match subdomain
ioc_domains: Dict[str, Dict[str, str]] = self.rcache.hgetall(
self.constants.IOC_DOMAINS
)
for malicious_domain, domain_info in ioc_domains.items():
malicious_domain: str
domain_info: str
# something like this
# {"description": "['hack''malware''phishing']",
# "source": "OCD-Datalake-russia-ukraine_IOCs-ALL.csv",
# "threat_level": "medium",
# "tags": ["Russia-UkraineIoCs"]}
domain_info: Dict[str, str] = json.loads(domain_info)
# if the we contacted images.google.com and we have
# google.com in our blacklists, we find a match
if malicious_domain in domain:
is_subdomain = True
return domain_info, is_subdomain
return False, is_subdomain
if match := self._match_subdomain(domain):
is_subdomain = True
return match, is_subdomain
return False, False

def get_all_blacklisted_ip_ranges(self) -> dict:
"""
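Putting the pieces together: is_blacklisted_domain() answers exact matches straight from the Redis hash, falls back to the in-memory trie for parent-domain matches, and every write to IOC_DOMAINS invalidates that trie so the next lookup rebuilds it. A simplified, Redis-free sketch of this lifecycle, using a plain dict in place of the IOC_DOMAINS hash (the class and method names are illustrative):

import json
from slips_files.common.data_structures.trie import Trie

class DomainBlacklist:
    def __init__(self):
        self._store = {}  # stands in for the IOC_DOMAINS Redis hash
        self.trie = None
        self.is_trie_cached = False

    def add(self, domain, info):
        self._store[domain] = json.dumps(info)
        self._invalidate_trie_cache()  # every write invalidates the cache

    def _invalidate_trie_cache(self):
        self.trie = None
        self.is_trie_cached = False

    def _build_trie(self):
        self.trie = Trie()
        for domain, info in self._store.items():
            self.trie.insert(domain, json.loads(info))
        self.is_trie_cached = True

    def is_blacklisted(self, domain):
        exact = self._store.get(domain)
        if exact:
            return json.loads(exact), False  # exact match, not a subdomain
        if not self.is_trie_cached:
            self._build_trie()  # rebuilt lazily after any invalidation
        found, info = self.trie.search(domain)
        return (info, True) if found else (False, False)

bl = DomainBlacklist()
bl.add("evil-tracker.com", {"threat_level": "medium"})
print(bl.is_blacklisted("cdn.evil-tracker.com"))  # ({...}, True)

The trade-off mirrors the commit: writes pay a small invalidation cost so that the hot path, domain lookups, never has to hgetall the full blacklist.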
