Skip to content

Commit

Permalink
Adding slice get_dtype() method to get the dtype directly on slic…
Browse files Browse the repository at this point in the history
…es. (#303)
  • Loading branch information
Narsil authored Jul 31, 2023
1 parent f1e4d06 commit 1a65a3f
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
20 changes: 20 additions & 0 deletions bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,26 @@ impl PySafeSlice {
Ok(shape)
}

/// Returns the dtype of the full underlying tensor
///
/// Returns:
/// (`str`):
/// The dtype of the full tensor
///
/// Example:
/// ```python
/// from safetensors import safe_open
///
/// with safe_open("model.safetensors", framework="pt", device=0) as f:
/// tslice = f.get_slice("embedding")
/// dtype = tslice.get_dtype() # "F32"
/// ```
pub fn get_dtype(&self, py: Python) -> PyResult<PyObject> {
let dtype = self.info.dtype;
let dtype: PyObject = format!("{:?}", dtype).into_py(py);
Ok(dtype)
}

pub fn __getitem__(&self, slices: Slice) -> PyResult<PyObject> {
let slices: Vec<&PySlice> = match slices {
Slice::Slice(slice) => vec![slice],
Expand Down
5 changes: 4 additions & 1 deletion bindings/python/tests/test_pt_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,10 @@ def test_cannot_serialize_shared(self):

def test_deserialization_slice(self):
with safe_open(self.local, framework="pt") as f:
tensor = f.get_slice("test")[:, :, 1:2]
_slice = f.get_slice("test")
self.assertEqual(_slice.get_shape(), [1, 2, 3])
self.assertEqual(_slice.get_dtype(), "F32")
tensor = _slice[:, :, 1:2]

self.assertEqual(
tensor.numpy().tobytes(),
Expand Down

0 comments on commit 1a65a3f

Please sign in to comment.