Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat yolo classif #320

Merged
merged 4 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 4 additions & 15 deletions backend/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import time
from datetime import datetime
from logging.handlers import TimedRotatingFileHandler
from contextlib import asynccontextmanager
from typing import Union
from uuid import uuid4

Expand Down Expand Up @@ -159,12 +160,7 @@ async def add_owasp_middleware(request: Request, call_next):
logger = setup_logs(PATH_LOGS)

# Load model
MODEL_PATH = os.path.join(CURRENT_DIR, "../model.pt")
model = None
if os.path.exists(MODEL_PATH):
model = load_model_inference(MODEL_PATH)
if not model:
raise RuntimeError("Model not found")
app.model = load_model_inference("./model.pt")

# Object storage
S3_URL_ENDPOINT = os.environ["S3_URL_ENDPOINT"]
Expand All @@ -173,13 +169,6 @@ async def add_owasp_middleware(request: Request, call_next):

s3 = boto3.resource("s3", endpoint_url=S3_URL_ENDPOINT, verify=False)

""" TODO : check if connection successful
try:
s3.meta.client.head_bucket(Bucket=S3_BUCKET_NAME)
except ClientError:
logger.exception("Cannot find s3 bucket ! Are you sure your credentials are correct ?")
"""

# Versions
if "versions.json" in os.listdir(os.path.dirname(CURRENT_DIR)):
with open("versions.json", "r") as f:
Expand Down Expand Up @@ -238,13 +227,13 @@ async def imageupload(

# send image to model for prediction
start = time.time()
label, confidence = predict_image(model, img_bytes)
label, confidence = predict_image(app.model, img_bytes)
extras_logging["bg_label"] = label
extras_logging["bg_confidence"] = confidence
extras_logging["bg_model_time"] = round(time.time() - start, 2)
if confidence < 0.76:
extras_logging["bg_confidence_level"] = "low"
elif confidence < 0.99:
elif confidence < 0.98:
extras_logging["bg_confidence_level"] = "medium"
else:
extras_logging["bg_confidence_level"] = "high"
Expand Down
5 changes: 2 additions & 3 deletions backend/src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,9 @@ def load_model_inference(model_path: str):
model_path (str): path to model (.pt file)

Returns:
Model: loaded model ready for prediction
Model: loaded model ready for prediction and Warm-up
"""
model = YOLO(model_path)
return model
return YOLO(model_path)


def predict_image(model, img: bytes) -> Union[str, float]:
Expand Down
11 changes: 3 additions & 8 deletions backend/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,12 @@

import pytest
from src.model import CLASSES, load_model_inference, predict_image

from src.main import app

class TestModel:
model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../model.pt")
assert os.path.exists(model_path)
model = load_model_inference(model_path)

def test_predict_image(self):
"""Checks the prediction of an image by the model"""
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "revolver.jpg")
with open(path, "rb") as f:
res = predict_image(self.model, f.read())
with open("./tests/revolver.jpg", "rb") as f:
res = predict_image(app.model, f.read())
assert res[0] == "revolver"
assert res[1] == pytest.approx(1, 0.1)
4 changes: 2 additions & 2 deletions frontend/src/components/ResultPage.vue
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,14 @@ function sendFeedback (isCorrect: boolean) {
<DsfrTag
class="fr-tag--sm success-tag"
>
Indice de fiabilité : {{ Math.floor(confidence) }}%
Indice de fiabilité : {{ Math.floor(confidence * 100) }}%
</DsfrTag>
</div>
<div v-else>
<DsfrTag
class="fr-tag--sm warning-tag"
>
Indice de fiabilité : {{ Math.floor(confidence) }}%
Indice de fiabilité : {{ Math.floor(confidence * 100) }}%
</DsfrTag>
<p class="warning-text">
Nous vous conseillons de faire appel à un expert pour confirmer cette réponse.
Expand Down
Loading