From c6f4ef279cb4303e7a0e4d158de06695fb6be8f5 Mon Sep 17 00:00:00 2001 From: gagewrye <95107220+gagewrye@users.noreply.github.com> Date: Sat, 23 Nov 2024 15:26:28 -0800 Subject: [PATCH] Typing error fix --- DroneClassification/data/MemoryMapDataset.py | 10 +++++----- DroneClassification/data/prepare_data.ipynb | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/DroneClassification/data/MemoryMapDataset.py b/DroneClassification/data/MemoryMapDataset.py index 6554228..ef48c97 100755 --- a/DroneClassification/data/MemoryMapDataset.py +++ b/DroneClassification/data/MemoryMapDataset.py @@ -1,10 +1,10 @@ from torch.utils.data import Dataset, Sampler, BatchSampler, DataLoader import numpy as np -from typing import Optional +from typing import Optional, List, Tuple import torch class MemmapDataset(Dataset): - def __init__(self, images: np.ndarray, labels: np.ndarray, validation_indices: Optional[tuple] = None,): + def __init__(self, images: np.ndarray, labels: np.ndarray, validation_indices: Optional[Tuple] = None,): """ Inputs are expected to be memory mapped numpy arrays (.npy) @@ -30,7 +30,7 @@ def __init__(self, images: np.ndarray, labels: np.ndarray, validation_indices: O def __len__(self) -> int: return self.images.shape[0] - def __getitem__(self, idx) -> tuple: + def __getitem__(self, idx) -> Tuple: image = self.images[idx] label = self.labels[idx] @@ -58,7 +58,7 @@ def split(self, split_ratio: float): return train_dataset, val_dataset - def split_into_folds(self, num_folds: int) -> list[Dataset]: + def split_into_folds(self, num_folds: int) -> List[Dataset]: """ Creates a list of validation datasets for cross validation. The original dataset will be used as the training dataset. @@ -92,7 +92,7 @@ class SliceSampler(Sampler): Takes slices of the dataset to minimize overhead of accessing a memory mapped array. Can optionally skip indices to allow for cross validation with memory mapping. """ - def __init__(self, dataset_len, batch_size, skip_indices: Optional[tuple] = None): + def __init__(self, dataset_len, batch_size, skip_indices: Optional[Tuple] = None): self.dataset_len = dataset_len self.batch_size = batch_size self.start_skip = None diff --git a/DroneClassification/data/prepare_data.ipynb b/DroneClassification/data/prepare_data.ipynb index f91f66c..ac00199 100755 --- a/DroneClassification/data/prepare_data.ipynb +++ b/DroneClassification/data/prepare_data.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -239,7 +239,7 @@ ], "metadata": { "kernelspec": { - "display_name": "mangrove", + "display_name": "i2sb", "language": "python", "name": "python3" }, @@ -253,7 +253,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": "3.8.18" }, "orig_nbformat": 4 },