diff --git a/parsl/executors/high_throughput/executor.py b/parsl/executors/high_throughput/executor.py index 882776cf5c..d3fa522619 100644 --- a/parsl/executors/high_throughput/executor.py +++ b/parsl/executors/high_throughput/executor.py @@ -331,6 +331,9 @@ def __init__(self, interchange_launch_cmd = DEFAULT_INTERCHANGE_LAUNCH_CMD self.interchange_launch_cmd = interchange_launch_cmd + self._result_queue_thread_exit = threading.Event() + self._result_queue_thread: Optional[threading.Thread] = None + radio_mode = "htex" enable_mpi_mode: bool = False mpi_launcher: str = "mpiexec" @@ -455,9 +458,11 @@ def _result_queue_worker(self): """ logger.debug("Result queue worker starting") - while not self.bad_state_is_set: + while not self.bad_state_is_set and not self._result_queue_thread_exit.is_set(): try: - msgs = self.incoming_q.get() + msgs = self.incoming_q.get(timeout_ms=self.poll_period) + if msgs is None: # timeout + continue except IOError as e: logger.exception("Caught broken queue with exception code {}: {}".format(e.errno, e)) @@ -515,6 +520,8 @@ def _result_queue_worker(self): else: raise BadMessage("Message received with unknown type {}".format(msg['type'])) + logger.info("Closing result ZMQ pipe") + self.incoming_q.close() logger.info("Result queue worker finished") def _start_local_interchange_process(self) -> None: @@ -817,6 +824,8 @@ def shutdown(self, timeout: float = 10.0): logger.info("Attempting HighThroughputExecutor shutdown") + logger.info("Terminating interchange and result queue thread") + self._result_queue_thread_exit.set() self.interchange_proc.terminate() try: self.interchange_proc.wait(timeout=timeout) @@ -841,6 +850,10 @@ def shutdown(self, timeout: float = 10.0): logger.info("Closing command client") self.command_client.close() + logger.info("Waiting for result queue thread exit") + if self._result_queue_thread: + self._result_queue_thread.join() + logger.info("Finished HighThroughputExecutor shutdown attempt") def 
get_usage_information(self): diff --git a/parsl/executors/high_throughput/zmq_pipes.py b/parsl/executors/high_throughput/zmq_pipes.py index 54ed8c1da9..a7278cf067 100644 --- a/parsl/executors/high_throughput/zmq_pipes.py +++ b/parsl/executors/high_throughput/zmq_pipes.py @@ -206,12 +206,21 @@ def __init__(self, ip_address, port_range, cert_dir: Optional[str] = None): self.port = self.results_receiver.bind_to_random_port(tcp_url(ip_address), min_port=port_range[0], max_port=port_range[1]) + self.poller = zmq.Poller() + self.poller.register(self.results_receiver, zmq.POLLIN) - def get(self): + def get(self, timeout_ms=None): + """Get a message from the queue, returning None if timeout_ms expires + without a message. timeout_ms is measured in milliseconds. + """ - logger.debug("Waiting for ResultsIncoming message") - m = self.results_receiver.recv_multipart() - logger.debug("Received ResultsIncoming message") - return m + # Deliberately no log before polling: a caller polling in a loop would emit it on every timeout. + socks = dict(self.poller.poll(timeout=timeout_ms)) + if self.results_receiver in socks and socks[self.results_receiver] == zmq.POLLIN: + m = self.results_receiver.recv_multipart() + logger.debug("Received ResultsIncoming message") + return m + else: + return None def close(self): self.results_receiver.close()