Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/pr/164'
Browse files Browse the repository at this point in the history
* origin/pr/164:
  Avoid closing stderr when MSG_DATA_EXIT_CODE is received
  Send EOF whenever closing stdout
  Send EOF on stdout when exiting due to EOF on stdin
  Always send EOF on stderr
  • Loading branch information
marmarek committed Jun 14, 2024
2 parents bab74d5 + 2b71adb commit d123700
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 79 deletions.
14 changes: 10 additions & 4 deletions daemon/qrexec-daemon-common.c
Original file line number Diff line number Diff line change
Expand Up @@ -386,9 +386,15 @@ int prepare_local_fds(struct qrexec_parsed_command *command, struct buffer *stdi
// See also qrexec-agent/qrexec-agent-data.c
static void handle_failed_exec(libvchan_t *data_vchan, bool is_service, int exit_code)
{
struct msg_header hdr = {
.type = MSG_DATA_STDOUT,
.len = 0,
const struct msg_header hdr[2] = {
{
.type = MSG_DATA_STDERR,
.len = 0,
},
{
.type = MSG_DATA_STDOUT,
.len = 0,
},
};

LOG(ERROR, "failed to spawn process, exiting");
Expand All @@ -404,7 +410,7 @@ static void handle_failed_exec(libvchan_t *data_vchan, bool is_service, int exit
* when we support sockets as a local process.
*/
if (is_service) {
libvchan_send(data_vchan, &hdr, sizeof(hdr));
libvchan_send(data_vchan, hdr, sizeof(hdr));
send_exit_code(data_vchan, exit_code);
}
}
Expand Down
20 changes: 18 additions & 2 deletions libqrexec/process_io.c
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,11 @@ int qrexec_process_io(const struct process_io_request *req,
struct timespec normal_timeout = { 10, 0 };
struct prefix_data empty = { 0, 0 }, prefix = req->prefix_data;

if (is_service && stderr_fd == -1) {
struct msg_header hdr = { .type = MSG_DATA_STDERR, .len = 0 };
libvchan_send(vchan, &hdr, (int)sizeof(hdr));
}

struct buffer remote_buffer = {
.data = malloc(max_chunk_size),
.buflen = max_chunk_size,
Expand Down Expand Up @@ -164,6 +169,14 @@ int qrexec_process_io(const struct process_io_request *req,
/* Convenience macros that eliminate a ton of error-prone boilerplate */
#define close_stdin() do { \
if (exit_on_stdin_eof) { \
/* If stdout is still open, send EOF */ \
if (stdout_fd != -1) { \
const struct msg_header hdr = { \
.type = stdout_msg_type, \
.len = 0, \
}; \
libvchan_send(vchan, &hdr, sizeof(hdr)); \
}; \
/* Set stdin_fd and stdout_fd to -1. \
* No need to close them as the process \
* will soon exit. */ \
Expand Down Expand Up @@ -320,9 +333,12 @@ int qrexec_process_io(const struct process_io_request *req,
* local FDs. However, don't exit yet, because there might
* still be some data in stdin_buf waiting to be flushed.
*/
if (stdout_fd != -1) {
/* Send EOF */
struct msg_header hdr = { .type = stdout_msg_type, .len = 0, };
libvchan_send(vchan, &hdr, (int)sizeof(hdr));
}
close_stdout();
close_stderr(stderr_fd);
stderr_fd = -1;
break;
}
if (prefix.len > 0 || (stdout_fd >= 0 && fds[FD_STDOUT].revents)) {
Expand Down
123 changes: 50 additions & 73 deletions qrexec/tests/socket/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,16 +688,7 @@ def test_socket_null_argument_finds_service_for_empty_argument(self):

good_server.sendall(b"stdout data")
good_server.close()
messages = target.recv_all_messages()
# No stderr
self.assertListEqual(
util.sort_messages(messages),
[
(qrexec.MSG_DATA_STDOUT, b"stdout data"),
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
self.assertExpectedStdout(target, b"stdout data")
self.check_dom0(dom0)

def _test_connect_socket_bad_config(self, forbidden_key):
Expand All @@ -717,7 +708,6 @@ def _test_connect_socket_bad_config(self, forbidden_key):

target, dom0 = self.execute_qubesrpc("qubes.SocketService+arg2", "domX")
messages = target.recv_all_messages()
# No stderr
self.assertListEqual(
util.sort_messages(messages),
[
Expand Down Expand Up @@ -757,14 +747,12 @@ def test_connect_socket_exit_on_stdin_eof(self):
target.send_message(qrexec.MSG_DATA_STDIN, b"")
# Check for EOF on stdin
self.assertEqual(server.recvall(len(message) + 1), message)
messages = target.recv_all_messages()
# No stderr
self.assertListEqual(
util.sort_messages(messages),
self.assertEqual(target.recv_all_messages(),
[
(qrexec.MSG_DATA_STDERR, b""),
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
])
self.check_dom0(dom0)
server.close()

Expand Down Expand Up @@ -792,15 +780,7 @@ def test_connect_socket_exit_on_stdout_eof(self):
# Trigger EOF on stdout
server.shutdown(socket.SHUT_WR)
# Server should exit
messages = target.recv_all_messages()
# No stderr
self.assertListEqual(
util.sort_messages(messages),
[
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
self.assertExpectedStdout(target, b"")
self.check_dom0(dom0)
server.close()

Expand Down Expand Up @@ -828,16 +808,7 @@ def test_connect_socket_no_metadata(self):

server.sendall(b"stdout data")
server.close()
messages = target.recv_all_messages()
# No stderr
self.assertListEqual(
util.sort_messages(messages),
[
(qrexec.MSG_DATA_STDOUT, b"stdout data"),
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
self.assertExpectedStdout(target, b"stdout data")
self.check_dom0(dom0)

def test_connect_socket_tcp(self):
Expand Down Expand Up @@ -872,20 +843,13 @@ def _test_tcp_raw(self, family: int, service: str, host: str, port: int, accept=
self.assertEqual(server.recvall(len(message)), message)
server.sendall(b"stdout data")
server.close()
messages = target.recv_all_messages()
self.check_dom0(dom0)
return util.sort_messages(messages)
return target

def _test_tcp(self, family: int, service: str, host: str, port: int) -> None:
# No stderr
self.assertListEqual(
self.assertExpectedStdout(
self._test_tcp_raw(family, service, host, port),
[
(qrexec.MSG_DATA_STDOUT, b"stdout data"),
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
b"stdout data")

def test_connect_socket_tcp_port_from_arg(self):
socket_path = os.path.join(
Expand Down Expand Up @@ -922,13 +886,9 @@ def test_connect_socket_tcp_ipv6_service_arg(self):
host = "::1"
os.symlink(f"/dev/tcp", socket_path)
service = f"qubes.SocketService+{host.replace(':', '+')}+{port}"
self.assertListEqual(
self.assertExpectedStdout(
self._test_tcp_raw(socket.AF_INET6, service, host, port, skip=False),
[
(qrexec.MSG_DATA_STDOUT, b"stdout data"),
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
b"stdout data",
)

def _test_connect_socket_tcp_unexpected_host(self, host):
Expand All @@ -938,16 +898,9 @@ def _test_connect_socket_tcp_unexpected_host(self, host):
port = 65535
path = f"/dev/tcp/{host}"
os.symlink(path, socket_path)
messages = self._test_tcp_raw(socket.AF_INET, f"qubes.SocketService+{host}+{port}",
target = self._test_tcp_raw(socket.AF_INET, f"qubes.SocketService+{host}+{port}",
host, port, accept=False)
self.assertListEqual(
messages,
[
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_STDERR, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\175\0\0\0"),
],
)
self.assertExpectedStdout(target, b"", exit_code=125)

def test_connect_socket_tcp_missing_host(self):
"""
Expand Down Expand Up @@ -1055,16 +1008,7 @@ def test_connect_socket(self):

server.sendall(b"stdout data")
server.close()
messages = target.recv_all_messages()
# No stderr
self.assertListEqual(
util.sort_messages(messages),
[
(qrexec.MSG_DATA_STDOUT, b"stdout data"),
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
self.assertExpectedStdout(target, b"stdout data")
self.check_dom0(dom0)

def test_service_close_stdout_stderr_early(self):
Expand Down Expand Up @@ -1461,6 +1405,37 @@ def test_run_client_bidirectional_shutdown(self):
remote.close()
local.close()

def test_run_client_bidirectional_shutdown_early_exit(self):
try:
remote, local = socket.socketpair()
target_client = self.run_service(stdio=remote)
initial_data = b"stdout data\n"
target_client.send_message(qrexec.MSG_DATA_STDOUT, initial_data)
# FIXME: data can be received in multiple messages
self.assertEqual(local.recv(len(initial_data)), initial_data)
initial_data = b"stdin data\n"
local.sendall(initial_data)
self.assertStdoutMessages(target_client, initial_data, qrexec.MSG_DATA_STDIN)
target_client.send_message(qrexec.MSG_DATA_STDOUT, b"")
# Check that EOF got propagated on this side too, even though
# we still have a reference to the socket. This indicates that
# qrexec-client-vm shut down the socket for writing.
self.assertEqual(local.recv(1), b"")
with self.assertRaises(BrokenPipeError):
remote.send(b"a")
target_client.send_message(
qrexec.MSG_DATA_EXIT_CODE, struct.pack("<L", 42)
)
local.shutdown(socket.SHUT_WR)
# Check that EOF received
self.assertEqual(target_client.recv_message(), (qrexec.MSG_DATA_STDIN, b""))
self.client.wait()
self.assertEqual(self.client.returncode, 42)
finally:
remote.close()
local.close()


def test_run_client_replace_chars(self):
target_client = self.run_service(options=["-t"])
target_client.send_message(
Expand Down Expand Up @@ -1496,7 +1471,8 @@ def test_run_client_failed(self):
)
# there should be no MSG_DATA_EXIT_CODE from qrexec-client-vm
# and also no MSG_DATA_STDIN after receiving MSG_DATA_EXIT_CODE
self.assertListEqual(target_client.recv_all_messages(), [])
self.assertListEqual(target_client.recv_all_messages(),
[(qrexec.MSG_DATA_STDIN, b"")])
self.assertEqual(self.client.stdout.read(), b"")
self.client.wait()
self.assertEqual(self.client.returncode, qrexec.QREXEC_EXIT_PROBLEM)
Expand Down Expand Up @@ -1586,7 +1562,8 @@ def test_run_client_with_local_proc_service_failed(self):
qrexec.MSG_DATA_EXIT_CODE, struct.pack("<L", qrexec.QREXEC_EXIT_PROBLEM)
)
# there should be no MSG_DATA_EXIT_CODE from qrexec-client-vm
self.assertListEqual(target_client.recv_all_messages(), [])
self.assertListEqual(target_client.recv_all_messages(),
[(qrexec.MSG_DATA_STDIN, b"")])
target_client.close()
self.assertEqual(self.client.stdout.read(), b"")
self.client.wait()
Expand Down
1 change: 1 addition & 0 deletions qrexec/tests/socket/daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,7 @@ def connect_service_request(self, cmd, timeout=None):

source.accept()
source.handshake()
self.assertEqual(source.recv_message(), (qrexec.MSG_DATA_STDERR, b""))
return source

def test_run_dom0_command_and_connect_vm(self):
Expand Down

0 comments on commit d123700

Please sign in to comment.