Skip to content

Commit

Permalink
Merge pull request #251 from hotosm/fix/dataset-size
Browse files Browse the repository at this point in the history
Bug Fix: Dataset Size
  • Loading branch information
kshitijrajsharma authored May 30, 2024
2 parents bbb55d7 + e784173 commit d866e42
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 23 deletions.
1 change: 1 addition & 0 deletions backend/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class Training(models.Model):
finished_at = models.DateTimeField(null=True, blank=True)
accuracy = models.FloatField(null=True, blank=True)
epochs = models.PositiveIntegerField()
chips_length = models.PositiveIntegerField(default=0)
batch_size = models.PositiveIntegerField()
freeze_layers = models.BooleanField(default=False)

Expand Down
58 changes: 43 additions & 15 deletions backend/core/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,14 @@
import os
import shutil
import sys
import tarfile
import traceback
from shutil import rmtree
import tarfile

import hot_fair_utilities
import ramp.utils
import tensorflow as tf
from celery import shared_task
from django.conf import settings
from django.contrib.gis.db.models.aggregates import Extent
from django.contrib.gis.geos import GEOSGeometry
from django.shortcuts import get_object_or_404
from django.utils import timezone
from hot_fair_utilities import preprocess, train
from hot_fair_utilities.training import run_feedback
from predictor import download_imagery, get_start_end_download_coords

from core.models import AOI, Feedback, FeedbackAOI, FeedbackLabel, Label, Training
from core.serializers import (
AOISerializer,
Expand All @@ -29,6 +20,14 @@
LabelFileSerializer,
)
from core.utils import bbox, is_dir_empty
from django.conf import settings
from django.contrib.gis.db.models.aggregates import Extent
from django.contrib.gis.geos import GEOSGeometry
from django.shortcuts import get_object_or_404
from django.utils import timezone
from hot_fair_utilities import preprocess, train
from hot_fair_utilities.training import run_feedback
from predictor import download_imagery, get_start_end_download_coords

logger = logging.getLogger(__name__)

Expand All @@ -37,6 +36,7 @@

DEFAULT_TILE_SIZE = 256


def xz_folder(folder_path, output_filename, remove_original=False):
"""
Compresses a folder and its contents into a .tar.xz file and optionally removes the original folder.
Expand All @@ -47,8 +47,8 @@ def xz_folder(folder_path, output_filename, remove_original=False):
- remove_original: If True, the original folder is removed after compression.
"""

if not output_filename.endswith('.tar.xz'):
output_filename += '.tar.xz'
if not output_filename.endswith(".tar.xz"):
output_filename += ".tar.xz"

with tarfile.open(output_filename, "w:xz") as tar:
tar.add(folder_path, arcname=os.path.basename(folder_path))
Expand All @@ -57,6 +57,20 @@ def xz_folder(folder_path, output_filename, remove_original=False):
shutil.rmtree(folder_path)


def get_file_count(path):
    """Return the number of regular files directly inside *path*.

    The scan is non-recursive: sub-directories are not counted and
    their contents are not visited. Symlinks to files count as files
    (matching ``os.path.isfile`` semantics).

    Args:
        path: Directory to scan.

    Returns:
        int: Count of plain files in the directory, or 0 if the
        directory cannot be read (missing path, permission denied, ...),
        keeping the original best-effort contract for callers that use
        this to record dataset chip counts.
    """
    try:
        # scandir avoids a separate stat() call per entry, unlike
        # listdir + os.path.isfile.
        with os.scandir(path) as entries:
            return sum(1 for entry in entries if entry.is_file())
    except OSError as e:
        # Narrowed from a bare `except Exception`: only filesystem
        # errors are expected here; anything else should propagate.
        print(f"An error occurred: {e}")
        return 0


@shared_task
def train_model(
dataset_id,
Expand Down Expand Up @@ -189,6 +203,10 @@ def train_model(
rasterize_options=["binary"],
georeference_images=True,
)
training_instance.chips_length = get_file_count(
os.path.join(preprocess_output, "chips")
)
training_instance.save()

# train

Expand Down Expand Up @@ -272,9 +290,19 @@ def train_model(
f.write(json.dumps(aoi_serializer.data))

# copy aois and labels to preprocess output before compressing it to tar
shutil.copyfile(os.path.join(output_path, "aois.geojson"), os.path.join(preprocess_output,'aois.geojson'))
shutil.copyfile(os.path.join(output_path, "labels.geojson"), os.path.join(preprocess_output,'labels.geojson'))
xz_folder(preprocess_output, os.path.join(output_path, "preprocessed.tar.xz"), remove_original=True)
shutil.copyfile(
os.path.join(output_path, "aois.geojson"),
os.path.join(preprocess_output, "aois.geojson"),
)
shutil.copyfile(
os.path.join(output_path, "labels.geojson"),
os.path.join(preprocess_output, "labels.geojson"),
)
xz_folder(
preprocess_output,
os.path.join(output_path, "preprocessed.tar.xz"),
remove_original=True,
)

# now remove the ramp-data all our outputs are copied to our training workspace
shutil.rmtree(base_path)
Expand Down
11 changes: 3 additions & 8 deletions backend/core/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@
ModelSerializer,
PredictionParamSerializer,
)
# from .tasks import train_model
from celery import Celery
from .tasks import train_model
from .utils import get_dir_size, gpx_generator, process_rawdata, request_rawdata


Expand Down Expand Up @@ -129,10 +128,8 @@ def create(self, validated_data):
# create the model instance
instance = Training.objects.create(**validated_data)

celery = Celery()

# run your function here
task = celery.train_model.delay(
task = train_model.delay(
dataset_id=instance.model.dataset.id,
training_id=instance.id,
epochs=instance.epochs,
Expand Down Expand Up @@ -474,9 +471,7 @@ def post(self, request, *args, **kwargs):
batch_size=batch_size,
source_imagery=training_instance.source_imagery,
)
celery = Celery()

task = celery.train_model.delay(
task = train_model.delay(
dataset_id=instance.model.dataset.id,
training_id=instance.id,
epochs=instance.epochs,
Expand Down

0 comments on commit d866e42

Please sign in to comment.