diff --git a/SpatialDE/_internal/tf_dataset.py b/SpatialDE/_internal/tf_dataset.py index ea441b4..4284f41 100644 --- a/SpatialDE/_internal/tf_dataset.py +++ b/SpatialDE/_internal/tf_dataset.py @@ -30,11 +30,10 @@ def __init__( def __call__(self): for i, g in enumerate(self.genes): - slice = self.adata[:, i] if self.layer is None: - data = slice.X + data = self.adata.X[:, i] else: - data = slice.layers[self.layer] + data = self.adata.layers[self.layer][:, i] if issparse(data): data = data.toarray() with tf.device(tf.DeviceSpec(device_type="CPU").to_string()):