diff --git a/daemon/qrexec-daemon-common.c b/daemon/qrexec-daemon-common.c index fa4a89ce..3d42ab5e 100644 --- a/daemon/qrexec-daemon-common.c +++ b/daemon/qrexec-daemon-common.c @@ -386,9 +386,15 @@ int prepare_local_fds(struct qrexec_parsed_command *command, struct buffer *stdi // See also qrexec-agent/qrexec-agent-data.c static void handle_failed_exec(libvchan_t *data_vchan, bool is_service, int exit_code) { - struct msg_header hdr = { - .type = MSG_DATA_STDOUT, - .len = 0, + const struct msg_header hdr[2] = { + { + .type = MSG_DATA_STDERR, + .len = 0, + }, + { + .type = MSG_DATA_STDOUT, + .len = 0, + }, }; LOG(ERROR, "failed to spawn process, exiting"); @@ -404,7 +410,7 @@ static void handle_failed_exec(libvchan_t *data_vchan, bool is_service, int exit * when we support sockets as a local process. */ if (is_service) { - libvchan_send(data_vchan, &hdr, sizeof(hdr)); + libvchan_send(data_vchan, hdr, sizeof(hdr)); send_exit_code(data_vchan, exit_code); } } diff --git a/libqrexec/process_io.c b/libqrexec/process_io.c index c65b48b6..c1dd8714 100644 --- a/libqrexec/process_io.c +++ b/libqrexec/process_io.c @@ -124,6 +124,11 @@ int qrexec_process_io(const struct process_io_request *req, struct timespec normal_timeout = { 10, 0 }; struct prefix_data empty = { 0, 0 }, prefix = req->prefix_data; + if (is_service && stderr_fd == -1) { + struct msg_header hdr = { .type = MSG_DATA_STDERR, .len = 0 }; + libvchan_send(vchan, &hdr, (int)sizeof(hdr)); + } + struct buffer remote_buffer = { .data = malloc(max_chunk_size), .buflen = max_chunk_size, @@ -164,6 +169,14 @@ int qrexec_process_io(const struct process_io_request *req, /* Convenience macros that eliminate a ton of error-prone boilerplate */ #define close_stdin() do { \ if (exit_on_stdin_eof) { \ + /* If stdout is still open, send EOF */ \ + if (stdout_fd != -1) { \ + const struct msg_header hdr = { \ + .type = stdout_msg_type, \ + .len = 0, \ + }; \ + libvchan_send(vchan, &hdr, sizeof(hdr)); \ + }; \ /* Set stdin_fd and stdout_fd to -1. \ * No need to close them as the process \ * will soon exit. */ \ @@ -320,9 +333,12 @@ int qrexec_process_io(const struct process_io_request *req, * local FDs. However, don't exit yet, because there might * still be some data in stdin_buf waiting to be flushed. */ + if (stdout_fd != -1) { + /* Send EOF */ + struct msg_header hdr = { .type = stdout_msg_type, .len = 0, }; + libvchan_send(vchan, &hdr, (int)sizeof(hdr)); + } close_stdout(); - close_stderr(stderr_fd); - stderr_fd = -1; break; } if (prefix.len > 0 || (stdout_fd >= 0 && fds[FD_STDOUT].revents)) { diff --git a/qrexec/tests/socket/agent.py b/qrexec/tests/socket/agent.py index 7b2f873a..af9c15d7 100644 --- a/qrexec/tests/socket/agent.py +++ b/qrexec/tests/socket/agent.py @@ -696,16 +696,7 @@ def test_socket_null_argument_finds_service_for_empty_argument(self): good_server.sendall(b"stdout data") good_server.close() - messages = target.recv_all_messages() - # No stderr - self.assertListEqual( - util.sort_messages(messages), - [ - (qrexec.MSG_DATA_STDOUT, b"stdout data"), - (qrexec.MSG_DATA_STDOUT, b""), - (qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"), - ], - ) + self.assertExpectedStdout(target, b"stdout data") self.check_dom0(dom0) def _test_connect_socket_bad_config(self, forbidden_key): @@ -725,7 +716,6 @@ def _test_connect_socket_bad_config(self, forbidden_key): target, dom0 = self.execute_qubesrpc("qubes.SocketService+arg2", "domX") messages = target.recv_all_messages() - # No stderr self.assertListEqual( util.sort_messages(messages), [ @@ -765,14 +755,12 @@ def test_connect_socket_exit_on_stdin_eof(self): target.send_message(qrexec.MSG_DATA_STDIN, b"") # Check for EOF on stdin self.assertEqual(server.recvall(len(message) + 1), message) - messages = target.recv_all_messages() - # No stderr - self.assertListEqual( - util.sort_messages(messages), + self.assertEqual(target.recv_all_messages(), [ + (qrexec.MSG_DATA_STDERR, b""), + (qrexec.MSG_DATA_STDOUT, b""), (qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"), - ], - ) + ]) self.check_dom0(dom0) server.close() @@ -800,15 +788,7 @@ def test_connect_socket_exit_on_stdout_eof(self): # Trigger EOF on stdout server.shutdown(socket.SHUT_WR) # Server should exit - messages = target.recv_all_messages() - # No stderr - self.assertListEqual( - util.sort_messages(messages), - [ - (qrexec.MSG_DATA_STDOUT, b""), - (qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"), - ], - ) + self.assertExpectedStdout(target, b"") self.check_dom0(dom0) server.close() @@ -836,16 +816,7 @@ def test_connect_socket_no_metadata(self): server.sendall(b"stdout data") server.close() - messages = target.recv_all_messages() - # No stderr - self.assertListEqual( - util.sort_messages(messages), - [ - (qrexec.MSG_DATA_STDOUT, b"stdout data"), - (qrexec.MSG_DATA_STDOUT, b""), - (qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"), - ], - ) + self.assertExpectedStdout(target, b"stdout data") self.check_dom0(dom0) def test_connect_socket_tcp(self): @@ -880,20 +851,13 @@ def _test_tcp_raw(self, family: int, service: str, host: str, port: int, accept= self.assertEqual(server.recvall(len(message)), message) server.sendall(b"stdout data") server.close() - messages = target.recv_all_messages() self.check_dom0(dom0) - return util.sort_messages(messages) + return target def _test_tcp(self, family: int, service: str, host: str, port: int) -> None: - # No stderr - self.assertListEqual( + self.assertExpectedStdout( self._test_tcp_raw(family, service, host, port), - [ - (qrexec.MSG_DATA_STDOUT, b"stdout data"), - (qrexec.MSG_DATA_STDOUT, b""), - (qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"), - ], - ) + b"stdout data") def test_connect_socket_tcp_port_from_arg(self): socket_path = os.path.join( @@ -930,13 +894,9 @@ def test_connect_socket_tcp_ipv6_service_arg(self): host = "::1" os.symlink(f"/dev/tcp", socket_path) service = f"qubes.SocketService+{host.replace(':', '+')}+{port}" - self.assertListEqual( + self.assertExpectedStdout( self._test_tcp_raw(socket.AF_INET6, service, host, port, skip=False), - [ - (qrexec.MSG_DATA_STDOUT, b"stdout data"), - (qrexec.MSG_DATA_STDOUT, b""), - (qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"), - ], + b"stdout data", ) def _test_connect_socket_tcp_unexpected_host(self, host): @@ -946,16 +906,9 @@ def _test_connect_socket_tcp_unexpected_host(self, host): port = 65535 path = f"/dev/tcp/{host}" os.symlink(path, socket_path) - messages = self._test_tcp_raw(socket.AF_INET, f"qubes.SocketService+{host}+{port}", + target = self._test_tcp_raw(socket.AF_INET, f"qubes.SocketService+{host}+{port}", host, port, accept=False) - self.assertListEqual( - messages, - [ - (qrexec.MSG_DATA_STDOUT, b""), - (qrexec.MSG_DATA_STDERR, b""), - (qrexec.MSG_DATA_EXIT_CODE, b"\175\0\0\0"), - ], - ) + self.assertExpectedStdout(target, b"", exit_code=125) def test_connect_socket_tcp_missing_host(self): """ @@ -1063,16 +1016,7 @@ def test_connect_socket(self): server.sendall(b"stdout data") server.close() - messages = target.recv_all_messages() - # No stderr - self.assertListEqual( - util.sort_messages(messages), - [ - (qrexec.MSG_DATA_STDOUT, b"stdout data"), - (qrexec.MSG_DATA_STDOUT, b""), - (qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"), - ], - ) + self.assertExpectedStdout(target, b"stdout data") self.check_dom0(dom0) def test_service_close_stdout_stderr_early(self): @@ -1469,6 +1413,37 @@ def test_run_client_bidirectional_shutdown(self): remote.close() local.close() + def test_run_client_bidirectional_shutdown_early_exit(self): + try: + remote, local = socket.socketpair() + target_client = self.run_service(stdio=remote) + initial_data = b"stdout data\n" + target_client.send_message(qrexec.MSG_DATA_STDOUT, initial_data) + # FIXME: data can be received in multiple messages + self.assertEqual(local.recv(len(initial_data)), initial_data) + initial_data = b"stdin data\n" + local.sendall(initial_data) + self.assertStdoutMessages(target_client, initial_data, qrexec.MSG_DATA_STDIN) + target_client.send_message(qrexec.MSG_DATA_STDOUT, b"") + # Check that EOF got propagated on this side too, even though + # we still have a reference to the socket. This indicates that + # qrexec-client-vm shut down the socket for writing. + self.assertEqual(local.recv(1), b"") + with self.assertRaises(BrokenPipeError): + remote.send(b"a") + target_client.send_message( + qrexec.MSG_DATA_EXIT_CODE, struct.pack("