diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index 7f065ad7..6d0e85ee 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -446,23 +446,24 @@ impl Version { } } - fn from_string(string: String) -> Result { + fn from_string(string: &str) -> Result { let mut parts = string.split('.'); - let major_str = parts.next().ok_or("Torch major version missing")?; - let minor_str = parts.next().ok_or("Torch minor version missing")?; - let patch_str = parts.next().ok_or("Torch path version missing")?; - let mut patch_parts = patch_str.split('+'); - let patch_str = patch_parts.next().ok_or("Torch path version missing")?; - - let major = major_str - .parse() - .map_err(|_| "Python major version not an integer")?; - let minor = minor_str - .parse() - .map_err(|_| "Python minor version not an integer")?; - let patch = patch_str - .parse() - .map_err(|_| "Python patch version not an integer")?; + let err = || format!("Could not parse torch package version {string}."); + let major_str = parts.next().ok_or_else(err)?; + let minor_str = parts.next().ok_or_else(err)?; + let patch_str = parts.next().ok_or_else(err)?; + // Patch is more complex and can be: + // - `1` a number + // - `1a0`, `1b0`, `1rc1` an alpha, beta, release candidate version + // - `1a0+git2323` from source with commit number + let patch_str: String = patch_str + .chars() + .take_while(|c| c.is_ascii_digit()) + .collect(); + + let major = major_str.parse().map_err(|_| err())?; + let minor = minor_str.parse().map_err(|_| err())?; + let patch = patch_str.parse().map_err(|_| err())?; Ok(Version { major, minor, @@ -542,7 +543,7 @@ impl safe_open { let version: String = module.getattr(intern!(py, "__version__"))?.extract()?; let version = - Version::from_string(version).map_err(exceptions::PyException::new_err)?; + Version::from_string(&version).map_err(exceptions::PyException::new_err)?; // Untyped storage only exists for versions over 1.11.0 // Same for torch.asarray which is necessary for zero-copy tensor @@ -1028,3 +1029,23 @@ fn safetensors_rust(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn version_parse() { + let torch_version = "1.1.1"; + let version = Version::from_string(torch_version).unwrap(); + assert_eq!(version, Version::new(1, 1, 1)); + + let torch_version = "2.0.0a0+gitd1123c9"; + let version = Version::from_string(torch_version).unwrap(); + assert_eq!(version, Version::new(2, 0, 0)); + + let torch_version = "something"; + let version = Version::from_string(torch_version); + assert!(version.is_err()); + } +}