diff --git a/.gitignore b/.gitignore index 9cc7b3d3..7121fe11 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,5 @@ target/ **/*.rs.bk Cargo.lock +.cargo/ +wintun.dll diff --git a/Cargo.toml b/Cargo.toml index e7353857..5cadf464 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tun" -version = "0.6.1" +version = "0.6.2" edition = "2021" authors = ["meh. "] diff --git a/README.md b/README.md index 7c468188..20b8a447 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ First, add the following to your `Cargo.toml`: ```toml [dependencies] -tun = "0.6.1" +tun = "0.6.2" ``` Next, add this to your crate root: @@ -21,7 +21,7 @@ If you want to use the TUN interface with mio/tokio, you need to enable the `asy ```toml [dependencies] -tun = { version = "0.6.1", features = ["async"] } +tun = { version = "0.6.2", features = ["async"] } ``` Example diff --git a/src/async/win/device.rs b/src/async/win/device.rs index 8b67037d..89d44978 100644 --- a/src/async/win/device.rs +++ b/src/async/win/device.rs @@ -21,7 +21,7 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio_util::codec::Framed; use crate::device::Device as D; -use crate::platform::{Device, Queue}; +use crate::platform::Device; use crate::r#async::codec::*; pub struct AsyncDevice { @@ -56,10 +56,9 @@ impl AsyncRead for AsyncDevice { cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - let rbuf = buf.initialize_unfilled(); - match Pin::new(&mut self.inner).poll_read(cx, rbuf) { - Poll::Ready(Ok(n)) => { - buf.advance(n); + //let rbuf = buf.initialize_unfilled(); + match Pin::new(&mut self.inner).poll_read(cx, buf) { + Poll::Ready(Ok(_)) => { Poll::Ready(Ok(())) } Poll::Ready(Err(e)) => Poll::Ready(Err(e)), @@ -92,67 +91,69 @@ impl AsyncWrite for AsyncDevice { } } -pub struct AsyncQueue { - inner: Queue, -} - -impl AsyncQueue { - /// Create a new `AsyncQueue` wrapping around a `Queue`. - pub fn new(queue: Queue) -> io::Result { - Ok(AsyncQueue { inner: queue }) - } - /// Returns a shared reference to the underlying Queue object - pub fn get_ref(&self) -> &Queue { - &self.inner - } - - /// Returns a mutable reference to the underlying Queue object - pub fn get_mut(&mut self) -> &mut Queue { - &mut self.inner - } - - /// Consumes this AsyncQueue and return a Framed object (unified Stream and Sink interface) - pub fn into_framed(self) -> Framed { - let codec = TunPacketCodec::new(false, 1512); - Framed::new(self, codec) - } -} - -impl AsyncRead for AsyncQueue { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - let rbuf = buf.initialize_unfilled(); - match Pin::new(&mut self.inner).poll_read(cx, rbuf) { - Poll::Ready(Ok(n)) => { - buf.advance(n); - Poll::Ready(Ok(())) - } - Poll::Ready(Err(e)) => Poll::Ready(Err(e)), - Poll::Pending => Poll::Pending, - } - } -} - -impl AsyncWrite for AsyncQueue { - fn poll_write( - mut self: Pin<&mut Self>, - _cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - match self.inner.write(buf) { - Ok(n) => Poll::Ready(Ok(n)), - Err(e) => Poll::Ready(Err(e)), - } - } - - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } -} +pub struct AsyncQueue; + +// pub struct AsyncQueue { +// inner: Queue, +// } + +// impl AsyncQueue { +// /// Create a new `AsyncQueue` wrapping around a `Queue`. +// pub fn new(queue: Queue) -> io::Result { +// Ok(AsyncQueue { inner: queue }) +// } +// /// Returns a shared reference to the underlying Queue object +// pub fn get_ref(&self) -> &Queue { +// &self.inner +// } + +// /// Returns a mutable reference to the underlying Queue object +// pub fn get_mut(&mut self) -> &mut Queue { +// &mut self.inner +// } + +// /// Consumes this AsyncQueue and return a Framed object (unified Stream and Sink interface) +// pub fn into_framed(self) -> Framed { +// let codec = TunPacketCodec::new(false, 1512); +// Framed::new(self, codec) +// } +// } + +// impl AsyncRead for AsyncQueue { +// fn poll_read( +// mut self: Pin<&mut Self>, +// cx: &mut Context<'_>, +// buf: &mut ReadBuf<'_>, +// ) -> Poll> { +// let rbuf = buf.initialize_unfilled(); +// match Pin::new(&mut self.inner).poll_read(cx, rbuf) { +// Poll::Ready(Ok(n)) => { +// buf.advance(n); +// Poll::Ready(Ok(())) +// } +// Poll::Ready(Err(e)) => Poll::Ready(Err(e)), +// Poll::Pending => Poll::Pending, +// } +// } +// } + +// impl AsyncWrite for AsyncQueue { +// fn poll_write( +// mut self: Pin<&mut Self>, +// _cx: &mut Context<'_>, +// buf: &[u8], +// ) -> Poll> { +// match self.inner.write(buf) { +// Ok(n) => Poll::Ready(Ok(n)), +// Err(e) => Poll::Ready(Err(e)), +// } +// } + +// fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { +// Poll::Ready(Ok(())) +// } + +// fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { +// Poll::Ready(Ok(())) +// } +// } diff --git a/src/platform/windows/device.rs b/src/platform/windows/device.rs index 53bee9e4..6b9eda8a 100644 --- a/src/platform/windows/device.rs +++ b/src/platform/windows/device.rs @@ -14,12 +14,12 @@ use std::io::{self, Read, Write}; use std::net::{IpAddr, Ipv4Addr}; -use std::pin::Pin; -use std::sync::{Arc, Mutex}; -use std::task::{Context, Poll}; +use std::sync::Arc; use std::thread; use std::vec::Vec; +use bytes::BufMut; +use tokio::io::ReadBuf; use wintun::Session; use crate::configuration::Configuration; @@ -28,8 +28,10 @@ use crate::error::*; /// A TUN device using the wintun driver. pub struct Device { - queue: Queue, - mtu: usize, + session: Arc, + receiver: tokio::sync::mpsc::UnboundedReceiver>, + _task: thread::JoinHandle<()>, + mtu:usize } impl Device { @@ -43,20 +45,31 @@ impl Device { Err(_) => wintun::Adapter::create(&wintun, tun_name, tun_name, guid)?, }; - let address = config.address.unwrap_or(Ipv4Addr::new(10, 1, 0, 2)); - let mask = config.netmask.unwrap_or(Ipv4Addr::new(255, 255, 255, 0)); + let address = config.address.ok_or(Error::InvalidConfig)?; + let mask = config.netmask.ok_or(Error::InvalidConfig)?; let gateway = config.destination.map(IpAddr::from); adapter.set_network_addresses_tuple(IpAddr::V4(address), IpAddr::V4(mask), gateway)?; let mtu = config.mtu.unwrap_or(1500) as usize; - let session = adapter.start_session(wintun::MAX_RING_CAPACITY)?; + let session = Arc::new(adapter.start_session(wintun::MAX_RING_CAPACITY)?); + + let (receiver_tx, receiver_rx) = tokio::sync::mpsc::unbounded_channel::>(); + + let session_reader = session.clone(); + let task = thread::spawn(move || { + loop { + let packet = session_reader.receive_blocking().unwrap(); + let bytes = packet.bytes().to_vec(); + // dbg!(&bytes); + receiver_tx.send(bytes).unwrap(); + } + }); let mut device = Device { - queue: Queue { - session: Arc::new(session), - cached: Arc::new(Mutex::new(Vec::with_capacity(mtu))), - }, - mtu, + session, + receiver: receiver_rx, + _task: task, + mtu }; // This is not needed since we use netsh to set the address. @@ -66,47 +79,80 @@ impl Device { } pub fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - Pin::new(&mut self.queue).poll_read(cx, buf) + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> std::task::Poll> { + match std::task::ready!(self.receiver.poll_recv(cx)) { + Some(bytes) => { + buf.put_slice(&bytes[..]); + std::task::Poll::Ready(Ok(())) + } + None => std::task::Poll::Ready(Ok(())), + } + } + + pub fn poll_write( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + let mut write_pack = self.session.allocate_send_packet(buf.len() as u16)?; + write_pack.bytes_mut().copy_from_slice(buf.as_ref()); + self.session.send_packet(write_pack); + std::task::Poll::Ready(Ok(buf.len())) + } + + pub fn poll_flush( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } + + pub fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) } } impl Read for Device { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.queue.read(buf) - } - - fn read_vectored(&mut self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result { - self.queue.read_vectored(bufs) + match self.receiver.blocking_recv(){ + Some(pkt) =>{ + buf.clone_from_slice(&pkt[..]); + return Ok(pkt.len()) + }, + None => Ok(0), + } } } impl Write for Device { fn write(&mut self, buf: &[u8]) -> io::Result { - self.queue.write(buf) - } - - fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result { - self.queue.write_vectored(bufs) + let len = buf.len(); + let mut write_pack = self.session.allocate_send_packet(len as u16)?; + write_pack.bytes_mut().copy_from_slice(buf.as_ref()); + self.session.send_packet(write_pack); + Ok(len) } fn flush(&mut self) -> io::Result<()> { - self.queue.flush() + Ok(()) } } impl D for Device { - type Queue = Queue; + type Queue = Device; fn name(&self) -> Result { - Ok(self.queue.session.get_adapter().get_name()?) + Ok(self.session.get_adapter().get_name()?) } fn set_name(&mut self, value: &str) -> Result<()> { - self.queue.session.get_adapter().set_name(value)?; + self.session.get_adapter().set_name(value)?; Ok(()) } @@ -115,7 +161,7 @@ impl D for Device { } fn address(&self) -> Result { - let addresses = self.queue.session.get_adapter().get_addresses()?; + let addresses = self.session.get_adapter().get_addresses()?; addresses .iter() .find_map(|a| match a { @@ -126,14 +172,13 @@ impl D for Device { } fn set_address(&mut self, value: Ipv4Addr) -> Result<()> { - self.queue.session.get_adapter().set_address(value)?; + self.session.get_adapter().set_address(value)?; Ok(()) } fn destination(&self) -> Result { // It's just the default gateway in windows. - self.queue - .session + self.session .get_adapter() .get_gateways()? .iter() @@ -146,7 +191,7 @@ impl D for Device { fn set_destination(&mut self, value: Ipv4Addr) -> Result<()> { // It's just set the default gateway in windows. - self.queue.session.get_adapter().set_gateway(Some(value))?; + self.session.get_adapter().set_gateway(Some(value))?; Ok(()) } @@ -162,7 +207,6 @@ impl D for Device { fn netmask(&self) -> Result { let current_addr = self.address()?; let netmask = self - .queue .session .get_adapter() .get_netmask_of_address(&IpAddr::V4(current_addr))?; @@ -173,7 +217,7 @@ impl D for Device { } fn set_netmask(&mut self, value: Ipv4Addr) -> Result<()> { - self.queue.session.get_adapter().set_netmask(value)?; + self.session.get_adapter().set_netmask(value)?; Ok(()) } @@ -183,121 +227,14 @@ impl D for Device { fn set_mtu(&mut self, value: i32) -> Result<()> { self.mtu = value as usize; - self.queue.cached = Arc::new(Mutex::new(Vec::with_capacity(self.mtu))); Ok(()) } fn queue(&mut self, _index: usize) -> Option<&mut Self::Queue> { - Some(&mut self.queue) - } -} - -pub struct Queue { - session: Arc, - cached: Arc>>, -} - -impl Queue { - pub fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - mut buf: &mut [u8], - ) -> Poll> { - { - let mut cached = self - .cached - .lock() - .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?; - if cached.len() > 0 { - let res = match io::copy(&mut cached.as_slice(), &mut buf) { - Ok(n) => Poll::Ready(Ok(n as usize)), - Err(e) => Poll::Ready(Err(e)), - }; - cached.clear(); - return res; - } - } - let reader_session = self.session.clone(); - match reader_session.try_receive() { - Err(e) => Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e))), - Ok(Some(packet)) => match io::copy(&mut packet.bytes(), &mut buf) { - Ok(n) => Poll::Ready(Ok(n as usize)), - Err(e) => Poll::Ready(Err(e)), - }, - Ok(None) => { - let waker = cx.waker().clone(); - let cached = self.cached.clone(); - thread::spawn(move || { - match reader_session.receive_blocking() { - Ok(packet) => { - if let Ok(mut cached) = cached.lock() { - cached.extend_from_slice(packet.bytes()); - } else { - log::error!("cached lock error in wintun reciever thread, packet will be dropped"); - } - } - Err(e) => log::error!("receive_blocking error: {:?}", e), - } - waker.wake() - }); - Poll::Pending - } - } - } - - #[allow(dead_code)] - fn try_read(&mut self, mut buf: &mut [u8]) -> io::Result { - let reader_session = self.session.clone(); - match reader_session.try_receive() { - Err(e) => Err(io::Error::new(io::ErrorKind::Other, e)), - Ok(op) => match op { - None => Ok(0), - Some(packet) => match io::copy(&mut packet.bytes(), &mut buf) { - Ok(s) => Ok(s as usize), - Err(e) => Err(e), - }, - }, - } - } -} - -impl Read for Queue { - fn read(&mut self, mut buf: &mut [u8]) -> io::Result { - let reader_session = self.session.clone(); - match reader_session.receive_blocking() { - Ok(pkt) => match io::copy(&mut pkt.bytes(), &mut buf) { - Ok(n) => Ok(n as usize), - Err(e) => Err(e), - }, - Err(e) => Err(io::Error::new(io::ErrorKind::ConnectionAborted, e)), - } + Some(self) } } -impl Write for Queue { - fn write(&mut self, mut buf: &[u8]) -> io::Result { - let size = buf.len(); - match self.session.allocate_send_packet(size as u16) { - Err(e) => Err(io::Error::new(io::ErrorKind::OutOfMemory, e)), - Ok(mut packet) => match io::copy(&mut buf, &mut packet.bytes_mut()) { - Ok(s) => { - self.session.send_packet(packet); - Ok(s as usize) - } - Err(e) => Err(e), - }, - } - } +pub struct Queue; - fn flush(&mut self) -> io::Result<()> { - Ok(()) - } -} -impl Drop for Queue { - fn drop(&mut self) { - if let Err(err) = self.session.shutdown() { - log::error!("failed to shutdown session: {:?}", err); - } - } -}