Skip to content

Commit

Permalink
Automatically get group id by pid, and change some default value
Browse files Browse the repository at this point in the history
When using killpg, we need to call with group id instead of pid

1. Get group id by pid for `killpg`
2. Modify `kill` and `terminate` 's default group to `False` for better
   compatibility.
  • Loading branch information
karajan1001 committed Dec 9, 2022
1 parent b5cc7d4 commit 643ace0
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 17 deletions.
15 changes: 8 additions & 7 deletions src/dvc_task/proc/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def run_signature(
immutable=immutable,
)

def send_signal(self, name: str, sig: int, group: bool = True):
def send_signal(self, name: str, sig: int, group: bool):
"""Send `signal` to the specified named process."""
try:
process_info = self[name]
Expand All @@ -134,12 +134,13 @@ def handle_closed_process():

if process_info.returncode is None:
try:
pid = (
os.getpgid(process_info.pid) if group else process_info.pid
)
if sys.platform != "win32" and group:
os.killpg( # pylint: disable=no-member
process_info.pid, sig
)
os.killpg(pid, sig) # pylint: disable=no-member
else:
os.kill(process_info.pid, sig)
os.kill(pid, sig)
except ProcessLookupError:
handle_closed_process()
raise
Expand All @@ -156,11 +157,11 @@ def interrupt(self, name: str, group: bool = True):
"""Send interrupt signal to specified named process"""
self.send_signal(name, signal.SIGINT, group)

def terminate(self, name: str, group: bool = True):
def terminate(self, name: str, group: bool = False):
"""Terminate the specified named process."""
self.send_signal(name, signal.SIGTERM, group)

def kill(self, name: str, group: bool = True):
def kill(self, name: str, group: bool = False):
"""Kill the specified named process."""
if sys.platform == "win32":
self.send_signal(name, signal.SIGTERM, group)
Expand Down
26 changes: 16 additions & 10 deletions tests/proc/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,19 @@ def test_send_signal(
process_manager.send_signal(running_process, signal.SIGTERM, False)
mock_kill.assert_called_once_with(PID_RUNNING, signal.SIGTERM)

gid = 100
mocker.patch("os.getpgid", return_value=gid)
if sys.platform != "win32":
mock_killpg = mocker.patch("os.killpg")
process_manager.send_signal(running_process, signal.SIGINT, True)
mock_killpg.assert_called_once_with(PID_RUNNING, signal.SIGINT)
else:
mock_killpg = mocker.patch("os.kill")

process_manager.send_signal(running_process, signal.SIGINT, True)
mock_killpg.assert_called_once_with(gid, signal.SIGINT)

mock_kill.reset_mock()
with pytest.raises(ProcessLookupError):
process_manager.send_signal(finished_process, signal.SIGTERM)
process_manager.send_signal(finished_process, signal.SIGTERM, False)
mock_kill.assert_not_called()

if sys.platform == "win32":
Expand All @@ -59,11 +64,11 @@ def side_effect(*args):

mocker.patch("os.kill", side_effect=side_effect)
with pytest.raises(ProcessLookupError):
process_manager.send_signal(running_process, signal.SIGTERM)
process_manager.send_signal(running_process, signal.SIGTERM, False)
assert process_manager[running_process].returncode == -1

with pytest.raises(ProcessLookupError):
process_manager.send_signal("nonexists", signal.SIGTERM)
process_manager.send_signal("nonexists", signal.SIGTERM, False)


if sys.platform == "win32":
Expand All @@ -73,25 +78,26 @@ def side_effect(*args):


@pytest.mark.parametrize(
"method, sig",
"method, sig, group",
[
("kill", SIGKILL),
("terminate", signal.SIGTERM),
("interrupt", signal.SIGINT),
("kill", SIGKILL, False),
("terminate", signal.SIGTERM, False),
("interrupt", signal.SIGINT, True),
],
)
def test_kill_commands(
mocker: MockerFixture,
process_manager: ProcessManager,
method: str,
sig: signal.Signals,
group: bool,
):
"""Test shortcut for different signals."""
name = "process"
mock_kill = mocker.patch.object(process_manager, "send_signal")
func = getattr(process_manager, method)
func(name)
mock_kill.assert_called_once_with(name, sig, True)
mock_kill.assert_called_once_with(name, sig, group)


def test_remove(
Expand Down

0 comments on commit 643ace0

Please sign in to comment.