diff --git a/src/unix_term.rs b/src/unix_term.rs index b8e0db2e..a6839b7d 100644 --- a/src/unix_term.rs +++ b/src/unix_term.rs @@ -1,10 +1,9 @@ use std::env; use std::fmt::Display; use std::fs; -use std::io; -use std::io::{BufRead, BufReader}; +use std::io::{self, BufRead, BufReader}; use std::mem; -use std::os::unix::io::AsRawFd; +use std::os::fd::{AsRawFd, RawFd}; use std::str; #[cfg(not(target_os = "macos"))] @@ -18,7 +17,7 @@ pub(crate) use crate::common_term::*; pub(crate) const DEFAULT_WIDTH: u16 = 80; #[inline] -pub(crate) fn is_a_terminal(out: &Term) -> bool { +pub(crate) fn is_a_terminal(out: &impl AsRawFd) -> bool { unsafe { libc::isatty(out.as_raw_fd()) != 0 } } @@ -66,41 +65,73 @@ pub(crate) fn terminal_size(out: &Term) -> Option<(u16, u16)> { } } -pub(crate) fn read_secure() -> io::Result { - let mut f_tty; - let fd = unsafe { - if libc::isatty(libc::STDIN_FILENO) == 1 { - f_tty = None; - libc::STDIN_FILENO - } else { - let f = fs::OpenOptions::new() - .read(true) - .write(true) - .open("/dev/tty")?; - let fd = f.as_raw_fd(); - f_tty = Some(BufReader::new(f)); - fd +enum Input { + Stdin(io::Stdin), + File(T), +} + +fn unbuffered_input() -> io::Result> { + let stdin = io::stdin(); + if is_a_terminal(&stdin) { + Ok(Input::Stdin(stdin)) + } else { + let f = fs::OpenOptions::new() + .read(true) + .write(true) + .open("/dev/tty")?; + Ok(Input::File(f)) + } +} + +fn buffered_input() -> io::Result>> { + Ok(match unbuffered_input()? { + Input::Stdin(s) => Input::Stdin(s), + Input::File(f) => Input::File(BufReader::new(f)), + }) +} + +// NB: this is not a full BufRead implementation because io::Stdin does not implement BufRead. +impl Input { + fn read_line(&mut self, buf: &mut String) -> io::Result { + match self { + Self::Stdin(s) => s.read_line(buf), + Self::File(f) => f.read_line(buf), } - }; + } +} + +impl AsRawFd for Input { + fn as_raw_fd(&self) -> RawFd { + match self { + Self::Stdin(s) => s.as_raw_fd(), + Self::File(f) => f.as_raw_fd(), + } + } +} + +impl AsRawFd for Input> { + fn as_raw_fd(&self) -> RawFd { + match self { + Self::Stdin(s) => s.as_raw_fd(), + Self::File(f) => f.get_ref().as_raw_fd(), + } + } +} + +pub(crate) fn read_secure() -> io::Result { + let mut input = buffered_input()?; let mut termios = mem::MaybeUninit::uninit(); - c_result(|| unsafe { libc::tcgetattr(fd, termios.as_mut_ptr()) })?; + c_result(|| unsafe { libc::tcgetattr(input.as_raw_fd(), termios.as_mut_ptr()) })?; let mut termios = unsafe { termios.assume_init() }; let original = termios; termios.c_lflag &= !libc::ECHO; - c_result(|| unsafe { libc::tcsetattr(fd, libc::TCSAFLUSH, &termios) })?; + c_result(|| unsafe { libc::tcsetattr(input.as_raw_fd(), libc::TCSAFLUSH, &termios) })?; let mut rv = String::new(); - let read_rv = if let Some(f) = &mut f_tty { - f.read_line(&mut rv) - } else { - io::stdin().read_line(&mut rv) - }; + let read_rv = input.read_line(&mut rv); - c_result(|| unsafe { libc::tcsetattr(fd, libc::TCSAFLUSH, &original) })?; - - // Ensure the fd is only closed after everything has been restored. - drop(f_tty); + c_result(|| unsafe { libc::tcsetattr(input.as_raw_fd(), libc::TCSAFLUSH, &original) })?; read_rv.map(|_| { let len = rv.trim_end_matches(&['\r', '\n'][..]).len(); @@ -109,7 +140,7 @@ pub(crate) fn read_secure() -> io::Result { }) } -fn poll_fd(fd: i32, timeout: i32) -> io::Result { +fn poll_fd(fd: RawFd, timeout: i32) -> io::Result { let mut pollfd = libc::pollfd { fd, events: libc::POLLIN, @@ -124,7 +155,7 @@ fn poll_fd(fd: i32, timeout: i32) -> io::Result { } #[cfg(target_os = "macos")] -fn select_fd(fd: i32, timeout: i32) -> io::Result { +fn select_fd(fd: RawFd, timeout: i32) -> io::Result { unsafe { let mut read_fd_set: libc::fd_set = mem::zeroed(); @@ -156,7 +187,7 @@ fn select_fd(fd: i32, timeout: i32) -> io::Result { } } -fn select_or_poll_term_fd(fd: i32, timeout: i32) -> io::Result { +fn select_or_poll_term_fd(fd: RawFd, timeout: i32) -> io::Result { // There is a bug on macos that ttys cannot be polled, only select() // works. However given how problematic select is in general, we // normally want to use poll there too. @@ -169,7 +200,7 @@ fn select_or_poll_term_fd(fd: i32, timeout: i32) -> io::Result { poll_fd(fd, timeout) } -fn read_single_char(fd: i32) -> io::Result> { +fn read_single_char(fd: RawFd) -> io::Result> { // timeout of zero means that it will not block let is_ready = select_or_poll_term_fd(fd, 0)?; @@ -188,7 +219,7 @@ fn read_single_char(fd: i32) -> io::Result> { // Similar to libc::read. Read count bytes into slice buf from descriptor fd. // If successful, return the number of bytes read. // Will return an error if nothing was read, i.e when called at end of file. -fn read_bytes(fd: i32, buf: &mut [u8], count: u8) -> io::Result { +fn read_bytes(fd: RawFd, buf: &mut [u8], count: u8) -> io::Result { let read = unsafe { libc::read(fd, buf.as_mut_ptr() as *mut _, count as usize) }; if read < 0 { Err(io::Error::last_os_error()) @@ -207,7 +238,7 @@ fn read_bytes(fd: i32, buf: &mut [u8], count: u8) -> io::Result { } } -fn read_single_key_impl(fd: i32) -> Result { +fn read_single_key_impl(fd: RawFd) -> Result { loop { match read_single_char(fd)? { Some('\x1b') => { @@ -301,27 +332,17 @@ fn read_single_key_impl(fd: i32) -> Result { } pub(crate) fn read_single_key(ctrlc_key: bool) -> io::Result { - let tty_f; - let fd = unsafe { - if libc::isatty(libc::STDIN_FILENO) == 1 { - libc::STDIN_FILENO - } else { - tty_f = fs::OpenOptions::new() - .read(true) - .write(true) - .open("/dev/tty")?; - tty_f.as_raw_fd() - } - }; + let input = unbuffered_input()?; + let mut termios = core::mem::MaybeUninit::uninit(); - c_result(|| unsafe { libc::tcgetattr(fd, termios.as_mut_ptr()) })?; + c_result(|| unsafe { libc::tcgetattr(input.as_raw_fd(), termios.as_mut_ptr()) })?; let mut termios = unsafe { termios.assume_init() }; let original = termios; unsafe { libc::cfmakeraw(&mut termios) }; termios.c_oflag = original.c_oflag; - c_result(|| unsafe { libc::tcsetattr(fd, libc::TCSADRAIN, &termios) })?; - let rv: io::Result = read_single_key_impl(fd); - c_result(|| unsafe { libc::tcsetattr(fd, libc::TCSADRAIN, &original) })?; + c_result(|| unsafe { libc::tcsetattr(input.as_raw_fd(), libc::TCSADRAIN, &termios) })?; + let rv: io::Result = read_single_key_impl(input.as_raw_fd()); + c_result(|| unsafe { libc::tcsetattr(input.as_raw_fd(), libc::TCSADRAIN, &original) })?; // if the user hit ^C we want to signal SIGINT to outselves. if let Err(ref err) = rv {