diff --git a/.deepsource.toml b/.deepsource.toml index cb880ada44..e784e76ee2 100644 --- a/.deepsource.toml +++ b/.deepsource.toml @@ -12,8 +12,8 @@ runtime_version = "3.x.x" [[transformers]] name = "black" -enabled = true +enabled = false [[transformers]] name = "isort" -enabled = true +enabled = false diff --git a/.env b/.env index f1766a92a2..d0c7a733a6 100644 --- a/.env +++ b/.env @@ -46,8 +46,9 @@ POSTGRES_USER=postgres POSTGRES_PASSWORD=postgres POSTGRES_EXPOSE=127.0.0.1:5432 -# Triton ML inference server +# Triton ML inference server & TF Serving TRITON_HOST=triton +TF_SERVING_HOST=tf_serving # InfluxDB INFLUXDB_HOST= @@ -62,15 +63,12 @@ INFLUXDB_AUTH_TOKEN= # MONGO_URI=mongodb://mongodb.po_default:27017 MONGO_URI=mongodb://mongodb:27017 +# Redis +REDIS_HOST=redis + # OpenFoodFacts API OFF_PASSWORD= OFF_USER= # Utils -SENTRY_DSN= - -# Workers -IPC_AUTHKEY=ipc -IPC_HOST=workers -IPC_PORT=6650 -WORKER_COUNT=8 +SENTRY_DSN= \ No newline at end of file diff --git a/.github/workflows/container-deploy-ml.yml b/.github/workflows/container-deploy-ml.yml index b1019d3d78..d27d50161d 100644 --- a/.github/workflows/container-deploy-ml.yml +++ b/.github/workflows/container-deploy-ml.yml @@ -71,7 +71,7 @@ jobs: echo "COMPOSE_HTTP_TIMEOUT=120" >> .env echo "COMPOSE_PATH_SEPARATOR=;" >> .env echo "COMPOSE_PROJECT_NAME=robotoff-ml" >> .env - echo "COMPOSE_FILE=docker/ml.yml" >> .env + echo "COMPOSE_FILE=docker-compose.yml;docker/ml.yml" >> .env echo "RESTART_POLICY=always" >> .env echo "TRITON_EXPOSE_HTTP=8003" >> .env diff --git a/.github/workflows/container-deploy.yml b/.github/workflows/container-deploy.yml index 4f1e28c64a..ff8c3d77b8 100644 --- a/.github/workflows/container-deploy.yml +++ b/.github/workflows/container-deploy.yml @@ -111,10 +111,7 @@ jobs: # Set app variables echo "ROBOTOFF_INSTANCE=${{ env.ROBOTOFF_INSTANCE }}" >> .env echo "ROBOTOFF_DOMAIN=${{ env.ROBOTOFF_DOMAIN }}" >> .env - echo "IPC_AUTHKEY=${{ secrets.IPC_AUTHKEY }}" >> .env - echo "IPC_HOST=0.0.0.0" >> .env - echo "IPC_PORT=6650" >> .env - echo "WORKER_COUNT=8" >> .env + echo "REDIS_HOST=redis" >> .env echo "POSTGRES_HOST=postgres" >> .env echo "POSTGRES_DB=postgres" >> .env echo "POSTGRES_USER=postgres" >> .env diff --git a/Makefile b/Makefile index 9315599337..d25287fb0e 100644 --- a/Makefile +++ b/Makefile @@ -55,7 +55,7 @@ up: @echo "🥫 Building and starting containers …" docker network create po_default || true ifdef service - ${DOCKER_COMPOSE} up -d ${service} 2>&1 + ${DOCKER_COMPOSE} up --remove-orphans -d ${service} 2>&1 else ${DOCKER_COMPOSE} up -d 2>&1 endif @@ -173,26 +173,26 @@ health: i18n-compile: @echo "🥫 Compiling translations …" # Note it's important to have --no-deps, to avoid launching a concurrent postgres instance - ${DOCKER_COMPOSE} run --rm --entrypoint bash --no-deps workers -c "cd i18n && . compile.sh" + ${DOCKER_COMPOSE} run --rm --entrypoint bash --no-deps worker_high -c "cd i18n && . 
compile.sh" unit-tests: @echo "🥫 Running tests …" # run tests in worker to have more memory # also, change project name to run in isolation - ${DOCKER_COMPOSE_TEST} run --rm workers poetry run pytest --cov-report xml --cov=robotoff tests/unit + ${DOCKER_COMPOSE_TEST} run --rm worker_high poetry run pytest --cov-report xml --cov=robotoff tests/unit integration-tests: @echo "🥫 Running integration tests …" # run tests in worker to have more memory # also, change project name to run in isolation - ${DOCKER_COMPOSE_TEST} run --rm workers poetry run pytest -vv --cov-report xml --cov=robotoff --cov-append tests/integration + ${DOCKER_COMPOSE_TEST} run --rm worker_high poetry run pytest -vv --cov-report xml --cov=robotoff --cov-append tests/integration ( ${DOCKER_COMPOSE_TEST} down -v || true ) # interactive testings # usage: make pytest args='test/unit/my-test.py --pdb' pytest: guard-args @echo "🥫 Running test: ${args} …" - ${DOCKER_COMPOSE_TEST} run --rm workers poetry run pytest ${args} + ${DOCKER_COMPOSE_TEST} run --rm worker_high poetry run pytest ${args} #------------# # Production # diff --git a/doc/how-to-guides/deployment/maintenance.md b/doc/how-to-guides/deployment/maintenance.md index 9c2d707ef8..591fca7ce6 100644 --- a/doc/how-to-guides/deployment/maintenance.md +++ b/doc/how-to-guides/deployment/maintenance.md @@ -47,7 +47,7 @@ robotoff_api_1 /bin/sh -c /docker-entrypo ... Up 0.0.0.0:5500->55 /tcp robotoff_postgres_1 docker-entrypoint.sh postg ... Up 127.0.0.1:5432->5432/tcp robotoff_scheduler_1 /bin/sh -c /docker-entrypo ... Up -robotoff_workers_1 /bin/sh -c /docker-entrypo ... Up +robotoff_worker_low_1 /bin/sh -c /docker-entrypo ... Up ``` ## Database backup and restore diff --git a/doc/introduction/architecture.md b/doc/introduction/architecture.md index d3e21627c2..bc053ac3e8 100644 --- a/doc/introduction/architecture.md +++ b/doc/introduction/architecture.md @@ -7,12 +7,12 @@ Robotoff is made of several services: - the public _API_ service - the _scheduler_, responsible for launching recurrent tasks (downloading new dataset, processing insights automatically,...) [^scheduler] - the _workers_, responsible for all long-lasting tasks +- a _redis_ instance -Communication between API and Workers happens through ipc events. [^ipc_events] +Communication between API and workers happens through Redis DB using [rq](https://python-rq.org). [^worker_job] [^scheduler]: See `scheduler.run` - -[^ipc_events]: See `robotoff.workers.client` and `robotoff.workers.listener` +[^worker_job]: See `robotoff.workers.queues` and `robotoff.workers.tasks` Robotoff allows to predict many information (also called _insights_), mostly from the product images or OCR. @@ -58,7 +58,7 @@ Some insights with high confidence are applied automatically, 10 minutes after i Robotoff is also notified by Product Opener every time a product is updated or deleted [^product_update]. This is used to delete insights associated with deleted products, or to update them accordingly. 
-[^product_update]: see `workers.tasks.product_updated` and `workers.tasks.delete_product_insights` +[^product_update]: see `workers.tasks.product_updated` and `workers.tasks.delete_product_insights_job` [^annotate]: see `robotoff.insights.annotate` diff --git a/docker-compose.yml b/docker-compose.yml index 40b60b746e..0a08b42fb6 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,7 +1,7 @@ version: "3.9" - -x-robotoff-base: &robotoff-base +x-robotoff-base: + &robotoff-base restart: $RESTART_POLICY image: ghcr.io/openfoodfacts/robotoff:${TAG} volumes: @@ -9,17 +9,15 @@ x-robotoff-base: &robotoff-base - ./tf_models:/opt/robotoff/tf_models - ./models:/opt/robotoff/models -x-robotoff-base-env: &robotoff-base-env +x-robotoff-base-env: + &robotoff-base-env ROBOTOFF_INSTANCE: ROBOTOFF_DOMAIN: ROBOTOFF_SCHEME: STATIC_OFF_DOMAIN: GUNICORN_NUM_WORKERS: - IPC_AUTHKEY: - IPC_HOST: workers - IPC_PORT: - WORKER_COUNT: ROBOTOFF_UPDATED_PRODUCT_WAIT: + REDIS_HOST: POSTGRES_HOST: POSTGRES_DB: POSTGRES_USER: @@ -43,19 +41,33 @@ services: <<: *robotoff-base environment: *robotoff-base-env mem_limit: 2g - depends_on: - - workers ports: - "${ROBOTOFF_EXPOSE:-5500}:5500" networks: - webnet - workers: + worker_high: <<: *robotoff-base - command: poetry run robotoff-cli run workers - environment: - <<: *robotoff-base-env - REAL_TIME_IMAGE_PREDICTION: 1 + deploy: + mode: replicated + replicas: 6 + command: poetry run robotoff-cli run-worker robotoff-high + environment: *robotoff-base-env + depends_on: + - postgres + mem_limit: 8g + networks: + - webnet + extra_hosts: + - host.docker.internal:host-gateway + + worker_low: + <<: *robotoff-base + deploy: + mode: replicated + replicas: 2 + command: poetry run robotoff-cli run-worker robotoff-low robotoff-high + environment: *robotoff-base-env depends_on: - postgres mem_limit: 8g @@ -67,7 +79,7 @@ services: scheduler: <<: *robotoff-base environment: *robotoff-base-env - command: poetry run robotoff-cli run scheduler + command: poetry run robotoff-cli run-scheduler mem_limit: 4g networks: - webnet @@ -89,6 +101,19 @@ services: networks: - webnet + redis: + restart: $RESTART_POLICY + image: redis:7.0.5-alpine + volumes: + - redis-data:/data + environment: + REDIS_ARGS: --save 60 1000 --appendonly yes + mem_limit: 4g + ports: + - "${REDIS_EXPOSE:-127.0.0.1:6379}:6379" + networks: + - webnet + elasticsearch: restart: $RESTART_POLICY image: raphael0202/elasticsearch @@ -113,6 +138,8 @@ services: volumes: postgres-data: es-data: + redis-data: + name: ${COMPOSE_PROJECT_NAME:-robotoff}_redis-data networks: webnet: diff --git a/docker/dev.yml b/docker/dev.yml index 671136c50f..e318f34b81 100644 --- a/docker/dev.yml +++ b/docker/dev.yml @@ -50,9 +50,20 @@ services: - robotoff.openfoodfacts.localhost - api webnet: - workers: + worker_high: <<: *robotoff-dev <<: *networks-productopener-local + deploy: + mode: replicated + # Only 1 replica is easier to deal with for local dev + replicas: 1 + worker_low: + <<: *robotoff-dev + <<: *networks-productopener-local + deploy: + mode: replicated + # Only 1 replica is easier to deal with for local dev + replicas: 1 scheduler: <<: *networks-productopener-local <<: *robotoff-dev diff --git a/docker/ml.yml b/docker/ml.yml index bee6d42c84..187b6a4d02 100644 --- a/docker/ml.yml +++ b/docker/ml.yml @@ -8,7 +8,7 @@ services: - 8501:8501 - 8500:8500 volumes: - - ../tf_models:/models + - ./tf_models:/models entrypoint: "tensorflow_model_server --port=8500 --rest_api_port=8501 --model_config_file=/models/models.config" mem_limit: 10g 
networks: @@ -28,11 +28,8 @@ services: - ${TRITON_EXPOSE_GRPC:-8001}:8001 - ${TRITON_EXPOSE_METRICS:-8002}:8002 volumes: - - ../models:/models + - ./models:/models entrypoint: "tritonserver --model-repository=/models" mem_limit: 10g networks: - webnet - -networks: - webnet: diff --git a/poetry.lock b/poetry.lock index cbeec5d067..ec545f06d3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -24,6 +24,25 @@ tornado = ["tornado (>=4.3)"] twisted = ["twisted"] zookeeper = ["kazoo"] +[[package]] +name = "arrow" +version = "1.2.3" +description = "Better dates & times for Python" +category = "dev" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +python-dateutil = ">=2.7.0" + +[[package]] +name = "async-timeout" +version = "4.0.2" +description = "Timeout context manager for asyncio programs" +category = "main" +optional = false +python-versions = ">=3.6" + [[package]] name = "attrs" version = "22.1.0" @@ -193,7 +212,7 @@ optional = false python-versions = ">=3.6" [package.extras] -dev = ["pylint", "mypy", "black", "coveralls", "pytest-cov", "pytest (>=5)"] +dev = ["pytest (>=5)", "pytest-cov", "coveralls", "black", "mypy", "pylint"] [[package]] name = "distlib" @@ -380,6 +399,24 @@ category = "main" optional = false python-versions = "*" +[[package]] +name = "flask" +version = "2.0.3" +description = "A simple framework for building complex web applications." +category = "dev" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +click = ">=7.1.2" +itsdangerous = ">=2.0" +Jinja2 = ">=3.0" +Werkzeug = ">=2.0" + +[package.extras] +async = ["asgiref (>=3.2)"] +dotenv = ["python-dotenv"] + [[package]] name = "fonttools" version = "4.38.0" @@ -414,7 +451,7 @@ python-versions = "*" python-dateutil = ">=2.8.1" [package.extras] -dev = ["wheel", "flake8", "markdown", "twine"] +dev = ["twine", "markdown", "flake8", "wheel"] [[package]] name = "grpcio" @@ -521,6 +558,14 @@ requirements_deprecated_finder = ["pipreqs", "pip-api"] colors = ["colorama (>=0.4.3,<0.5.0)"] plugins = ["setuptools"] +[[package]] +name = "itsdangerous" +version = "2.1.2" +description = "Safely pass data to untrusted environments and back." +category = "dev" +optional = false +python-versions = ">=3.7" + [[package]] name = "jinja2" version = "3.1.2" @@ -876,8 +921,8 @@ optional = false python-versions = ">=3.6" [package.extras] -testing = ["pytest-benchmark", "pytest"] -dev = ["tox", "pre-commit"] +dev = ["pre-commit", "tox"] +testing = ["pytest", "pytest-benchmark"] [[package]] name = "pre-commit" @@ -1106,6 +1151,20 @@ category = "main" optional = false python-versions = ">=3.6" +[[package]] +name = "python-redis-lock" +version = "4.0.0" +description = "Lock context manager implemented via redis SETNX/BLPOP." 
+category = "main" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +redis = ">=2.10.0" + +[package.extras] +django = ["django-redis (>=3.8.0)"] + [[package]] name = "pytz" version = "2022.6" @@ -1155,6 +1214,22 @@ python-versions = ">=3.7,<4.0" [package.dependencies] typing-extensions = ">=4.1.1,<5.0.0" +[[package]] +name = "redis" +version = "4.3.5" +description = "Python client for Redis database and key-value store" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +async-timeout = ">=4.0.2" +packaging = ">=20.4" + +[package.extras] +hiredis = ["hiredis (>=1.0.0)"] +ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"] + [[package]] name = "requests" version = "2.28.1" @@ -1173,9 +1248,35 @@ urllib3 = ">=1.21.1,<1.27" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use_chardet_on_py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "rq" +version = "1.11.1" +description = "RQ is a simple, lightweight, library for creating background jobs, and processing them." +category = "main" +optional = false +python-versions = ">=3.5" + +[package.dependencies] +click = ">=5.0.0" +redis = ">=3.5.0" + +[[package]] +name = "rq-dashboard" +version = "0.6.1" +description = "rq-dashboard is a general purpose, lightweight, web interface to monitor your RQ queues, jobs, and workers in realtime." +category = "dev" +optional = false +python-versions = "*" + +[package.dependencies] +arrow = "*" +Flask = "*" +redis = "*" +rq = ">=1.0" + [[package]] name = "sentry-sdk" -version = "1.11.0" +version = "1.11.1" description = "Python client for Sentry (https://sentry.io)" category = "main" optional = false @@ -1447,7 +1548,7 @@ telegram = ["requests"] [[package]] name = "tritonclient" -version = "2.27.0" +version = "2.28.0" description = "Python client library and utilities for communicating with Triton Inference Server" category = "main" optional = false @@ -1476,10 +1577,10 @@ python-versions = ">=3.6" click = ">=7.1.1,<7.2.0" [package.extras] -test = ["isort (>=5.0.6,<6.0.0)", "black (>=19.10b0,<20.0b0)", "mypy (==0.782)", "pytest-sugar (>=0.9.4,<0.10.0)", "pytest-xdist (>=1.32.0,<2.0.0)", "coverage (>=5.2,<6.0)", "pytest-cov (>=2.10.0,<3.0.0)", "pytest (>=4.4.0,<5.4.0)", "shellingham (>=1.3.0,<2.0.0)"] -doc = ["markdown-include (>=0.5.1,<0.6.0)", "mkdocs-material (>=5.4.0,<6.0.0)", "mkdocs (>=1.1.2,<2.0.0)"] -dev = ["flake8 (>=3.8.3,<4.0.0)", "autoflake (>=1.3.1,<2.0.0)"] -all = ["shellingham (>=1.3.0,<2.0.0)", "colorama (>=0.4.3,<0.5.0)"] +test = ["pytest-xdist (>=1.32.0,<2.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "mypy (==0.782)", "black (>=19.10b0,<20.0b0)", "isort (>=5.0.6,<6.0.0)", "shellingham (>=1.3.0,<2.0.0)", "pytest (>=4.4.0,<5.4.0)", "pytest-cov (>=2.10.0,<3.0.0)", "coverage (>=5.2,<6.0)"] +all = ["colorama (>=0.4.3,<0.5.0)", "shellingham (>=1.3.0,<2.0.0)"] +dev = ["autoflake (>=1.3.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)"] +doc = ["mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=5.4.0,<6.0.0)", "markdown-include (>=0.5.1,<0.6.0)"] [[package]] name = "typer-cli" @@ -1518,6 +1619,14 @@ category = "dev" optional = false python-versions = "*" +[[package]] +name = "types-redis" +version = "4.3.21.6" +description = "Typing stubs for redis" +category = "dev" +optional = false +python-versions = "*" + [[package]] name = "types-requests" version = "2.28.11.5" @@ -1682,10 +1791,12 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "flake8 [metadata] lock-version = "1.1" python-versions = "^3.9" -content-hash = 
"90e4363b144d74816e13198ebddd49e729ada66420449ffbdb2ce910273603d1" +content-hash = "351e4731216306ffab0c42181c810d7a2f356002ee38f086753ac7fbb4ca76fe" [metadata.files] apscheduler = [] +arrow = [] +async-timeout = [] attrs = [] black = [] blis = [] @@ -1717,6 +1828,7 @@ flake8 = [] flake8-bugbear = [] flake8-github-actions = [] flashtext = [] +flask = [] fonttools = [] ghp-import = [] grpcio = [] @@ -1727,6 +1839,7 @@ importlib-metadata = [] influxdb-client = [] iniconfig = [] isort = [] +itsdangerous = [] jinja2 = [] jsonschema = [] kiwisolver = [] @@ -1777,12 +1890,16 @@ pytest-httpserver = [] pytest-mock = [] python-dateutil = [] python-rapidjson = [] +python-redis-lock = [] pytz = [] pytz-deprecation-shim = [] pyyaml = [] pyyaml-env-tag = [] reactivex = [] +redis = [] requests = [] +rq = [] +rq-dashboard = [] sentry-sdk = [] setuptools-scm = [] shellingham = [] @@ -1805,6 +1922,7 @@ typer-cli = [] types-cachetools = [] types-certifi = [] types-protobuf = [] +types-redis = [] types-requests = [] types-setuptools = [] types-six = [] diff --git a/pyproject.toml b/pyproject.toml index e06daf2739..f142731bcb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,8 @@ py-healthcheck = "^1.10.1" spacy-lookups-data = "^1.0.3" cachetools = "^5.2.0" tritonclient = {extras = ["grpc"], version = "^2.26.0"} +rq = "~1.11.1" +python-redis-lock = "~4.0.0" [tool.poetry.dependencies.sentry-sdk] version = "~1.11.0" @@ -104,6 +106,8 @@ types-setuptools = "^65.6.0.0" types-toml = "^0.10.3" pytest-httpserver = "^1.0.4" types-cachetools = "^5.2.1" +types-redis = "^4.3.21" +rq-dashboard = "~0.6.1" [tool.poetry.scripts] robotoff-cli = 'robotoff.cli.main:main' diff --git a/robotoff/app/api.py b/robotoff/app/api.py index 7fcf75395b..fea691814b 100644 --- a/robotoff/app/api.py +++ b/robotoff/app/api.py @@ -63,7 +63,13 @@ from robotoff.utils.i18n import TranslationStore from robotoff.utils.text import get_tag from robotoff.utils.types import JSONType -from robotoff.workers.client import send_ipc_event +from robotoff.workers.queues import enqueue_in_job, enqueue_job, high_queue, low_queue +from robotoff.workers.tasks import ( + delete_product_insights_job, + download_product_dataset_job, + run_import_image_job, + update_insights_job, +) logger = get_logger() @@ -438,7 +444,11 @@ def on_post(self, req: falcon.Request, resp: falcon.Response): class UpdateDatasetResource: def on_post(self, req: falcon.Request, resp: falcon.Response): - send_ipc_event("download_dataset") + """Re-import the Product Opener product dump.""" + + enqueue_job( + download_product_dataset_job, queue=low_queue, job_kwargs={"timeout": "1h"} + ) resp.media = { "status": "scheduled", @@ -458,22 +468,19 @@ def on_post(self, req: falcon.Request, resp: falcon.Response): server_domain = req.get_param("server_domain", required=True) if server_domain != settings.OFF_SERVER_DOMAIN: - logger.info("Rejecting image import from {}".format(server_domain)) + logger.info(f"Rejecting image import from {server_domain}") resp.media = { "status": "rejected", } return - send_ipc_event( - "import_image", - { - "barcode": barcode, - "image_url": image_url, - "ocr_url": ocr_url, - "server_domain": server_domain, - }, + high_queue.enqueue( + run_import_image_job, + barcode=barcode, + image_url=image_url, + ocr_url=ocr_url, + server_domain=server_domain, ) - resp.media = { "status": "scheduled", } @@ -893,19 +900,18 @@ def on_post(self, req: falcon.Request, resp: falcon.Response): ) if action == "updated": - send_ipc_event( - "product_updated", - { - "barcode": 
barcode, - "server_domain": server_domain, - # add some latency - "task_delay": settings.UPDATED_PRODUCT_WAIT, - }, + enqueue_in_job( + update_insights_job, + high_queue, + settings.UPDATED_PRODUCT_WAIT, + barcode=barcode, + server_domain=server_domain, ) - elif action == "deleted": - send_ipc_event( - "product_deleted", {"barcode": barcode, "server_domain": server_domain} + high_queue.enqueue( + delete_product_insights_job, + barcode=barcode, + server_domain=server_domain, ) resp.media = { diff --git a/robotoff/cli/main.py b/robotoff/cli/main.py index 668cbcbf2c..cc2d45f7c2 100644 --- a/robotoff/cli/main.py +++ b/robotoff/cli/main.py @@ -1,3 +1,4 @@ +import enum import pathlib import sys from pathlib import Path @@ -10,10 +11,34 @@ @app.command() -def run(service: str) -> None: - from robotoff.cli.run import run as run_ +def run_scheduler(): + """Launch the scheduler service.""" + from robotoff import scheduler + from robotoff.utils import get_logger + + # Defining a root logger + get_logger() + scheduler.run() + - run_(service) +class WorkerQueue(enum.Enum): + robotoff_high = "robotoff-high" + robotoff_low = "robotoff-low" + + +@app.command() +def run_worker( + queues: list[WorkerQueue] = typer.Argument( + ..., help="Names of the queues to listen to" + ), + burst: bool = typer.Option( + False, help="Run in burst mode (quit after all work is done)" + ), +): + """Launch a worker.""" + from robotoff.workers.main import run + + run(queues=[x.value for x in queues], burst=burst) @app.command() @@ -188,8 +213,8 @@ def refresh_insights( None, help="Refresh a specific product. If not provided, all products are updated", ), - server_domain: Optional[str] = typer.Option( - None, help="The server domain to use, Open Food Facts by default" + batch_size: int = typer.Option( + 100, help="Number of products to send in a worker tasks" ), ): """Refresh insights based on available predictions. @@ -197,22 +222,171 @@ def refresh_insights( If a `barcode` is provided, only the insights of this product is refreshed, otherwise insights of all products are refreshed. """ + import tqdm + from more_itertools import chunked + from peewee import fn + from robotoff import settings - from robotoff.insights.importer import refresh_all_insights from robotoff.insights.importer import refresh_insights as refresh_insights_ + from robotoff.models import Prediction as PredictionModel + from robotoff.models import db from robotoff.utils import get_logger + from robotoff.workers.queues import enqueue_job, low_queue + from robotoff.workers.tasks import refresh_insights_job logger = get_logger() - server_domain = server_domain or settings.OFF_SERVER_DOMAIN if barcode is not None: logger.info(f"Refreshing product {barcode}") - imported = refresh_insights_(barcode, server_domain) + imported = refresh_insights_(barcode, settings.OFF_SERVER_DOMAIN) + logger.info(f"Refreshed insights: {imported}") + else: + logger.info("Launching insight refresh on full database") + with db: + barcodes = [ + barcode + for (barcode,) in PredictionModel.select( + fn.Distinct(PredictionModel.barcode) + ).tuples() + ] + + batches = list(chunked(barcodes, batch_size)) + confirm = typer.confirm( + f"{len(batches)} jobs are going to be launched, confirm?" 
+ ) + + if not confirm: + return + + logger.info("Adding refresh_insights jobs in queue...") + for barcode_batch in tqdm.tqdm(batches, desc="barcode batch"): + enqueue_job( + refresh_insights_job, + low_queue, + job_kwargs={"result_ttl": 0, "timeout": "5m"}, + barcodes=barcode_batch, + server_domain=settings.OFF_SERVER_DOMAIN, + ) + + +@app.command() +def import_images_in_db( + batch_size: int = typer.Option( + 500, help="Number of items to send in a worker tasks" + ), +): + """Make sure that every image available in MongoDB is saved in `image` + table.""" + import tqdm + from more_itertools import chunked + + from robotoff import settings + from robotoff.models import ImageModel, db + from robotoff.off import generate_image_path + from robotoff.products import get_product_store + from robotoff.utils import get_logger + from robotoff.workers.queues import enqueue_job, low_queue + from robotoff.workers.tasks.import_image import save_image_job + + logger = get_logger() + + with db: + logger.info("Fetching existing images in DB...") + existing_images = set( + ImageModel.select(ImageModel.barcode, ImageModel.image_id).tuples() + ) + + store = get_product_store() + to_add = [] + for product in tqdm.tqdm( + store.iter_product(projection=["images", "code"]), desc="product" + ): + barcode = product.barcode + for image_id in (id_ for id_ in product.images.keys() if id_.isdigit()): + if (barcode, image_id) not in existing_images: + to_add.append((barcode, generate_image_path(barcode, image_id))) + + batches = list(chunked(to_add, batch_size)) + if typer.confirm( + f"{len(batches)} add image jobs are going to be launched, confirm?" + ): + for batch in tqdm.tqdm(batches, desc="job"): + enqueue_job( + save_image_job, + low_queue, + job_kwargs={"result_ttl": 0}, + batch=batch, + server_domain=settings.OFF_SERVER_DOMAIN, + ) + + +class ObjectDetectionModel(enum.Enum): + nutriscore = "nutriscore" + universal_logo_detector = "universal-logo-detector" + nutrition_table = "nutrition-table" + + +@app.command() +def run_object_detection_model( + model_name: ObjectDetectionModel = typer.Argument( + ..., help="Name of the object detection model" + ), + limit: Optional[int] = typer.Option(None, help="Maximum numbers of job to launch"), +): + """Run universal-logo-detector and nutrition-table object detection models + on all images in DB.""" + import tqdm + from peewee import JOIN + + from robotoff import settings + from robotoff.models import ImageModel, ImagePrediction, db + from robotoff.off import generate_image_url + from robotoff.workers.queues import enqueue_job, low_queue + from robotoff.workers.tasks.import_image import ( + run_logo_object_detection, + run_nutriscore_object_detection, + run_nutrition_table_object_detection, + ) + + if model_name == ObjectDetectionModel.universal_logo_detector: + func = run_logo_object_detection + elif model_name == ObjectDetectionModel.nutrition_table: + func = run_nutrition_table_object_detection else: - logger.info("Refreshing insights of all products") - imported = refresh_all_insights(server_domain) + func = run_nutriscore_object_detection - logger.info(f"Refreshed insights: {imported}") + with db: + query = ( + ImageModel.select(ImageModel.barcode, ImageModel.id) + .join( + ImagePrediction, + JOIN.LEFT_OUTER, + on=( + (ImagePrediction.image_id == ImageModel.id) + & (ImagePrediction.model_name == model_name) + ), + ) + .where(ImagePrediction.model_name.is_null()) + .tuples() + ) + if limit: + query = query.limit(limit) + missing_items = list(query) + + if limit: + 
missing_items = missing_items[:limit] + + if typer.confirm(f"{len(missing_items)} jobs are going to be launched, confirm?"): + for barcode, image_id in tqdm.tqdm(missing_items, desc="image"): + image_url = generate_image_url(barcode, image_id) + enqueue_job( + func, + low_queue, + job_kwargs={"result_ttl": 0}, + barcode=barcode, + image_url=image_url, + server_domain=settings.OFF_SERVER_DOMAIN, + ) @app.command() diff --git a/robotoff/cli/run.py b/robotoff/cli/run.py deleted file mode 100644 index 50d88cf310..0000000000 --- a/robotoff/cli/run.py +++ /dev/null @@ -1,33 +0,0 @@ -import subprocess - -import click - -from robotoff import settings - - -def run(service: str): - if service == "api": - subprocess.run( - [ - "gunicorn", - "--config", - str(settings.PROJECT_DIR / "gunicorn.py"), - "robotoff.app.api:api", - ] - ) - - elif service == "workers": - from robotoff.workers import listener - - listener.run() - - elif service == "scheduler": - from robotoff import scheduler - from robotoff.utils import get_logger - - # Defining a root logger - get_logger() - scheduler.run() - - else: - click.echo("invalid service: '{}'".format(service), err=True) diff --git a/robotoff/health.py b/robotoff/health.py index 17e624e2f5..2e4db73aab 100644 --- a/robotoff/health.py +++ b/robotoff/health.py @@ -4,6 +4,7 @@ from playhouse.postgres_ext import PostgresqlExtDatabase from pymongo import MongoClient from pymongo.errors import ServerSelectionTimeoutError +from redis import Redis from robotoff import settings from robotoff.utils import get_logger @@ -23,6 +24,13 @@ def test_connect_mongodb(): return True, "MongoDB DB connection succeeded!" +def test_connect_redis(): + logger.debug("health: testing Redis connection to %s", settings.REDIS_HOST) + client = Redis(host=settings.REDIS_HOST) + client.ping() + return True, "Redis DB connection success!" + + def test_connect_postgres(): logger.debug("health: testing postgres connection to %s", settings.POSTGRES_HOST) client = PostgresqlExtDatabase( @@ -67,5 +75,6 @@ def test_connect_ann(): health.add_check(test_connect_mongodb) health.add_check(test_connect_postgres) health.add_check(test_connect_influxdb) +health.add_check(test_connect_redis) health.add_check(test_connect_robotoff_api) health.add_check(test_connect_ann) diff --git a/robotoff/insights/extraction.py b/robotoff/insights/extraction.py index b822b81698..90cda6648d 100644 --- a/robotoff/insights/extraction.py +++ b/robotoff/insights/extraction.py @@ -1,10 +1,16 @@ -from typing import Dict, Iterable, List, Optional +import datetime +from typing import Iterable, List, Optional from PIL import Image +from robotoff.models import ImageModel, ImagePrediction from robotoff.off import get_source_from_url from robotoff.prediction import ocr -from robotoff.prediction.object_detection import ObjectDetectionModelRegistry +from robotoff.prediction.object_detection import ( + OBJECT_DETECTION_MODEL_VERSION, + ObjectDetectionModel, + ObjectDetectionModelRegistry, +) from robotoff.prediction.ocr.core import get_ocr_result from robotoff.prediction.types import Prediction, PredictionType from robotoff.utils import get_logger, http_session @@ -35,6 +41,61 @@ ] +def run_object_detection_model( + model_name: ObjectDetectionModel, + image: Image.Image, + source_image: str, + threshold: float = 0.1, +) -> Optional[ImagePrediction]: + """Run a model detection model and save the results in the + `image_prediction` table. + + An item with the corresponding `source_image` in the `image` table is + expected to exist. 
Nothing is done if an image prediction already exists + in DB for this image and model. + + :param model_name: name of the object detection model to use + :param image: the input Pillow image + :param source_image: the source image path (used to fetch the image from + `image` table) + :param threshold: the minimum object score above which we keep the object data + + :return: return None if the image does not exist in DB, or the created + `ImagePrediction` otherwise + """ + image_instance = ImageModel.get_or_none(source_image=source_image) + + if image_instance is None: + logger.warning("Missing image in DB for image %s", source_image) + return None + + existing_image_prediction = ImagePrediction.get_or_none( + image=image_instance, model_name=model_name.value + ) + if existing_image_prediction is not None: + logger.info( + f"Object detection results for {model_name} already exist for " + f"image {source_image}: ID {existing_image_prediction.id}" + ) + return None + + timestamp = datetime.datetime.utcnow() + results = ObjectDetectionModelRegistry.get(model_name.value).detect_from_image( + image, output_image=False + ) + data = results.to_json(threshold=threshold) + max_confidence = max([item["score"] for item in data], default=None) + return ImagePrediction.create( + image=image_instance, + type="object_detection", + model_name=model_name.value, + model_version=OBJECT_DETECTION_MODEL_VERSION[model_name], + data={"objects": data}, + timestamp=timestamp, + max_confidence=max_confidence, + ) + + def get_predictions_from_product_name( barcode: str, product_name: str ) -> List[Prediction]: @@ -53,49 +114,11 @@ def get_predictions_from_product_name( return predictions_all -def get_predictions_from_image( - barcode: str, image: Image.Image, source_image: str, ocr_url: str -) -> List[Prediction]: - logger.info(f"Generating OCR predictions from OCR {ocr_url}") - ocr_predictions = extract_ocr_predictions( - barcode, ocr_url, DEFAULT_OCR_PREDICTION_TYPES - ) - extract_nutriscore = any( - prediction.value_tag == "en:nutriscore" - and prediction.type == PredictionType.label - for prediction in ocr_predictions - ) - image_ml_predictions = extract_image_ml_predictions( - barcode, image, source_image, extract_nutriscore=extract_nutriscore - ) - return ocr_predictions + image_ml_predictions - - -def extract_image_ml_predictions( - barcode: str, image: Image.Image, source_image: str, extract_nutriscore: bool = True -) -> List[Prediction]: - if extract_nutriscore: - # Currently all of the automatic processing for the Nutri-Score grades has been - # disabled due to a prediction quality issue. - # Last automatic processing threshold was set to 0.9 - resulting in ~70% incorrect - # detection. 
- nutriscore_prediction = extract_nutriscore_label( - image, - source_image, - manual_threshold=0.5, - ) - - if nutriscore_prediction: - nutriscore_prediction.barcode = barcode - nutriscore_prediction.source_image = source_image - return [nutriscore_prediction] - - return [] - - def extract_ocr_predictions( barcode: str, ocr_url: str, prediction_types: Iterable[PredictionType] ) -> List[Prediction]: + logger.info(f"Generating OCR predictions from OCR {ocr_url}") + predictions_all: List[Prediction] = [] source_image = get_source_from_url(ocr_url) ocr_result = get_ocr_result(ocr_url, http_session, error_raise=False) @@ -109,51 +132,3 @@ def extract_ocr_predictions( ) return predictions_all - - -NUTRISCORE_LABELS: Dict[str, str] = { - "nutriscore-a": "en:nutriscore-grade-a", - "nutriscore-b": "en:nutriscore-grade-b", - "nutriscore-c": "en:nutriscore-grade-c", - "nutriscore-d": "en:nutriscore-grade-d", - "nutriscore-e": "en:nutriscore-grade-e", -} - - -def extract_nutriscore_label( - image: Image.Image, - source_image: str, - manual_threshold: float, - automatic_threshold: Optional[float] = None, -) -> Optional[Prediction]: - model = ObjectDetectionModelRegistry.get("nutriscore") - raw_result = model.detect_from_image(image, output_image=False) - results = raw_result.select(threshold=manual_threshold) - - if not results: - return None - - if len(results) > 1: - logger.warning("more than one nutriscore detected, discarding detections") - return None - - result = results[0] - score = result.score - - automatic_processing = False - if automatic_threshold: - automatic_processing = score >= automatic_threshold - label_tag = NUTRISCORE_LABELS[result.label] - - return Prediction( - type=PredictionType.label, - source_image=source_image, - value_tag=label_tag, - automatic_processing=automatic_processing, - data={ - "confidence": score, - "bounding_box": result.bounding_box, - "model": "nutriscore", - "notify": True, - }, - ) diff --git a/robotoff/insights/importer.py b/robotoff/insights/importer.py index f25362d527..ee78783056 100644 --- a/robotoff/insights/importer.py +++ b/robotoff/insights/importer.py @@ -6,7 +6,6 @@ from pathlib import Path from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Type -from peewee import fn from playhouse.shortcuts import model_to_dict from robotoff import settings @@ -24,6 +23,7 @@ get_product_store, is_valid_image, ) +from robotoff.redis import Lock, LockedResourceException from robotoff.taxonomy import ( Taxonomy, TaxonomyType, @@ -275,6 +275,7 @@ def get_required_prediction_types() -> Set[PredictionType]: @classmethod def import_insights( cls, + barcode: str, predictions: List[Prediction], server_domain: str, product_store: DBProductStore, @@ -288,37 +289,48 @@ def import_insights( if prediction.type not in required_prediction_types: raise ValueError(f"unexpected prediction type: '{prediction.type}'") - inserts = 0 - for to_create, to_update, to_delete in cls.generate_insights( - predictions, server_domain, product_store - ): - if to_delete: - to_delete_ids = [insight.id for insight in to_delete] - logger.info(f"Deleting {len(to_delete_ids)} insights") - ProductInsight.delete().where( - ProductInsight.id.in_(to_delete_ids) - ).execute() - if to_create: - inserts += batch_insert( - ProductInsight, - (model_to_dict(insight) for insight in to_create), - 50, + if ( + len( + prediction_barcodes := set( + prediction.barcode for prediction in predictions ) + ) + > 1 + ): + raise ValueError( + f"predictions for more than 1 product were provided: 
{prediction_barcodes}" + ) - for insight in to_update: - insight.save() + inserts = 0 + to_create, to_update, to_delete = cls.generate_insights( + barcode, predictions, server_domain, product_store + ) + if to_delete: + to_delete_ids = [insight.id for insight in to_delete] + logger.info(f"Deleting {len(to_delete_ids)} insights") + ProductInsight.delete().where( + ProductInsight.id.in_(to_delete_ids) + ).execute() + if to_create: + inserts += batch_insert( + ProductInsight, + (model_to_dict(insight) for insight in to_create), + 50, + ) + + for insight in to_update: + insight.save() return inserts @classmethod def generate_insights( cls, + barcode: str, predictions: List[Prediction], server_domain: str, product_store: DBProductStore, - ) -> Iterator[ - Tuple[List[ProductInsight], List[ProductInsight], List[ProductInsight]] - ]: + ) -> Tuple[List[ProductInsight], List[ProductInsight], List[ProductInsight]]: """Given a list of predictions, yield tuples of ProductInsight to create, update and delete. @@ -328,65 +340,57 @@ def generate_insights( timestamp = datetime.datetime.utcnow() server_type = get_server_type(server_domain).name - for barcode, group in itertools.groupby( - sorted(predictions, key=operator.attrgetter("barcode")), - operator.attrgetter("barcode"), - ): - product = product_store[barcode] - references = get_existing_insight(cls.get_type(), barcode, server_domain) + product = product_store[barcode] + references = get_existing_insight(cls.get_type(), barcode, server_domain) - if product is None: - logger.info( - f"Product {barcode} not found in DB, deleting existing insights" - ) - if references: - yield [], [], references - continue - - product_predictions = sort_predictions(group) - candidates = [ - candidate - for candidate in cls.generate_candidates(product, product_predictions) - if is_valid_insight_image(product.images, candidate.source_image) - ] - for candidate in candidates: - if candidate.automatic_processing is None: - logger.warning( - "Insight with automatic_processing=None: %s", candidate.__data__ - ) - candidate.automatic_processing = False - - if not is_trustworthy_insight_image( - product.images, candidate.source_image - ): - # Don't process automatically if the insight image is not - # trustworthy (too old and not selected) - candidate.automatic_processing = False - if candidate.data.get("is_annotation"): - username = candidate.data.get("username") - if username: - # logo annotation by a user - candidate.username = username - # Note: we could add vote annotation for anonymous user, - # but it should be done outside this loop. It's not yet implemented - - to_create, to_update, to_delete = cls.get_insight_update( - candidates, references + if product is None: + logger.info( + f"Product {barcode} not found in DB, deleting existing insights" ) + return [], [], references - for insight in to_create: - cls.add_fields(insight, product, timestamp, server_domain, server_type) - - for insight, reference_insight in to_update: - # Keep `reference_insight` in DB (as the value/value_tag/source_image is the same), - # but update information from `insight`. 
- # This way, we don't unnecessarily insert/delete rows in ProductInsight table - # and we keep associated votes - cls.update_fields(insight, reference_insight, product, timestamp) - - yield to_create, [ - reference_insight for (_, reference_insight) in to_update - ], to_delete + predictions = sort_predictions(predictions) + candidates = [ + candidate + for candidate in cls.generate_candidates(product, predictions) + if is_valid_insight_image(product.images, candidate.source_image) + ] + for candidate in candidates: + if candidate.automatic_processing is None: + logger.warning( + "Insight with automatic_processing=None: %s", candidate.__data__ + ) + candidate.automatic_processing = False + + if not is_trustworthy_insight_image(product.images, candidate.source_image): + # Don't process automatically if the insight image is not + # trustworthy (too old and not selected) + candidate.automatic_processing = False + if candidate.data.get("is_annotation"): + username = candidate.data.get("username") + if username: + # logo annotation by a user + candidate.username = username + # Note: we could add vote annotation for anonymous user, + # but it should be done outside this loop. It's not yet implemented + + to_create, to_update, to_delete = cls.get_insight_update(candidates, references) + + for insight in to_create: + cls.add_fields(insight, product, timestamp, server_domain, server_type) + + for insight, reference_insight in to_update: + # Keep `reference_insight` in DB (as the value/value_tag/source_image is the same), + # but update information from `insight`. + # This way, we don't unnecessarily insert/delete rows in ProductInsight table + # and we keep associated votes + cls.update_fields(insight, reference_insight, product, timestamp) + + return ( + to_create, + [reference_insight for (_, reference_insight) in to_update], + to_delete, + ) @classmethod @abc.abstractmethod @@ -1031,6 +1035,11 @@ def import_insights( server_domain: str, product_store: Optional[DBProductStore] = None, ) -> int: + """Import predictions and generate (and import) insights from these + predictions. 
+ + :param predictions: an iterable of Predictions to import + """ if product_store is None: product_store = get_product_store() @@ -1051,6 +1060,9 @@ def import_insights_for_products( :param prediction_types_by_barcode: a dict that associates each barcode with a set of prediction type that were updated + :param server_domain: The server domain associated with the predictions + :param product_store: The product store to use + :return: Number of imported insights """ imported = 0 @@ -1068,9 +1080,25 @@ def import_insights_for_products( selected_barcodes, list(required_prediction_types) ) ] - imported += importer.import_insights( - predictions, server_domain, product_store - ) + + for barcode, product_predictions in itertools.groupby( + sorted(predictions, key=operator.attrgetter("barcode")), + operator.attrgetter("barcode"), + ): + try: + with Lock(name=f"robotoff:import:{barcode}", expire=60, timeout=10): + imported += importer.import_insights( + barcode, + list(product_predictions), + server_domain, + product_store, + ) + except LockedResourceException: + logger.info( + "Couldn't acquire insight import lock, skipping insight import for product %s", + barcode, + ) + continue return imported @@ -1141,6 +1169,7 @@ def refresh_insights( required_prediction_types = importer.get_required_prediction_types() if prediction_types >= required_prediction_types: imported += importer.import_insights( + barcode, [p for p in predictions if p.type in required_prediction_types], server_domain, product_store, @@ -1149,26 +1178,6 @@ def refresh_insights( return imported -def refresh_all_insights( - server_domain: str, - product_store: Optional[DBProductStore] = None, -): - """Refresh insights of all products for which we have predictions. - - :param server_domain: The server domain associated with the predictions. - :param product_store: The product store to use, defaults to None - :return: The number of imported insights. - """ - imported = 0 - for (barcode,) in ( - PredictionModel.select(fn.Distinct(PredictionModel.barcode)).tuples().iterator() - ): - logger.info(f"Refreshing insights for product {barcode}") - imported += refresh_insights(barcode, server_domain, product_store) - - return imported - - def get_product_predictions( barcodes: List[str], prediction_types: Optional[List[str]] = None ) -> Iterator[Dict]: diff --git a/robotoff/logos.py b/robotoff/logos.py index a32a8e1aad..ee518e7993 100644 --- a/robotoff/logos.py +++ b/robotoff/logos.py @@ -1,6 +1,7 @@ import operator from typing import Dict, List, Optional, Set, Tuple +import cachetools import numpy as np from robotoff import settings @@ -10,7 +11,6 @@ from robotoff.prediction.types import Prediction, PredictionType from robotoff.slack import NotifierFactory from robotoff.utils import get_logger, http_session -from robotoff.utils.cache import CachedStore from robotoff.utils.types import JSONType logger = get_logger(__name__) @@ -27,6 +27,13 @@ BoundingBoxType = Tuple[float, float, float, float] +def load_resources(): + """Load and cache resources.""" + logger.info("Loading logo resources...") + get_logo_confidence_thresholds() + get_logo_annotations() + + def compute_iou(box_1: BoundingBoxType, box_2: BoundingBoxType) -> float: """Compute the IoU (intersection over union) for two bounding boxes. 
@@ -77,6 +84,7 @@ def filter_logos( return filtered +@cachetools.cached(cachetools.LRUCache(maxsize=1)) def get_logo_confidence_thresholds() -> Dict[LogoLabelType, float]: thresholds = {} @@ -86,11 +94,6 @@ def get_logo_confidence_thresholds() -> Dict[LogoLabelType, float]: return thresholds -LOGO_CONFIDENCE_THRESHOLDS = CachedStore( - get_logo_confidence_thresholds, expiration_interval=10 -) - - def get_stored_logo_ids() -> Set[int]: r = http_session.get( settings.BaseURLProvider().robotoff().get() + "/api/v1/ann/stored", timeout=30 @@ -164,6 +167,7 @@ def save_nearest_neighbors(logos: List[LogoAnnotation]) -> int: return saved +@cachetools.cached(cachetools.LRUCache(maxsize=1)) def get_logo_annotations() -> Dict[int, LogoLabelType]: annotations: Dict[int, LogoLabelType] = {} @@ -185,9 +189,6 @@ def get_logo_annotations() -> Dict[int, LogoLabelType]: return annotations -LOGO_ANNOTATIONS_CACHE = CachedStore(get_logo_annotations, expiration_interval=1) - - def predict_label(logo: LogoAnnotation) -> Optional[LogoLabelType]: probs = predict_proba(logo) @@ -206,7 +207,7 @@ def predict_proba( nn_distances = logo.nearest_neighbors["distances"] nn_logo_ids = logo.nearest_neighbors["logo_ids"] - logo_annotations = LOGO_ANNOTATIONS_CACHE.get() + logo_annotations = get_logo_annotations() nn_labels: List[LogoLabelType] = [] for nn_logo_id in nn_logo_ids: diff --git a/robotoff/models.py b/robotoff/models.py index dcfcee869c..9476ba19b2 100644 --- a/robotoff/models.py +++ b/robotoff/models.py @@ -16,6 +16,7 @@ password=settings.POSTGRES_PASSWORD, host=settings.POSTGRES_HOST, port=5432, + autoconnect=False, ) @@ -25,10 +26,9 @@ def with_db(fn): @functools.wraps(fn) def wrapper(*args, **kwargs): with db: - # use atomic to avoid falling in a bad state + # use atomic transaction to avoid falling in a bad state # (error in the main transaction) - with db.atomic(): - return fn(*args, **kwargs) + return fn(*args, **kwargs) return wrapper @@ -220,8 +220,8 @@ class ImagePrediction(BaseModel): They are created by API `ImagePredictorResource`, `ImagePredictionImporterResource` or cli `import_logos` - Predictions come from a model, see `OBJECT_DETECTION_MODEL_VERSION` in - settings.py for available models. + Predictions come from a model, see `ObjectDetectionModel` in + predictions/object_detection/core.py for available models. """ type = peewee.CharField(max_length=256) @@ -242,7 +242,7 @@ class LogoAnnotation(BaseModel): """Annotation(s) for an image prediction (an image prediction might lead to several annotations) - At the moment, this is mostly for logo (see run_object_detection), + At the moment, this is mostly for logo (see run_logo_object_detection), when we have a logo prediction above a certain threshold we create an entry, to ask user for annotation on the logo (https://hunger.openfoodfacts.org/logos) and eventual annotation will land there. 
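
The `robotoff/logos.py` hunks above drop the home-grown `CachedStore` wrappers in favour of `cachetools.cached(cachetools.LRUCache(maxsize=1))` decorators plus explicit `load_resources()` warm-up helpers (the `matcher.py` and `taxonomy.py` hunks that follow add the same kind of helper); the new rq worker calls these warm-ups at start-up and again from its maintenance hook (`robotoff/workers/main.py`, further down in this diff). A minimal sketch of that pattern, with a hypothetical resource loader — nothing below is copied from the Robotoff codebase:

```python
import cachetools


@cachetools.cached(cachetools.LRUCache(maxsize=1))
def get_expensive_resource() -> dict:
    # Stand-in for an expensive DB query or file parse; the decorator keeps
    # the first result in memory, so every later call is a cache hit.
    return {"loaded": True}


def load_resources():
    """Warm the cache once, e.g. at worker start-up, so the first real job
    does not pay the loading cost."""
    get_expensive_resource()
```
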
diff --git a/robotoff/prediction/category/matcher.py b/robotoff/prediction/category/matcher.py index fdbaf54c4a..86e98e62ba 100644 --- a/robotoff/prediction/category/matcher.py +++ b/robotoff/prediction/category/matcher.py @@ -95,6 +95,17 @@ MatchMapType = Dict[str, Dict[str, List[Tuple[str, str]]]] +def load_resources(): + """Load and cache resources.""" + logger.info("Loading matcher resources...") + get_processors() + get_intersect_categories_ingredients() + + for lang in SUPPORTED_LANG: + logger.info(f"Loading NLP for {lang}...") + get_lemmatizing_nlp(lang) + + def preprocess_product_name(name: str, lang: str) -> str: """Preprocess product name before matching: - remove all weight mentions (100 g, 1l,...) @@ -224,6 +235,7 @@ def get_processors() -> Dict[str, KeywordProcessor]: This enables a fast matching of query parts against matched maps keys. """ + logger.info("Loading category matcher processors...") match_maps = get_match_maps(TaxonomyType.category.name) processors = {} for lang, items in match_maps.items(): @@ -267,6 +279,7 @@ def get_intersect_categories_ingredients(): See `generate_intersect_categories_ingredients` function for more information. """ + logger.info("Loading category intersection ingredient...") return { k: set(v) for k, v in load_json( diff --git a/robotoff/prediction/category/neural/category_classifier.py b/robotoff/prediction/category/neural/category_classifier.py index 08fdb9827c..9a06d0785f 100644 --- a/robotoff/prediction/category/neural/category_classifier.py +++ b/robotoff/prediction/category/neural/category_classifier.py @@ -84,7 +84,9 @@ def predict( } r = http_session.post( - f"{settings.TF_SERVING_BASE_URL}/category-classifier:predict", json=data + f"{settings.TF_SERVING_BASE_URL}/category-classifier:predict", + json=data, + timeout=(3.0, 10.0), ) r.raise_for_status() response = r.json() diff --git a/robotoff/prediction/object_detection/__init__.py b/robotoff/prediction/object_detection/__init__.py index a372d78db7..61d853dbf0 100644 --- a/robotoff/prediction/object_detection/__init__.py +++ b/robotoff/prediction/object_detection/__init__.py @@ -1,2 +1,7 @@ # flake8: noqa -from .core import ObjectDetectionModelRegistry, ObjectDetectionRawResult +from .core import ( + OBJECT_DETECTION_MODEL_VERSION, + ObjectDetectionModel, + ObjectDetectionModelRegistry, + ObjectDetectionRawResult, +) diff --git a/robotoff/prediction/object_detection/core.py b/robotoff/prediction/object_detection/core.py index afa33fad9d..7586e07a58 100644 --- a/robotoff/prediction/object_detection/core.py +++ b/robotoff/prediction/object_detection/core.py @@ -1,4 +1,5 @@ import dataclasses +import enum import pathlib from typing import Dict, List, Optional, Tuple @@ -18,6 +19,19 @@ LABEL_NAMES_FILENAME = "labels.txt" +class ObjectDetectionModel(enum.Enum): + nutriscore = "nutriscore" + universal_logo_detector = "universal-logo-detector" + nutrition_table = "nutrition-table" + + +OBJECT_DETECTION_MODEL_VERSION = { + ObjectDetectionModel.nutriscore: "tf-nutriscore-1.0", + ObjectDetectionModel.nutrition_table: "tf-nutrition-table-1.0", + ObjectDetectionModel.universal_logo_detector: "tf-universal-logo-detector-1.0", +} + + @dataclasses.dataclass class ObjectDetectionResult: bounding_box: Tuple @@ -201,7 +215,8 @@ def get_available_models(cls) -> List[str]: def load_all(cls): if cls._loaded: return - for model_name in settings.OBJECT_DETECTION_MODEL_VERSION: + for model in ObjectDetectionModel: + model_name = model.value file_path = settings.MODELS_DIR / model_name if file_path.is_dir(): 
logger.info(f"Model '{model_name}' found") diff --git a/robotoff/products.py b/robotoff/products.py index 3efb9c8321..f23c6b7ad9 100644 --- a/robotoff/products.py +++ b/robotoff/products.py @@ -493,7 +493,10 @@ def __getitem__(self, barcode: str) -> Optional[Product]: return None def __iter__(self): - raise NotImplementedError("cannot iterate over database product store") + yield from self.iter() + + def iter_product(self, projection: Optional[list[str]] = None): + yield from (Product(p) for p in self.collection.find(projection=projection)) def load_min_dataset() -> ProductStore: diff --git a/robotoff/redis.py b/robotoff/redis.py new file mode 100644 index 0000000000..360b465679 --- /dev/null +++ b/robotoff/redis.py @@ -0,0 +1,42 @@ +from typing import Optional + +from redis import Redis +from redis_lock import Lock as BaseLock + +from robotoff import settings + +redis_conn = Redis(host=settings.REDIS_HOST) + + +class LockedResourceException(Exception): + pass + + +class Lock(BaseLock): + _enabled = True + + def __init__( + self, + name: str, + blocking: bool = False, + timeout: Optional[float] = None, + expire: int = 60, + **kwargs, + ): + self.timeout = timeout + if timeout is not None: + blocking = True + self.blocking = blocking + if self._enabled: + super().__init__(redis_conn, name=name, expire=expire, **kwargs) + + def __enter__(self): + if self._enabled: + acquired = self.acquire(blocking=self.blocking, timeout=self.timeout) + if not acquired: + raise LockedResourceException() + return self + + def __exit__(self, *args, **kwargs): + if self._enabled: + self.release() diff --git a/robotoff/scheduler/__init__.py b/robotoff/scheduler/__init__.py index 4868313a7d..a259d2699e 100644 --- a/robotoff/scheduler/__init__.py +++ b/robotoff/scheduler/__init__.py @@ -23,7 +23,7 @@ save_facet_metrics, save_insight_metrics, ) -from robotoff.models import ProductInsight, with_db +from robotoff.models import ProductInsight, db, with_db from robotoff.prediction.category.matcher import predict_from_dataset from robotoff.products import ( CACHED_PRODUCT_STORE, @@ -44,38 +44,39 @@ # Note: we do not use with_db, for atomicity is handled in annotator def process_insights(): - processed = 0 - for insight in ( - ProductInsight.select() - .where( - ProductInsight.annotation.is_null(), - ProductInsight.process_after.is_null(False), - ProductInsight.process_after <= datetime.datetime.utcnow(), - ) - .iterator() - ): - try: - annotator = InsightAnnotatorFactory.get(insight.type) - logger.info( - "Annotating insight %s (product: %s)", insight.id, insight.barcode + with db.connection_context(): + processed = 0 + for insight in ( + ProductInsight.select() + .where( + ProductInsight.annotation.is_null(), + ProductInsight.process_after.is_null(False), + ProductInsight.process_after <= datetime.datetime.utcnow(), ) - annotation_result = annotator.annotate(insight, 1, update=True) - processed += 1 - - if annotation_result == UPDATED_ANNOTATION_RESULT and insight.data.get( - "notify", False - ): - slack.NotifierFactory.get_notifier().notify_automatic_processing( - insight + .iterator() + ): + try: + annotator = InsightAnnotatorFactory.get(insight.type) + logger.info( + "Annotating insight %s (product: %s)", insight.id, insight.barcode + ) + annotation_result = annotator.annotate(insight, 1, update=True) + processed += 1 + + if annotation_result == UPDATED_ANNOTATION_RESULT and insight.data.get( + "notify", False + ): + slack.NotifierFactory.get_notifier().notify_automatic_processing( + insight + ) + except Exception 
as e: + # continue to the next one + # Note: annotator already rolled-back the transaction + logger.exception( + f"exception {e} while handling annotation of insight %s (product) %s", + insight.id, + insight.barcode, ) - except Exception as e: - # continue to the next one - # Note: annotator already rolled-back the transaction - logger.exception( - f"exception {e} while handling annotation of insight %s (product) %s", - insight.id, - insight.barcode, - ) logger.info("%d insights processed", processed) @@ -228,10 +229,11 @@ def generate_insights(): dataset = ProductDataset(settings.JSONL_DATASET_PATH) product_predictions_iter = predict_from_dataset(dataset, datetime_threshold) - imported = import_insights( - product_predictions_iter, server_domain=settings.OFF_SERVER_DOMAIN - ) - logger.info("{} category insights imported".format(imported)) + with db: + imported = import_insights( + product_predictions_iter, server_domain=settings.OFF_SERVER_DOMAIN + ) + logger.info(f"{imported} category insights imported") def transform_insight_iter(insights_iter: Iterable[Dict]): diff --git a/robotoff/scheduler/latent.py b/robotoff/scheduler/latent.py index 696e7e0fa1..83edaca9d0 100644 --- a/robotoff/scheduler/latent.py +++ b/robotoff/scheduler/latent.py @@ -4,7 +4,7 @@ from robotoff import settings from robotoff.insights.dataclass import InsightType -from robotoff.models import Prediction, ProductInsight +from robotoff.models import Prediction, ProductInsight, with_db from robotoff.off import get_server_type from robotoff.prediction.types import PredictionType from robotoff.products import ( @@ -30,6 +30,7 @@ def generate_quality_facets(): generate_fiber_quality_facet() +@with_db def generate_fiber_quality_facet(): product_store: DBProductStore = get_product_store() collection = product_store.collection diff --git a/robotoff/settings.py b/robotoff/settings.py index a82b132b44..3e56e148d7 100644 --- a/robotoff/settings.py +++ b/robotoff/settings.py @@ -3,7 +3,7 @@ import logging import os from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional import sentry_sdk from sentry_sdk.integrations import Integration @@ -170,11 +170,9 @@ def off_credentials() -> Dict[str, str]: MONGO_URI = os.environ.get("MONGO_URI", "mongodb://mongodb:27017") -IPC_AUTHKEY = os.environ.get("IPC_AUTHKEY", "IPC").encode("utf-8") -IPC_HOST = os.environ.get("IPC_HOST", "localhost") -IPC_PORT = int(os.environ.get("IPC_PORT", 6650)) -IPC_ADDRESS: Tuple[str, int] = (IPC_HOST, IPC_PORT) -WORKER_COUNT = int(os.environ.get("WORKER_COUNT", 8)) +# Redis +REDIS_HOST = os.environ.get("REDIS_HOST", "localhost") + # how many seconds should we wait to compute insight on product updated UPDATED_PRODUCT_WAIT = float(os.environ.get("ROBOTOFF_UPDATED_PRODUCT_WAIT", 10)) @@ -281,12 +279,6 @@ def init_sentry(integrations: Optional[List[Integration]] = None): OBJECT_DETECTION_IMAGE_MAX_SIZE = (1024, 1024) -OBJECT_DETECTION_MODEL_VERSION = { - "nutriscore": "tf-nutriscore-1.0", - "nutrition-table": "tf-nutrition-table-1.0", - "universal-logo-detector": "tf-universal-logo-detector-1.0", -} - # We require a minimum of 15 occurences of the brands already on OFF to perform the extraction. This reduces false positive. 
# We require a minimum of 4 characters for the brand diff --git a/robotoff/taxonomy.py b/robotoff/taxonomy.py index 0a7c6024ce..7e5975b544 100644 --- a/robotoff/taxonomy.py +++ b/robotoff/taxonomy.py @@ -357,3 +357,13 @@ def match_taxonomized_value(value_tag: str, taxonomy_type: str) -> Optional[str] return value_tag return get_taxonomy_mapping(taxonomy_type).get(value_tag) + + +def load_resources(): + """Load and cache resources.""" + logger.info("Loading taxonomy resources...") + for taxonomy_type in settings.TAXONOMY_URLS.keys(): + get_taxonomy(taxonomy_type) + + for taxonomy_type in (TaxonomyType.brand, TaxonomyType.label): + get_taxonomy_mapping(taxonomy_type.name) diff --git a/robotoff/workers/client.py b/robotoff/workers/client.py deleted file mode 100644 index c3480f4090..0000000000 --- a/robotoff/workers/client.py +++ /dev/null @@ -1,19 +0,0 @@ -from multiprocessing.connection import Client -from typing import Dict, Optional - -from robotoff import settings -from robotoff.utils import get_logger - -logger = get_logger(__name__) - - -def send_ipc_event(event_type: str, meta: Optional[Dict] = None): - meta = meta or {} - - logger.info("Connecting listener server on {}:{}" "".format(*settings.IPC_ADDRESS)) - with Client( # type: ignore - settings.IPC_ADDRESS, authkey=settings.IPC_AUTHKEY, family="AF_INET" - ) as conn: - logger.info("Sending event through IPC") - conn.send({"type": event_type, "meta": meta}) - logger.info("IPC event sent") diff --git a/robotoff/workers/listener.py b/robotoff/workers/listener.py deleted file mode 100644 index 297d162cc7..0000000000 --- a/robotoff/workers/listener.py +++ /dev/null @@ -1,57 +0,0 @@ -import threading -import time -from multiprocessing.connection import Listener -from multiprocessing.pool import Pool -from typing import Dict - -from sentry_sdk import capture_exception - -from robotoff import settings -from robotoff.utils import get_logger -from robotoff.workers.tasks import run_task - -settings.init_sentry() - -logger = get_logger() - - -def send_task_to_pool(pool, event_type, event_kwargs, delay): - """Simply pass the task to a worker in the pool, while eventually applying a delay""" - if delay: - time.sleep(delay) - logger.debug("Sending task to pool...") - pool.apply_async(run_task, (event_type, event_kwargs)) - logger.debug("Task sent") - - -def run(): - """This is the event listener, it will receive task requests and launch them""" - pool: Pool = Pool(settings.WORKER_COUNT, maxtasksperchild=30) - - logger.info("Starting listener server on {}:{}".format(*settings.IPC_ADDRESS)) - logger.info("Starting listener server") - - with Listener( - settings.IPC_ADDRESS, authkey=settings.IPC_AUTHKEY, family="AF_INET" - ) as listener: - while True: - try: - logger.debug("Waiting for a connection...") - - with listener.accept() as conn: - event = conn.recv() - event_type: str = event["type"] - logger.info(f"New '{event_type}' event received") - event_kwargs: Dict = event.get("meta", {}) - - delay = event_kwargs.pop("task_delay", None) - args = [pool, event_type, event_kwargs, delay] - if delay: - # we have a delay, so spend it in a thread instead of listener main thread - threading.Thread(target=send_task_to_pool, args=args).start() - else: - # direct call, it's fast - send_task_to_pool(*args) - - except Exception: - capture_exception() diff --git a/robotoff/workers/main.py b/robotoff/workers/main.py new file mode 100644 index 0000000000..b08e664376 --- /dev/null +++ b/robotoff/workers/main.py @@ -0,0 +1,52 @@ +import sys + +from rq import 
Connection, Worker + +from robotoff import settings +from robotoff.models import with_db +from robotoff.utils import get_logger +from robotoff.workers.queues import redis_conn + +logger = get_logger() +settings.init_sentry() + + +@with_db +def load_resources(refresh: bool = False): + """Load cacheable resources in memory. + + This way, all resources are available in memory before the worker forks. + """ + if refresh: + logger.info("Refreshing worker resource caches...") + else: + logger.info("Loading resources in memory...") + + from robotoff import logos, taxonomy + from robotoff.prediction.category import matcher + from robotoff.prediction.object_detection import ObjectDetectionModelRegistry + + matcher.load_resources() + taxonomy.load_resources() + logos.load_resources() + + if not refresh: + logger.info("Loading object detection model labels...") + ObjectDetectionModelRegistry.load_all() + + +class CustomWorker(Worker): + def run_maintenance_tasks(self): + super().run_maintenance_tasks() + load_resources(refresh=True) + + +def run(queues: list[str], burst: bool = False): + load_resources() + try: + with Connection(connection=redis_conn): + w = CustomWorker(queues=queues) + w.work(logging_level="INFO", burst=burst) + except ConnectionError as e: + print(e) + sys.exit(1) diff --git a/robotoff/workers/queues.py b/robotoff/workers/queues.py new file mode 100644 index 0000000000..6f40ffcabd --- /dev/null +++ b/robotoff/workers/queues.py @@ -0,0 +1,68 @@ +import enum +import threading +import time +from typing import Callable, Optional + +from rq import Queue +from rq.job import Job + +from robotoff.redis import redis_conn + + +class AvailableQueue(enum.Enum): + robotoff_high = "robotoff-high" + robotoff_low = "robotoff-low" + + +high_queue = Queue(AvailableQueue.robotoff_high.value, connection=redis_conn) +low_queue = Queue(AvailableQueue.robotoff_low.value, connection=redis_conn) + + +def enqueue_in_job( + func: Callable, + queue: Queue, + job_delay: float, + job_kwargs: Optional[dict] = None, + **kwargs +): + """Enqueue a job in `job_delay` seconds. + + Launch a new Thread where we sleep `job_delay` seconds and the job is + then enqueued. + + :param job_delay: number of seconds to sleep before sending the job to the + queue + """ + threading.Thread( + target=_enqueue_in_job, + args=(func, queue, job_delay, job_kwargs, kwargs), + ).start() + + +def _enqueue_in_job( + func: Callable, + queue: Queue, + job_delay: float, + job_kwargs: Optional[dict], + kwargs, +): + time.sleep(job_delay) + enqueue_job(func, queue, job_kwargs, **kwargs) + + +def enqueue_job( + func: Callable, queue: Queue, job_kwargs: Optional[dict] = None, **kwargs +): + """Create a new job from the function and kwargs and enqueue it in the + queue. + + The function will be called by one of the rq workers. For safety, only + keyword parameters can be provided to the function. 
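`robotoff.workers.main.run` replaces the old IPC listener process: it pre-loads the cacheable resources, then starts an rq worker (the `CustomWorker` subclass, which also refreshes those caches during rq's periodic maintenance) on the given queues. The CLI wiring that calls `run` is not part of this excerpt; a plausible invocation, assuming the queue names defined in `AvailableQueue`, is:

    from robotoff.workers.main import run
    from robotoff.workers.queues import AvailableQueue

    # Consume only the high-priority queue; burst=True exits once the queue is
    # drained, which is convenient for local testing.
    run(queues=[AvailableQueue.robotoff_high.value], burst=True)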
+ + :param func: the function to use + :param queue: the queue to use + :param job_kwargs: optional kwargs parameters to provide to `Job.create` + """ + job_kwargs = job_kwargs or {} + job = Job.create(func=func, kwargs=kwargs, connection=redis_conn, **job_kwargs) + return queue.enqueue_job(job=job) diff --git a/robotoff/workers/tasks/__init__.py b/robotoff/workers/tasks/__init__.py index 782f62fd5a..d5dce919b6 100644 --- a/robotoff/workers/tasks/__init__.py +++ b/robotoff/workers/tasks/__init__.py @@ -1,44 +1,29 @@ -import logging -import multiprocessing -from typing import Callable, Dict - -from robotoff.models import Prediction, ProductInsight, db, with_db +from robotoff.insights.importer import refresh_insights +from robotoff.models import Prediction, ProductInsight, with_db from robotoff.products import fetch_dataset, has_dataset_changed -from robotoff.utils import configure_root_logger, get_logger +from robotoff.utils import get_logger -from .import_image import run_import_image_job -from .product_updated import update_insights +from .import_image import run_import_image_job # noqa: F401 +from .product_updated import update_insights_job # noqa: F401 logger = get_logger(__name__) -root_logger = multiprocessing.get_logger() - -if root_logger.level == logging.NOTSET: - configure_root_logger(root_logger) - - -def run_task(event_type: str, event_kwargs: Dict) -> None: - if event_type not in EVENT_MAPPING: - raise ValueError(f"unknown event type: '{event_type}") - - func = EVENT_MAPPING[event_type] - - try: - # we run task inside transaction to avoid side effects - with db: - with db.atomic(): - func(**event_kwargs) - except Exception as e: - logger.error(e, exc_info=1) @with_db -def download_product_dataset(): +def download_product_dataset_job(): + """This job is triggered via /api/v1/products/dataset and causes Robotoff + to re-import the Product Opener product dump.""" if has_dataset_changed(): fetch_dataset() @with_db -def delete_product_insights(barcode: str, server_domain: str): +def delete_product_insights_job(barcode: str, server_domain: str): + """This job is triggered by Product Opener via /api/v1/webhook/product + when the given product has been removed from the database - in this case + we must delete all of the associated predictions and insights that have + not been annotated. + """ logger.info(f"Product {barcode} deleted, deleting associated insights...") deleted_predictions = ( Prediction.delete() @@ -64,28 +49,10 @@ def delete_product_insights(barcode: str, server_domain: str): ) -EVENT_MAPPING: Dict[str, Callable] = { - # 'import_image' is triggered every time there is a new OCR image available for processing by Robotoff, via /api/v1/images/import. - # - # On each image import, Robotoff performs the following tasks: - # 1. Generates various predictions based on the OCR-extracted text from the image. - # 2. Extracts the nutriscore prediction based on the nutriscore ML model. - # 3. Triggers the 'object_detection' task, which is described below. - # 4. Stores the imported image metadata in the Robotoff DB. - # - "import_image": run_import_image_job, - # 'download_dataset' is triggered via /api/v1/products/dataset and causes Robotoff to re-import the Product Opener product dump. - # - "download_dataset": download_product_dataset, - # 'product_deleted' is triggered by Product Opener via /api/v1/webhook/product when the given product has been removed from the - # database - in this case we must delete all of the associated predictions and insights that have not been annotated. 
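With the listener gone, the old `EVENT_MAPPING` dispatch is replaced by enqueueing the corresponding job functions directly on the rq queues. The actual call sites (e.g. in the API layer) are not shown in this excerpt, and which queue each job uses in production is an assumption here; the barcode values are placeholders. A sketch of the equivalents of the former events:

    from robotoff import settings
    from robotoff.workers.queues import enqueue_in_job, enqueue_job, high_queue, low_queue
    from robotoff.workers.tasks import (
        delete_product_insights_job,
        download_product_dataset_job,
        update_insights_job,
    )

    # Former "download_dataset" event.
    enqueue_job(download_product_dataset_job, low_queue)

    # Former "product_deleted" event; only keyword arguments are allowed.
    enqueue_job(
        delete_product_insights_job,
        low_queue,
        barcode="123456789",  # placeholder barcode
        server_domain=settings.OFF_SERVER_DOMAIN,
    )

    # Former "product_updated" event, delayed so the Product Opener write that
    # triggered the webhook has time to complete.
    enqueue_in_job(
        update_insights_job,
        high_queue,
        job_delay=settings.UPDATED_PRODUCT_WAIT,
        barcode="123456789",  # placeholder barcode
        server_domain=settings.OFF_SERVER_DOMAIN,
    )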
- # - "product_deleted": delete_product_insights, - # 'product_updated' is similarly triggered by the webhook API, when product information has been updated. - # - # When a product is updated, Robotoff will: - # 1. Generate new predictions related to the product's category and name. - # 2. Regenerate all insights from the product associated predictions. - # - "product_updated": update_insights, -} +@with_db +def refresh_insights_job(barcodes: list[str], server_domain: str): + logger.info( + f"Refreshing insights for {len(barcodes)} products, server_domain: {server_domain}" + ) + for barcode in barcodes: + refresh_insights(barcode, server_domain) diff --git a/robotoff/workers/tasks/import_image.py b/robotoff/workers/tasks/import_image.py index 5806711825..bab8ddbbc8 100644 --- a/robotoff/workers/tasks/import_image.py +++ b/robotoff/workers/tasks/import_image.py @@ -3,25 +3,35 @@ from typing import Optional import requests -from PIL import Image -from robotoff import settings -from robotoff.insights.extraction import get_predictions_from_image +from robotoff.insights.extraction import ( + DEFAULT_OCR_PREDICTION_TYPES, + extract_ocr_predictions, + run_object_detection_model, +) from robotoff.insights.importer import import_insights from robotoff.logos import ( - LOGO_CONFIDENCE_THRESHOLDS, add_logos_to_ann, filter_logos, + get_logo_confidence_thresholds, import_logo_insights, save_nearest_neighbors, ) -from robotoff.models import ImageModel, ImagePrediction, LogoAnnotation, db +from robotoff.models import ( + ImageModel, + ImagePrediction, + LogoAnnotation, + Prediction, + db, + with_db, +) from robotoff.off import get_server_type, get_source_from_url -from robotoff.prediction.object_detection import ObjectDetectionModelRegistry +from robotoff.prediction.object_detection import ObjectDetectionModel from robotoff.prediction.types import PredictionType from robotoff.products import Product, get_product_store from robotoff.slack import NotifierFactory from robotoff.utils import get_image_from_url, get_logger, http_session +from robotoff.workers.queues import enqueue_job, high_queue logger = get_logger(__name__) @@ -29,12 +39,23 @@ def run_import_image_job( barcode: str, image_url: str, ocr_url: str, server_domain: str ): + """This job is triggered every time there is a new OCR image available for + processing by Robotoff, via /api/v1/images/import. + + On each image import, Robotoff performs the following tasks: + + 1. Generates various predictions based on the OCR-extracted text from the image. + 2. Extracts the nutriscore prediction based on the nutriscore ML model. + 3. Triggers the 'object_detection' task + 4. Stores the imported image metadata in the Robotoff DB. 
+ """ logger.info( f"Running `import_image` for product {barcode} ({server_domain}), image {image_url}" ) image = get_image_from_url(image_url, error_raise=False, session=http_session) if image is None: + logger.info(f"Error while downloading image {image_url}") return source_image = get_source_from_url(image_url) @@ -47,37 +68,95 @@ def run_import_image_job( return with db: - with db.atomic(): - save_image(barcode, source_image, product, server_domain) - import_insights_from_image( - barcode, image, source_image, ocr_url, server_domain - ) - with db.atomic(): - # Launch object detection in a new SQL transaction - run_object_detection(barcode, image, source_image, server_domain) + save_image(barcode, source_image, product, server_domain) + + enqueue_job( + import_insights_from_image, + high_queue, + barcode=barcode, + image_url=image_url, + ocr_url=ocr_url, + server_domain=server_domain, + ) + enqueue_job( + run_logo_object_detection, + high_queue, + barcode=barcode, + image_url=image_url, + server_domain=server_domain, + ) + enqueue_job( + run_nutrition_table_object_detection, + high_queue, + barcode=barcode, + image_url=image_url, + server_domain=server_domain, + ) def import_insights_from_image( barcode: str, - image: Image.Image, - source_image: str, + image_url: str, ocr_url: str, server_domain: str, ): - predictions_all = get_predictions_from_image(barcode, image, source_image, ocr_url) + image = get_image_from_url(image_url, error_raise=False, session=http_session) + + if image is None: + logger.info(f"Error while downloading image {image_url}") + return + + source_image = get_source_from_url(image_url) + predictions = extract_ocr_predictions( + barcode, ocr_url, DEFAULT_OCR_PREDICTION_TYPES + ) + if any( + prediction.value_tag == "en:nutriscore" + and prediction.type == PredictionType.label + for prediction in predictions + ): + enqueue_job( + run_nutriscore_object_detection, + high_queue, + barcode=barcode, + image_url=image_url, + server_domain=server_domain, + ) NotifierFactory.get_notifier().notify_image_flag( - [p for p in predictions_all if p.type == PredictionType.image_flag], + [p for p in predictions if p.type == PredictionType.image_flag], source_image, barcode, ) - imported = import_insights(predictions_all, server_domain) - logger.info(f"Import finished, {imported} insights imported") + + with db: + imported = import_insights(predictions, server_domain) + logger.info(f"Import finished, {imported} insights imported") + + +@with_db +def save_image_job(batch: list[tuple[str, str]], server_domain: str): + """Save a batch of images in DB. 
+ + :param batch: a batch of (barcode, source_image) tuples + :param server_domain: the server domain to use + """ + for barcode, source_image in batch: + product = get_product_store()[barcode] + if product is None: + continue + save_image(barcode, source_image, product, server_domain) def save_image( barcode: str, source_image: str, product: Product, server_domain: str ) -> Optional[ImageModel]: """Save imported image details in DB.""" + if existing_image_model := ImageModel.get_or_none(source_image=source_image): + logger.info( + f"Image {source_image} already exist in DB, returning existing image", + ) + return existing_image_model + image_id = pathlib.Path(source_image).stem if not image_id.isdigit(): @@ -126,57 +205,155 @@ def save_image( return image_model -def run_object_detection( - barcode: str, image: Image.Image, source_image: str, server_domain: str +def run_nutrition_table_object_detection( + barcode: str, image_url: str, server_domain: str ): + logger.info( + f"Running nutrition table object detection for product {barcode} " + f"({server_domain}), image {image_url}" + ) + + image = get_image_from_url(image_url, error_raise=False, session=http_session) + + if image is None: + logger.info(f"Error while downloading image {image_url}") + return + + source_image = get_source_from_url(image_url) + + with db: + run_object_detection_model( + ObjectDetectionModel.nutrition_table, image, source_image + ) + + +NUTRISCORE_LABELS = { + "nutriscore-a": "en:nutriscore-grade-a", + "nutriscore-b": "en:nutriscore-grade-b", + "nutriscore-c": "en:nutriscore-grade-c", + "nutriscore-d": "en:nutriscore-grade-d", + "nutriscore-e": "en:nutriscore-grade-e", +} + + +def run_nutriscore_object_detection(barcode: str, image_url: str, server_domain: str): + logger.info( + f"Running nutriscore object detection for product {barcode} " + f"({server_domain}), image {image_url}" + ) + + image = get_image_from_url(image_url, error_raise=False, session=http_session) + + if image is None: + logger.info(f"Error while downloading image {image_url}") + return + + source_image = get_source_from_url(image_url) + + with db: + image_prediction = run_object_detection_model( + ObjectDetectionModel.nutriscore, image, source_image + ) + + if not image_prediction: + return + + results = [ + item for item in image_prediction.data["objects"] if item["score"] >= 0.5 + ] + + if len(results) > 1: + logger.info("more than one nutriscore detected, discarding detections") + return + + result = results[0] + score = result["score"] + label_tag = NUTRISCORE_LABELS[result["label"]] + + prediction = Prediction( + type=PredictionType.label, + barcode=barcode, + source_image=source_image, + value_tag=label_tag, + automatic_processing=False, + server_domain=server_domain, + data={ + "confidence": score, + "bounding_box": result["bounding_box"], + "model": ObjectDetectionModel.nutriscore.value, + }, + ) + import_insights([prediction], server_domain) + + +def run_logo_object_detection(barcode: str, image_url: str, server_domain: str): """Detect logos using the universal logo detector model and generate - logo-related insights. + logo-related predictions. 
:param barcode: Product barcode - :param image: Pillow Image to run the object detection on :param image_url: URL of the image to use :param server_domain: The server domain associated with the image """ logger.info( - f"Running object detection for product {barcode} ({server_domain}), " - f"image {source_image}" + f"Running logo object detection for product {barcode} " + f"({server_domain}), image {image_url}" ) - image_instance = ImageModel.get_or_none(source_image=source_image) - if image_instance is None: - logger.warning("Missing image in DB for image %s", source_image) + image = get_image_from_url(image_url, error_raise=False, session=http_session) + + if image is None: + logger.info(f"Error while downloading image {image_url}") return - timestamp = datetime.datetime.utcnow() - model_name = "universal-logo-detector" - results = ObjectDetectionModelRegistry.get(model_name).detect_from_image( - image, output_image=False - ) - data = results.to_json(threshold=0.1) - max_confidence = max([item["score"] for item in data], default=None) - image_prediction = ImagePrediction.create( - image=image_instance, - type="object_detection", - model_name=model_name, - model_version=settings.OBJECT_DETECTION_MODEL_VERSION[model_name], - data={"objects": data}, - timestamp=timestamp, - max_confidence=max_confidence, - ) + source_image = get_source_from_url(image_url) + + with db: + image_prediction = run_object_detection_model( + ObjectDetectionModel.universal_logo_detector, image, source_image + ) - logos = [] - for i, item in filter_logos(data, score_threshold=0.5, iou_threshold=0.95): - logos.append( - LogoAnnotation.create( - image_prediction=image_prediction, - index=i, - score=item["score"], - bounding_box=item["bounding_box"], + if image_prediction is None: + # Can occur in normal conditions if an image prediction + # already exists for this image and model + return + + logo_ids = [] + for i, item in filter_logos( + image_prediction.data["objects"], + score_threshold=0.5, + iou_threshold=0.95, + ): + logo_ids.append( + LogoAnnotation.create( + image_prediction=image_prediction, + index=i, + score=item["score"], + bounding_box=item["bounding_box"], + ).id ) + + logger.info(f"{len(logo_ids)} logos found for image {source_image}") + if logo_ids: + enqueue_job( + process_created_logos, + high_queue, + job_kwargs={}, + image_prediction_id=image_prediction.id, + server_domain=server_domain, ) - logger.info(f"{len(logos)} logos found for image {source_image}") + +@with_db +def process_created_logos(image_prediction_id: int, server_domain: str): + logos = ( + LogoAnnotation.select() + .join(ImagePrediction) + .join(ImageModel) + .where(ImagePrediction.id == image_prediction_id) + ) + if logos: + image_instance = logos[0].image_prediction.image add_logos_to_ann(image_instance, logos) try: @@ -188,5 +365,5 @@ def run_object_detection( resp.text, ) - thresholds = LOGO_CONFIDENCE_THRESHOLDS.get() + thresholds = get_logo_confidence_thresholds() import_logo_insights(logos, thresholds=thresholds, server_domain=server_domain) diff --git a/robotoff/workers/tasks/product_updated.py b/robotoff/workers/tasks/product_updated.py index ce4c013672..0ef994b6eb 100644 --- a/robotoff/workers/tasks/product_updated.py +++ b/robotoff/workers/tasks/product_updated.py @@ -7,6 +7,7 @@ from robotoff.prediction.category.matcher import predict as predict_category_matcher from robotoff.prediction.category.neural.category_classifier import CategoryClassifier from robotoff.products import get_product +from robotoff.redis import Lock, 
LockedResourceException from robotoff.taxonomy import TaxonomyType, get_taxonomy from robotoff.utils import get_logger from robotoff.utils.types import JSONType @@ -15,21 +16,35 @@ @with_db -def update_insights(barcode: str, server_domain: str): - # Sleep 10s to let the OFF update request that triggered the webhook call - # to finish - logger.info(f"Running `update_insights` for product {barcode} ({server_domain})") +def update_insights_job(barcode: str, server_domain: str): + """This job is triggered by the webhook API, when product information has + been updated. - product_dict = get_product(barcode) + When a product is updated, Robotoff will: - if product_dict is None: - logger.warning("Updated product does not exist: %s", barcode) - return + 1. Generate new predictions related to the product's category and name. + 2. Regenerate all insights from the product associated predictions. + """ + logger.info(f"Running `update_insights` for product {barcode} ({server_domain})") - updated_product_predict_insights(barcode, product_dict, server_domain) - logger.info("Refreshing insights...") - imported = refresh_insights(barcode, server_domain) - logger.info(f"{imported} insights created after refresh") + try: + with Lock( + name=f"robotoff:product_update_job:{barcode}", expire=300, timeout=10 + ): + product_dict = get_product(barcode) + + if product_dict is None: + logger.warning("Updated product does not exist: %s", barcode) + return + + updated_product_predict_insights(barcode, product_dict, server_domain) + logger.info("Refreshing insights...") + imported = refresh_insights(barcode, server_domain) + logger.info(f"{imported} insights created after refresh") + except LockedResourceException: + logger.info( + f"Couldn't acquire product_update lock, skipping product_update for product {barcode}" + ) def add_category_insight(barcode: str, product: JSONType, server_domain: str) -> bool: diff --git a/tests/conftest.py b/tests/conftest.py index 622c3ca487..f2914ae3c6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,15 @@ import pytest from robotoff import models +from robotoff.redis import Lock + + +@pytest.fixture(scope="session", autouse=True) +def disable_redis_lock(): + previous_value = Lock._enabled + Lock._enabled = False + yield + Lock._enabled = previous_value @pytest.fixture(scope="session") diff --git a/tests/integration/insights/test_annotate.py b/tests/integration/insights/test_annotate.py index 6090f1f1f6..f7064b3a5f 100644 --- a/tests/integration/insights/test_annotate.py +++ b/tests/integration/insights/test_annotate.py @@ -8,10 +8,12 @@ @pytest.fixture(autouse=True) def _set_up_and_tear_down(peewee_db): - clean_db() - # Run the test case. - yield - clean_db() + with peewee_db: + # clean db + clean_db() + # Run the test case. 
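The integration tests now open their own database connection explicitly: factory calls and queries are wrapped in `with peewee_db:` blocks instead of relying on a connection being opened implicitly elsewhere. The `peewee_db` fixture itself lives in `tests/conftest.py` and is not shown in this excerpt; a minimal stand-in of the pattern, assuming it simply yields the peewee database object from `robotoff.models`, looks like:

    import pytest

    from robotoff import models


    @pytest.fixture()
    def peewee_db():
        # Assumed stand-in for the project's fixture: expose the peewee database
        # so each test manages its connection explicitly.
        yield models.db


    def test_example(peewee_db):
        # Using the database as a context manager opens a connection and wraps
        # the block in a transaction (standard peewee behaviour).
        with peewee_db:
            ...  # create factories / run queries here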
+ yield + clean_db() def test_annotation_fails_is_rolledback(mocker): diff --git a/tests/integration/insights/test_category_import.py b/tests/integration/insights/test_category_import.py index d4ff1d2379..3737e80b87 100644 --- a/tests/integration/insights/test_category_import.py +++ b/tests/integration/insights/test_category_import.py @@ -14,27 +14,28 @@ @pytest.fixture(autouse=True) def _set_up_and_tear_down(peewee_db): - # clean db - clean_db() - # a category already exists - PredictionFactory( - barcode=barcode1, - type="category", - value_tag="en:salmons", - automatic_processing=False, - predictor="matcher", - ) - ProductInsightFactory( - id=insight_id1, - barcode=barcode1, - type="category", - value_tag="en:salmons", - predictor="matcher", - ) - # Run the test case. - yield - # Tear down. - clean_db() + with peewee_db: + # clean db + clean_db() + # a category already exists + PredictionFactory( + barcode=barcode1, + type="category", + value_tag="en:salmons", + automatic_processing=False, + predictor="matcher", + ) + ProductInsightFactory( + id=insight_id1, + barcode=barcode1, + type="category", + value_tag="en:salmons", + predictor="matcher", + ) + # Run the test case. + yield + # Tear down. + clean_db() def matcher_prediction(category): diff --git a/tests/integration/insights/test_extraction.py b/tests/integration/insights/test_extraction.py new file mode 100644 index 0000000000..ee90f0304e --- /dev/null +++ b/tests/integration/insights/test_extraction.py @@ -0,0 +1,87 @@ +import numpy as np +import pytest +from PIL import Image + +from robotoff.insights.extraction import run_object_detection_model +from robotoff.models import ImagePrediction +from robotoff.prediction.object_detection.core import ( + ObjectDetectionModel, + ObjectDetectionRawResult, + RemoteModel, +) + +from ..models_utils import ImageModelFactory, clean_db + + +@pytest.fixture() +def image_model(peewee_db): + with peewee_db: + clean_db() + yield ImageModelFactory(source_image="/1/1.jpg") + clean_db() + + +class FakeNutriscoreModel(RemoteModel): + def __init__(self, raw_result: ObjectDetectionRawResult): + self.raw_result = raw_result + + def detect_from_image( + self, image: Image.Image, output_image: bool = False + ) -> ObjectDetectionRawResult: + return self.raw_result + + +@pytest.mark.parametrize( + "model_name,label_names", + [ + (ObjectDetectionModel.universal_logo_detector, ["brand", "label"]), + ( + ObjectDetectionModel.nutriscore, + [ + "nutriscore-a", + "nutriscore-b", + "nutriscore-d", + "nutriscore-d", + "nutriscore-e", + ], + ), + ], +) +def test_run_object_detection_model(mocker, image_model, model_name, label_names): + raw_result = ObjectDetectionRawResult( + num_detections=1, + detection_boxes=np.array([[1, 2, 3, 4]]), + detection_scores=np.array([0.8]), + detection_classes=np.array([1]), + label_names=label_names, + ) + mocker.patch( + "robotoff.prediction.object_detection.core.ObjectDetectionModelRegistry.get", + return_value=FakeNutriscoreModel(raw_result), + ) + image_prediction = run_object_detection_model( + model_name, + None, + source_image=image_model.source_image, + threshold=0.1, + ) + assert isinstance(image_prediction, ImagePrediction) + assert image_prediction.type == "object_detection" + assert image_prediction.model_name == model_name.value + assert image_prediction.data == { + "objects": [ + {"bounding_box": (1, 2, 3, 4), "score": 0.8, "label": label_names[1]} + ] + } + assert image_prediction.max_confidence == 0.8 + + +def 
test_run_object_detection_model_no_image_instance(peewee_db): + with peewee_db: + image_prediction = run_object_detection_model( + ObjectDetectionModel.nutriscore, + None, + source_image="/images/1/1.jpg", + threshold=0.1, + ) + assert image_prediction is None diff --git a/tests/integration/insights/test_process_insights.py b/tests/integration/insights/test_process_insights.py index 48313ed65b..e923faef4d 100644 --- a/tests/integration/insights/test_process_insights.py +++ b/tests/integration/insights/test_process_insights.py @@ -11,12 +11,13 @@ @pytest.fixture(autouse=True) def _set_up_and_tear_down(peewee_db): - # clean db - clean_db() - # Run the test case. - yield - # Tear down. - clean_db() + with peewee_db: + # clean db + clean_db() + # Run the test case. + yield + # Tear down. + clean_db() # global for generating items diff --git a/tests/integration/test_annotate_image.py b/tests/integration/test_annotate_image.py index 35e1dea737..2182eb9c25 100644 --- a/tests/integration/test_annotate_image.py +++ b/tests/integration/test_annotate_image.py @@ -22,10 +22,13 @@ def client(): @pytest.fixture(autouse=True) def _set_up_and_tear_down(peewee_db): - clean_db() - # Run the test case. + with peewee_db: + clean_db() + # Run the test case. yield - clean_db() + + with peewee_db: + clean_db() def _fake_store(monkeypatch, barcode): @@ -126,11 +129,14 @@ def test_logo_annotation_missing_value_when_required(logo_type, client): } -def test_logo_annotation_incorrect_value_label_type(client): +def test_logo_annotation_incorrect_value_label_type(client, peewee_db): """A language-prefixed value is expected for label type.""" - ann = LogoAnnotationFactory( - image_prediction__image__source_image="/images/2.jpg", annotation_type="label" - ) + + with peewee_db: + ann = LogoAnnotationFactory( + image_prediction__image__source_image="/images/2.jpg", + annotation_type="label", + ) result = client.simulate_post( "/api/v1/images/logos/annotate", json={ @@ -148,10 +154,12 @@ def test_logo_annotation_incorrect_value_label_type(client): } -def test_logo_annotation_brand(client, monkeypatch, fake_taxonomy): - ann = LogoAnnotationFactory( - image_prediction__image__source_image="/images/2.jpg", annotation_type="brand" - ) +def test_logo_annotation_brand(client, peewee_db, monkeypatch, fake_taxonomy): + with peewee_db: + ann = LogoAnnotationFactory( + image_prediction__image__source_image="/images/2.jpg", + annotation_type="brand", + ) barcode = ann.image_prediction.image.barcode _fake_store(monkeypatch, barcode) monkeypatch.setattr( @@ -169,7 +177,9 @@ def test_logo_annotation_brand(client, monkeypatch, fake_taxonomy): end = datetime.utcnow() assert result.status_code == 200 assert result.json == {"created insights": 1} - ann = LogoAnnotation.get(LogoAnnotation.id == ann.id) + + with peewee_db: + ann = LogoAnnotation.get(LogoAnnotation.id == ann.id) assert ann.annotation_type == "brand" assert ann.annotation_value == "etorki" assert ann.annotation_value_tag == "etorki" @@ -177,7 +187,9 @@ def test_logo_annotation_brand(client, monkeypatch, fake_taxonomy): assert ann.username == "a" assert start <= ann.completed_at <= end # we generate a prediction - predictions = list(Prediction.select().filter(barcode=barcode).execute()) + + with peewee_db: + predictions = list(Prediction.select().filter(barcode=barcode).execute()) assert len(predictions) == 1 (prediction,) = predictions assert prediction.type == "brand" @@ -195,7 +207,9 @@ def test_logo_annotation_brand(client, monkeypatch, fake_taxonomy): assert start <= 
prediction.timestamp <= end assert prediction.automatic_processing # We check that this prediction in turn generates an insight - insights = list(ProductInsight.select().filter(barcode=barcode).execute()) + + with peewee_db: + insights = list(ProductInsight.select().filter(barcode=barcode).execute()) assert len(insights) == 1 (insight,) = insights assert insight.type == "brand" @@ -216,11 +230,14 @@ def test_logo_annotation_brand(client, monkeypatch, fake_taxonomy): assert insight.completed_at is None # we did not run annotate yet -def test_logo_annotation_label(client, monkeypatch, fake_taxonomy): +def test_logo_annotation_label(client, peewee_db, monkeypatch, fake_taxonomy): """This test will check that, given an image with a logo above the confidence threshold, that is then fed into the ANN logos and labels model, we annotate properly a product. """ - ann = LogoAnnotationFactory(image_prediction__image__source_image="/images/2.jpg") + with peewee_db: + ann = LogoAnnotationFactory( + image_prediction__image__source_image="/images/2.jpg" + ) barcode = ann.image_prediction.image.barcode _fake_store(monkeypatch, barcode) start = datetime.utcnow() @@ -237,7 +254,8 @@ def test_logo_annotation_label(client, monkeypatch, fake_taxonomy): end = datetime.utcnow() assert result.status_code == 200 assert result.json == {"created insights": 1} - ann = LogoAnnotation.get(LogoAnnotation.id == ann.id) + with peewee_db: + ann = LogoAnnotation.get(LogoAnnotation.id == ann.id) assert ann.annotation_type == "label" assert ann.annotation_value == "en:eu-organic" assert ann.annotation_value_tag == "en:eu-organic" @@ -245,7 +263,8 @@ def test_logo_annotation_label(client, monkeypatch, fake_taxonomy): assert ann.username == "a" assert start <= ann.completed_at <= end # we generate a prediction - predictions = list(Prediction.select().filter(barcode=barcode).execute()) + with peewee_db: + predictions = list(Prediction.select().filter(barcode=barcode).execute()) assert len(predictions) == 1 (prediction,) = predictions assert prediction.type == "label" @@ -263,7 +282,8 @@ def test_logo_annotation_label(client, monkeypatch, fake_taxonomy): assert start <= prediction.timestamp <= end assert prediction.automatic_processing # We check that this prediction in turn generates an insight - insights = list(ProductInsight.select().filter(barcode=barcode).execute()) + with peewee_db: + insights = list(ProductInsight.select().filter(barcode=barcode).execute()) assert len(insights) == 1 (insight,) = insights assert insight.type == "label" diff --git a/tests/integration/test_api.py b/tests/integration/test_api.py index fea4648bb8..a9be17cd85 100644 --- a/tests/integration/test_api.py +++ b/tests/integration/test_api.py @@ -26,13 +26,15 @@ @pytest.fixture(autouse=True) def _set_up_and_tear_down(peewee_db): - # clean db - clean_db() - # Set up. - ProductInsightFactory(id=insight_id, barcode=1) + with peewee_db: + # clean db + clean_db() + # Set up. + ProductInsightFactory(id=insight_id, barcode=1) # Run the test case. 
yield - clean_db() + with peewee_db: + clean_db() @pytest.fixture() @@ -75,12 +77,13 @@ def test_random_question(client, mocker): } -def test_random_question_user_has_already_seen(client, mocker): +def test_random_question_user_has_already_seen(client, mocker, peewee_db): mocker.patch("robotoff.insights.question.get_product", return_value={}) - AnnotationVoteFactory( - insight_id=insight_id, - device_id="device1", - ) + with peewee_db: + AnnotationVoteFactory( + insight_id=insight_id, + device_id="device1", + ) result = client.simulate_get("/api/v1/questions/random?device_id=device1") @@ -110,11 +113,13 @@ def test_popular_question(client, mocker): } -def test_popular_question_pagination(client, mocker): +def test_popular_question_pagination(client, mocker, peewee_db): mocker.patch("robotoff.insights.question.get_product", return_value={}) - ProductInsight.delete().execute() # remove default sample - for i in range(0, 12): - ProductInsightFactory(barcode=i, unique_scans_n=100 - i) + + with peewee_db: + ProductInsight.delete().execute() # remove default sample + for i in range(0, 12): + ProductInsightFactory(barcode=i, unique_scans_n=100 - i) result = client.simulate_get("/api/v1/questions/popular?count=5&page=1") assert result.status_code == 200 @@ -170,7 +175,7 @@ def test_barcode_question(client, mocker): } -def test_annotate_insight_authenticated(client): +def test_annotate_insight_authenticated(client, peewee_db): result = client.simulate_post( "/api/v1/insights/annotate", params={ @@ -188,23 +193,26 @@ def test_annotate_insight_authenticated(client): } # For authenticated users we expect the insight to be validated directly, tracking the username of the annotator. - votes = list(AnnotationVote.select()) - assert len(votes) == 0 - - insight = next( - ProductInsight.select() - .where(ProductInsight.id == insight_id) - .dicts() - .iterator() - ) - assert insight.items() > {"username": "a", "annotation": 0, "n_votes": 0}.items() - assert "completed_at" in insight + with peewee_db: + votes = list(AnnotationVote.select()) + assert len(votes) == 0 + + insight = next( + ProductInsight.select() + .where(ProductInsight.id == insight_id) + .dicts() + .iterator() + ) + assert ( + insight.items() > {"username": "a", "annotation": 0, "n_votes": 0}.items() + ) + assert "completed_at" in insight # check if "annotated_result" is saved assert insight["annotated_result"] == 1 -def test_annotate_insight_authenticated_ignore(client): +def test_annotate_insight_authenticated_ignore(client, peewee_db): result = client.simulate_post( "/api/v1/insights/annotate", params={ @@ -221,21 +229,23 @@ def test_annotate_insight_authenticated_ignore(client): "description": "the annotation vote was saved", } - votes = list(AnnotationVote.select()) - assert len(votes) == 1 + with peewee_db: + votes = list(AnnotationVote.select()) + assert len(votes) == 1 - insight = next( - ProductInsight.select() - .where(ProductInsight.id == insight_id) - .dicts() - .iterator() - ) - assert ( - insight.items() > {"username": None, "annotation": None, "n_votes": 0}.items() - ) + insight = next( + ProductInsight.select() + .where(ProductInsight.id == insight_id) + .dicts() + .iterator() + ) + assert ( + insight.items() + > {"username": None, "annotation": None, "n_votes": 0}.items() + ) -def test_annotate_insight_not_enough_votes(client): +def test_annotate_insight_not_enough_votes(client, peewee_db): result = client.simulate_post( "/api/v1/insights/annotate", params={ @@ -253,46 +263,48 @@ def 
test_annotate_insight_not_enough_votes(client): } # For non-authenticated users we expect the insight to not be validated, with only a vote being cast. - votes = list(AnnotationVote.select().dicts()) + with peewee_db: + votes = list(AnnotationVote.select().dicts()) assert len(votes) == 1 assert votes[0]["value"] == 1 assert votes[0]["username"] is None assert votes[0]["device_id"] == "voter1" - - insight = next( - ProductInsight.select() - .where(ProductInsight.id == insight_id) - .dicts() - .iterator() - ) + with peewee_db: + insight = next( + ProductInsight.select() + .where(ProductInsight.id == insight_id) + .dicts() + .iterator() + ) assert not any(insight[key] for key in ("username", "completed_at", "annotation")) assert insight.items() > {"n_votes": 1}.items() -def test_annotate_insight_majority_annotation(client): +def test_annotate_insight_majority_annotation(client, peewee_db): # Add pre-existing insight votes. - AnnotationVoteFactory( - insight_id=insight_id, - value=1, - device_id="yes-voter1", - ) - AnnotationVoteFactory( - insight_id=insight_id, - value=1, - device_id="yes-voter2", - ) - AnnotationVoteFactory( - insight_id=insight_id, - value=0, - device_id="no-voter1", - ) - AnnotationVoteFactory( - insight_id=insight_id, - value=-1, - device_id="ignore-voter1", - ) + with peewee_db: + AnnotationVoteFactory( + insight_id=insight_id, + value=1, + device_id="yes-voter1", + ) + AnnotationVoteFactory( + insight_id=insight_id, + value=1, + device_id="yes-voter2", + ) + AnnotationVoteFactory( + insight_id=insight_id, + value=0, + device_id="no-voter1", + ) + AnnotationVoteFactory( + insight_id=insight_id, + value=-1, + device_id="ignore-voter1", + ) result = client.simulate_post( "/api/v1/insights/annotate", @@ -311,38 +323,40 @@ def test_annotate_insight_majority_annotation(client): "description": "the annotation was saved", } - votes = list(AnnotationVote.select()) - assert len(votes) == 5 + with peewee_db: + votes = list(AnnotationVote.select()) + assert len(votes) == 5 - insight = next( - ProductInsight.select() - .where(ProductInsight.id == insight_id) - .dicts() - .iterator() - ) + insight = next( + ProductInsight.select() + .where(ProductInsight.id == insight_id) + .dicts() + .iterator() + ) # The insight should be annoted with '1', with a None username since this was resolved with an # anonymous vote. `n_votes = 4, as -1 votes are not considered assert insight.items() > {"annotation": 1, "username": None, "n_votes": 4}.items() # This test checks for handling of cases where we have 2 votes for 2 different annotations. -def test_annotate_insight_opposite_votes(client): +def test_annotate_insight_opposite_votes(client, peewee_db): # Add pre-existing insight votes. 
- AnnotationVoteFactory( - insight_id=insight_id, - value=1, - device_id="yes-voter1", - ) - AnnotationVoteFactory( - insight_id=insight_id, - value=1, - device_id="yes-voter2", - ) - AnnotationVoteFactory( - insight_id=insight_id, - value=0, - device_id="no-voter1", - ) + with peewee_db: + AnnotationVoteFactory( + insight_id=insight_id, + value=1, + device_id="yes-voter1", + ) + AnnotationVoteFactory( + insight_id=insight_id, + value=1, + device_id="yes-voter2", + ) + AnnotationVoteFactory( + insight_id=insight_id, + value=0, + device_id="no-voter1", + ) result = client.simulate_post( "/api/v1/insights/annotate", @@ -361,15 +375,16 @@ def test_annotate_insight_opposite_votes(client): "description": "the annotation was saved", } - votes = list(AnnotationVote.select()) - assert len(votes) == 4 + with peewee_db: + votes = list(AnnotationVote.select()) + assert len(votes) == 4 - insight = next( - ProductInsight.select() - .where(ProductInsight.id == insight_id) - .dicts() - .iterator() - ) + insight = next( + ProductInsight.select() + .where(ProductInsight.id == insight_id) + .dicts() + .iterator() + ) # The insight should be annoted with '-1', with a None username since this was resolved with an # anonymous vote. assert insight.items() > {"annotation": -1, "username": None, "n_votes": 4}.items() @@ -377,28 +392,29 @@ def test_annotate_insight_opposite_votes(client): # This test checks for handling of cases where we have 3 votes for one annotation, # but the follow-up has 2 votes. -def test_annotate_insight_majority_vote_overridden(client): +def test_annotate_insight_majority_vote_overridden(client, peewee_db): # Add pre-existing insight votes. - AnnotationVoteFactory( - insight_id=insight_id, - value=1, - device_id="yes-voter1", - ) - AnnotationVoteFactory( - insight_id=insight_id, - value=1, - device_id="yes-voter2", - ) - AnnotationVoteFactory( - insight_id=insight_id, - value=0, - device_id="no-voter1", - ) - AnnotationVoteFactory( - insight_id=insight_id, - value=0, - device_id="no-voter2", - ) + with peewee_db: + AnnotationVoteFactory( + insight_id=insight_id, + value=1, + device_id="yes-voter1", + ) + AnnotationVoteFactory( + insight_id=insight_id, + value=1, + device_id="yes-voter2", + ) + AnnotationVoteFactory( + insight_id=insight_id, + value=0, + device_id="no-voter1", + ) + AnnotationVoteFactory( + insight_id=insight_id, + value=0, + device_id="no-voter2", + ) result = client.simulate_post( "/api/v1/insights/annotate", @@ -417,21 +433,22 @@ def test_annotate_insight_majority_vote_overridden(client): "description": "the annotation was saved", } - votes = list(AnnotationVote.select()) - assert len(votes) == 5 + with peewee_db: + votes = list(AnnotationVote.select()) + assert len(votes) == 5 - insight = next( - ProductInsight.select() - .where(ProductInsight.id == insight_id) - .dicts() - .iterator() - ) + insight = next( + ProductInsight.select() + .where(ProductInsight.id == insight_id) + .dicts() + .iterator() + ) # The insight should be annoted with '0', with a None username since this was resolved with an # anonymous vote. 
assert insight.items() > {"annotation": -1, "username": None, "n_votes": 5}.items() -def test_annotate_insight_anonymous_then_authenticated(client, mocker): +def test_annotate_insight_anonymous_then_authenticated(client, mocker, peewee_db): """Test that annotating first as anonymous, then, just after, as authenticated validate the anotation""" # mock because as we validate the insight, we will ask mongo for product @@ -458,17 +475,18 @@ def test_annotate_insight_anonymous_then_authenticated(client, mocker): } # For non-authenticated users we expect the insight to not be validated, with only a vote being cast. - votes = list(AnnotationVote.select()) - assert len(votes) == 1 - # no category added - add_category.assert_not_called() - - insight = next( - ProductInsight.select() - .where(ProductInsight.id == insight_id) - .dicts() - .iterator() - ) + with peewee_db: + votes = list(AnnotationVote.select()) + assert len(votes) == 1 + # no category added + add_category.assert_not_called() + + insight = next( + ProductInsight.select() + .where(ProductInsight.id == insight_id) + .dicts() + .iterator() + ) assert not any( insight[key] @@ -496,15 +514,16 @@ def test_annotate_insight_anonymous_then_authenticated(client, mocker): "status_code": 2, } # We have the previous vote, but the last request should validate the insight directly - votes = list(AnnotationVote.select()) - assert len(votes) == 1 # this is the previous vote - - insight = next( - ProductInsight.select() - .where(ProductInsight.id == insight_id) - .dicts() - .iterator() - ) + with peewee_db: + votes = list(AnnotationVote.select()) + assert len(votes) == 1 # this is the previous vote + + insight = next( + ProductInsight.select() + .where(ProductInsight.id == insight_id) + .dicts() + .iterator() + ) # we still have the vote, but we also have an authenticated validation assert insight.items() > {"username": "a", "n_votes": 1, "annotation": 1}.items() assert insight.get("completed_at") is not None @@ -528,9 +547,10 @@ def test_image_collection_no_result(client): assert data["status"] == "no_images" -def test_image_collection(client): - image_model = ImageModelFactory(barcode="123") - ImagePredictionFactory(image__barcode="456") +def test_image_collection(client, peewee_db): + with peewee_db: + image_model = ImageModelFactory(barcode="123") + ImagePredictionFactory(image__barcode="456") result = client.simulate_get( "/api/v1/images", @@ -611,9 +631,9 @@ def test_prediction_collection_no_result(client): assert result.json == {"count": 0, "predictions": [], "status": "no_predictions"} -def test_prediction_collection_no_filter(client): - - prediction1 = PredictionFactory(value_tag="en:seeds") +def test_prediction_collection_no_filter(client, peewee_db): + with peewee_db: + prediction1 = PredictionFactory(value_tag="en:seeds") result = client.simulate_get("/api/v1/predictions") assert result.status_code == 200 data = result.json @@ -624,9 +644,10 @@ def test_prediction_collection_no_filter(client): assert prediction_data[0]["type"] == "category" assert prediction_data[0]["value_tag"] == "en:seeds" - prediction2 = PredictionFactory( - value_tag="en:beers", data={"sample": 1}, type="brand" - ) + with peewee_db: + prediction2 = PredictionFactory( + value_tag="en:beers", data={"sample": 1}, type="brand" + ) result = client.simulate_get("/api/v1/predictions") assert result.status_code == 200 data = result.json @@ -641,47 +662,37 @@ def test_prediction_collection_no_filter(client): assert prediction_data[1]["value_tag"] == "en:beers" -def 
test_get_unanswered_questions_api_empty(client): - ProductInsight.delete().execute() # remove default sample +def test_get_unanswered_questions_api_empty(client, peewee_db): + with peewee_db: + ProductInsight.delete().execute() # remove default sample result = client.simulate_get("/api/v1/questions/unanswered") assert result.status_code == 200 assert result.json == {"count": 0, "questions": [], "status": "no_questions"} -def test_get_unanswered_questions_api(client): - ProductInsight.delete().execute() # remove default sample - - ProductInsightFactory(type="category", value_tag="en:apricot", barcode="123") - - ProductInsightFactory(type="label", value_tag="en:beer", barcode="456") - - ProductInsightFactory(type="nutrition", value_tag="en:soups", barcode="789") - - ProductInsightFactory(type="nutrition", value_tag="en:salad", barcode="302") - - ProductInsightFactory(type="nutrition", value_tag="en:salad", barcode="403") - - ProductInsightFactory(type="category", value_tag="en:soups", barcode="194") - - ProductInsightFactory(type="category", value_tag="en:soups", barcode="967") - - ProductInsightFactory(type="label", value_tag="en:beer", barcode="039") - - ProductInsightFactory(type="category", value_tag="en:apricot", barcode="492") - - ProductInsightFactory(type="category", value_tag="en:soups", barcode="594") - - ProductInsightFactory( - type="category", - value_tag="en:apricot", - barcode="780", - annotation=1, - ) - - ProductInsightFactory( - type="category", value_tag="en:apricot", barcode="983", annotation=0 - ) +def test_get_unanswered_questions_api(client, peewee_db): + with peewee_db: + ProductInsight.delete().execute() # remove default sample + ProductInsightFactory(type="category", value_tag="en:apricot", barcode="123") + ProductInsightFactory(type="label", value_tag="en:beer", barcode="456") + ProductInsightFactory(type="nutrition", value_tag="en:soups", barcode="789") + ProductInsightFactory(type="nutrition", value_tag="en:salad", barcode="302") + ProductInsightFactory(type="nutrition", value_tag="en:salad", barcode="403") + ProductInsightFactory(type="category", value_tag="en:soups", barcode="194") + ProductInsightFactory(type="category", value_tag="en:soups", barcode="967") + ProductInsightFactory(type="label", value_tag="en:beer", barcode="039") + ProductInsightFactory(type="category", value_tag="en:apricot", barcode="492") + ProductInsightFactory(type="category", value_tag="en:soups", barcode="594") + ProductInsightFactory( + type="category", + value_tag="en:apricot", + barcode="780", + annotation=1, + ) + ProductInsightFactory( + type="category", value_tag="en:apricot", barcode="983", annotation=0 + ) # test to get all "category" with "annotation=None" @@ -724,17 +735,19 @@ def test_get_unanswered_questions_api(client): assert data["status"] == "found" -def test_get_unanswered_questions_api_with_country_filter(client): - ProductInsight.delete().execute() # remove default sample - - # test for filter with "country" - - ProductInsightFactory( - type="location", value_tag="en:dates", barcode="032", countries=["en:india"] - ) - ProductInsightFactory( - type="location", value_tag="en:dates", barcode="033", countries=["en:france"] - ) +def test_get_unanswered_questions_api_with_country_filter(client, peewee_db): + with peewee_db: + ProductInsight.delete().execute() # remove default sample + # test for filter with "country" + ProductInsightFactory( + type="location", value_tag="en:dates", barcode="032", countries=["en:india"] + ) + ProductInsightFactory( + type="location", + 
value_tag="en:dates", + barcode="033", + countries=["en:france"], + ) result = client.simulate_get( "/api/v1/questions/unanswered", params={"country": "en:india"} @@ -747,10 +760,11 @@ def test_get_unanswered_questions_api_with_country_filter(client): assert data["status"] == "found" -def test_get_unanswered_questions_pagination(client): - ProductInsight.delete().execute() # remove default sample - for i in range(0, 12): - ProductInsightFactory(type="nutrition", value_tag=f"en:soups-{i:02}") +def test_get_unanswered_questions_pagination(client, peewee_db): + with peewee_db: + ProductInsight.delete().execute() # remove default sample + for i in range(0, 12): + ProductInsightFactory(type="nutrition", value_tag=f"en:soups-{i:02}") result = client.simulate_get( "/api/v1/questions/unanswered?count=5&page=1&type=nutrition" @@ -804,22 +818,22 @@ def test_image_prediction_collection_empty(client): assert result.status_code == 200 -def test_image_prediction_collection(client): - - logo_annotation_category_123 = LogoAnnotationFactory( - image_prediction__image__barcode="123", - image_prediction__type="category", - ) - prediction_category_123 = logo_annotation_category_123.image_prediction - logo_annotation_label_789 = LogoAnnotationFactory( - image_prediction__image__barcode="789", - image_prediction__type="label", - ) - prediction_label_789 = logo_annotation_label_789.image_prediction +def test_image_prediction_collection(client, peewee_db): + with peewee_db: + logo_annotation_category_123 = LogoAnnotationFactory( + image_prediction__image__barcode="123", + image_prediction__type="category", + ) + prediction_category_123 = logo_annotation_category_123.image_prediction + logo_annotation_label_789 = LogoAnnotationFactory( + image_prediction__image__barcode="789", + image_prediction__type="label", + ) + prediction_label_789 = logo_annotation_label_789.image_prediction - prediction_label_789_no_logo = ImagePredictionFactory( - image__barcode="789", type="label" - ) + prediction_label_789_no_logo = ImagePredictionFactory( + image__barcode="789", type="label" + ) # test with "barcode=123" and "with_logo=True" result = client.simulate_get( @@ -886,44 +900,39 @@ def test_logo_annotation_collection_empty(client): assert result.json == {"count": 0, "annotation": [], "status": "no_annotation"} -def test_logo_annotation_collection_api(client): - LogoAnnotation.delete().execute() # remove default sample - - annotation_123_1 = LogoAnnotationFactory( - image_prediction__image__barcode="123", - annotation_value_tag="etorki", - annotation_type="brand", - ) - - annotation_123_2 = LogoAnnotationFactory( - image_prediction__image__barcode="123", - annotation_value_tag="etorki", - annotation_type="brand", - ) - - annotation_295 = LogoAnnotationFactory( - image_prediction__image__barcode="295", - annotation_value_tag="cheese", - annotation_type="dairies", - ) - - annotation_789 = LogoAnnotationFactory( - image_prediction__image__barcode="789", - annotation_value_tag="creme", - annotation_type="dairies", - ) - - annotation_306 = LogoAnnotationFactory( - image_prediction__image__barcode="306", - annotation_value_tag="yoghurt", - annotation_type="dairies", - ) - - annotation_604 = LogoAnnotationFactory( - image_prediction__image__barcode="604", - annotation_value_tag="meat", - annotation_type="category", - ) +def test_logo_annotation_collection_api(client, peewee_db): + with peewee_db: + LogoAnnotation.delete().execute() # remove default sample + annotation_123_1 = LogoAnnotationFactory( + 
image_prediction__image__barcode="123", + annotation_value_tag="etorki", + annotation_type="brand", + ) + annotation_123_2 = LogoAnnotationFactory( + image_prediction__image__barcode="123", + annotation_value_tag="etorki", + annotation_type="brand", + ) + annotation_295 = LogoAnnotationFactory( + image_prediction__image__barcode="295", + annotation_value_tag="cheese", + annotation_type="dairies", + ) + annotation_789 = LogoAnnotationFactory( + image_prediction__image__barcode="789", + annotation_value_tag="creme", + annotation_type="dairies", + ) + annotation_306 = LogoAnnotationFactory( + image_prediction__image__barcode="306", + annotation_value_tag="yoghurt", + annotation_type="dairies", + ) + annotation_604 = LogoAnnotationFactory( + image_prediction__image__barcode="604", + annotation_value_tag="meat", + annotation_type="category", + ) # test with "barcode" @@ -982,22 +991,23 @@ def test_logo_annotation_collection_api(client): assert annotations[3]["image_prediction"]["image"]["barcode"] == "604" -def test_logo_annotation_collection_pagination(client): - LogoAnnotation.delete().execute() # remove default sample - for i in range(0, 12): - LogoAnnotationFactory( - annotation_type="label", annotation_value_tag=f"no lactose-{i:02}" - ) +def test_logo_annotation_collection_pagination(client, peewee_db): + with peewee_db: + LogoAnnotation.delete().execute() # remove default sample + for i in range(0, 12): + LogoAnnotationFactory( + annotation_type="label", annotation_value_tag=f"no lactose-{i:02}" + ) - for i in range(0, 2): - LogoAnnotationFactory( - annotation_type="vegan", annotation_value_tag=f"truffle cake-{i:02}" - ) + for i in range(0, 2): + LogoAnnotationFactory( + annotation_type="vegan", annotation_value_tag=f"truffle cake-{i:02}" + ) - for i in range(0, 2): - LogoAnnotationFactory( - annotation_type="category", annotation_value_tag=f"sea food-{i:02}" - ) + for i in range(0, 2): + LogoAnnotationFactory( + annotation_type="category", annotation_value_tag=f"sea food-{i:02}" + ) result = client.simulate_get( "/api/v1/annotation/collection?count=5&page=1&types=label" diff --git a/tests/integration/test_core_integration.py b/tests/integration/test_core_integration.py index 5e6aa69364..bf830ce28d 100644 --- a/tests/integration/test_core_integration.py +++ b/tests/integration/test_core_integration.py @@ -20,11 +20,12 @@ @pytest.fixture(autouse=True) def _set_up_and_tear_down(peewee_db): - # clean db - clean_db() - # Run the test case. - yield - clean_db() + with peewee_db: + # clean db + clean_db() + # Run the test case. + yield + clean_db() def prediction_ids(data): diff --git a/tests/integration/test_models_integration.py b/tests/integration/test_models_integration.py index 8b29fee12b..69f577b720 100644 --- a/tests/integration/test_models_integration.py +++ b/tests/integration/test_models_integration.py @@ -7,12 +7,12 @@ @pytest.fixture(autouse=True) def _set_up_and_tear_down(peewee_db): - # clean db - clean_db() - # Run the test case. - yield - # Tear down. - clean_db() + with peewee_db: + # clean db + clean_db() + # Run the test case. + yield + clean_db() def test_vote_cascade_on_insight_deletion(peewee_db): diff --git a/tests/integration/test_scheduler.py b/tests/integration/test_scheduler.py index 92caa1aaff..44911fffc4 100644 --- a/tests/integration/test_scheduler.py +++ b/tests/integration/test_scheduler.py @@ -10,10 +10,12 @@ @pytest.fixture(autouse=True) def _set_up_and_tear_down(peewee_db): - clean_db() - # Run the test case. 
- yield - clean_db() + with peewee_db: + # clean db + clean_db() + # Run the test case. + yield + clean_db() def test_mark_insights(): diff --git a/tests/unit/insights/test_extraction.py b/tests/unit/insights/test_extraction.py deleted file mode 100644 index 94a3bee9e9..0000000000 --- a/tests/unit/insights/test_extraction.py +++ /dev/null @@ -1,63 +0,0 @@ -import numpy as np -import pytest -from PIL import Image - -from robotoff.insights.extraction import extract_nutriscore_label -from robotoff.prediction.object_detection.core import ( - ObjectDetectionRawResult, - RemoteModel, -) -from robotoff.prediction.types import Prediction, PredictionType - - -class FakeNutriscoreModel(RemoteModel): - def __init__(self, raw_result: ObjectDetectionRawResult): - self.raw_result = raw_result - - def detect_from_image( - self, image: Image.Image, output_image: bool = False - ) -> ObjectDetectionRawResult: - return self.raw_result - - -@pytest.mark.parametrize( - "automatic_threshold, processed_automatically, source_image", - [ - (None, False, "/image/1"), - (0.7, True, "/image/1"), - ], -) -def test_extract_nutriscore_label_automatic( - mocker, source_image, automatic_threshold, processed_automatically -): - raw_result = ObjectDetectionRawResult( - num_detections=1, - detection_boxes=np.array([[1, 2, 3, 4]]), - detection_scores=np.array([0.8]), - detection_classes=np.array([1]), - label_names=["NULL", "nutriscore-a"], - ) - mocker.patch( - "robotoff.prediction.object_detection.core.ObjectDetectionModelRegistry.get", - return_value=FakeNutriscoreModel(raw_result), - ) - - insight = extract_nutriscore_label( - Image.Image, - source_image=source_image, - manual_threshold=0.5, - automatic_threshold=automatic_threshold, - ) - - assert insight == Prediction( - type=PredictionType.label, - data={ - "confidence": 0.8, - "bounding_box": (1, 2, 3, 4), - "model": "nutriscore", - "notify": True, - }, - source_image=source_image, - value_tag="en:nutriscore-grade-a", - automatic_processing=processed_automatically, - ) diff --git a/tests/unit/insights/test_importer.py b/tests/unit/insights/test_importer.py index 16de2cb59d..faab6d0a87 100644 --- a/tests/unit/insights/test_importer.py +++ b/tests/unit/insights/test_importer.py @@ -500,48 +500,49 @@ def test_get_insight_update_annotated_reference(self): assert to_delete == [] assert to_update == [] - def test_generate_insights_no_predictions(self): + def test_generate_insights_no_predictions(self, mocker): + get_existing_insight_mock = mocker.patch( + "robotoff.insights.importer.get_existing_insight", return_value=[] + ) assert ( - list( - InsightImporter.generate_insights( - [], - DEFAULT_SERVER_DOMAIN, - product_store=FakeProductStore(), - ) + CategoryImporter.generate_insights( + DEFAULT_BARCODE, + [], + DEFAULT_SERVER_DOMAIN, + product_store=FakeProductStore(), ) - == [] + == ([], [], []) ) + get_existing_insight_mock.assert_called_once() - def test_generate_insights_missing_product_no_references(self, mocker): + def test_generate_insights_no_predictions_with_existing_insight(self, mocker): + existing_insight = ProductInsight( + barcode=DEFAULT_BARCODE, + type=InsightType.category.name, + value_tag="en:fishes", + ) get_existing_insight_mock = mocker.patch( - "robotoff.insights.importer.get_existing_insight", return_value=[] + "robotoff.insights.importer.get_existing_insight", + return_value=[existing_insight], ) assert ( - list( - InsightImporter.generate_insights( - [ - Prediction( - type=PredictionType.category, - barcode=DEFAULT_BARCODE, - data={}, - ) - ], - 
DEFAULT_SERVER_DOMAIN, - product_store=FakeProductStore(), - ) + CategoryImporter.generate_insights( + DEFAULT_BARCODE, + [], + DEFAULT_SERVER_DOMAIN, + product_store=FakeProductStore(), ) - == [] + == ([], [], [existing_insight]) ) get_existing_insight_mock.assert_called_once() - def test_generate_insights_missing_product_with_reference(self, mocker): - reference = ProductInsight(barcode=DEFAULT_BARCODE, type=InsightType.category) + def test_generate_insights_missing_product_no_references(self, mocker): get_existing_insight_mock = mocker.patch( - "robotoff.insights.importer.get_existing_insight", - return_value=[reference], + "robotoff.insights.importer.get_existing_insight", return_value=[] ) - generated = list( + assert ( InsightImporter.generate_insights( + DEFAULT_BARCODE, [ Prediction( type=PredictionType.category, @@ -552,8 +553,29 @@ def test_generate_insights_missing_product_with_reference(self, mocker): DEFAULT_SERVER_DOMAIN, product_store=FakeProductStore(), ) + == ([], [], []) + ) + get_existing_insight_mock.assert_called_once() + + def test_generate_insights_missing_product_with_reference(self, mocker): + reference = ProductInsight(barcode=DEFAULT_BARCODE, type=InsightType.category) + get_existing_insight_mock = mocker.patch( + "robotoff.insights.importer.get_existing_insight", + return_value=[reference], + ) + generated = InsightImporter.generate_insights( + DEFAULT_BARCODE, + [ + Prediction( + type=PredictionType.category, + barcode=DEFAULT_BARCODE, + data={}, + ) + ], + DEFAULT_SERVER_DOMAIN, + product_store=FakeProductStore(), ) - assert generated == [([], [], [reference])] + assert generated == ([], [], [reference]) get_existing_insight_mock.assert_called_once() def test_generate_insights_creation_and_deletion(self, mocker): @@ -591,31 +613,29 @@ def get_insight_update(cls, candidates, references): automatic_processing=True, source_image="/images/products/322/982/001/9192/8.jpg", ) - generated = list( - FakeImporter.generate_insights( - [prediction], - DEFAULT_SERVER_DOMAIN, - product_store=FakeProductStore( - data={ - DEFAULT_BARCODE: Product( - { - "code": DEFAULT_BARCODE, - "images": { - "8": { - "uploaded_t": ( - datetime.datetime.utcnow() - - datetime.timedelta(days=600) - ).timestamp() - } - }, - } - ) - } - ), - ) + generated = FakeImporter.generate_insights( + DEFAULT_BARCODE, + [prediction], + DEFAULT_SERVER_DOMAIN, + product_store=FakeProductStore( + data={ + DEFAULT_BARCODE: Product( + { + "code": DEFAULT_BARCODE, + "images": { + "8": { + "uploaded_t": ( + datetime.datetime.utcnow() + - datetime.timedelta(days=600) + ).timestamp() + } + }, + } + ) + } + ), ) - assert len(generated) == 1 - to_create, to_update, to_delete = generated[0] + to_create, to_update, to_delete = generated assert len(to_create) == 1 assert len(to_update) == 0 created_insight = to_create[0] @@ -655,17 +675,15 @@ def get_insight_update(cls, candidates, references): data={}, automatic_processing=True, ) - generated = list( - FakeImporter.generate_insights( - [prediction], - DEFAULT_SERVER_DOMAIN, - product_store=FakeProductStore( - data={DEFAULT_BARCODE: Product({"code": DEFAULT_BARCODE})} - ), - ) + generated = FakeImporter.generate_insights( + DEFAULT_BARCODE, + [prediction], + DEFAULT_SERVER_DOMAIN, + product_store=FakeProductStore( + data={DEFAULT_BARCODE: Product({"code": DEFAULT_BARCODE})} + ), ) - assert len(generated) == 1 - to_create, to_update, to_delete = generated[0] + to_create, to_update, to_delete = generated assert not to_delete assert to_update == [] assert len(to_create) 
== 1 @@ -680,6 +698,7 @@ def get_required_prediction_types(): with pytest.raises(ValueError, match="unexpected prediction type: 'label'"): FakeImporter.import_insights( + DEFAULT_BARCODE, [Prediction(type=PredictionType.label)], DEFAULT_SERVER_DOMAIN, product_store=FakeProductStore(), @@ -692,26 +711,33 @@ def get_required_prediction_types(): return {PredictionType.label} @classmethod - def generate_insights(cls, predictions, server_domains, product_store): - yield [ - ProductInsight( - barcode=DEFAULT_BARCODE, - type=InsightType.label.name, - value_tag="tag1", - ) - ], [], [ - ProductInsight( - barcode=DEFAULT_BARCODE, - type=InsightType.label.name, - value_tag="tag2", - ) - ] + def generate_insights( + cls, barcode, predictions, server_domains, product_store + ): + return ( + [ + ProductInsight( + barcode=DEFAULT_BARCODE, + type=InsightType.label.name, + value_tag="tag1", + ) + ], + [], + [ + ProductInsight( + barcode=DEFAULT_BARCODE, + type=InsightType.label.name, + value_tag="tag2", + ) + ], + ) product_insight_delete_mock = mocker.patch.object(ProductInsight, "delete") batch_insert_mock = mocker.patch( "robotoff.insights.importer.batch_insert", return_value=1 ) imported = FakeImporter.import_insights( + DEFAULT_BARCODE, [Prediction(type=PredictionType.label)], DEFAULT_SERVER_DOMAIN, product_store=FakeProductStore(), @@ -1187,9 +1213,7 @@ def test_import_insights_no_element(self, mocker): product_store=product_store, ) get_product_predictions_mock.assert_called_once() - import_insights_mock.assert_called_once_with( - [], DEFAULT_SERVER_DOMAIN, product_store - ) + import_insights_mock.assert_not_called() def test_import_insights_single_product(self, mocker): prediction_dict = { @@ -1221,7 +1245,7 @@ def test_import_insights_single_product(self, mocker): assert imported == 1 get_product_predictions_mock.assert_called_once() import_insights_mock.assert_called_once_with( - [prediction], DEFAULT_SERVER_DOMAIN, product_store + DEFAULT_BARCODE, [prediction], DEFAULT_SERVER_DOMAIN, product_store ) def test_import_insights_type_mismatch(self, mocker):