diff --git a/API/api_worker.py b/API/api_worker.py index de0b23c3..9a8ba6a6 100644 --- a/API/api_worker.py +++ b/API/api_worker.py @@ -25,6 +25,7 @@ DEFAULT_SOFT_TASK_LIMIT, ENABLE_SOZIP, ENABLE_TILES, + EXPORT_PATH, HDX_HARD_TASK_LIMIT, HDX_SOFT_TASK_LIMIT, ) @@ -273,11 +274,38 @@ def process_raw_data(self, params, user=None): raise ex +class BaseclassTask(celery.Task): + """Celery base task that cleans up after failed export tasks. + + On failure, the export directory named after the task id under + EXPORT_PATH is removed so no stale artifacts remain on disk. + """ + + def on_failure(self, exc, task_id, args, kwargs, einfo): + """Log the failure and remove the task's export directory. + + Args: + exc (Exception): The exception raised by the task. + task_id (str): Id of the failed task; also names its export directory. + args (tuple): Original positional arguments of the failed task. + kwargs (dict): Original keyword arguments of the failed task. + einfo (ExceptionInfo): Traceback information for the failure. + """ + # Best-effort cleanup: a partially written export dir is useless once + # the task has failed, so remove it to free disk space. + # The directory may not exist if failure happened before any write. + print("{0!r} failed: {1!r}".format(task_id, exc)) + clean_dir = os.path.join(EXPORT_PATH, task_id) + if os.path.exists(clean_dir): + shutil.rmtree(clean_dir) + + @celery.task( bind=True, name="process_custom_request", time_limit=HDX_HARD_TASK_LIMIT, soft_time_limit=HDX_SOFT_TASK_LIMIT, + base=BaseclassTask, ) def process_custom_request(self, params, user=None): if self.request.retries > 0: @@ -286,7 +314,7 @@ def process_custom_request(self, params, user=None): if not params.dataset: params.dataset = DatasetConfig() - custom_object = CustomExport(params) + custom_object = CustomExport(params, uid=self.request.id) try: return custom_object.process_custom_categories() except Exception as ex: diff --git a/src/app.py b/src/app.py index c159bd5a..0c9227ab 100644 --- a/src/app.py +++ b/src/app.py @@ -1234,7 +1234,7 @@ class CustomExport: - params (DynamicCategoriesModel): An instance of DynamicCategoriesModel containing configuration settings. 
""" - def __init__(self, params): + def __init__(self, params, uid=None): self.params = params self.iso3 = self.params.iso3 self.HDX_SUPPORTED_FORMATS = ["geojson", "gpkg", "kml", "shp"] @@ -1264,8 +1264,10 @@ def __init__(self, params): self.params.dataset.dataset_prefix = dataset_prefix if not self.params.dataset.dataset_locations: self.params.dataset.dataset_locations = json.loads(dataset_locations) + self.uuid = uid + if self.uuid is None: + self.uuid = str(uuid.uuid4().hex) - self.uuid = str(uuid.uuid4().hex) self.parallel_process_state = False self.default_export_base_name = ( self.iso3.upper() if self.iso3 else self.params.dataset.dataset_prefix