Skip to content

Commit

Permalink
Improve performance of Prediction import
Browse files Browse the repository at this point in the history
Retrieve existing predictions in batch (one query for all predictions sharing the same barcode).
  • Loading branch information
Mithridatea committed Jan 28, 2022
1 parent be840a8 commit 1e63ca2
Showing 1 changed file with 42 additions and 23 deletions.
65 changes: 42 additions & 23 deletions robotoff/insights/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,23 +630,6 @@ def is_valid_product_predictions(
return True


def is_duplicated_prediction(
    prediction: Prediction, product_predictions: ProductPredictions, server_domain: str
) -> bool:
    """Tell whether an identical prediction is already stored in DB.

    A stored row is considered a duplicate when its barcode, type, server
    domain, source image, value tag and value all equal those of the
    candidate prediction.

    :param prediction: Candidate prediction; supplies `value_tag` and `value`.
    :param product_predictions: Group the candidate belongs to; supplies
        `barcode`, `type` and `source_image`.
    :param server_domain: The server domain associated with the prediction.
    :return: True if at least one matching row exists, False otherwise.
    """
    return (
        PredictionModel.select()
        .where(
            PredictionModel.barcode == product_predictions.barcode,
            PredictionModel.type == product_predictions.type,
            PredictionModel.server_domain == server_domain,
            PredictionModel.source_image == product_predictions.source_image,
            PredictionModel.value_tag == prediction.value_tag,
            PredictionModel.value == prediction.value,
        )
        # Only existence matters here: let the database stop at the first
        # matching row instead of counting all of them.
        .exists()
    )


def create_prediction_model(
prediction: Prediction,
product_predictions: ProductPredictions,
Expand All @@ -668,19 +651,49 @@ def create_prediction_model(


def import_product_predictions(
product_predictions_iter: Iterable[ProductPredictions], server_domain: str
barcode: str,
product_predictions_iter: Iterable[ProductPredictions],
server_domain: str,
):
"""Import product predictions.
If a prediction already exists in DB (same (barcode, type, server_domain,
source_image, value, value_tag)), it won't be imported.
:param barcode: Barcode of the product. All `product_predictions` must
have the same barcode.
:param product_predictions_iter: Iterable of ProductPredictions.
:param server_domain: The server domain associated with the predictions.
:return: The number of items imported in DB.
"""
timestamp = datetime.datetime.utcnow()
existing_predictions = set(
PredictionModel.select(
PredictionModel.type,
PredictionModel.server_domain,
PredictionModel.source_image,
PredictionModel.value_tag,
PredictionModel.value,
)
.where(PredictionModel.barcode == barcode)
.tuples()
)

to_import = itertools.chain.from_iterable(
(
(
create_prediction_model(
prediction, product_predictions, server_domain, timestamp
)
for prediction in product_predictions.predictions
if not is_duplicated_prediction(
prediction, product_predictions, server_domain
if (
product_predictions.type,
server_domain,
product_predictions.source_image,
prediction.value_tag,
prediction.value,
)
not in existing_predictions
)
for product_predictions in product_predictions_iter
)
Expand Down Expand Up @@ -739,9 +752,15 @@ def import_insights(
product_predictions = [
p for p in product_predictions if is_valid_product_predictions(p, product_store)
]
predictions_imported = import_product_predictions(
product_predictions, server_domain
)

predictions_imported = 0
for barcode, product_predictions_group in itertools.groupby(
sorted(product_predictions, key=operator.attrgetter("barcode")),
operator.attrgetter("barcode"),
):
predictions_imported += import_product_predictions(
barcode, product_predictions_group, server_domain
)
logger.info(f"{predictions_imported} predictions imported")

prediction_type: PredictionType
Expand Down

0 comments on commit 1e63ca2

Please sign in to comment.