Skip to content

Commit

Permalink
Improve type safety, extract identical code
Browse files Browse the repository at this point in the history
Avoid fragility of tracking objects and their FDs separately.
  • Loading branch information
tamird committed Dec 17, 2024
1 parent 35d0451 commit 44a44a5
Showing 1 changed file with 73 additions and 52 deletions.
125 changes: 73 additions & 52 deletions src/unix_term.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
use std::env;
use std::fmt::Display;
use std::fs;
use std::io;
use std::io::{BufRead, BufReader};
use std::io::{self, BufRead, BufReader};
use std::mem;
use std::os::unix::io::AsRawFd;
use std::os::fd::{AsRawFd, RawFd};
use std::str;

#[cfg(not(target_os = "macos"))]
Expand All @@ -18,7 +17,7 @@ pub(crate) use crate::common_term::*;
pub(crate) const DEFAULT_WIDTH: u16 = 80;

#[inline]
pub(crate) fn is_a_terminal(out: &Term) -> bool {
pub(crate) fn is_a_terminal(out: &impl AsRawFd) -> bool {
unsafe { libc::isatty(out.as_raw_fd()) != 0 }
}

Expand Down Expand Up @@ -66,41 +65,73 @@ pub(crate) fn terminal_size(out: &Term) -> Option<(u16, u16)> {
}
}

pub(crate) fn read_secure() -> io::Result<String> {
let mut f_tty;
let fd = unsafe {
if libc::isatty(libc::STDIN_FILENO) == 1 {
f_tty = None;
libc::STDIN_FILENO
} else {
let f = fs::OpenOptions::new()
.read(true)
.write(true)
.open("/dev/tty")?;
let fd = f.as_raw_fd();
f_tty = Some(BufReader::new(f));
fd
enum Input<T> {
Stdin(io::Stdin),
File(T),
}

fn unbuffered_input() -> io::Result<Input<fs::File>> {
let stdin = io::stdin();
if is_a_terminal(&stdin) {
Ok(Input::Stdin(stdin))
} else {
let f = fs::OpenOptions::new()
.read(true)
.write(true)
.open("/dev/tty")?;
Ok(Input::File(f))
}
}

fn buffered_input() -> io::Result<Input<BufReader<fs::File>>> {
Ok(match unbuffered_input()? {
Input::Stdin(s) => Input::Stdin(s),
Input::File(f) => Input::File(BufReader::new(f)),
})
}

// NB: this is not a full BufRead implementation because io::Stdin does not implement BufRead.
impl<T: BufRead> Input<T> {
fn read_line(&mut self, buf: &mut String) -> io::Result<usize> {
match self {
Self::Stdin(s) => s.read_line(buf),
Self::File(f) => f.read_line(buf),
}
};
}
}

impl AsRawFd for Input<fs::File> {
fn as_raw_fd(&self) -> RawFd {
match self {
Self::Stdin(s) => s.as_raw_fd(),
Self::File(f) => f.as_raw_fd(),
}
}
}

impl AsRawFd for Input<BufReader<fs::File>> {
fn as_raw_fd(&self) -> RawFd {
match self {
Self::Stdin(s) => s.as_raw_fd(),
Self::File(f) => f.get_ref().as_raw_fd(),
}
}
}

pub(crate) fn read_secure() -> io::Result<String> {
let mut input = buffered_input()?;

let mut termios = mem::MaybeUninit::uninit();
c_result(|| unsafe { libc::tcgetattr(fd, termios.as_mut_ptr()) })?;
c_result(|| unsafe { libc::tcgetattr(input.as_raw_fd(), termios.as_mut_ptr()) })?;
let mut termios = unsafe { termios.assume_init() };
let original = termios;
termios.c_lflag &= !libc::ECHO;
c_result(|| unsafe { libc::tcsetattr(fd, libc::TCSAFLUSH, &termios) })?;
c_result(|| unsafe { libc::tcsetattr(input.as_raw_fd(), libc::TCSAFLUSH, &termios) })?;
let mut rv = String::new();

let read_rv = if let Some(f) = &mut f_tty {
f.read_line(&mut rv)
} else {
io::stdin().read_line(&mut rv)
};
let read_rv = input.read_line(&mut rv);

c_result(|| unsafe { libc::tcsetattr(fd, libc::TCSAFLUSH, &original) })?;

// Ensure the fd is only closed after everything has been restored.
drop(f_tty);
c_result(|| unsafe { libc::tcsetattr(input.as_raw_fd(), libc::TCSAFLUSH, &original) })?;

read_rv.map(|_| {
let len = rv.trim_end_matches(&['\r', '\n'][..]).len();
Expand All @@ -109,7 +140,7 @@ pub(crate) fn read_secure() -> io::Result<String> {
})
}

fn poll_fd(fd: i32, timeout: i32) -> io::Result<bool> {
fn poll_fd(fd: RawFd, timeout: i32) -> io::Result<bool> {
let mut pollfd = libc::pollfd {
fd,
events: libc::POLLIN,
Expand All @@ -124,7 +155,7 @@ fn poll_fd(fd: i32, timeout: i32) -> io::Result<bool> {
}

#[cfg(target_os = "macos")]
fn select_fd(fd: i32, timeout: i32) -> io::Result<bool> {
fn select_fd(fd: RawFd, timeout: i32) -> io::Result<bool> {
unsafe {
let mut read_fd_set: libc::fd_set = mem::zeroed();

Expand Down Expand Up @@ -156,7 +187,7 @@ fn select_fd(fd: i32, timeout: i32) -> io::Result<bool> {
}
}

fn select_or_poll_term_fd(fd: i32, timeout: i32) -> io::Result<bool> {
fn select_or_poll_term_fd(fd: RawFd, timeout: i32) -> io::Result<bool> {
// There is a bug on macos that ttys cannot be polled, only select()
// works. However given how problematic select is in general, we
// normally want to use poll there too.
Expand All @@ -169,7 +200,7 @@ fn select_or_poll_term_fd(fd: i32, timeout: i32) -> io::Result<bool> {
poll_fd(fd, timeout)
}

fn read_single_char(fd: i32) -> io::Result<Option<char>> {
fn read_single_char(fd: RawFd) -> io::Result<Option<char>> {
// timeout of zero means that it will not block
let is_ready = select_or_poll_term_fd(fd, 0)?;

Expand All @@ -188,7 +219,7 @@ fn read_single_char(fd: i32) -> io::Result<Option<char>> {
// Similar to libc::read. Read count bytes into slice buf from descriptor fd.
// If successful, return the number of bytes read.
// Will return an error if nothing was read, i.e when called at end of file.
fn read_bytes(fd: i32, buf: &mut [u8], count: u8) -> io::Result<u8> {
fn read_bytes(fd: RawFd, buf: &mut [u8], count: u8) -> io::Result<u8> {
let read = unsafe { libc::read(fd, buf.as_mut_ptr() as *mut _, count as usize) };
if read < 0 {
Err(io::Error::last_os_error())
Expand All @@ -207,7 +238,7 @@ fn read_bytes(fd: i32, buf: &mut [u8], count: u8) -> io::Result<u8> {
}
}

fn read_single_key_impl(fd: i32) -> Result<Key, io::Error> {
fn read_single_key_impl(fd: RawFd) -> Result<Key, io::Error> {
loop {
match read_single_char(fd)? {
Some('\x1b') => {
Expand Down Expand Up @@ -301,27 +332,17 @@ fn read_single_key_impl(fd: i32) -> Result<Key, io::Error> {
}

pub(crate) fn read_single_key(ctrlc_key: bool) -> io::Result<Key> {
let tty_f;
let fd = unsafe {
if libc::isatty(libc::STDIN_FILENO) == 1 {
libc::STDIN_FILENO
} else {
tty_f = fs::OpenOptions::new()
.read(true)
.write(true)
.open("/dev/tty")?;
tty_f.as_raw_fd()
}
};
let input = unbuffered_input()?;

let mut termios = core::mem::MaybeUninit::uninit();
c_result(|| unsafe { libc::tcgetattr(fd, termios.as_mut_ptr()) })?;
c_result(|| unsafe { libc::tcgetattr(input.as_raw_fd(), termios.as_mut_ptr()) })?;
let mut termios = unsafe { termios.assume_init() };
let original = termios;
unsafe { libc::cfmakeraw(&mut termios) };
termios.c_oflag = original.c_oflag;
c_result(|| unsafe { libc::tcsetattr(fd, libc::TCSADRAIN, &termios) })?;
let rv: io::Result<Key> = read_single_key_impl(fd);
c_result(|| unsafe { libc::tcsetattr(fd, libc::TCSADRAIN, &original) })?;
c_result(|| unsafe { libc::tcsetattr(input.as_raw_fd(), libc::TCSADRAIN, &termios) })?;
let rv: io::Result<Key> = read_single_key_impl(input.as_raw_fd());
c_result(|| unsafe { libc::tcsetattr(input.as_raw_fd(), libc::TCSADRAIN, &original) })?;

// if the user hit ^C we want to signal SIGINT to outselves.
if let Err(ref err) = rv {
Expand Down

0 comments on commit 44a44a5

Please sign in to comment.