Skip to content

Commit

Permalink
Fixing torch version parsing. (#143)
Browse files Browse the repository at this point in the history
* Fixing torch version parsing.

* Anticipating a bit more variations of this.
  • Loading branch information
Narsil authored Dec 29, 2022
1 parent ba8de5f commit 6cd64c7
Showing 1 changed file with 38 additions and 17 deletions.
55 changes: 38 additions & 17 deletions bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -446,23 +446,24 @@ impl Version {
}
}

fn from_string(string: String) -> Result<Self, &'static str> {
fn from_string(string: &str) -> Result<Self, String> {
let mut parts = string.split('.');
let major_str = parts.next().ok_or("Torch major version missing")?;
let minor_str = parts.next().ok_or("Torch minor version missing")?;
let patch_str = parts.next().ok_or("Torch path version missing")?;
let mut patch_parts = patch_str.split('+');
let patch_str = patch_parts.next().ok_or("Torch path version missing")?;

let major = major_str
.parse()
.map_err(|_| "Python major version not an integer")?;
let minor = minor_str
.parse()
.map_err(|_| "Python minor version not an integer")?;
let patch = patch_str
.parse()
.map_err(|_| "Python patch version not an integer")?;
let err = || format!("Could not parse torch package version {string}.");
let major_str = parts.next().ok_or_else(err)?;
let minor_str = parts.next().ok_or_else(err)?;
let patch_str = parts.next().ok_or_else(err)?;
// Patch is more complex and can be:
// - `1` a number
// - `1a0`, `1b0`, `1rc1` an alpha, beta, release candidate version
// - `1a0+git2323` from source with commit number
let patch_str: String = patch_str
.chars()
.take_while(|c| c.is_ascii_digit())
.collect();

let major = major_str.parse().map_err(|_| err())?;
let minor = minor_str.parse().map_err(|_| err())?;
let patch = patch_str.parse().map_err(|_| err())?;
Ok(Version {
major,
minor,
Expand Down Expand Up @@ -542,7 +543,7 @@ impl safe_open {

let version: String = module.getattr(intern!(py, "__version__"))?.extract()?;
let version =
Version::from_string(version).map_err(exceptions::PyException::new_err)?;
Version::from_string(&version).map_err(exceptions::PyException::new_err)?;

// Untyped storage only exists for versions over 1.11.0
// Same for torch.asarray which is necessary for zero-copy tensor
Expand Down Expand Up @@ -1028,3 +1029,23 @@ fn safetensors_rust(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<safe_open>()?;
Ok(())
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn version_parse() {
let torch_version = "1.1.1";
let version = Version::from_string(torch_version).unwrap();
assert_eq!(version, Version::new(1, 1, 1));

let torch_version = "2.0.0a0+gitd1123c9";
let version = Version::from_string(torch_version).unwrap();
assert_eq!(version, Version::new(2, 0, 0));

let torch_version = "something";
let version = Version::from_string(torch_version);
assert!(version.is_err());
}
}

0 comments on commit 6cd64c7

Please sign in to comment.