Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing torch version parsing. #143

Merged
merged 2 commits into from
Dec 29, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 38 additions & 17 deletions bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -446,23 +446,24 @@ impl Version {
}
}

fn from_string(string: String) -> Result<Self, &'static str> {
fn from_string(string: &str) -> Result<Self, String> {
let mut parts = string.split('.');
let major_str = parts.next().ok_or("Torch major version missing")?;
let minor_str = parts.next().ok_or("Torch minor version missing")?;
let patch_str = parts.next().ok_or("Torch path version missing")?;
let mut patch_parts = patch_str.split('+');
let patch_str = patch_parts.next().ok_or("Torch path version missing")?;

let major = major_str
.parse()
.map_err(|_| "Python major version not an integer")?;
let minor = minor_str
.parse()
.map_err(|_| "Python minor version not an integer")?;
let patch = patch_str
.parse()
.map_err(|_| "Python patch version not an integer")?;
let err = || format!("Could not parse torch package version {string}.");
let major_str = parts.next().ok_or_else(err)?;
let minor_str = parts.next().ok_or_else(err)?;
let patch_str = parts.next().ok_or_else(err)?;
// Patch is more complex and can be:
// - `1` a number
// - `1a0`, `1b0`, `1rc1` an alpha, beta, release candidate version
// - `1a0+git2323` from source with commit number
let patch_str: String = patch_str
.chars()
.take_while(|c| c.is_ascii_digit())
.collect();

let major = major_str.parse().map_err(|_| err())?;
let minor = minor_str.parse().map_err(|_| err())?;
let patch = patch_str.parse().map_err(|_| err())?;
Ok(Version {
major,
minor,
Expand Down Expand Up @@ -542,7 +543,7 @@ impl safe_open {

let version: String = module.getattr(intern!(py, "__version__"))?.extract()?;
let version =
Version::from_string(version).map_err(exceptions::PyException::new_err)?;
Version::from_string(&version).map_err(exceptions::PyException::new_err)?;

// Untyped storage only exists for versions over 1.11.0
// Same for torch.asarray which is necessary for zero-copy tensor
Expand Down Expand Up @@ -1028,3 +1029,23 @@ fn safetensors_rust(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<safe_open>()?;
Ok(())
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn version_parse() {
let torch_version = "1.1.1";
let version = Version::from_string(torch_version).unwrap();
assert_eq!(version, Version::new(1, 1, 1));

let torch_version = "2.0.0a0+gitd1123c9";
let version = Version::from_string(torch_version).unwrap();
assert_eq!(version, Version::new(2, 0, 0));

let torch_version = "something";
let version = Version::from_string(torch_version);
assert!(version.is_err());
}
}