Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Always send EOF on streams #164

Merged
merged 4 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions daemon/qrexec-daemon-common.c
Original file line number Diff line number Diff line change
Expand Up @@ -386,9 +386,15 @@ int prepare_local_fds(struct qrexec_parsed_command *command, struct buffer *stdi
// See also qrexec-agent/qrexec-agent-data.c
static void handle_failed_exec(libvchan_t *data_vchan, bool is_service, int exit_code)
{
struct msg_header hdr = {
.type = MSG_DATA_STDOUT,
.len = 0,
const struct msg_header hdr[2] = {
{
.type = MSG_DATA_STDERR,
.len = 0,
},
{
.type = MSG_DATA_STDOUT,
.len = 0,
},
};

LOG(ERROR, "failed to spawn process, exiting");
Expand All @@ -404,7 +410,7 @@ static void handle_failed_exec(libvchan_t *data_vchan, bool is_service, int exit
* when we support sockets as a local process.
*/
if (is_service) {
libvchan_send(data_vchan, &hdr, sizeof(hdr));
libvchan_send(data_vchan, hdr, sizeof(hdr));
send_exit_code(data_vchan, exit_code);
}
}
Expand Down
20 changes: 18 additions & 2 deletions libqrexec/process_io.c
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,11 @@ int qrexec_process_io(const struct process_io_request *req,
struct timespec normal_timeout = { 10, 0 };
struct prefix_data empty = { 0, 0 }, prefix = req->prefix_data;

if (is_service && stderr_fd == -1) {
struct msg_header hdr = { .type = MSG_DATA_STDERR, .len = 0 };
libvchan_send(vchan, &hdr, (int)sizeof(hdr));
}

struct buffer remote_buffer = {
.data = malloc(max_chunk_size),
.buflen = max_chunk_size,
Expand Down Expand Up @@ -164,6 +169,14 @@ int qrexec_process_io(const struct process_io_request *req,
/* Convenience macros that eliminate a ton of error-prone boilerplate */
#define close_stdin() do { \
if (exit_on_stdin_eof) { \
/* If stdout is still open, send EOF */ \
if (stdout_fd != -1) { \
const struct msg_header hdr = { \
.type = stdout_msg_type, \
.len = 0, \
}; \
libvchan_send(vchan, &hdr, sizeof(hdr)); \
}; \
/* Set stdin_fd and stdout_fd to -1. \
* No need to close them as the process \
* will soon exit. */ \
Expand Down Expand Up @@ -320,9 +333,12 @@ int qrexec_process_io(const struct process_io_request *req,
* local FDs. However, don't exit yet, because there might
* still be some data in stdin_buf waiting to be flushed.
*/
if (stdout_fd != -1) {
/* Send EOF */
struct msg_header hdr = { .type = stdout_msg_type, .len = 0, };
libvchan_send(vchan, &hdr, (int)sizeof(hdr));
}
close_stdout();
close_stderr(stderr_fd);
stderr_fd = -1;
break;
}
if (prefix.len > 0 || (stdout_fd >= 0 && fds[FD_STDOUT].revents)) {
Expand Down
123 changes: 50 additions & 73 deletions qrexec/tests/socket/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,16 +696,7 @@ def test_socket_null_argument_finds_service_for_empty_argument(self):

good_server.sendall(b"stdout data")
good_server.close()
messages = target.recv_all_messages()
# No stderr
self.assertListEqual(
util.sort_messages(messages),
[
(qrexec.MSG_DATA_STDOUT, b"stdout data"),
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
self.assertExpectedStdout(target, b"stdout data")
self.check_dom0(dom0)

def _test_connect_socket_bad_config(self, forbidden_key):
Expand All @@ -725,7 +716,6 @@ def _test_connect_socket_bad_config(self, forbidden_key):

target, dom0 = self.execute_qubesrpc("qubes.SocketService+arg2", "domX")
messages = target.recv_all_messages()
# No stderr
self.assertListEqual(
util.sort_messages(messages),
[
Expand Down Expand Up @@ -765,14 +755,12 @@ def test_connect_socket_exit_on_stdin_eof(self):
target.send_message(qrexec.MSG_DATA_STDIN, b"")
# Check for EOF on stdin
self.assertEqual(server.recvall(len(message) + 1), message)
messages = target.recv_all_messages()
# No stderr
self.assertListEqual(
util.sort_messages(messages),
self.assertEqual(target.recv_all_messages(),
[
(qrexec.MSG_DATA_STDERR, b""),
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
])
self.check_dom0(dom0)
server.close()

Expand Down Expand Up @@ -800,15 +788,7 @@ def test_connect_socket_exit_on_stdout_eof(self):
# Trigger EOF on stdout
server.shutdown(socket.SHUT_WR)
# Server should exit
messages = target.recv_all_messages()
# No stderr
self.assertListEqual(
util.sort_messages(messages),
[
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
self.assertExpectedStdout(target, b"")
self.check_dom0(dom0)
server.close()

Expand Down Expand Up @@ -836,16 +816,7 @@ def test_connect_socket_no_metadata(self):

server.sendall(b"stdout data")
server.close()
messages = target.recv_all_messages()
# No stderr
self.assertListEqual(
util.sort_messages(messages),
[
(qrexec.MSG_DATA_STDOUT, b"stdout data"),
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
self.assertExpectedStdout(target, b"stdout data")
self.check_dom0(dom0)

def test_connect_socket_tcp(self):
Expand Down Expand Up @@ -880,20 +851,13 @@ def _test_tcp_raw(self, family: int, service: str, host: str, port: int, accept=
self.assertEqual(server.recvall(len(message)), message)
server.sendall(b"stdout data")
server.close()
messages = target.recv_all_messages()
self.check_dom0(dom0)
return util.sort_messages(messages)
return target

def _test_tcp(self, family: int, service: str, host: str, port: int) -> None:
# No stderr
self.assertListEqual(
self.assertExpectedStdout(
self._test_tcp_raw(family, service, host, port),
[
(qrexec.MSG_DATA_STDOUT, b"stdout data"),
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
b"stdout data")

def test_connect_socket_tcp_port_from_arg(self):
socket_path = os.path.join(
Expand Down Expand Up @@ -930,13 +894,9 @@ def test_connect_socket_tcp_ipv6_service_arg(self):
host = "::1"
os.symlink(f"/dev/tcp", socket_path)
service = f"qubes.SocketService+{host.replace(':', '+')}+{port}"
self.assertListEqual(
self.assertExpectedStdout(
self._test_tcp_raw(socket.AF_INET6, service, host, port, skip=False),
[
(qrexec.MSG_DATA_STDOUT, b"stdout data"),
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
b"stdout data",
)

def _test_connect_socket_tcp_unexpected_host(self, host):
Expand All @@ -946,16 +906,9 @@ def _test_connect_socket_tcp_unexpected_host(self, host):
port = 65535
path = f"/dev/tcp/{host}"
os.symlink(path, socket_path)
messages = self._test_tcp_raw(socket.AF_INET, f"qubes.SocketService+{host}+{port}",
target = self._test_tcp_raw(socket.AF_INET, f"qubes.SocketService+{host}+{port}",
host, port, accept=False)
self.assertListEqual(
messages,
[
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_STDERR, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\175\0\0\0"),
],
)
self.assertExpectedStdout(target, b"", exit_code=125)

def test_connect_socket_tcp_missing_host(self):
"""
Expand Down Expand Up @@ -1063,16 +1016,7 @@ def test_connect_socket(self):

server.sendall(b"stdout data")
server.close()
messages = target.recv_all_messages()
# No stderr
self.assertListEqual(
util.sort_messages(messages),
[
(qrexec.MSG_DATA_STDOUT, b"stdout data"),
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
self.assertExpectedStdout(target, b"stdout data")
self.check_dom0(dom0)

def test_service_close_stdout_stderr_early(self):
Expand Down Expand Up @@ -1469,6 +1413,37 @@ def test_run_client_bidirectional_shutdown(self):
remote.close()
local.close()

def test_run_client_bidirectional_shutdown_early_exit(self):
try:
remote, local = socket.socketpair()
target_client = self.run_service(stdio=remote)
initial_data = b"stdout data\n"
target_client.send_message(qrexec.MSG_DATA_STDOUT, initial_data)
# FIXME: data can be received in multiple messages
self.assertEqual(local.recv(len(initial_data)), initial_data)
initial_data = b"stdin data\n"
local.sendall(initial_data)
self.assertStdoutMessages(target_client, initial_data, qrexec.MSG_DATA_STDIN)
target_client.send_message(qrexec.MSG_DATA_STDOUT, b"")
# Check that EOF got propagated on this side too, even though
# we still have a reference to the socket. This indicates that
# qrexec-client-vm shut down the socket for writing.
self.assertEqual(local.recv(1), b"")
with self.assertRaises(BrokenPipeError):
remote.send(b"a")
target_client.send_message(
qrexec.MSG_DATA_EXIT_CODE, struct.pack("<L", 42)
)
local.shutdown(socket.SHUT_WR)
# Check that EOF received
self.assertEqual(target_client.recv_message(), (qrexec.MSG_DATA_STDIN, b""))
self.client.wait()
self.assertEqual(self.client.returncode, 42)
finally:
remote.close()
local.close()


def test_run_client_replace_chars(self):
target_client = self.run_service(options=["-t"])
target_client.send_message(
Expand Down Expand Up @@ -1504,7 +1479,8 @@ def test_run_client_failed(self):
)
# there should be no MSG_DATA_EXIT_CODE from qrexec-client-vm
# and also no MSG_DATA_STDIN after receiving MSG_DATA_EXIT_CODE
self.assertListEqual(target_client.recv_all_messages(), [])
self.assertListEqual(target_client.recv_all_messages(),
[(qrexec.MSG_DATA_STDIN, b"")])
self.assertEqual(self.client.stdout.read(), b"")
self.client.wait()
self.assertEqual(self.client.returncode, qrexec.QREXEC_EXIT_PROBLEM)
Expand Down Expand Up @@ -1594,7 +1570,8 @@ def test_run_client_with_local_proc_service_failed(self):
qrexec.MSG_DATA_EXIT_CODE, struct.pack("<L", qrexec.QREXEC_EXIT_PROBLEM)
)
# there should be no MSG_DATA_EXIT_CODE from qrexec-client-vm
self.assertListEqual(target_client.recv_all_messages(), [])
self.assertListEqual(target_client.recv_all_messages(),
[(qrexec.MSG_DATA_STDIN, b"")])
target_client.close()
self.assertEqual(self.client.stdout.read(), b"")
self.client.wait()
Expand Down
1 change: 1 addition & 0 deletions qrexec/tests/socket/daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,7 @@ def connect_service_request(self, cmd, timeout=None):

source.accept()
source.handshake()
self.assertEqual(source.recv_message(), (qrexec.MSG_DATA_STDERR, b""))
return source

def test_run_dom0_command_and_connect_vm(self):
Expand Down