Merge pull request #309 from hotosm/feature/s3-upload
Feature/s3 upload
kshitijrajsharma authored Dec 22, 2024
2 parents 2f63658 + 4db41ed commit b9d59c3
Showing 9 changed files with 401 additions and 129 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,3 +1,5 @@
runs

# frontend
# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.

5 changes: 5 additions & 0 deletions backend/README.md
@@ -170,6 +170,11 @@ You can start flower to start monitoring your tasks
celery -A aiproject --broker=redis://127.0.0.1:6379/0 flower
```

## Start background tasks
```bash
python manage.py qcluster
```
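
Once the cluster is running, application code can hand work to it. A minimal sketch of enqueuing a job with django-q (the task path and arguments here are illustrative, not from this repo):

```python
# Illustrative django-q usage; "core.tasks.some_long_job" is a hypothetical task path.
from django_q.tasks import async_task, result

# Enqueue a function call; a running qcluster worker will pick it up.
task_id = async_task("core.tasks.some_long_job", 42)

# Fetch the return value later; result() returns None until the task finishes.
value = result(task_id)
```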

## Run Tests

```
25 changes: 24 additions & 1 deletion backend/aiproject/settings.py
@@ -55,9 +55,18 @@
OSM_SECRET_KEY = env("OSM_SECRET_KEY")


# Limiter
# S3
BUCKET_NAME = env("BUCKET_NAME")
PARENT_BUCKET_FOLDER = env(
"PARENT_BUCKET_FOLDER", default="dev"
) # use prod for production
AWS_REGION = env("AWS_REGION", default="us-east-1")
AWS_ACCESS_KEY_ID = env("AWS_ACCESS_KEY_ID", default=None)
AWS_SECRET_ACCESS_KEY = env("AWS_SECRET_ACCESS_KEY", default=None)
PRESIGNED_URL_EXPIRY = env("PRESIGNED_URL_EXPIRY", default=3600)
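
A hedged sketch of how these settings can drive a boto3 client to mint a presigned download URL (illustrative only; the object key below is made up, and this snippet is not part of the diff):

```python
# Sketch: presigned GET URL from the S3 settings above; "example.tif" is a made-up key.
import boto3
from django.conf import settings

s3 = boto3.client(
    "s3",
    region_name=settings.AWS_REGION,
    aws_access_key_id=settings.AWS_ACCESS_KEY_ID,
    aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY,
)
url = s3.generate_presigned_url(
    "get_object",
    Params={
        "Bucket": settings.BUCKET_NAME,
        "Key": f"{settings.PARENT_BUCKET_FOLDER}/example.tif",
    },
    ExpiresIn=int(settings.PRESIGNED_URL_EXPIRY),
)
```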


# Limiter
EPOCHS_LIMIT = env("EPOCHS_LIMIT", default=20) ## TODO : Remove this global variable
BATCH_SIZE_LIMIT = env("BATCH_SIZE_LIMIT", default=8)

@@ -94,6 +103,7 @@
"drf_yasg",
"celery",
"django_celery_results",
"django_q",
]

MIDDLEWARE = [
@@ -215,6 +225,19 @@
) # if you don't want to use redis pass 'django-db' to use app db itself


Q_CLUSTER = {
"name": "DjangORM",
"workers": 4,
"retry": 60 * 6,
"max_retires": 1,
"recycle": 50,
"queue_limit": 50,
"timeout": 60 * 5, # number of seconds
"label": "Django Q",
"redis": CELERY_BROKER_URL,
}


AUTH_USER_MODEL = "login.OsmUser"

SWAGGER_SETTINGS = {
5 changes: 4 additions & 1 deletion backend/api-requirements.txt
@@ -11,7 +11,7 @@ django-filter==22.1
django-cors-headers==3.13.0 # used for enabling cors when frontend is hosted on different server / origin
osm-login-python==0.0.2
celery==5.2.7
redis==4.4.0
redis>=3.5.3
django_celery_results==2.4.0
flower==1.2.0
validators==0.20.0
@@ -25,3 +25,6 @@ rasterio==1.3.8
numpy<2.0.0
mercantile==1.2.1

boto3==1.35.76

django-q==1.3.9
4 changes: 2 additions & 2 deletions backend/core/models.py
@@ -63,7 +63,7 @@ class ModelStatus(models.IntegerChoices):
name = models.CharField(max_length=50)
created_at = models.DateTimeField(auto_now_add=True)
last_modified = models.DateTimeField(auto_now=True)
description = models.TextField(max_length=500, null=True, blank=True)
description = models.TextField(max_length=4000, null=True, blank=True)
user = models.ForeignKey(OsmUser, to_field="osm_id", on_delete=models.CASCADE)
published_training = models.PositiveIntegerField(null=True, blank=True)
status = models.IntegerField(default=-1, choices=ModelStatus.choices)
@@ -161,7 +161,7 @@ class ApprovedPredictions(models.Model):


class Banner(models.Model):
message = models.TextField()
message = models.TextField(max_length=500)
start_date = models.DateTimeField(default=timezone.now)
end_date = models.DateTimeField(null=True, blank=True)

63 changes: 44 additions & 19 deletions backend/core/tasks.py
@@ -35,6 +35,8 @@
from django.utils import timezone
from predictor import download_imagery, get_start_end_download_coords

from .utils import S3Uploader

logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
@@ -50,6 +52,22 @@
DEFAULT_TILE_SIZE = 256


def upload_to_s3(
path,
parent=settings.PARENT_BUCKET_FOLDER,
bucket_name=settings.BUCKET_NAME,
aws_access_key_id=settings.AWS_ACCESS_KEY_ID,
aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY,
):
uploader = S3Uploader(
bucket_name=bucket_name,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
parent=parent,
)
return uploader.upload(path)
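
`S3Uploader` is imported from `core/utils.py`, which this diff does not show. A plausible minimal shape, assuming it mirrors a local directory under the `parent` prefix — every name in this sketch is a guess, not the PR's actual implementation:

```python
# Hypothetical sketch of core/utils.py's S3Uploader; the real class is not shown in this diff.
import os
import boto3


class S3Uploader:
    def __init__(self, bucket_name, aws_access_key_id, aws_secret_access_key, parent="dev"):
        self.bucket_name = bucket_name
        self.parent = parent
        self.client = boto3.client(
            "s3",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
        )

    def upload(self, path):
        """Mirror a local directory into s3://<bucket>/<parent>/ and return the prefix."""
        for root, _dirs, files in os.walk(path):
            for name in files:
                local = os.path.join(root, name)
                key = f"{self.parent}/{os.path.relpath(local, path)}"
                self.client.upload_file(local, self.bucket_name, key)
        return f"s3://{self.bucket_name}/{self.parent}/"
```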


class print_time:
def __init__(self, name):
self.name = name
@@ -224,6 +242,7 @@ def ramp_model_training(
if os.path.exists(output_path):
shutil.rmtree(output_path)
shutil.copytree(final_model_path, os.path.join(output_path, "checkpoint.tf"))

shutil.copytree(preprocess_output, os.path.join(output_path, "preprocessed"))
shutil.copytree(
model_input_image_path, os.path.join(output_path, "preprocessed", "input")
@@ -232,6 +251,19 @@ def ramp_model_training(
graph_output_path = f"{base_path}/train/graphs"
shutil.copytree(graph_output_path, os.path.join(output_path, "graphs"))

model = tf.keras.models.load_model(os.path.join(output_path, "checkpoint.tf"))

model.save(os.path.join(output_path, "checkpoint.h5"))

logger.info(model.inputs)
logger.info(model.outputs)

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

with open(os.path.join(output_path, "checkpoint.tflite"), "wb") as f:
f.write(tflite_model)
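
The commit now exports the RAMP checkpoint in three formats (`.tf`, `.h5`, `.tflite`). For reference, the exported TFLite file can be loaded with TensorFlow's interpreter; a brief usage sketch, not part of this commit:

```python
# Sketch: loading the exported checkpoint.tflite (not part of this commit).
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="checkpoint.tflite")
interpreter.allocate_tensors()  # prepares input/output tensors for inference
print(interpreter.get_input_details())
print(interpreter.get_output_details())
```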

with open(os.path.join(output_path, "labels.geojson"), "w", encoding="utf-8") as f:
f.write(json.dumps(serialized_field.data))

@@ -262,6 +294,11 @@ def ramp_model_training(
remove_original=True,
)
shutil.rmtree(base_path)
dir_result = upload_to_s3(
output_path,
parent=f"{settings.PARENT_BUCKET_FOLDER}/training_{training_instance.id}",
)
print(f"Uploaded to s3:{dir_result}")
training_instance.accuracy = float(final_accuracy)
training_instance.finished_at = timezone.now()
training_instance.status = "FINISHED"
@@ -377,30 +414,13 @@ def yolo_model_training(
os.path.join(os.path.dirname(output_model_path), "best.onnx"),
os.path.join(output_path, "checkpoint.onnx"),
)
shutil.copyfile(
os.path.join(os.path.dirname(output_model_path), "best.onnx"),
os.path.join(output_path, "checkpoint.onnx"),
)
# shutil.copyfile(os.path.dirname(output_model_path,'checkpoint.tflite'), os.path.join(output_path, "checkpoint.tflite"))

shutil.copytree(preprocess_output, os.path.join(output_path, "preprocessed"))
shutil.copytree(
model_input_image_path, os.path.join(output_path, "preprocessed", "input")
)
os.makedirs(os.path.join(output_path, model), exist_ok=True)

shutil.copytree(
os.path.join(yolo_data_dir, "images"),
os.path.join(output_path, model, "images"),
)
shutil.copytree(
os.path.join(yolo_data_dir, "labels"),
os.path.join(output_path, model, "labels"),
)
shutil.copyfile(
os.path.join(yolo_data_dir, "yolo_dataset.yaml"),
os.path.join(output_path, model, "yolo_dataset.yaml"),
)
shutil.copytree(
os.path.join(yolo_data_dir, "images"),
os.path.join(output_path, model, "images"),
@@ -457,6 +477,11 @@ def yolo_model_training(
remove_original=True,
)
shutil.rmtree(base_path)
dir_result = upload_to_s3(
output_path,
parent=f"{settings.PARENT_BUCKET_FOLDER}/training_{training_instance.id}",
)
print(f"Uploaded to s3:{dir_result}")
training_instance.accuracy = float(final_accuracy)
training_instance.finished_at = timezone.now()
training_instance.status = "FINISHED"
@@ -495,15 +520,15 @@ def train_model(
if training_instance.task_id is None or training_instance.task_id.strip() == "":
training_instance.task_id = train_model.request.id
training_instance.save()
log_file = os.path.join(settings.LOG_PATH, f"run_{train_model.request.id}_log.txt")
log_file = os.path.join(settings.LOG_PATH, f"run_{train_model.request.id}.log")

if model_instance.base_model == "YOLO_V8_V1" and settings.YOLO_HOME is None:
raise ValueError("YOLO Home is not configured")
elif model_instance.base_model != "YOLO_V8_V1" and settings.RAMP_HOME is None:
raise ValueError("Ramp Home is not configured")

try:
with open(log_file, "w") as f:
with open(log_file, "a") as f:
# redirect stdout to the log file
sys.stdout = f
training_input_image_source, aoi_serializer, serialized_field = (
2 changes: 2 additions & 0 deletions backend/core/urls.py
@@ -16,6 +16,7 @@
FeedbackViewset,
GenerateFeedbackAOIGpxView,
GenerateGpxView,
LabelUploadView,
LabelViewSet,
ModelCentroidView,
ModelViewSet,
@@ -52,6 +53,7 @@
urlpatterns = [
path("", include(router.urls)),
path("label/osm/fetch/<int:aoi_id>/", RawdataApiAOIView.as_view()),
path("label/upload/<int:aoi_id>/", LabelUploadView.as_view(), name="label-upload"),
path(
"label/feedback/osm/fetch/<int:feedbackaoi_id>/",
RawdataApiFeedbackView.as_view(),
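
Since the new `label/upload/<int:aoi_id>/` route is named, it can be resolved in tests. A hedged sketch of exercising it (the payload is a guess — `LabelUploadView`'s expected body is not visible in this diff):

```python
# Sketch: hitting the new label-upload route; the GeoJSON payload shape is a guess.
from django.urls import reverse
from rest_framework.test import APIClient

client = APIClient()
url = reverse("label-upload", args=[1])  # aoi_id=1 is arbitrary
response = client.post(
    url,
    {"type": "FeatureCollection", "features": []},
    format="json",
)
print(response.status_code)
```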