diff --git a/src/shared_socket.rs b/src/shared_socket.rs index 993d8b888..e11196eec 100644 --- a/src/shared_socket.rs +++ b/src/shared_socket.rs @@ -2,7 +2,7 @@ use pyo3::prelude::*; use log::debug; use socket2::{Domain, Protocol, Socket, Type}; -use std::net::SocketAddr; +use std::net::{IpAddr, SocketAddr}; #[pyclass] #[derive(Debug)] @@ -13,9 +13,14 @@ pub struct SocketHeld { #[pymethods] impl SocketHeld { #[new] - pub fn new(address: String, port: i32) -> PyResult { - let socket = Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP))?; - let address: SocketAddr = format!("{}:{}", address, port).parse()?; + pub fn new(ip: String, port: u16) -> PyResult { + let ip: IpAddr = ip.parse()?; + let socket = if ip.is_ipv4() { + Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP))? + } else { + Socket::new(Domain::IPV6, Type::STREAM, Some(Protocol::TCP))? + }; + let address = SocketAddr::new(ip, port); debug!("{}", address); // reuse port is not available on windows #[cfg(not(target_os = "windows"))]