diff --git a/.gitignore b/.gitignore
index fe210adb..426860a2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
+runs
+
 # frontend
 # See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
diff --git a/backend/README.md b/backend/README.md
index f86dbff7..3c424802 100644
--- a/backend/README.md
+++ b/backend/README.md
@@ -170,6 +170,11 @@ You can start flower to start monitoring your tasks
 celery -A aiproject --broker=redis://127.0.0.1:6379/0 flower
 ```
 
+## Start background tasks
+```bash
+python manage.py qcluster
+```
+
 ## Run Tests
 
 ```
diff --git a/backend/aiproject/settings.py b/backend/aiproject/settings.py
index c77b639f..082901c6 100644
--- a/backend/aiproject/settings.py
+++ b/backend/aiproject/settings.py
@@ -55,9 +55,18 @@
 
 OSM_SECRET_KEY = env("OSM_SECRET_KEY")
 
-# Limiter
+# S3
+BUCKET_NAME = env("BUCKET_NAME")
+PARENT_BUCKET_FOLDER = env(
+    "PARENT_BUCKET_FOLDER", default="dev"
+)  # use prod for production
+AWS_REGION = env("AWS_REGION", default="us-east-1")
+AWS_ACCESS_KEY_ID = env("AWS_ACCESS_KEY_ID", default=None)
+AWS_SECRET_ACCESS_KEY = env("AWS_SECRET_ACCESS_KEY", default=None)
+PRESIGNED_URL_EXPIRY = env("PRESIGNED_URL_EXPIRY", default=3600)
 
+# Limiter
 EPOCHS_LIMIT = env("EPOCHS_LIMIT", default=20)  ## TODO : Remove this global variable
 BATCH_SIZE_LIMIT = env("BATCH_SIZE_LIMIT", default=8)
@@ -94,6 +103,7 @@
     "drf_yasg",
     "celery",
     "django_celery_results",
+    "django_q",
 ]
 
 MIDDLEWARE = [
@@ -215,6 +225,19 @@
 )  # if you don't want to use redis pass 'django-db' to use app db itself
 
+Q_CLUSTER = {
+    "name": "DjangORM",
+    "workers": 4,
+    "retry": 60 * 6,
+    "max_retires": 1,
+    "recycle": 50,
+    "queue_limit": 50,
+    "timeout": 60 * 5,  # number of seconds
+    "label": "Django Q",
+    "redis": CELERY_BROKER_URL,
+}
+
+
 AUTH_USER_MODEL = "login.OsmUser"
 
 SWAGGER_SETTINGS = {
diff --git a/backend/api-requirements.txt b/backend/api-requirements.txt
index ca141a48..bc0b72fd 100644
--- a/backend/api-requirements.txt
+++ b/backend/api-requirements.txt
@@ -11,7 +11,7 @@ django-filter==22.1
 django-cors-headers==3.13.0 # used for enabling cors when frontend is hosted on different server / origin
 osm-login-python==0.0.2
 celery==5.2.7
-redis==4.4.0
+redis>=3.5.3
 django_celery_results==2.4.0
 flower==1.2.0
 validators==0.20.0
@@ -25,3 +25,6 @@ rasterio==1.3.8
 numpy<2.0.0
 mercantile==1.2.1
+boto3==1.35.76
+
+django-q==1.3.9
\ No newline at end of file
diff --git a/backend/core/models.py b/backend/core/models.py
index 3244ca83..f1e533d9 100644
--- a/backend/core/models.py
+++ b/backend/core/models.py
@@ -63,7 +63,7 @@ class ModelStatus(models.IntegerChoices):
     name = models.CharField(max_length=50)
     created_at = models.DateTimeField(auto_now_add=True)
     last_modified = models.DateTimeField(auto_now=True)
-    description = models.TextField(max_length=500, null=True, blank=True)
+    description = models.TextField(max_length=4000, null=True, blank=True)
     user = models.ForeignKey(OsmUser, to_field="osm_id", on_delete=models.CASCADE)
     published_training = models.PositiveIntegerField(null=True, blank=True)
     status = models.IntegerField(default=-1, choices=ModelStatus.choices)
@@ -161,7 +161,7 @@ class ApprovedPredictions(models.Model):
 
 
 class Banner(models.Model):
-    message = models.TextField()
+    message = models.TextField(max_length=500)
     start_date = models.DateTimeField(default=timezone.now)
     end_date = models.DateTimeField(null=True, blank=True)
diff --git a/backend/core/tasks.py b/backend/core/tasks.py
index 451c3d60..f4c17099 100644
--- a/backend/core/tasks.py
+++ b/backend/core/tasks.py
@@ -35,6 +35,8 @@
 from django.utils import timezone
 from predictor import download_imagery, get_start_end_download_coords
 
+from .utils import S3Uploader
+
 logging.basicConfig(
     format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
 )
@@ -50,6 +52,22 @@
 DEFAULT_TILE_SIZE = 256
 
 
+def upload_to_s3(
+    path,
+    parent=settings.PARENT_BUCKET_FOLDER,
+    bucket_name=settings.BUCKET_NAME,
+    aws_access_key_id=settings.AWS_ACCESS_KEY_ID,
+    aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY,
+):
+    uploader = S3Uploader(
+        bucket_name=bucket_name,
+        aws_access_key_id=aws_access_key_id,
+        aws_secret_access_key=aws_secret_access_key,
+        parent=parent,
+    )
+    return uploader.upload(path)
+
+
 class print_time:
     def __init__(self, name):
         self.name = name
@@ -224,6 +242,7 @@ def ramp_model_training(
     if os.path.exists(output_path):
         shutil.rmtree(output_path)
     shutil.copytree(final_model_path, os.path.join(output_path, "checkpoint.tf"))
+    shutil.copytree(preprocess_output, os.path.join(output_path, "preprocessed"))
     shutil.copytree(
         model_input_image_path, os.path.join(output_path, "preprocessed", "input")
     )
@@ -232,6 +251,19 @@ def ramp_model_training(
     graph_output_path = f"{base_path}/train/graphs"
     shutil.copytree(graph_output_path, os.path.join(output_path, "graphs"))
+    model = tf.keras.models.load_model(os.path.join(output_path, "checkpoint.tf"))
+
+    model.save(os.path.join(output_path, "checkpoint.h5"))
+
+    logger.info(model.inputs)
+    logger.info(model.outputs)
+
+    converter = tf.lite.TFLiteConverter.from_keras_model(model)
+    tflite_model = converter.convert()
+
+    with open(os.path.join(output_path, "checkpoint.tflite"), "wb") as f:
+        f.write(tflite_model)
+
     with open(os.path.join(output_path, "labels.geojson"), "w", encoding="utf-8") as f:
         f.write(json.dumps(serialized_field.data))
@@ -262,6 +294,11 @@ def ramp_model_training(
         remove_original=True,
     )
     shutil.rmtree(base_path)
+    dir_result = upload_to_s3(
+        output_path,
+        parent=f"{settings.PARENT_BUCKET_FOLDER}/training_{training_instance.id}",
+    )
+    print(f"Uploaded to s3:{dir_result}")
     training_instance.accuracy = float(final_accuracy)
     training_instance.finished_at = timezone.now()
     training_instance.status = "FINISHED"
@@ -377,11 +414,6 @@ def yolo_model_training(
         os.path.join(os.path.dirname(output_model_path), "best.onnx"),
         os.path.join(output_path, "checkpoint.onnx"),
     )
-    shutil.copyfile(
-        os.path.join(os.path.dirname(output_model_path), "best.onnx"),
-        os.path.join(output_path, "checkpoint.onnx"),
-    )
-    # shutil.copyfile(os.path.dirname(output_model_path,'checkpoint.tflite'), os.path.join(output_path, "checkpoint.tflite"))
 
     shutil.copytree(preprocess_output, os.path.join(output_path, "preprocessed"))
     shutil.copytree(
@@ -389,18 +421,6 @@ def yolo_model_training(
     )
     os.makedirs(os.path.join(output_path, model), exist_ok=True)
-    shutil.copytree(
-        os.path.join(yolo_data_dir, "images"),
-        os.path.join(output_path, model, "images"),
-    )
-    shutil.copytree(
-        os.path.join(yolo_data_dir, "labels"),
-        os.path.join(output_path, model, "labels"),
-    )
-    shutil.copyfile(
-        os.path.join(yolo_data_dir, "yolo_dataset.yaml"),
-        os.path.join(output_path, model, "yolo_dataset.yaml"),
-    )
     shutil.copytree(
         os.path.join(yolo_data_dir, "images"),
         os.path.join(output_path, model, "images"),
     )
@@ -457,6 +477,11 @@ def yolo_model_training(
         remove_original=True,
     )
     shutil.rmtree(base_path)
+    dir_result = upload_to_s3(
+        output_path,
+        parent=f"{settings.PARENT_BUCKET_FOLDER}/training_{training_instance.id}",
+    )
+    print(f"Uploaded to s3:{dir_result}")
     training_instance.accuracy = float(final_accuracy)
     training_instance.finished_at = timezone.now()
     training_instance.status = "FINISHED"
@@ -495,7 +520,7 @@ def train_model(
     if training_instance.task_id is None or training_instance.task_id.strip() == "":
         training_instance.task_id = train_model.request.id
         training_instance.save()
-    log_file = os.path.join(settings.LOG_PATH, f"run_{train_model.request.id}_log.txt")
+    log_file = os.path.join(settings.LOG_PATH, f"run_{train_model.request.id}.log")
 
     if model_instance.base_model == "YOLO_V8_V1" and settings.YOLO_HOME is None:
         raise ValueError("YOLO Home is not configured")
@@ -503,7 +528,7 @@ def train_model(
         raise ValueError("Ramp Home is not configured")
 
     try:
-        with open(log_file, "w") as f:
+        with open(log_file, "a") as f:
             # redirect stdout to the log file
             sys.stdout = f
             training_input_image_source, aoi_serializer, serialized_field = (
diff --git a/backend/core/urls.py b/backend/core/urls.py
index 7508752d..9e92fd8b 100644
--- a/backend/core/urls.py
+++ b/backend/core/urls.py
@@ -16,6 +16,7 @@
     FeedbackViewset,
     GenerateFeedbackAOIGpxView,
     GenerateGpxView,
+    LabelUploadView,
     LabelViewSet,
     ModelCentroidView,
     ModelViewSet,
@@ -52,6 +53,7 @@
 urlpatterns = [
     path("", include(router.urls)),
     path("label/osm/fetch/<int:aoi_id>/", RawdataApiAOIView.as_view()),
+    path("label/upload/<int:aoi_id>/", LabelUploadView.as_view(), name="label-upload"),
     path(
         "label/feedback/osm/fetch/<int:feedback_aoi_id>/",
         RawdataApiFeedbackView.as_view(),
diff --git a/backend/core/utils.py b/backend/core/utils.py
index 6fda5626..e898c951 100644
--- a/backend/core/utils.py
+++ b/backend/core/utils.py
@@ -11,8 +11,11 @@
 from xml.dom import ValidationErr
 from zipfile import ZipFile
 
+import boto3
 import requests
+from botocore.exceptions import ClientError, NoCredentialsError
 from django.conf import settings
+from django.http import HttpResponseRedirect
 from gpxpy.gpx import GPX, GPXTrack, GPXTrackSegment, GPXWaypoint
 from tqdm import tqdm
 
@@ -20,6 +23,118 @@
 from .serializers import FeedbackLabelSerializer, LabelSerializer
 
 
+def get_s3_client():
+    if settings.AWS_ACCESS_KEY_ID and settings.AWS_SECRET_ACCESS_KEY:
+        return boto3.client(
+            "s3",
+            aws_access_key_id=settings.AWS_ACCESS_KEY_ID,
+            aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY,
+            region_name=settings.AWS_REGION,
+        )
+    else:
+        return boto3.client("s3")
+
+
+s3_client = get_s3_client()
+
+
+def s3_object_exists(bucket_name, key):
+    """Check if an object exists in S3."""
+    try:
+        s3_client.head_object(Bucket=bucket_name, Key=key)
+        return True
+    except ClientError as e:
+        if e.response["Error"]["Code"] == "404":
+            return False
+        raise
+
+
+def download_s3_file(bucket_name, s3_key):
+    """Generate a presigned URL for downloading a file from S3."""
+    try:
+        presigned_url = s3_client.generate_presigned_url(
+            "get_object",
+            Params={"Bucket": bucket_name, "Key": s3_key},
+            ExpiresIn=settings.PRESIGNED_URL_EXPIRY,
+        )
+        return presigned_url
+    except ClientError as e:
+        return None
+
+
+def get_s3_metadata(bucket_name, key):
+    """Retrieve metadata for an S3 object."""
+    try:
+        response = s3_client.head_object(Bucket=bucket_name, Key=key)
+        return {"size": response.get("ContentLength")}
+    except Exception as e:
+        raise Exception(f"Error fetching metadata: {str(e)}")
+
+
+def get_s3_directory_size_and_length(bucket_name, prefix):
+    """
+    Get the total size and number of files for a directory in S3.
+
+    Args:
+        bucket_name (str): The S3 bucket name.
+        prefix (str): The prefix (path) to the directory.
+
+    Returns:
+        tuple: (size, length) - size in bytes, length as number of files.
+    """
+    total_size = 0
+    total_length = 0
+    paginator = s3_client.get_paginator("list_objects_v2")
+
+    for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix):
+        total_length += len(page.get("Contents", []))
+
+        total_size += sum(item["Size"] for item in page.get("Contents", []))
+
+    return total_size, total_length
+
+
+def get_s3_directory(bucket_name, prefix):
+    """List objects in an S3 directory."""
+    data = {"file": {}, "dir": {}}
+    paginator = s3_client.get_paginator("list_objects_v2")
+    for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix, Delimiter="/"):
+        for obj in page.get("Contents", []):
+            key = obj["Key"]
+            data["file"][os.path.basename(key)] = {"size": obj["Size"]}
+        for prefix_obj in page.get("CommonPrefixes", []):
+            sub_prefix = prefix_obj["Prefix"]
+            sub_dir_size, sub_dir_len = get_s3_directory_size_and_length(
+                bucket_name, sub_prefix
+            )
+
+            data["dir"][os.path.basename(sub_prefix.rstrip("/"))] = {
+                "size": sub_dir_size,
+                "len": sub_dir_len,
+            }
+    return data
+
+
+def get_local_metadata(base_dir):
+    """Retrieve metadata for local files or directories."""
+    data = {"file": {}, "dir": {}}
+    if os.path.isdir(base_dir):
+        for entry in os.scandir(base_dir):
+            if entry.is_file():
+                data["file"][entry.name] = {"size": entry.stat().st_size}
+            elif entry.is_dir():
+                subdir_size = get_dir_size(entry.path)
+                data["dir"][entry.name] = {
+                    "len": sum(1 for _ in os.scandir(entry.path)),
+                    "size": subdir_size,
+                }
+    elif os.path.isfile(base_dir):
+        data["file"][os.path.basename(base_dir)] = {
+            "size": os.path.getsize(base_dir),
+        }
+    return data
+
+
 def get_dir_size(directory):
     total_size = 0
     for entry in os.scandir(directory):
@@ -269,3 +384,69 @@ def process_geojson(geojson_file_path, aoi_id, feedback=False):
             f.result()
 
     print("writing to database finished")
+
+
+class S3Uploader:
+    def __init__(
+        self,
+        bucket_name=None,
+        aws_access_key_id=None,
+        aws_secret_access_key=None,
+        parent="fair-dev",
+    ):
+        try:
+            if aws_access_key_id and aws_secret_access_key:
+                self.aws_session = boto3.Session(
+                    aws_access_key_id=aws_access_key_id,
+                    aws_secret_access_key=aws_secret_access_key,
+                )
+            else:
+                self.aws_session = boto3.Session()
+
+            self.s3_client = self.aws_session.client("s3")
+            self.bucket_name = bucket_name
+            self.parent = parent
+            logging.info("S3 connection initialized successfully")
+        except (NoCredentialsError, ClientError) as ex:
+            logging.error(f"S3 Connection Error: {ex}")
+            raise
+
+    def upload(self, path, bucket_name=None):
+        if not os.path.exists(path):
+            raise FileNotFoundError(f"Path not found: {path}")
+
+        bucket = bucket_name or self.bucket_name
+        if not bucket:
+            raise ValueError("Bucket name must be provided")
+
+        try:
+            if os.path.isfile(path):
+                return self._upload_file(path, bucket)
+            elif os.path.isdir(path):
+                return self._upload_directory(path, bucket)
+            else:
+                raise ValueError("Path must be a file or directory")
+        except Exception as ex:
+            logging.error(f"Upload failed: {ex}")
+            raise
+
+    def _upload_file(self, file_path, bucket_name):
+        s3_key = f"{self.parent}/{os.path.basename(file_path)}"
+        self.s3_client.upload_file(file_path, bucket_name, s3_key)
+        return f"s3://{bucket_name}/{s3_key}"
+
+    def _upload_directory(self, directory_path, bucket_name):
+        total_files = 0
+        for root, _, files in os.walk(directory_path):
+            for file in files:
+                local_path = os.path.join(root, file)
+                relative_path = os.path.relpath(local_path, directory_path)
+                relative_path = relative_path.replace("\\", "/")
+                s3_key = f"{self.parent}/{relative_path}"
+                self.s3_client.upload_file(local_path, bucket_name, s3_key)
+                total_files += 1
+        return {
+            "directory_name": os.path.basename(directory_path),
+            "total_files_uploaded": total_files,
+            "s3_path": f"s3://{bucket_name}/{self.parent}/",
+        }
diff --git a/backend/core/views.py b/backend/core/views.py
index a34005e3..652996f7 100644
--- a/backend/core/views.py
+++ b/backend/core/views.py
@@ -11,6 +11,7 @@
 import zipfile
 from datetime import datetime
 from tempfile import NamedTemporaryFile
+from urllib.parse import quote
 
 # import tensorflow as tf
 from celery import current_app
@@ -20,6 +21,7 @@
     FileResponse,
     HttpResponse,
     HttpResponseBadRequest,
+    HttpResponseRedirect,
     StreamingHttpResponse,
 )
 from django.shortcuts import get_object_or_404, redirect
@@ -28,6 +30,7 @@
 from django.views.decorators.cache import cache_page
 from django.views.decorators.vary import vary_on_cookie, vary_on_headers
 from django_filters.rest_framework import DjangoFilterBackend
+from django_q.tasks import async_task
 from drf_yasg.utils import swagger_auto_schema
 from geojson2osm import geojson2osm
 from login.authentication import OsmAuthentication
@@ -37,6 +40,7 @@
 from rest_framework.decorators import api_view
 from rest_framework.exceptions import ValidationError
 from rest_framework.generics import ListAPIView
+from rest_framework.parsers import FormParser, MultiPartParser
 from rest_framework.response import Response
 from rest_framework.views import APIView
 from rest_framework_gis.filters import InBBoxFilter, TMSTileFilter
@@ -71,7 +75,16 @@
     UserSerializer,
 )
 from .tasks import train_model
-from .utils import get_dir_size, gpx_generator, process_rawdata, request_rawdata
+from .utils import (
+    download_s3_file,
+    get_dir_size,
+    get_local_metadata,
+    get_s3_directory,
+    gpx_generator,
+    process_rawdata,
+    request_rawdata,
+    s3_object_exists,
+)
 
 if settings.ENABLE_PREDICTION_API:
     from predictor import predict
@@ -159,7 +172,7 @@ def create(self, validated_data):
                 raise ValidationError(
                     f"Batch size can't be greater than {settings.RAMP_BATCH_SIZE_LIMIT} on this server"
                 )
-        if model.base_model in ["YOLO_V8_V1","YOLO_V8_V2"]:
+        if model.base_model in ["YOLO_V8_V1", "YOLO_V8_V2"]:
 
             if epochs > settings.YOLO_EPOCHS_LIMIT:
                 raise ValidationError(
@@ -363,24 +376,92 @@ def create(self, request, *args, **kwargs):
         aoi_id = request.data.get("aoi")
         geom = request.data.get("geom")
 
-        # Check if a label with the same AOI and geometry exists
         existing_label = Label.objects.filter(aoi=aoi_id, geom=geom).first()
 
         if existing_label:
-            # If it exists, update the existing label
             serializer = LabelSerializer(existing_label, data=request.data)
         else:
-            # If it doesn't exist, create a new label
            serializer = LabelSerializer(data=request.data)
 
         if serializer.is_valid():
             serializer.save()
-            return Response(
-                serializer.data, status=status.HTTP_200_OK
-            )  # 200 for update, 201 for create
+            return Response(serializer.data, status=status.HTTP_200_OK)
 
         return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
 
+
+class LabelUploadView(APIView):
+    authentication_classes = [OsmAuthentication]
+    permission_classes = [IsOsmAuthenticated]
+    parser_classes = (MultiPartParser, FormParser)
+
+    def post(self, request, aoi_id, *args, **kwargs):
+        geojson_file = request.FILES.get("geojson_file")
+        if geojson_file:
+            try:
+                geojson_data = json.load(geojson_file)
+                self.validate_geojson(geojson_data)
+                async_task(
+                    "core.views.process_labels_geojson",
+                    geojson_data,
+                    aoi_id,
+                )
+                return Response(
+                    {"status": "GeoJSON file is being processed"},
+                    status=status.HTTP_202_ACCEPTED,
+                )
+            except (json.JSONDecodeError, ValidationError) as e:
+                return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST)
+        return Response(
+            {"error": "No GeoJSON file provided"}, status=status.HTTP_400_BAD_REQUEST
+        )
+
+    def validate_geojson(self, geojson_data):
+        if geojson_data.get("type") != "FeatureCollection":
+            raise ValidationError("Invalid GeoJSON type. Expected 'FeatureCollection'.")
+        if "features" not in geojson_data or not isinstance(
+            geojson_data["features"], list
+        ):
+            raise ValidationError("Invalid GeoJSON format. 'features' must be a list.")
+        if not geojson_data["features"]:
+            raise ValidationError("GeoJSON 'features' list is empty.")
+
+        # Validate the first feature
+        first_feature = geojson_data["features"][0]
+        if first_feature.get("type") != "Feature":
+            raise ValidationError("Invalid GeoJSON feature type. Expected 'Feature'.")
+        if "geometry" not in first_feature or "properties" not in first_feature:
+            raise ValidationError(
+                "Invalid GeoJSON feature format. 'geometry' and 'properties' are required."
+            )
+
+        # Validate the first feature with the serializer
+        first_feature["properties"]["aoi"] = self.kwargs.get("aoi_id")
+        serializer = LabelSerializer(data=first_feature)
+
+        if not serializer.is_valid():
+            raise ValidationError(serializer.errors)
+
+
+def process_labels_geojson(geojson_data, aoi_id):
+    obj = get_object_or_404(AOI, id=aoi_id)
+    try:
+        obj.label_status = AOI.DownloadStatus.RUNNING
+        obj.save()
+        for feature in geojson_data["features"]:
+            feature["properties"]["aoi"] = aoi_id
+            serializer = LabelSerializer(data=feature)
+            if serializer.is_valid():
+                serializer.save()
+
+        obj.label_status = AOI.DownloadStatus.DOWNLOADED
+        obj.label_fetched = datetime.utcnow()
+        obj.save()
+    except Exception as ex:
+        obj.label_status = AOI.DownloadStatus.NOT_DOWNLOADED
+        obj.save()
+        logging.error(ex)
+
+
 class ApprovedPredictionsViewSet(viewsets.ModelViewSet):
     authentication_classes = [OsmAuthentication]
     permission_classes = [IsOsmAuthenticated]
@@ -466,20 +547,24 @@ def post(self, request, aoi_id, *args, **kwargs):
             status: Success/Failed
         """
         obj = get_object_or_404(AOI, id=aoi_id)
-        try:
-            obj.label_status = 0
-            obj.save()
-            file_download_url = request_rawdata(obj.geom.geojson)
-            process_rawdata(file_download_url, aoi_id)
-            obj.label_status = 1
-            obj.label_fetched = datetime.utcnow()
-            obj.save()
-            return Response("Success", status=status.HTTP_201_CREATED)
-        except Exception as ex:
-            obj.label_status = -1
-            obj.save()
-            # raise ex
-            return Response("OSM Fetch Failed", status=500)
+        async_task("core.views.process_rawdata_task", obj.geom.geojson, aoi_id)
+        return Response("Processing started", status=status.HTTP_202_ACCEPTED)
+
+
+def process_rawdata_task(geom_geojson, aoi_id):
+    obj = get_object_or_404(AOI, id=aoi_id)
+    try:
+        obj.label_status = AOI.DownloadStatus.RUNNING
+        obj.save()
+        file_download_url = request_rawdata(geom_geojson)
+        process_rawdata(file_download_url, aoi_id)
+        obj.label_status = AOI.DownloadStatus.DOWNLOADED
+        obj.label_fetched = datetime.utcnow()
+        obj.save()
+    except Exception as ex:
+        obj.label_status = AOI.DownloadStatus.NOT_DOWNLOADED
+        obj.save()
+        raise ex
 
 
 @api_view(["GET"])
@@ -567,9 +652,7 @@ def run_task_status(request, run_id: str):
         # read the last 10 lines of the log file
         cmd = ["tail", "-n", str(settings.LOG_LINE_STREAM_TRUNCATE_VALUE), log_file]
         # print(cmd)
-        output = subprocess.check_output(
-            cmd
-        ).decode("utf-8")
+        output = subprocess.check_output(cmd).decode("utf-8")
     except Exception as e:
         output = str(e)
     result = {
@@ -822,98 +905,46 @@ def get(self, request, feedback_aoi_id: int):
 
 class TrainingWorkspaceView(APIView):
     @method_decorator(cache_page(60 * 15))
-    # @method_decorator(vary_on_headers("access-token"))
+    @method_decorator(vary_on_headers("access-token"))
     def get(self, request, lookup_dir):
-        """
-        List the status of the training workspace.
-
-        ### Returns:
-        - **Size**: The total size of the workspace in bytes.
-        - **dir/file**: The current dir/file on the lookup_dir.
-
-        ### Workspace Structure:
-        By default, the training workspace is organized as follows:
-        - Training files are stored in the directory: `dataset{dataset_id}/output/training_{training}`
-        """
-
-        # {workspace_dir:{file_name:{size:20,type:file},dir_name:{size:20,len:4,type:dir}}}
-        base_dir = settings.TRAINING_WORKSPACE
-        if lookup_dir:
-            base_dir = os.path.join(base_dir, lookup_dir)
-            if not os.path.exists(base_dir):
-                return Response({"Errr:File/Dir not Found"}, status=404)
-        data = {"file": {}, "dir": {}}
-        if os.path.isdir(base_dir):
-            for entry in os.scandir(base_dir):
-                if entry.is_file():
-                    data["file"][entry.name] = {
-                        "size": entry.stat().st_size,
-                    }
-                elif entry.is_dir():
-                    subdir_size = get_dir_size(entry.path)
-                    data["dir"][entry.name] = {
-                        "len": sum(1 for _ in os.scandir(entry.path)),
-                        "size": subdir_size,
-                    }
-        elif os.path.isfile(base_dir):
-            data["file"][os.path.basename(base_dir)] = {
-                "size": os.path.getsize(base_dir)
-            }
+        bucket_name = settings.BUCKET_NAME
+        encoded_file_path = quote(lookup_dir.strip("/"))
+        s3_prefix = f"{settings.PARENT_BUCKET_FOLDER}/{encoded_file_path}/"
+        try:
+            data = get_s3_directory(bucket_name, s3_prefix)
+        except Exception as e:
+            return Response({"Error": str(e)}, status=500)
 
-        return Response(data, status=status.HTTP_201_CREATED)
+        return Response(data, status=status.HTTP_200_OK)
 
 
 class TrainingWorkspaceDownloadView(APIView):
-    authentication_classes = [OsmAuthentication]
-    permission_classes = [IsOsmAuthenticated]
+    # authentication_classes = [OsmAuthentication]
+    # permission_classes = [IsOsmAuthenticated]
 
-    def dispatch(self, request, *args, **kwargs):
-        lookup_dir = kwargs.get("lookup_dir")
-        if lookup_dir.endswith("training_accuracy.png"):
-            # bypass
-            self.authentication_classes = []
-            self.permission_classes = []
+    # def dispatch(self, request, *args, **kwargs):
+    #     lookup_dir = kwargs.get("lookup_dir")
+    #     if lookup_dir.endswith("training_accuracy.png"):
+    #         # bypass
+    #         self.authentication_classes = []
+    #         self.permission_classes = []
 
-        return super().dispatch(request, *args, **kwargs)
+    #     return super().dispatch(request, *args, **kwargs)
 
     def get(self, request, lookup_dir):
-        base_dir = os.path.join(settings.TRAINING_WORKSPACE, lookup_dir)
-        if not os.path.exists(base_dir):
-            return Response({"Errr: File/Dir not found"}, status=404)
-        size = (
-            get_dir_size(base_dir)
-            if os.path.isdir(base_dir)
-            else os.path.getsize(base_dir)
-        ) / (1024**2)
-        if (
-            size > settings.TRAINING_WORKSPACE_DOWNLOAD_LIMIT
-        ):  # if file is greater than 200 mb exit
-            return Response(
-                {
-                    f"Errr: File Size {size} MB Exceed More than {settings.TRAINING_WORKSPACE_DOWNLOAD_LIMIT} MB"
-                },
-                status=403,
-            )
+        s3_key = os.path.join(settings.PARENT_BUCKET_FOLDER, lookup_dir)
+        bucket_name = settings.BUCKET_NAME
 
-        if os.path.isfile(base_dir):
-            response = FileResponse(open(base_dir, "rb"))
-            response["Content-Disposition"] = 'attachment; filename="{}"'.format(
-                os.path.basename(base_dir)
-            )
-            return response
+        if not s3_object_exists(bucket_name, s3_key):
+            return Response("File not found in S3", status=404)
+        presigned_url = download_s3_file(bucket_name, s3_key)
+        # ?url_only=true
+        url_only = request.query_params.get("url_only", "false").lower() == "true"
+
+        if url_only:
+            return Response({"result": presigned_url})
         else:
-            # TODO : This will take time to zip also based on the reading/writing speed of the dir
-            temp = NamedTemporaryFile()
-            shutil.make_archive(temp.name, "zip", base_dir)
-            # rewind the file so it can be read from the beginning
-            temp.seek(0)
-            response = StreamingHttpResponse(
-                open(temp.name + ".zip", "rb").read(), content_type="application/zip"
-            )
-            response["Content-Disposition"] = 'attachment; filename="{}.zip"'.format(
-                os.path.basename(base_dir)
-            )
-            return response
+            return HttpResponseRedirect(presigned_url)
 
 
 class BannerViewSet(viewsets.ModelViewSet):