Skip to content

Commit

Permalink
Fix tests and Warmup path
Browse files Browse the repository at this point in the history
  • Loading branch information
thomashbrnrd committed Jan 23, 2024
1 parent f78c5bc commit dd67c7c
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 32 deletions.
Binary file removed backend/src/Warmup.jpg
Binary file not shown.
21 changes: 5 additions & 16 deletions backend/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import time
from datetime import datetime
from logging.handlers import TimedRotatingFileHandler
from contextlib import asynccontextmanager
from typing import Union
from uuid import uuid4

Expand Down Expand Up @@ -159,12 +160,7 @@ async def add_owasp_middleware(request: Request, call_next):
logger = setup_logs(PATH_LOGS)

# Load model
MODEL_PATH = os.path.join(CURRENT_DIR, "../model.pt")
model = None
if os.path.exists(MODEL_PATH):
model = load_model_inference(MODEL_PATH)
if not model:
raise RuntimeError("Model not found")
app.model = load_model_inference("./model.pt")

# Object storage
S3_URL_ENDPOINT = os.environ["S3_URL_ENDPOINT"]
Expand All @@ -173,13 +169,6 @@ async def add_owasp_middleware(request: Request, call_next):

s3 = boto3.resource("s3", endpoint_url=S3_URL_ENDPOINT, verify=False)

""" TODO : check if connection successful
try:
s3.meta.client.head_bucket(Bucket=S3_BUCKET_NAME)
except ClientError:
logger.exception("Cannot find s3 bucket ! Are you sure your credentials are correct ?")
"""

# Versions
if "versions.json" in os.listdir(os.path.dirname(CURRENT_DIR)):
with open("versions.json", "r") as f:
Expand Down Expand Up @@ -238,13 +227,13 @@ async def imageupload(

# send image to model for prediction
start = time.time()
label, confidence = predict_image(model, img_bytes)
label, confidence = predict_image(app.model, img_bytes)
extras_logging["bg_label"] = label
extras_logging["bg_confidence"] = confidence
extras_logging["bg_model_time"] = round(time.time() - start, 2)
if confidence < 76:
if confidence < 0.76:
extras_logging["bg_confidence_level"] = "low"
elif confidence < 99:
elif confidence < 0.98:
extras_logging["bg_confidence_level"] = "medium"
else:
extras_logging["bg_confidence_level"] = "high"
Expand Down
9 changes: 3 additions & 6 deletions backend/src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,9 @@ def load_model_inference(model_path: str):
model_path (str): path to model (.pt file)
Returns:
Model: loaded model ready for prediction and Warmud-up
Model: loaded model ready for prediction and Warm-up
"""
model = YOLO(model_path)
test = Image.open("./Warmup.jpg")
model(test, verbose=False)
return model
return YOLO(model_path)


def predict_image(model, img: bytes) -> Union[str, float]:
Expand All @@ -51,4 +48,4 @@ def predict_image(model, img: bytes) -> Union[str, float]:
predicted_class = results[0].probs.top5[0]
label = CLASSES[predicted_class]
confidence = float(results[0].probs.top5conf[0])
return (label, 100*confidence)
return (label, confidence)
Binary file removed backend/tests/Warmup.jpg
Binary file not shown.
11 changes: 3 additions & 8 deletions backend/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,12 @@

import pytest
from src.model import CLASSES, load_model_inference, predict_image

from src.main import app

class TestModel:
model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../model.pt")
assert os.path.exists(model_path)
model = load_model_inference(model_path)

def test_predict_image(self):
"""Checks the prediction of an image by the model"""
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "revolver.jpg")
with open(path, "rb") as f:
res = predict_image(self.model, f.read())
with open("./tests/revolver.jpg", "rb") as f:
res = predict_image(app.model, f.read())
assert res[0] == "revolver"
assert res[1] == pytest.approx(1, 0.1)
4 changes: 2 additions & 2 deletions frontend/src/components/ResultPage.vue
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,14 @@ function sendFeedback (isCorrect: boolean) {
<DsfrTag
class="fr-tag--sm success-tag"
>
Indice de fiabilité : {{ Math.floor(confidence) }}%
Indice de fiabilité : {{ Math.floor(confidence * 100) }}%
</DsfrTag>
</div>
<div v-else>
<DsfrTag
class="fr-tag--sm warning-tag"
>
Indice de fiabilité : {{ Math.floor(confidence) }}%
Indice de fiabilité : {{ Math.floor(confidence * 100) }}%
</DsfrTag>
<p class="warning-text">
Nous vous conseillons de faire appel à un expert pour confirmer cette réponse.
Expand Down

0 comments on commit dd67c7c

Please sign in to comment.