diff --git a/tests/module_factory.py b/tests/module_factory.py index b31647470..dcbb5e0a7 100644 --- a/tests/module_factory.py +++ b/tests/module_factory.py @@ -188,8 +188,8 @@ def create_ip_info_obj(self, mock_db): ip_info.print = do_nothing return ip_info - def create_asn_obj(self, db): - return ASN(db) + def create_asn_obj(self, mock_db): + return ASN(mock_db) def create_leak_detector_obj(self, mock_db): # this file will be used for storing the module output @@ -212,9 +212,9 @@ def create_leak_detector_obj(self, mock_db): return leak_detector - def create_profiler_obj(self): + def create_profiler_obj(self, mock_db): dummy_semaphore = Semaphore(0) - profilerProcess = Profiler( + profiler = Profiler( self.logger, 'output/', 6379, @@ -225,9 +225,10 @@ def create_profiler_obj(self): ) # override the self.print function to avoid broken pipes - profilerProcess.print = do_nothing - profilerProcess.whitelist_path = 'tests/test_whitelist.conf' - return profilerProcess + profiler.print = do_nothing + profiler.whitelist_path = 'tests/test_whitelist.conf' + profiler.db = mock_db + return profiler def create_redis_manager_obj(self, main): return RedisManager(main) @@ -244,7 +245,7 @@ def create_threatintel_obj(self, mock_db): 'dummy_output_dir', 6379, self.dummy_termination_event) - threatintel.db.rdb = mock_db + threatintel.db = mock_db # override the self.print function to avoid broken pipes threatintel.print = do_nothing diff --git a/tests/test_profiler.py b/tests/test_profiler.py index 693eb43ae..92f121a82 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -18,7 +18,7 @@ def test_define_separator_suricata(file, input_type, expected_value, mock_db ): - profilerProcess = ModuleFactory().create_profiler_obj() + profilerProcess = ModuleFactory().create_profiler_obj(mock_db) with open(file) as f: while True: sample_flow = f.readline().replace('\n', '') @@ -40,7 +40,7 @@ def test_define_separator_suricata(file, input_type, expected_value, def test_define_separator_zeek_tab(file, input_type, expected_value, mock_db ): - profilerProcess = ModuleFactory().create_profiler_obj() + profilerProcess = ModuleFactory().create_profiler_obj(mock_db) with open(file) as f: while True: sample_flow = f.readline().replace('\n', '') @@ -66,7 +66,7 @@ def test_define_separator_zeek_dict(file, input_type, expected_value, :param input_type: as determined by slips.py """ - profilerProcess = ModuleFactory().create_profiler_obj() + profilerProcess = ModuleFactory().create_profiler_obj(mock_db) with open(file) as f: sample_flow = f.readline().replace('\n', '') @@ -98,7 +98,7 @@ def test_define_separator_nfdump(nfdump_file, else: break - profilerProcess = ModuleFactory().create_profiler_obj() + profilerProcess = ModuleFactory().create_profiler_obj(mock_db) sample_flow = { 'data': nfdump_line, } @@ -132,7 +132,7 @@ def test_define_separator_nfdump(nfdump_file, # line = f.readline() # if line.startswith('#fields'): # break -# profilerProcess = ModuleFactory().create_profiler_obj() +# profilerProcess = ModuleFactory().create_profiler_obj(mock_db) # line = {'data': line} # profilerProcess.separator = separator # assert profilerProcess.define_columns(line) == expected_value @@ -152,8 +152,8 @@ def test_define_separator_nfdump(nfdump_file, # ('dataset/test9-mixed-zeek-dir/files.log', 'files.log'), ], ) -def test_process_line(file, flow_type): - profiler = ModuleFactory().create_profiler_obj() +def test_process_line(file, flow_type, mock_db): + profiler = ModuleFactory().create_profiler_obj(mock_db) # we're testing another functionality here profiler.whitelist.is_whitelisted_flow = do_nothing profiler.input_type = 'zeek' @@ -194,7 +194,7 @@ def test_process_line(file, flow_type): assert added_flow is not None def test_get_rev_profile(mock_db): - profiler = ModuleFactory().create_profiler_obj() + profiler = ModuleFactory().create_profiler_obj(mock_db) profiler.flow = Conn( '1.0', '1234', @@ -209,13 +209,12 @@ def test_get_rev_profile(mock_db): '','', 'Established','' ) - profiler.daddr_as_obj = ipaddress.ip_address(profiler.flow.daddr) mock_db.get_profileid_from_ip.return_value = None mock_db.get_timewindow.return_value = 'timewindow1' assert profiler.get_rev_profile() == ('profile_8.8.8.8', 'timewindow1') -def test_get_rev_profile_no_daddr(flow): - profiler = ModuleFactory().create_profiler_obj() +def test_get_rev_profile_no_daddr(flow, mock_db): + profiler = ModuleFactory().create_profiler_obj(mock_db) profiler.flow = flow profiler.flow.daddr = None profiler.daddr_as_obj = None