Skip to content

Commit

Permalink
Fix zmq inference (#1800)
Browse files Browse the repository at this point in the history
* Ensure that we always pass in the zmq_port dict to LossViewer

* Ensure zmq_ports has correct keys inside LossViewer

* Use specified controller and publish ports for first attempted addresses

* Add test for ports being set in LossViewer

* Add max attempts to find unused port

* Fix find free port loop and add for controller port also

* Improve code readablility and reuse

* Improve error message when unable to find free port
  • Loading branch information
roomrys authored Jun 11, 2024
1 parent 834c68a commit bbe1246
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 26 deletions.
10 changes: 4 additions & 6 deletions sleap/gui/learning/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,13 +605,11 @@ def run_gui_training(
from sleap.gui.widgets.monitor import LossViewer
from sleap.gui.widgets.imagedir import QtImageDirectoryWidget

if all(key in inference_params for key in ["controller_port", "publish_port"]):
zmq_ports = {
"controller_port": inference_params["controller_port"],
"publish_port": inference_params["publish_port"],
}
zmq_ports = dict()
zmq_ports["controller_port"] = inference_params.get("controller_port", 9000)
zmq_ports["publish_port"] = inference_params.get("publish_port", 9001)

# open training monitor window
# Open training monitor window
win = LossViewer(zmq_ports=zmq_ports)

# Reassign the values in the inference parameters in case the ports were changed
Expand Down
55 changes: 37 additions & 18 deletions sleap/gui/widgets/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ def __init__(
self.stop_button = None
self.cancel_button = None
self.canceled = False

# Set up ZMQ ports for communication.
zmq_ports = zmq_ports or dict()
zmq_ports["publish_port"] = zmq_ports.get("publish_port", 9001)
zmq_ports["controller_port"] = zmq_ports.get("controller_port", 9000)
self.zmq_ports = zmq_ports

self.batches_to_show = -1 # -1 to show all
Expand Down Expand Up @@ -305,37 +310,51 @@ def setup_zmq(self, zmq_context: Optional[zmq.Context] = None):
self.ctx_given = zmq_context is not None
self.ctx = zmq.Context() if zmq_context is None else zmq_context

# Default publish and control address
controller_address = "tcp://127.0.0.1:9000"
publish_address = "tcp://127.0.0.1:9001"

# Progress monitoring, SUBSCRIBER
self.sub = self.ctx.socket(zmq.SUB)
self.sub.subscribe("")

if self.zmq_ports and not is_port_free(
port=self.zmq_ports["publish_port"], zmq_context=self.ctx
):
self.zmq_ports["publish_port"] = select_zmq_port(zmq_context=self.ctx)
publish_address = "tcp://127.0.0.1:" + str(self.zmq_ports["publish_port"])
def find_free_port(port: int, zmq_context: zmq.Context):
"""Find free port to bind to.
Args:
port: The port to start searching from.
zmq_context: The ZMQ context to use.
Returns:
The free port.
"""
attempts = 0
max_attempts = 10
while not is_port_free(port=port, zmq_context=zmq_context):
if attempts >= max_attempts:
raise RuntimeError(
f"Could not find free port to display training progress after "
f"{max_attempts} attempts. Please check your network settings "
"or use the CLI `sleap-train` command."
)
port = select_zmq_port(zmq_context=self.ctx)
attempts += 1

return port

# Find a free port and bind to it.
self.zmq_ports["publish_port"] = find_free_port(
port=self.zmq_ports["publish_port"], zmq_context=self.ctx
)
publish_address = f"tcp://127.0.0.1:{self.zmq_ports['publish_port']}"
self.sub.bind(publish_address)

# Controller, PUBLISHER
self.zmq_ctrl = None
if self.show_controller:
self.zmq_ctrl = self.ctx.socket(zmq.PUB)

if self.zmq_ports and not is_port_free(
# Find a free port and bind to it.
self.zmq_ports["controller_port"] = find_free_port(
port=self.zmq_ports["controller_port"], zmq_context=self.ctx
):
self.zmq_ports["controller_port"] = select_zmq_port(
zmq_context=self.ctx
)
controller_address = "tcp://127.0.0.1:" + str(
self.zmq_ports["controller_port"]
)

)
controller_address = f"tcp://127.0.0.1:{self.zmq_ports['controller_port']}"
self.zmq_ctrl.bind(controller_address)

# Set timer to poll for messages.
Expand Down
21 changes: 19 additions & 2 deletions tests/gui/test_monitor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from turtle import title
from sleap.gui.widgets.monitor import LossViewer
from sleap import TrainingJobConfig

Expand All @@ -12,6 +11,9 @@ def test_monitor_release(qtbot, min_centroid_model_path):
win.reset(what="Model Type", config=config)
assert win.config.optimization.early_stopping.plateau_patience == 10

# Ensure zmq port is set correctly
assert win.zmq_ports["controller_port"] == 9000
assert win.zmq_ports["publish_port"] == 9001
# Ensure all lines of update_runtime() are run error-free
win.is_running = True
win.t0 = 0
Expand All @@ -33,8 +35,12 @@ def test_monitor_release(qtbot, min_centroid_model_path):
win.close()

# Make sure the first monitor released its zmq socket
win2 = LossViewer()
controller_port = 9191
zmq_ports = dict(controller_port=controller_port)
win2 = LossViewer(zmq_ports=zmq_ports)
win2.show()
assert win2.zmq_ports["controller_port"] == controller_port
assert win2.zmq_ports["publish_port"] == 9001

# Make sure batches to show field is working correction

Expand All @@ -47,3 +53,14 @@ def test_monitor_release(qtbot, min_centroid_model_path):
assert win2.batches_to_show == 200

win2.close()

# Ensure zmq port is set correctly
controller_port = 9191
publish_port = 9101
zmq_ports = dict(controller_port=controller_port, publish_port=publish_port)
win3 = LossViewer(zmq_ports=zmq_ports)
win3.show()
assert win3.zmq_ports["controller_port"] == controller_port
assert win3.zmq_ports["publish_port"] == publish_port

win3.close()

0 comments on commit bbe1246

Please sign in to comment.