diff --git a/aio/aio-proxy/aio_proxy/decorators/http_exception.py b/aio/aio-proxy/aio_proxy/decorators/http_exception.py index cd2b75d2..fd7345f3 100755 --- a/aio/aio-proxy/aio_proxy/decorators/http_exception.py +++ b/aio/aio-proxy/aio_proxy/decorators/http_exception.py @@ -13,15 +13,12 @@ def http_exception_handler(func): def inner_function(*args, **kwargs): try: return func(*args, **kwargs) - except ( - elasticsearch.exceptions.RequestError, - TypeError, - ) as error: + except elasticsearch.exceptions.RequestError as error: raise web.HTTPBadRequest( text=serialize_error_text(str(error)), content_type="application/json", ) - except ValueError as error: + except (ValueError, TypeError) as error: with push_scope() as scope: # group value errors together based on their response (Bad request) scope.fingerprint = ["Bad Request"] diff --git a/aio/aio-proxy/aio_proxy/main.py b/aio/aio-proxy/aio_proxy/main.py index 7707cb8d..c694ea63 100755 --- a/aio/aio-proxy/aio_proxy/main.py +++ b/aio/aio-proxy/aio_proxy/main.py @@ -4,9 +4,7 @@ import aiohttp from aiohttp import web from aiohttp_swagger3 import ReDocUiSettings, SwaggerDocs -from elasticapm.contrib.aiohttp import ElasticAPM -from aio_proxy.response.helpers import APM_URL, CURRENT_ENV from aio_proxy.routes import routes from aio_proxy.settings import config @@ -30,12 +28,14 @@ def main(): components=open_api_path, ) app["config"] = config + """ app["ELASTIC_APM"] = { "SERVICE_NAME": "SEARCH APM", "SERVER_URL": APM_URL, "ELASTIC_APM_ENVIRONMENT": CURRENT_ENV, } ElasticAPM(app) + """ web.run_app(app, host=config["host"], port=config["port"]) app.on_startup.append(swagger) diff --git a/aio/aio-proxy/aio_proxy/request/parsers/empty_params.py b/aio/aio-proxy/aio_proxy/request/parsers/empty_params.py index 6a46b7cb..d72207e1 100644 --- a/aio/aio-proxy/aio_proxy/request/parsers/empty_params.py +++ b/aio/aio-proxy/aio_proxy/request/parsers/empty_params.py @@ -3,15 +3,14 @@ @value_exception_handler(error="Veuillez indiquer au moins un paramètre de recherche.") def check_empty_params(parameters): - # If all parameters are empty (except matching size because it always has a - # default value) raise value error - empty_parameters = all( - val is None - for val in [ - param_value - for param, param_value in vars(parameters).items() - if param not in ["page", "per_page", "matching_size"] - ] - ) - if empty_parameters: + # If all parameters are empty (except matching size and pagination + # because they always have a default value) raise value error + # Check if all non-default parameters are empty, raise a ValueError if they are + non_default_params = [ + param_value + for param, param_value in parameters.items() + if param not in ["page", "per_page", "matching_size"] + ] + + if all(val is None for val in non_default_params): raise ValueError diff --git a/aio/aio-proxy/aio_proxy/request/parsers/page.py b/aio/aio-proxy/aio_proxy/request/parsers/page.py index 35512d7e..23da5287 100644 --- a/aio/aio-proxy/aio_proxy/request/parsers/page.py +++ b/aio/aio-proxy/aio_proxy/request/parsers/page.py @@ -1,27 +1,3 @@ -from aio_proxy.decorators.value_exception import value_exception_handler - -MAX_PAGE_VALUE = 1000 -MIN_PAGE_NUMBER = 0 - - -@value_exception_handler( - error="Veuillez indiquer un numéro de page entier entre 1 et 1000, par défaut 1." -) -def parse_and_validate_page(request) -> int: - """Extract and Check the validity of page number. - - Args: - request: HTTP request. - - Returns: - page(int) if valid. - default 1. - - Raises: - ValueError: if page is not integer, lower than 1 or higher than 1000. - """ - page = int(request.rel_url.query.get("page", 1)) - 1 # default 1 - # 1000 is elasticsearch's default page limit - if page <= MIN_PAGE_NUMBER - 1 or page >= MAX_PAGE_VALUE: - raise ValueError - return page +def parse_int(request, param) -> int: + integer = int(request.rel_url.query.get(param)) + return integer diff --git a/aio/aio-proxy/aio_proxy/request/parsers/per_page.py b/aio/aio-proxy/aio_proxy/request/parsers/per_page.py index eaa6b318..b2b317e5 100644 --- a/aio/aio-proxy/aio_proxy/request/parsers/per_page.py +++ b/aio/aio-proxy/aio_proxy/request/parsers/per_page.py @@ -5,18 +5,6 @@ error="Veuillez indiquer un `per_page` entre 1 et 25, par défaut 10." ) def parse_and_validate_per_page(request) -> int: - """Extract and Check the validity of per page. - - Args: - request: HTTP request. - - Returns: - per_page(int) if valid. - default 10. - - Raises: - ValueError: if per_page is not integer. - """ per_page = int(request.rel_url.query.get("per_page", 10)) # default 10 # Limit number of results per page for performance reasons max_per_page = 25 diff --git a/aio/aio-proxy/aio_proxy/request/parsers/string_parser.py b/aio/aio-proxy/aio_proxy/request/parsers/string_parser.py index 6533d964..f443aa49 100644 --- a/aio/aio-proxy/aio_proxy/request/parsers/string_parser.py +++ b/aio/aio-proxy/aio_proxy/request/parsers/string_parser.py @@ -1,4 +1,4 @@ -def clean_parameter(request, param: str): +def clean_parameter(param: str): """Extract and clean param from request. Remove white spaces and use upper case. @@ -10,16 +10,13 @@ def clean_parameter(request, param: str): None if None. clean_param otherwise. """ - param = parse_parameter(request, param) if param is None: return None + param = param.replace("-", " ") param_clean = param.replace(" ", "").upper() return param_clean -def parse_parameter(request, param: str, default_value=None): - param = request.rel_url.query.get(param, default_value) - if param is None: - return None +def clean_name(param: str): param = param.replace("-", " ").lower() return param diff --git a/aio/aio-proxy/aio_proxy/request/parsers/terms.py b/aio/aio-proxy/aio_proxy/request/parsers/terms.py index 127719d1..d7feb36c 100644 --- a/aio/aio-proxy/aio_proxy/request/parsers/terms.py +++ b/aio/aio-proxy/aio_proxy/request/parsers/terms.py @@ -1,21 +1,3 @@ -def parse_and_validate_terms(request, default_value=None): - """Extract search terms from request. - - Args: - request: HTTP request. - default_value: - - Returns: - terms if given. - Raises: - ValueError: otherwise. - """ - terms = request.rel_url.query.get("q", default_value) - if terms: - return terms.upper() - return terms - - def check_short_terms_and_no_param(search_params): """Prevent performance issues by refusing query terms less than 3 characters. Accept less than 3 characters if at least one parameter is filled. @@ -29,13 +11,13 @@ def check_short_terms_and_no_param(search_params): """ min_chars_in_terms = 3 if ( - search_params.terms is not None - and len(search_params.terms) < min_chars_in_terms + search_params.get("terms", None) is not None + and len(search_params.get("terms", None)) < min_chars_in_terms and all( val is None for val in [ param_value - for param, param_value in vars(search_params).items() + for param, param_value in search_params.items() if param not in ["terms", "page", "per_page", "matching_size"] ] ) diff --git a/aio/aio-proxy/aio_proxy/request/search_params_builder.py b/aio/aio-proxy/aio_proxy/request/search_params_builder.py index 3d8192fa..7640aa9f 100644 --- a/aio/aio-proxy/aio_proxy/request/search_params_builder.py +++ b/aio/aio-proxy/aio_proxy/request/search_params_builder.py @@ -1,212 +1,66 @@ from aio_proxy.request.helpers import validate_date_range -from aio_proxy.request.parsers.activite_principale import ( - validate_activite_principale, -) from aio_proxy.request.parsers.ban_params import ban_params -from aio_proxy.request.parsers.bool_fields import parse_and_validate_bool_field -from aio_proxy.request.parsers.categorie_entreprise import ( - validate_categorie_entreprise, -) -from aio_proxy.request.parsers.code_commune import validate_code_commune -from aio_proxy.request.parsers.code_postal import validate_code_postal -from aio_proxy.request.parsers.collectivite_territoriale import ( - validate_code_collectivite_territoriale, -) -from aio_proxy.request.parsers.convention_collective import ( - validate_id_convention_collective, -) -from aio_proxy.request.parsers.date_parser import parse_and_validate_date -from aio_proxy.request.parsers.departement import validate_departement from aio_proxy.request.parsers.empty_params import check_empty_params -from aio_proxy.request.parsers.etat_administratif import ( - validate_etat_administratif, -) -from aio_proxy.request.parsers.finess import validate_id_finess -from aio_proxy.request.parsers.insee_bool import match_bool_to_insee_value -from aio_proxy.request.parsers.int_parser import parse_and_validate_int -from aio_proxy.request.parsers.latitude import parse_and_validate_latitude -from aio_proxy.request.parsers.longitude import parse_and_validate_longitude -from aio_proxy.request.parsers.matching_size import ( - parse_and_validate_matching_size, -) -from aio_proxy.request.parsers.nature_juridique import validate_nature_juridique -from aio_proxy.request.parsers.page import parse_and_validate_page -from aio_proxy.request.parsers.per_page import parse_and_validate_per_page -from aio_proxy.request.parsers.radius import parse_and_validate_radius -from aio_proxy.request.parsers.region import validate_region -from aio_proxy.request.parsers.rge import validate_id_rge -from aio_proxy.request.parsers.section_activite_principale import ( - validate_section_activite_principale, -) from aio_proxy.request.parsers.selected_fields import ( validate_inclusion_fields, - validate_selected_fields, -) -from aio_proxy.request.parsers.string_parser import ( - clean_parameter, - parse_parameter, ) -from aio_proxy.request.parsers.terms import ( - check_short_terms_and_no_param, - parse_and_validate_terms, -) -from aio_proxy.request.parsers.tranche_effectif import ( - validate_tranche_effectif_salarie, -) -from aio_proxy.request.parsers.type_personne import validate_type_personne -from aio_proxy.request.parsers.uai import validate_id_uai +from aio_proxy.request.parsers.terms import check_short_terms_and_no_param from aio_proxy.request.search_params_model import SearchParams from aio_proxy.request.search_type import SearchType -from aio_proxy.utils.utils import str_to_list class SearchParamsBuilder: """This class extracts parameter values from request and saves them in a SearchParams dataclass object.""" + PARAMETER_MAPPING = { + "q": "terms", + "limite_matching_etablissements": "matching_size", + "nature_juridique": "nature_juridique_unite_legale", + "date_naissance_personne_min": "min_date_naiss_personne", + "date_naissance_personne_max": "max_date_naiss_personne", + "etat_administratif": "etat_administratif_unite_legale", + "activite_principale": "activite_principale_unite_legale", + "code_commune": "commune", + "tranche_effectif_salarie": "tranche_effectif_salarie_unite_legale", + "est_ess": "economie_sociale_solidaire_unite_legale", + "long": "lon", + } + + @staticmethod + def map_request_parameters(request): + # Extract all query parameters from the request + request_params = request.rel_url.query + + # Include parameters specified in PARAMETER_MAPPING with mapping + mapped_params = { + model_field: request_params.get(http_param, None) + for http_param, model_field in SearchParamsBuilder.PARAMETER_MAPPING.items() + if request_params.get(http_param, None) is not None + } + # Include parameters not specified in PARAMETER_MAPPING without mapping + for key, value in request_params.items(): + if key not in SearchParamsBuilder.PARAMETER_MAPPING: + mapped_params[key] = value + return mapped_params + @staticmethod def get_text_search_params(request): - params = SearchParams( - page=parse_and_validate_page(request), - per_page=parse_and_validate_per_page(request), - terms=parse_and_validate_terms(request), - matching_size=parse_and_validate_matching_size(request), - nature_juridique_unite_legale=validate_nature_juridique( - str_to_list(clean_parameter(request, param="nature_juridique")) - ), - id_rge=validate_id_rge(clean_parameter(request, param="id_rge")), - nom_personne=parse_parameter(request, param="nom_personne"), - prenoms_personne=parse_parameter(request, param="prenoms_personne"), - min_date_naiss_personne=parse_and_validate_date( - request, param="date_naissance_personne_min" - ), - max_date_naiss_personne=parse_and_validate_date( - request, param="date_naissance_personne_max" - ), - ca_min=parse_and_validate_int(request, param="ca_min"), - ca_max=parse_and_validate_int(request, param="ca_max"), - resultat_net_min=parse_and_validate_int(request, param="resultat_net_min"), - resultat_net_max=parse_and_validate_int(request, param="resultat_net_max"), - type_personne=validate_type_personne( - clean_parameter(request, param="type_personne") - ), - etat_administratif_unite_legale=validate_etat_administratif( - clean_parameter(request, param="etat_administratif") - ), - activite_principale_unite_legale=validate_activite_principale( - str_to_list(clean_parameter(request, param="activite_principale")) - ), - categorie_entreprise=validate_categorie_entreprise( - str_to_list(clean_parameter(request, param="categorie_entreprise")) - ), - commune=validate_code_commune( - str_to_list(clean_parameter(request, param="code_commune")) - ), - code_postal=validate_code_postal( - str_to_list(clean_parameter(request, param="code_postal")) - ), - departement=validate_departement( - str_to_list(clean_parameter(request, param="departement")) - ), - region=validate_region( - str_to_list(clean_parameter(request, param="region")) - ), - est_entrepreneur_individuel=parse_and_validate_bool_field( - request, param="est_entrepreneur_individuel" - ), - section_activite_principale=validate_section_activite_principale( - str_to_list( - clean_parameter(request, param="section_activite_principale") - ) - ), - tranche_effectif_salarie_unite_legale=validate_tranche_effectif_salarie( - str_to_list(clean_parameter(request, param="tranche_effectif_salarie")) - ), - convention_collective_renseignee=parse_and_validate_bool_field( - request, param="convention_collective_renseignee" - ), - egapro_renseignee=parse_and_validate_bool_field( - request, param="egapro_renseignee" - ), - est_bio=parse_and_validate_bool_field(request, param="est_bio"), - est_finess=parse_and_validate_bool_field(request, param="est_finess"), - est_uai=parse_and_validate_bool_field(request, param="est_uai"), - est_collectivite_territoriale=parse_and_validate_bool_field( - request, param="est_collectivite_territoriale" - ), - est_entrepreneur_spectacle=parse_and_validate_bool_field( - request, param="est_entrepreneur_spectacle" - ), - est_association=parse_and_validate_bool_field( - request, param="est_association" - ), - est_organisme_formation=parse_and_validate_bool_field( - request, param="est_organisme_formation" - ), - est_qualiopi=parse_and_validate_bool_field(request, param="est_qualiopi"), - est_rge=parse_and_validate_bool_field(request, param="est_rge"), - est_service_public=parse_and_validate_bool_field( - request, param="est_service_public" - ), - est_societe_mission=match_bool_to_insee_value( - parse_and_validate_bool_field(request, param="est_societe_mission"), - ), - economie_sociale_solidaire_unite_legale=match_bool_to_insee_value( - parse_and_validate_bool_field(request, param="est_ess"), - ), - id_convention_collective=validate_id_convention_collective( - clean_parameter(request, param="id_convention_collective") - ), - id_finess=validate_id_finess(clean_parameter(request, param="id_finess")), - id_uai=validate_id_uai(clean_parameter(request, param="id_uai")), - code_collectivite_territoriale=validate_code_collectivite_territoriale( - str_to_list( - clean_parameter(request, param="code_collectivite_territoriale") - ) - ), - minimal=parse_and_validate_bool_field(request, param="minimal"), - include=validate_selected_fields( - str_to_list(clean_parameter(request, param="include")) - ), - include_admin=validate_selected_fields( - str_to_list(clean_parameter(request, param="include_admin")), - admin=True, - ), - ) + # Map the request parameters to match the Pydantic model's field names + mapped_params = SearchParamsBuilder.map_request_parameters(request) + params = SearchParams(**mapped_params) SearchParamsBuilder.check_and_validate_params(request, params) - return params + return params.dict() @staticmethod def get_geo_search_params(request): - params = SearchParams( - page=parse_and_validate_page(request), - per_page=parse_and_validate_per_page(request), - lat=parse_and_validate_latitude(request), - lon=parse_and_validate_longitude(request), - radius=parse_and_validate_radius(request), - activite_principale_unite_legale=validate_activite_principale( - str_to_list(clean_parameter(request, param="activite_principale")) - ), - section_activite_principale=validate_section_activite_principale( - str_to_list( - clean_parameter(request, param="section_activite_principale") - ) - ), - matching_size=parse_and_validate_matching_size(request), - minimal=parse_and_validate_bool_field(request, param="minimal"), - include=validate_selected_fields( - str_to_list(clean_parameter(request, param="include")) - ), - include_admin=validate_selected_fields( - str_to_list(clean_parameter(request, param="include_admin")), - admin=True, - ), - ) - return params + mapped_params = SearchParamsBuilder.map_request_parameters(request) + params = SearchParams(**mapped_params) + return params.dict() @staticmethod def check_and_validate_params(request, params): - check_empty_params(params) + check_empty_params(params.dict(exclude_unset=True)) ban_params(request, "localisation") validate_inclusion_fields(params.minimal, params.include) validate_date_range( @@ -215,7 +69,7 @@ def check_and_validate_params(request, params): ) # Prevent performance issues by refusing query terms less than 3 characters # unless another param is provided - check_short_terms_and_no_param(params) + check_short_terms_and_no_param(params.dict(exclude_unset=True)) @staticmethod def extract_params(request, search_type): diff --git a/aio/aio-proxy/aio_proxy/request/search_params_model.py b/aio/aio-proxy/aio_proxy/request/search_params_model.py index 6aefb81c..1288d0d3 100644 --- a/aio/aio-proxy/aio_proxy/request/search_params_model.py +++ b/aio/aio-proxy/aio_proxy/request/search_params_model.py @@ -1,4 +1,5 @@ import re +from datetime import date, datetime from aio_proxy.labels.helpers import ( codes_naf, @@ -10,7 +11,13 @@ valid_admin_fields_to_select, valid_fields_to_select, ) -from pydantic import BaseModel, validator +from aio_proxy.request.parsers.insee_bool import match_bool_to_insee_value +from aio_proxy.request.parsers.string_parser import ( + clean_name, + clean_parameter, +) +from aio_proxy.utils.utils import str_to_list +from pydantic import BaseModel, field_validator MAX_PAGE_VALUE = 1000 MIN_PAGE_NUMBER = 0 @@ -45,7 +52,7 @@ class SearchParams(BaseModel): est_qualiopi: bool | None = None est_rge: bool | None = None est_service_public: bool | None = None - est_societe_mission: bool | None = None + est_societe_mission: str | None = None economie_sociale_solidaire_unite_legale: str | None = None id_convention_collective: str | None = None id_finess: str | None = None @@ -54,8 +61,8 @@ class SearchParams(BaseModel): id_rge: str | None = None nom_personne: str | None = None prenoms_personne: str | None = None - min_date_naiss_personne: str | None = None - max_date_naiss_personne: str | None = None + min_date_naiss_personne: datetime | None = None + max_date_naiss_personne: datetime | None = None ca_min: str | None = None ca_max: str | None = None resultat_net_min: float | None = None @@ -71,163 +78,188 @@ class SearchParams(BaseModel): include: list | None = None include_admin: list | None = None - @validator("page", pre=True, always=True) + @field_validator("page", mode="before") def validate_page(cls, page): - page = page - 1 # default 1 - # 1000 is elasticsearch's default page limit - if page <= MIN_PAGE_NUMBER - 1 or page >= MAX_PAGE_VALUE: - raise ValueError( - "Veuillez indiquer un numéro de page entier entre 1 " - "et 1000, par défaut 1." + try: + page = int(page) - 1 # default 1 + # 1000 is elasticsearch's default page limit + if page <= MIN_PAGE_NUMBER - 1 or page >= MAX_PAGE_VALUE: + raise TypeError + except TypeError: + raise TypeError( + "Veuillez indiquer un numéro de page entier entre 1 et " + "1000, par défaut 1." ) return page - @validator("per_page", pre=True, always=True) + @field_validator("per_page", mode="before") def validate_per_page(cls, per_page): max_per_page = 25 min_per_page = 1 - if per_page > max_per_page or per_page < min_per_page: - raise ValueError( + try: + if per_page > max_per_page or per_page < min_per_page: + raise TypeError + return per_page + except TypeError: + raise TypeError( "Veuillez indiquer un `per_page` entre 1 et 25, par défaut 10." ) - return per_page - @validator("terms", pre=True, always=True) + @field_validator("terms", mode="before") def validate_terms(cls, terms): - if terms: - return terms.upper() - return terms + return terms.upper() - @validator("matching_size", pre=True, always=True) + @field_validator("matching_size", mode="before") def validate_matching_size(cls, matching_size): min_matching_size = 0 max_matching_size = 100 - if matching_size <= min_matching_size or matching_size > max_matching_size: - raise ValueError( + try: + if matching_size <= min_matching_size or matching_size > max_matching_size: + raise TypeError + return matching_size + except TypeError: + raise TypeError( "Veuillez indiquer un nombre d'établissements" "connexes entier entre 1 et " "100, par défaut 10." ) - return matching_size - @validator("nature_juridique_unite_legale", pre=True, always=True) - def validate_nature_juridique(cls, list_nature_juridique): + @field_validator("nature_juridique_unite_legale", mode="before") + def validate_nature_juridique(cls, nature_juridique): + list_nature_juridique = str_to_list(clean_parameter(nature_juridique)) for nature_juridique in list_nature_juridique: if nature_juridique not in natures_juridiques: - raise ValueError( + raise TypeError( f"Au moins une nature juridique est non valide. " f"Les natures juridiques valides : " f"{[nature_juridique for nature_juridique in natures_juridiques]}." ) return list_nature_juridique - @validator( - "min_date_naiss_personne", "max_date_naiss_personne", pre=True, always=True + @field_validator("nom_personne", "prenoms_personne", mode="before") + def validate_nom(cls, nom): + return clean_name(nom) + + @field_validator( + "min_date_naiss_personne", "max_date_naiss_personne", mode="before" ) - def validate_date(cls, date): - return date.fromisoformat(date) + def validate_date(cls, date_string): + try: + return date.fromisoformat(date_string) + except ValueError: + raise TypeError( + "Veuillez indiquer une date sous" + "le format : aaaa-mm-jj. Exemple : '1990-01-02'" + ) - @validator( + @field_validator( "ca_min", "ca_max", "resultat_net_min", "resultat_net_max", - pre=True, - always=True, + mode="before", ) def validate_long_int(cls, int_val): # Elasticsearch `long` type maxes out at this range min_val = -9223372036854775295 max_val = 9223372036854775295 - - if min_val <= int_val <= max_val: - return int_val - else: - raise ValueError( + try: + if min_val <= int(int_val) <= max_val: + return int(int_val) + except TypeError: + raise TypeError( f"Veuillez indiquer un entier entre {min_val} et {max_val}." ) - @validator("type_personne", pre=True, always=True) + @field_validator("type_personne", mode="before") def validate_type_personne(cls, type_personne): - if type_personne not in ["ELU", "DIRIGEANT"]: - raise ValueError( + if type_personne.upper() not in ["ELU", "DIRIGEANT"]: + raise TypeError( "type_personne doit prendre la valeur 'dirigeant' ou 'elu' !" ) return type_personne - @validator("etat_administratif_unite_legale", pre=True, always=True) + @field_validator("etat_administratif_unite_legale", mode="before") def validate_etat_administratif(cls, etat_administratif): - if etat_administratif not in ["A", "C"]: - raise ValueError("L'état administratif doit prendre la valeur 'A' ou 'C' !") + if etat_administratif.upper() not in ["A", "C"]: + raise TypeError("L'état administratif doit prendre la valeur 'A' ou 'C' !") return etat_administratif - @validator("activite_principale_unite_legale", pre=True, always=True) + @field_validator("activite_principale_unite_legale", mode="before") def validate_activite_principale(cls, activite_principale_unite_legale): + list_activite_principale = str_to_list( + clean_parameter(activite_principale_unite_legale) + ) length_activite_principale = 6 - for activite_principale in activite_principale_unite_legale: + for activite_principale in list_activite_principale: if len(activite_principale) != length_activite_principale: - raise ValueError( + raise TypeError( "Chaque activité principale doit contenir 6 caractères." ) if activite_principale not in codes_naf: - raise ValueError("Au moins une des activités principales est inconnue.") - return activite_principale_unite_legale + raise TypeError("Au moins une des activités principales est inconnue.") + return list_activite_principale - @validator("categorie_entreprise", pre=True, always=True) - def validate_categorie_entreprise(cls, list_categorie_entreprise): + @field_validator("categorie_entreprise", mode="before") + def validate_categorie_entreprise(cls, categorie_entreprise): + list_categorie_entreprise = str_to_list(clean_parameter(categorie_entreprise)) for categorie_entreprise in list_categorie_entreprise: if categorie_entreprise not in ["GE", "PME", "ETI"]: - raise ValueError( + raise TypeError( "Chaque catégorie d'entreprise doit prendre une de ces " "valeurs 'GE', 'PME' ou 'ETI'." ) return list_categorie_entreprise - @validator("commune", pre=True, always=True) - def validate_commune(cls, list_commune): + @field_validator("commune", mode="before") + def validate_commune(cls, commune): + list_commune = str_to_list(clean_parameter(commune)) length_code_commune = 5 for code_commune in list_commune: if len(code_commune) != length_code_commune: - raise ValueError("Chaque code commune doit contenir 5 caractères !") + raise TypeError("Chaque code commune doit contenir 5 caractères !") codes_valides = r"^([013-9]\d|2[AB1-9])\d{3}$" if not re.search(codes_valides, code_commune): - raise ValueError("Au moins un des codes communes est non valide.") + raise TypeError("Au moins un des codes communes est non valide.") return list_commune - @validator("code_postal", pre=True, always=True) - def validate_code_postal(cls, list_code_postal): + @field_validator("code_postal", mode="before") + def validate_code_postal(cls, code_postal): + list_code_postal = str_to_list(clean_parameter(code_postal)) length_cod_postal = 5 for code_postal in list_code_postal: if len(code_postal) != length_cod_postal: - raise ValueError("Chaque code postal doit contenir 5 caractères !") + raise TypeError("Chaque code postal doit contenir 5 caractères !") codes_valides = "^((0[1-9])|([1-8][0-9])|(9[0-8])|(2A)|(2B))[0-9]{3}$" if not re.search(codes_valides, code_postal): - raise ValueError("Au moins un code postal est non valide.") + raise TypeError("Au moins un code postal est non valide.") return list_code_postal - @validator("departement", pre=True, always=True) - def validate_departement(cls, list_departement): + @field_validator("departement", mode="before") + def validate_departement(cls, departement): + list_departement = str_to_list(clean_parameter(departement)) for departement in list_departement: if departement not in departements: - raise ValueError( + raise TypeError( f"Au moins un département est non valide." f" Les départements valides" f" : {[dep for dep in departements]}" ) return list_departement - @validator("region", pre=True, always=True) - def validate_region(cls, list_region): + @field_validator("region", mode="before") + def validate_region(cls, region): + list_region = str_to_list(clean_parameter(region)) for region in list_region: if region not in regions: - raise ValueError( + raise TypeError( f"Au moins une region est non valide." f" Les région valides" f" : {regions}" ) return list_region - @validator( + @field_validator( "est_entrepreneur_individuel", "convention_collective_renseignee", "egapro_renseignee", @@ -241,134 +273,142 @@ def validate_region(cls, list_region): "est_qualiopi", "est_rge", "est_service_public", - "est_societe_mission", - "economie_sociale_solidaire_unite_legale", "minimal", - pre=True, - always=True, + mode="before", + ) + def validate_bool(cls, boolean, info): + param_name = info.field_name + if boolean.upper() not in ["TRUE", "FALSE"]: + raise TypeError(f"{param_name} doit prendre la valeur 'true' ou 'false' !") + return boolean.upper() == "TRUE" + + @field_validator( + "est_societe_mission", "economie_sociale_solidaire_unite_legale", mode="before" ) - def validate_bool(cls, bool, field): - param_name = field.name - if bool not in ["TRUE", "FALSE"]: - raise ValueError(f"{param_name} doit prendre la valeur 'true' ou 'false' !") - return bool == "TRUE" - - @validator("section_activite_principale", pre=True, always=True) - def validate_section_activite_principale(cls, list_section_activite_principale): + def validate_societe_a_mission(cls, boolean, info): + param_name = info.field_name + if boolean.upper() not in ["TRUE", "FALSE"]: + # Using TypeError because it is not wrapped in a Validation Error + # in Pydantic + raise TypeError(f"{param_name} doit prendre la valeur 'true' ou 'false' !") + return match_bool_to_insee_value(boolean.upper() == "TRUE") + + @field_validator("section_activite_principale", mode="before") + def validate_section_activite_principale(cls, section_activite_principale): + list_section_activite_principale = str_to_list( + clean_parameter(section_activite_principale) + ) for section_activite_principale in list_section_activite_principale: if section_activite_principale not in sections_codes_naf: - raise ValueError( + raise TypeError( "Au moins une section d'activité principale est non valide." ) return list_section_activite_principale - @validator("tranche_effectif_salarie_unite_legale", pre=True, always=True) - def validate_tranche_effectif_salarie(cls, list_tranche_effectif_salarie): + @field_validator("tranche_effectif_salarie_unite_legale", mode="before") + def validate_tranche_effectif_salarie(cls, tranche_effectif_salarie): + list_tranche_effectif_salarie = str_to_list( + clean_parameter(tranche_effectif_salarie) + ) length_tranche_effectif_salarie = 2 for tranche_effectif_salarie in list_tranche_effectif_salarie: if len(tranche_effectif_salarie) != length_tranche_effectif_salarie: - raise ValueError("Chaque tranche salariés doit contenir 2 caractères.") + raise TypeError("Chaque tranche salariés doit contenir 2 caractères.") if tranche_effectif_salarie not in tranches_effectifs: - raise ValueError("Au moins une tranche salariés est non valide.") + raise TypeError("Au moins une tranche salariés est non valide.") return list_tranche_effectif_salarie - @validator("id_convention_collective", pre=True, always=True) + @field_validator("id_convention_collective", mode="before") def validate_id_convention_collective(cls, id_convention_collective): length_convention_collective = 4 if len(id_convention_collective) != length_convention_collective: - raise ValueError( + raise TypeError( "L'identifiant de convention collective doit contenir 4 caractères." ) return id_convention_collective - @validator("id_finess", pre=True, always=True) + @field_validator("id_finess", mode="before") def validate_id_finess(cls, id_finess): len_id_finess = 9 if len(id_finess) != len_id_finess: - raise ValueError("L'identifiant FINESS doit contenir 9 caractères.") + raise TypeError("L'identifiant FINESS doit contenir 9 caractères.") return id_finess - @validator("id_uai", pre=True, always=True) + @field_validator("id_uai", mode="before") def validate_id_uai(cls, id_uai): length_id_uai = 8 if len(id_uai) != length_id_uai: - raise ValueError("L'identifiant UAI doit contenir 8 caractères.") + raise TypeError("L'identifiant UAI doit contenir 8 caractères.") return id_uai - @validator("code_collectivite_territoriale", pre=True, always=True) - def validate_code_collectivite_territoriale(cls, list_code_cc): + @field_validator("code_collectivite_territoriale", mode="before") + def validate_code_collectivite_territoriale(cls, code_cc): + list_code_cc = str_to_list(clean_parameter(code_cc)) min_len_code_collectivite_territoriale = 2 for code_collectivite_territoriale in list_code_cc: if ( len(code_collectivite_territoriale) < min_len_code_collectivite_territoriale ): - raise ValueError( + raise TypeError( "Chaque identifiant code insee d'une collectivité " "territoriale doit contenir au moins 2 caractères." ) return list_code_cc - @validator("include", pre=True, always=True) - def validate_include(cls, list_fields): - valid_fields_to_check = valid_fields_to_select - for field in list_fields: - if field not in valid_fields_to_check: - valid_fields_lowercase = [ - field.lower() for field in valid_fields_to_check - ] - raise ValueError( - f"Au moins un champ à inclure est non valide. " - f"Les champs valides : {valid_fields_lowercase}." - ) - return list_fields - - @validator("include_admin", pre=True, always=True) - def validate_include_admin(cls, list_fields): - valid_fields_to_check = valid_admin_fields_to_select + @field_validator("include", "include_admin", mode="before") + def validate_include(cls, fields, info): + list_fields = str_to_list(clean_parameter(fields)) + if info.field_name == "include_admin": + valid_fields_to_check = valid_admin_fields_to_select + else: + valid_fields_to_check = valid_fields_to_select for field in list_fields: if field not in valid_fields_to_check: valid_fields_lowercase = [ field.lower() for field in valid_fields_to_check ] - raise ValueError( + raise TypeError( f"Au moins un champ à inclure est non valide. " f"Les champs valides : {valid_fields_lowercase}." ) return list_fields - @validator("lat", pre=True, always=True) + @field_validator("lat", mode="before") def validate_lat(cls, lat): min_latitude = -90 max_latitude = 90 if lat == "nan": - raise ValueError("Veuillez indiquer une latitude entre -90° et 90°.") + raise TypeError("Veuillez indiquer une latitude entre -90° et 90°.") try: lat = float(lat) - except ValueError: - raise ValueError("Veuillez indiquer une latitude entre -90° et 90°.") + except TypeError: + raise TypeError("Veuillez indiquer une latitude entre -90° et 90°.") if lat > max_latitude or lat < min_latitude: - raise ValueError("Veuillez indiquer une latitude entre -90° et 90°.") + raise TypeError("Veuillez indiquer une latitude entre -90° et 90°.") - @validator("lon", pre=True, always=True) + @field_validator("lon", mode="before") def validate_lon(cls, lon): min_longitude = -180 max_longitude = 180 if lon == "nan": - raise ValueError("Veuillez indiquer une longitude entre -180° et 180°.") + raise TypeError("Veuillez indiquer une longitude entre -180° et 180°.") try: lon = float(lon) - except ValueError: - raise ValueError("Veuillez indiquer une longitude entre -180° et 180°.") + except TypeError: + raise TypeError("Veuillez indiquer une longitude entre -180° et 180°.") if lon > max_longitude or lon < min_longitude: - raise ValueError("Veuillez indiquer une longitude entre -180° et 180°.") + raise TypeError("Veuillez indiquer une longitude entre -180° et 180°.") return lon - @validator("radius", pre=True, always=True) + @field_validator("radius", mode="before") def validate_radius(cls, radius): - if radius <= MIN_RADIUS or radius > MAX_RADIUS: - raise ValueError( + try: + if float(radius) <= MIN_RADIUS or float(radius) > MAX_RADIUS: + raise TypeError + return float(radius) + except TypeError: + raise TypeError( "Veuillez indiquer un radius entier ou flottant " "bentre 0 et 50 (en km)." ) - return radius diff --git a/aio/aio-proxy/aio_proxy/response/helpers.py b/aio/aio-proxy/aio_proxy/response/helpers.py index 52fa4f97..ea831ab5 100755 --- a/aio/aio-proxy/aio_proxy/response/helpers.py +++ b/aio/aio-proxy/aio_proxy/response/helpers.py @@ -36,11 +36,11 @@ def hash_string(string: str): def create_fields_to_include(search_params): - if search_params.minimal: - if search_params.include is None: + if search_params["minimal"]: + if search_params["include"] is None: return [] else: - return search_params.include + return search_params["include"] else: return [ "SIEGE", @@ -52,7 +52,7 @@ def create_fields_to_include(search_params): def create_admin_fields_to_include(search_params): - if search_params.include_admin is None: + if search_params["include_admin"] is None: return [] else: - return search_params.include_admin + return search_params["include_admin"] diff --git a/aio/aio-proxy/aio_proxy/response/response_builder.py b/aio/aio-proxy/aio_proxy/response/response_builder.py index 864053d3..62854a07 100755 --- a/aio/aio-proxy/aio_proxy/response/response_builder.py +++ b/aio/aio-proxy/aio_proxy/response/response_builder.py @@ -6,11 +6,11 @@ class ResponseBuilder: def __init__(self, search_params, es_search_results): self.total_results = min(int(es_search_results.total_results), 10000) - self.per_page = search_params.per_page + self.per_page = search_params["per_page"] self.results = format_search_results( es_search_results.es_search_results, search_params ) - self.page = search_params.page + 1 + self.page = search_params["page"] + 1 self.total_pages = self.calculate_total_pages() response = ResponseModel( results=self.results, diff --git a/aio/aio-proxy/aio_proxy/search/filters/boolean.py b/aio/aio-proxy/aio_proxy/search/filters/boolean.py index caf92b37..473050bb 100644 --- a/aio/aio-proxy/aio_proxy/search/filters/boolean.py +++ b/aio/aio-proxy/aio_proxy/search/filters/boolean.py @@ -7,7 +7,7 @@ def filter_search_by_bool_fields_unite_legale( search_params, filters_to_include: list, ): - for param_name, param_value in vars(search_params).items(): + for param_name, param_value in search_params.items(): should_apply_bool_filter = ( param_value is not None and param_name in filters_to_include ) @@ -38,7 +38,7 @@ def filter_search_by_bool_fields_unite_legale( def filter_search_by_bool_nested_fields_unite_legale( search, search_params, filters_to_include: list, path ): - for param_name, param_value in vars(search_params).items(): + for param_name, param_value in search_params.items(): should_apply_bool_filter = ( param_value is not None and param_name in filters_to_include ) diff --git a/aio/aio-proxy/aio_proxy/search/filters/nested_etablissements_filters.py b/aio/aio-proxy/aio_proxy/search/filters/nested_etablissements_filters.py index eed28813..88fb0cb7 100644 --- a/aio/aio-proxy/aio_proxy/search/filters/nested_etablissements_filters.py +++ b/aio/aio-proxy/aio_proxy/search/filters/nested_etablissements_filters.py @@ -43,7 +43,7 @@ def build_etablissements_filters(search_params): must_not_filters = [] # params is the list of parameters (filters) provided in the request - for param_name, param_value in vars(search_params).items(): + for param_name, param_value in search_params.items(): should_apply_text_filter = ( param_value is not None and param_name in text_filters ) @@ -115,7 +115,7 @@ def build_nested_etablissements_filters_query(search_params, with_inner_hits=Fal if with_inner_hits: filters_query["nested"]["inner_hits"] = { - "size": search_params.matching_size, + "size": search_params["matching_size"], "sort": { "unite_legale.etablissements.etat_administratif": {"order": "asc"} }, diff --git a/aio/aio-proxy/aio_proxy/search/filters/term_filters.py b/aio/aio-proxy/aio_proxy/search/filters/term_filters.py index 05c9738a..4e223cb6 100644 --- a/aio/aio-proxy/aio_proxy/search/filters/term_filters.py +++ b/aio/aio-proxy/aio_proxy/search/filters/term_filters.py @@ -9,7 +9,7 @@ def filter_term_search_unite_legale( """Use filters to reduce search results.""" # search_params is the object containing the list of parameters (filters) provided # in the request - for param_name, param_value in vars(search_params).items(): + for param_name, param_value in search_params.items(): if param_value is not None and param_name in filters_to_include: search = search.filter( "term", @@ -28,7 +28,7 @@ def filter_term_list_search_unite_legale( """Use filters to reduce search results.""" # search_params is the object containing the list of parameters (filters) provided # in the request - for param_name, param_value in vars(search_params).items(): + for param_name, param_value in search_params.items(): if param_value is not None and param_name in filters_to_include: search = search.filter( "terms", diff --git a/aio/aio-proxy/aio_proxy/search/geo_search.py b/aio/aio-proxy/aio_proxy/search/geo_search.py index bc969f78..6b217370 100644 --- a/aio/aio-proxy/aio_proxy/search/geo_search.py +++ b/aio/aio-proxy/aio_proxy/search/geo_search.py @@ -28,17 +28,19 @@ def build_es_search_geo_query(es_search_builder): "bool": { "filter": { "geo_distance": { - "distance": f"{es_search_builder.search_params.radius}km", + "distance": ( + f'{es_search_builder.search_params["radius"]}km' + ), "unite_legale.etablissements.coordonnees": { - "lat": es_search_builder.search_params.lat, - "lon": es_search_builder.search_params.lon, + "lat": es_search_builder.search_params["lat"], + "lon": es_search_builder.search_params["lon"], }, }, } } }, "inner_hits": { - "size": es_search_builder.search_params.matching_size, + "size": es_search_builder.search_params["matching_size"], }, } } diff --git a/aio/aio-proxy/aio_proxy/search/helpers/bilan_filters_used.py b/aio/aio-proxy/aio_proxy/search/helpers/bilan_filters_used.py index 300a7a21..830b4740 100644 --- a/aio/aio-proxy/aio_proxy/search/helpers/bilan_filters_used.py +++ b/aio/aio-proxy/aio_proxy/search/helpers/bilan_filters_used.py @@ -5,7 +5,7 @@ def is_any_bilan_filter_used(search_params) -> bool: "resultat_net_min", "resultat_net_max", ] - for param_name, param_value in vars(search_params).items(): + for param_name, param_value in search_params.items(): if param_value is not None and param_name in bilan_filters: return True return False diff --git a/aio/aio-proxy/aio_proxy/search/helpers/etablissements_filters_used.py b/aio/aio-proxy/aio_proxy/search/helpers/etablissements_filters_used.py index 6bc060b3..c3ce1e06 100644 --- a/aio/aio-proxy/aio_proxy/search/helpers/etablissements_filters_used.py +++ b/aio/aio-proxy/aio_proxy/search/helpers/etablissements_filters_used.py @@ -15,7 +15,7 @@ def is_any_etablissement_filter_used(search_params) -> bool: "id_rge", "region", ] - for param_name, param_value in vars(search_params).items(): + for param_name, param_value in search_params.items(): if param_value is not None and param_name in etablissements_filters: return True return False diff --git a/aio/aio-proxy/aio_proxy/search/helpers/exclude_etablissements.py b/aio/aio-proxy/aio_proxy/search/helpers/exclude_etablissements.py index 381c385c..99dc65cc 100644 --- a/aio/aio-proxy/aio_proxy/search/helpers/exclude_etablissements.py +++ b/aio/aio-proxy/aio_proxy/search/helpers/exclude_etablissements.py @@ -1,8 +1,8 @@ def exclude_etablissements_from_search(es_search_builder): # # By default, exclude etablissements list from response include_etablissements = ( - es_search_builder.search_params.include_admin - and "ETABLISSEMENTS" in es_search_builder.search_params.include_admin + es_search_builder.search_params["include_admin"] + and "ETABLISSEMENTS" in es_search_builder.search_params["include_admin"] ) if not include_etablissements: diff --git a/aio/aio-proxy/aio_proxy/search/helpers/helpers.py b/aio/aio-proxy/aio_proxy/search/helpers/helpers.py index bc959a83..8970d032 100644 --- a/aio/aio-proxy/aio_proxy/search/helpers/helpers.py +++ b/aio/aio-proxy/aio_proxy/search/helpers/helpers.py @@ -36,7 +36,7 @@ def page_through_results(es_search_builder): ElasticSearchBuilder Instance with pagination """ - size = es_search_builder.search_params.per_page - offset = es_search_builder.search_params.page * size + size = es_search_builder.search_params["per_page"] + offset = es_search_builder.search_params["page"] * size search_client = es_search_builder.es_search_client return search_client[offset : (offset + size)] diff --git a/aio/aio-proxy/aio_proxy/search/queries/bilan.py b/aio/aio-proxy/aio_proxy/search/queries/bilan.py index 78c25853..758ebc9b 100644 --- a/aio/aio-proxy/aio_proxy/search/queries/bilan.py +++ b/aio/aio-proxy/aio_proxy/search/queries/bilan.py @@ -10,7 +10,7 @@ def search_bilan( search_options = [] bilan_filters = [] for filter in bilan_filters_to_include: - filter_value = getattr(search_params, filter) + filter_value = search_params.get(filter, None) if filter_value is not None: if "min" in filter: operator = "gte" diff --git a/aio/aio-proxy/aio_proxy/search/queries/person.py b/aio/aio-proxy/aio_proxy/search/queries/person.py index 36387bee..3f777698 100755 --- a/aio/aio-proxy/aio_proxy/search/queries/person.py +++ b/aio/aio-proxy/aio_proxy/search/queries/person.py @@ -15,7 +15,7 @@ def search_person( person_filters = [] boost_queries = [] # Nom - nom_person = getattr(search_params, param_nom) + nom_person = search_params.get(param_nom, "None") if nom_person: # match queries returns any document containing the search item, # even if it contains another item @@ -51,7 +51,7 @@ def search_person( ) # Prénoms - prenoms_person = getattr(search_params, param_prenom) + prenoms_person = search_params.get("param_prenom", None) if prenoms_person: # Same logic as "nom" is used for "prenoms" for prenom in prenoms_person.split(" "): @@ -81,7 +81,7 @@ def search_person( ) # Date de naissance - min_date_naiss_person = getattr(search_params, param_date_min) + min_date_naiss_person = search_params.get("param_date_min", None) if min_date_naiss_person: person_filters.append( { @@ -96,7 +96,7 @@ def search_person( } ) - max_date_naiss_person = getattr(search_params, param_date_max) + max_date_naiss_person = search_params.get("param_date_max", None) if max_date_naiss_person: person_filters.append( { diff --git a/aio/aio-proxy/aio_proxy/search/text_search.py b/aio/aio-proxy/aio_proxy/search/text_search.py index 97bee5a2..6bb8812e 100755 --- a/aio/aio-proxy/aio_proxy/search/text_search.py +++ b/aio/aio-proxy/aio_proxy/search/text_search.py @@ -27,7 +27,7 @@ def build_es_search_text_query(es_search_builder): - query_terms = es_search_builder.search_params.terms + query_terms = es_search_builder.search_params["terms"] # Filter by siren/siret first (if query is a `siren` or 'siret' number), # and return search results directly without text search. if is_siren(query_terms) or is_siret(query_terms): @@ -123,7 +123,7 @@ def build_es_search_text_query(es_search_builder): if query_terms: text_query = build_text_query( terms=query_terms, - matching_size=es_search_builder.search_params.matching_size, + matching_size=es_search_builder.search_params["matching_size"], ) text_query_with_filters = ( add_nested_etablissements_filters_to_text_query( @@ -147,16 +147,15 @@ def build_es_search_text_query(es_search_builder): Q(filters_etablissements_query_with_inner_hits) ) ) - else: + elif query_terms: # Text search only without etablissements filters - if query_terms: - text_query = build_text_query( - terms=query_terms, - matching_size=es_search_builder.search_params.matching_size, - ) - es_search_builder.es_search_client = ( - es_search_builder.es_search_client.query(Q(text_query)) - ) + text_query = build_text_query( + terms=query_terms, + matching_size=es_search_builder.search_params["matching_size"], + ) + es_search_builder.es_search_client = ( + es_search_builder.es_search_client.query(Q(text_query)) + ) # Search by chiffre d'affaire or resultat net in bilan_financier is_bilan_bilan_used = is_any_bilan_filter_used(es_search_builder.search_params) @@ -173,7 +172,7 @@ def build_es_search_text_query(es_search_builder): ) # Search 'élus' only - type_personne = es_search_builder.search_params.type_personne + type_personne = es_search_builder.search_params["type_personne"] if type_personne == "ELU": es_search_builder.es_search_client = search_person( es_search_builder.es_search_client, @@ -240,7 +239,7 @@ def build_es_search_text_query(es_search_builder): "nom_personne", "prenoms_personne", ]: - if getattr(es_search_builder.search_params, item): + if es_search_builder.search_params.get(item, None): es_search_builder.has_full_text_query = True exclude_etablissements_from_search(es_search_builder)