Skip to content

Commit

Permalink
Formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
gagewrye committed Nov 23, 2024
1 parent c6f4ef2 commit 3f012b7
Show file tree
Hide file tree
Showing 6 changed files with 501 additions and 1,836 deletions.
2 changes: 2 additions & 0 deletions DroneClassification/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from data import MemmapDataset
from models import ResNet_UNet, LandmassLoss, JaccardLoss
5 changes: 3 additions & 2 deletions DroneClassification/data/tiff_processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
from rasterio.features import rasterize
from shapely.validation import make_valid
import numpy as np
from typing import Tuple

def tile_tiff_pair(chunk_path: str, image_size=128) -> tuple[np.ndarray, np.ndarray]:
def tile_tiff_pair(chunk_path: str, image_size=128) -> Tuple[np.ndarray, np.ndarray]:
name = chunk_path.split('/')[-1]
print(f"Processing {name}...")

Expand Down Expand Up @@ -120,7 +121,7 @@ def tile_generator(data, tile_size):
if i + tile_size <= nrows and j + tile_size <= ncols:
yield data[:, i:i+tile_size, j:j+tile_size], (i, j)

def create_pairs(rgb_data, label_data, tile_size) -> tuple[np.ndarray, np.ndarray]:
def create_pairs(rgb_data, label_data, tile_size) -> Tuple[np.ndarray, np.ndarray]:
images = []
labels = []
for (rgb_tile, _), (label_tile, _) in zip(tile_generator(rgb_data, tile_size), tile_generator(label_data, tile_size)):
Expand Down
2 changes: 1 addition & 1 deletion DroneClassification/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .models import ResNet_UNet, ResNet_FC, DenseNet_UNet, SegmentModelWrapper
from .models import ResNet_UNet, ResNet_FC, SegmentModelWrapper
from .loss import JaccardLoss, FocalLoss, DistanceCountLoss, LandmassLoss
Loading

0 comments on commit 3f012b7

Please sign in to comment.