Skip to content

Commit

Permalink
unit tests: use the mock db in the profiler obj
Browse files Browse the repository at this point in the history
  • Loading branch information
AlyaGomaa committed Feb 1, 2024
1 parent 657b18f commit ff72c3e
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 19 deletions.
17 changes: 9 additions & 8 deletions tests/module_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,8 @@ def create_ip_info_obj(self, mock_db):
ip_info.print = do_nothing
return ip_info

def create_asn_obj(self, db):
return ASN(db)
def create_asn_obj(self, mock_db):
return ASN(mock_db)

def create_leak_detector_obj(self, mock_db):
# this file will be used for storing the module output
Expand All @@ -212,9 +212,9 @@ def create_leak_detector_obj(self, mock_db):
return leak_detector


def create_profiler_obj(self):
def create_profiler_obj(self, mock_db):
dummy_semaphore = Semaphore(0)
profilerProcess = Profiler(
profiler = Profiler(
self.logger,
'output/',
6379,
Expand All @@ -225,9 +225,10 @@ def create_profiler_obj(self):
)

# override the self.print function to avoid broken pipes
profilerProcess.print = do_nothing
profilerProcess.whitelist_path = 'tests/test_whitelist.conf'
return profilerProcess
profiler.print = do_nothing
profiler.whitelist_path = 'tests/test_whitelist.conf'
profiler.db = mock_db
return profiler

def create_redis_manager_obj(self, main):
return RedisManager(main)
Expand All @@ -244,7 +245,7 @@ def create_threatintel_obj(self, mock_db):
'dummy_output_dir',
6379,
self.dummy_termination_event)
threatintel.db.rdb = mock_db
threatintel.db = mock_db

# override the self.print function to avoid broken pipes
threatintel.print = do_nothing
Expand Down
21 changes: 10 additions & 11 deletions tests/test_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
def test_define_separator_suricata(file, input_type, expected_value,
mock_db
):
profilerProcess = ModuleFactory().create_profiler_obj()
profilerProcess = ModuleFactory().create_profiler_obj(mock_db)
with open(file) as f:
while True:
sample_flow = f.readline().replace('\n', '')
Expand All @@ -40,7 +40,7 @@ def test_define_separator_suricata(file, input_type, expected_value,
def test_define_separator_zeek_tab(file, input_type, expected_value,
mock_db
):
profilerProcess = ModuleFactory().create_profiler_obj()
profilerProcess = ModuleFactory().create_profiler_obj(mock_db)
with open(file) as f:
while True:
sample_flow = f.readline().replace('\n', '')
Expand All @@ -66,7 +66,7 @@ def test_define_separator_zeek_dict(file, input_type, expected_value,
:param input_type: as determined by slips.py
"""

profilerProcess = ModuleFactory().create_profiler_obj()
profilerProcess = ModuleFactory().create_profiler_obj(mock_db)
with open(file) as f:
sample_flow = f.readline().replace('\n', '')

Expand Down Expand Up @@ -98,7 +98,7 @@ def test_define_separator_nfdump(nfdump_file,
else:
break

profilerProcess = ModuleFactory().create_profiler_obj()
profilerProcess = ModuleFactory().create_profiler_obj(mock_db)
sample_flow = {
'data': nfdump_line,
}
Expand Down Expand Up @@ -132,7 +132,7 @@ def test_define_separator_nfdump(nfdump_file,
# line = f.readline()
# if line.startswith('#fields'):
# break
# profilerProcess = ModuleFactory().create_profiler_obj()
# profilerProcess = ModuleFactory().create_profiler_obj(mock_db)
# line = {'data': line}
# profilerProcess.separator = separator
# assert profilerProcess.define_columns(line) == expected_value
Expand All @@ -152,8 +152,8 @@ def test_define_separator_nfdump(nfdump_file,
# ('dataset/test9-mixed-zeek-dir/files.log', 'files.log'),
],
)
def test_process_line(file, flow_type):
profiler = ModuleFactory().create_profiler_obj()
def test_process_line(file, flow_type, mock_db):
profiler = ModuleFactory().create_profiler_obj(mock_db)
# we're testing another functionality here
profiler.whitelist.is_whitelisted_flow = do_nothing
profiler.input_type = 'zeek'
Expand Down Expand Up @@ -194,7 +194,7 @@ def test_process_line(file, flow_type):
assert added_flow is not None

def test_get_rev_profile(mock_db):
profiler = ModuleFactory().create_profiler_obj()
profiler = ModuleFactory().create_profiler_obj(mock_db)
profiler.flow = Conn(
'1.0',
'1234',
Expand All @@ -209,13 +209,12 @@ def test_get_rev_profile(mock_db):
'','',
'Established',''
)
profiler.daddr_as_obj = ipaddress.ip_address(profiler.flow.daddr)
mock_db.get_profileid_from_ip.return_value = None
mock_db.get_timewindow.return_value = 'timewindow1'
assert profiler.get_rev_profile() == ('profile_8.8.8.8', 'timewindow1')

def test_get_rev_profile_no_daddr(flow):
profiler = ModuleFactory().create_profiler_obj()
def test_get_rev_profile_no_daddr(flow, mock_db):
profiler = ModuleFactory().create_profiler_obj(mock_db)
profiler.flow = flow
profiler.flow.daddr = None
profiler.daddr_as_obj = None
Expand Down

0 comments on commit ff72c3e

Please sign in to comment.