Skip to content

Commit

Permalink
Add black, isort and pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
thomashbrnrd committed Oct 12, 2023
1 parent 93e2f7e commit a2e29e1
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 120 deletions.
12 changes: 12 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/psf/black
rev: 22.10.0
hooks:
- id: black
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
- id: isort
args: ["--profile", "black"]
103 changes: 64 additions & 39 deletions backend/src/main.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,30 @@

import os
import json
import logging
import os
import time
import json
from logging.handlers import TimedRotatingFileHandler
from datetime import datetime
from uuid import uuid4
from logging.handlers import TimedRotatingFileHandler
from typing import Union
from uuid import uuid4

import boto3
from fastapi import BackgroundTasks, Cookie, FastAPI, APIRouter, File, Form, HTTPException, Request, Response, UploadFile
from fastapi.responses import PlainTextResponse
from fastapi import (
APIRouter,
BackgroundTasks,
Cookie,
FastAPI,
File,
Form,
HTTPException,
Request,
Response,
UploadFile,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse
from gelfformatter import GelfFormatter
from user_agents import parse
from src.model import load_model_inference, predict_image

from user_agents import parse

CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))

Expand All @@ -34,9 +43,9 @@ def init_variable(key: str, value: str) -> str:
VAR = os.environ[key]
else:
VAR = value
print("WARNING: The variable "+key+" is not set. Using", VAR)
print("WARNING: The variable " + key + " is not set. Using", VAR)
if os.path.isabs(VAR):
os.makedirs(VAR, exist_ok = True)
os.makedirs(VAR, exist_ok=True)
return VAR


Expand All @@ -57,10 +66,8 @@ def setup_logs(log_dir: str) -> logging.Logger:
# new log file at midnight
log_file = os.path.join(log_dir, "log.json")
handler = TimedRotatingFileHandler(
log_file,
when="midnight",
interval=1,
backupCount=7)
log_file, when="midnight", interval=1, backupCount=7
)
logger.setLevel(logging.INFO)
handler.setFormatter(formatter)
logger.addHandler(handler)
Expand Down Expand Up @@ -111,7 +118,6 @@ def get_base_logs(user_agent, user_id: str) -> dict:
return extras_logging



def upload_image(content: bytes, image_key: str):
"""Uploads an image to s3 bucket
path uploaded-images/WORKSPACE/img_name
Expand All @@ -126,8 +132,8 @@ def upload_image(content: bytes, image_key: str):
object.put(Body=content)
extras_logging = {
"bg_date": datetime.now().isoformat(),
"bg_upload_time": time.time()-start,
"bg_image_url": image_key
"bg_upload_time": time.time() - start,
"bg_image_url": image_key,
}
logger.info("Upload successful", extra=extras_logging)

Expand All @@ -140,7 +146,7 @@ def upload_image(content: bytes, image_key: str):
app = FastAPI(docs_url="/api/docs")
router = APIRouter(prefix="/api")

origins = [ # allow requests from front-end
origins = [ # allow requests from front-end
"http://basegun.fr",
"https://basegun.fr",
"http://preprod.basegun.fr",
Expand All @@ -158,14 +164,13 @@ def upload_image(content: bytes, image_key: str):
)

# Logs
PATH_LOGS = init_variable("PATH_LOGS",
os.path.abspath(os.path.join(CURRENT_DIR,"/tmp/logs")))
PATH_LOGS = init_variable(
"PATH_LOGS", os.path.abspath(os.path.join(CURRENT_DIR, "/tmp/logs"))
)
logger = setup_logs(PATH_LOGS)

# Load model
MODEL_PATH = os.path.join(
CURRENT_DIR,
"weights/model.pth")
MODEL_PATH = os.path.join(CURRENT_DIR, "weights/model.pth")
model = None
if os.path.exists(MODEL_PATH):
model = load_model_inference(MODEL_PATH)
Expand All @@ -175,7 +180,7 @@ def upload_image(content: bytes, image_key: str):
# Object storage
S3_URL_ENDPOINT = init_variable("S3_URL_ENDPOINT", "https://s3.gra.io.cloud.ovh.net/")
S3_BUCKET_NAME = "basegun-s3"
S3_PREFIX = os.path.join("uploaded-images/", os.environ['WORKSPACE'])
S3_PREFIX = os.path.join("uploaded-images/", os.environ["WORKSPACE"])
s3 = boto3.resource("s3", endpoint_url=S3_URL_ENDPOINT)
""" TODO : check if connection successful
try:
Expand Down Expand Up @@ -229,7 +234,8 @@ async def imageupload(
image: UploadFile = File(...),
date: float = Form(...),
geolocation: str = Form(...),
user_id: Union[str, None] = Cookie(None) ):
user_id: Union[str, None] = Cookie(None),
):

# prepare content logs
user_agent = parse(request.headers.get("user-agent"))
Expand All @@ -238,8 +244,9 @@ async def imageupload(
extras_logging["bg_upload_time"] = round(time.time() - date, 2)

try:
img_key = os.path.join(S3_PREFIX,
str(uuid4()) + os.path.splitext(image.filename)[1].lower())
img_key = os.path.join(
S3_PREFIX, str(uuid4()) + os.path.splitext(image.filename)[1].lower()
)
img_bytes = image.file.read()

# upload image to OVH Cloud
Expand All @@ -257,7 +264,7 @@ async def imageupload(
label, confidence = predict_image(model, img_bytes)
extras_logging["bg_label"] = label
extras_logging["bg_confidence"] = confidence
extras_logging["bg_model_time"] = round(time.time()-start, 2)
extras_logging["bg_model_time"] = round(time.time() - start, 2)
if confidence < 46:
extras_logging["bg_confidence_level"] = "low"
elif confidence < 76:
Expand All @@ -271,7 +278,7 @@ async def imageupload(
"path": img_key,
"label": label,
"confidence": confidence,
"confidence_level": extras_logging["bg_confidence_level"]
"confidence_level": extras_logging["bg_confidence_level"],
}

except Exception as e:
Expand All @@ -289,40 +296,58 @@ async def log_feedback(request: Request, user_id: Union[str, None] = Cookie(None

extras_logging["bg_feedback_bool"] = res["feedback"]
for key in ["image_url", "label", "confidence", "confidence_level"]:
extras_logging["bg_"+key] = res[key]
extras_logging["bg_" + key] = res[key]

logger.info("Identification feedback", extra=extras_logging)
return


@router.post("/tutorial-feedback")
async def log_tutorial_feedback(request: Request, user_id: Union[str, None] = Cookie(None)):
async def log_tutorial_feedback(
request: Request, user_id: Union[str, None] = Cookie(None)
):
res = await request.json()

user_agent = parse(request.headers.get("user-agent"))
extras_logging = get_base_logs(user_agent, user_id)

for key in ["image_url", "label", "confidence", "confidence_level",
"tutorial_feedback", "tutorial_option", "route_name"]:
extras_logging["bg_"+key] = res[key]
for key in [
"image_url",
"label",
"confidence",
"confidence_level",
"tutorial_feedback",
"tutorial_option",
"route_name",
]:
extras_logging["bg_" + key] = res[key]

logger.info("Tutorial feedback", extra=extras_logging)
return


@router.post("/identification-dummy")
async def log_identification_dummy(request: Request, user_id: Union[str, None] = Cookie(None)):
async def log_identification_dummy(
request: Request, user_id: Union[str, None] = Cookie(None)
):
res = await request.json()

user_agent = parse(request.headers.get("user-agent"))
extras_logging = get_base_logs(user_agent, user_id)

# to know if the firearm is dummy or real
extras_logging["bg_dummy_bool"] = res["is_dummy"]
for key in ["image_url", "label", "confidence", "confidence_level", "tutorial_option"]:
extras_logging["bg_"+key] = res[key]
for key in [
"image_url",
"label",
"confidence",
"confidence_level",
"tutorial_option",
]:
extras_logging["bg_" + key] = res[key]

logger.info("Identification dummy", extra=extras_logging)
return

app.include_router(router)

app.include_router(router)
86 changes: 48 additions & 38 deletions backend/src/model.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,41 @@
from io import BytesIO
from typing import Union
from PIL import Image

import numpy as np
import torch
import torchvision.models as Model
from PIL import Image
from torchvision import transforms


CLASSES = ['autre_pistolet',
'epaule_a_levier_sous_garde',
'epaule_a_pompe',
'epaule_a_un_coup_par_canon',
'epaule_a_verrou',
'epaule_mecanisme_ancien',
'epaule_semi_auto_style_chasse',
'epaule_semi_auto_style_militaire_milieu_20e',
'pistolet_mecanisme_ancien',
'pistolet_semi_auto_moderne',
'revolver',
'semi_auto_style_militaire_autre']
CLASSES = [
"autre_pistolet",
"epaule_a_levier_sous_garde",
"epaule_a_pompe",
"epaule_a_un_coup_par_canon",
"epaule_a_verrou",
"epaule_mecanisme_ancien",
"epaule_semi_auto_style_chasse",
"epaule_semi_auto_style_militaire_milieu_20e",
"pistolet_mecanisme_ancien",
"pistolet_semi_auto_moderne",
"revolver",
"semi_auto_style_militaire_autre",
]

MODEL_TORCH = Model.efficientnet_b7
INPUT_SIZE = 600
device = torch.device('cpu')
device = torch.device("cpu")


class ConvertRgb(object):
"""Converts an image to RGB
"""
"""Converts an image to RGB"""

def __init__(self):
pass

def __call__(self, image):
if image.mode != 'RGB':
image = image.convert('RGB')
if image.mode != "RGB":
image = image.convert("RGB")
return image


Expand Down Expand Up @@ -74,18 +75,23 @@ def __init__(self, output_size):

def __call__(self, image):
w, h = image.size
pads = {'horiz': [self.output_size - w,0,0],
'vert': [self.output_size - h,0,0]}
if pads['horiz'][0] >= 0 and pads['vert'][0] >= 0:
for direction in ['horiz', 'vert'] :
pads = {
"horiz": [self.output_size - w, 0, 0],
"vert": [self.output_size - h, 0, 0],
}
if pads["horiz"][0] >= 0 and pads["vert"][0] >= 0:
for direction in ["horiz", "vert"]:
pads[direction][1] = pads[direction][0] // 2
if pads[direction][0] % 2 == 1: # if the size to pad is odd, add a random +1 on one side
pads[direction][1] += np.random.randint(0,1)
if (
pads[direction][0] % 2 == 1
): # if the size to pad is odd, add a random +1 on one side
pads[direction][1] += np.random.randint(0, 1)
pads[direction][2] = pads[direction][0] - pads[direction][1]

return transforms.functional.pad(image,
[pads['horiz'][1], pads['vert'][1], pads['horiz'][2], pads['vert'][2]],
fill = int(np.random.choice([0, 255])) # border randomly white or black
return transforms.functional.pad(
image,
[pads["horiz"][1], pads["vert"][1], pads["horiz"][2], pads["vert"][2]],
fill=int(np.random.choice([0, 255])), # border randomly white or black
)
else:
return image
Expand Down Expand Up @@ -121,7 +127,9 @@ def load_model_inference(state_dict_path: str) -> Model:
"""
model = build_model(MODEL_TORCH())
# Initialize model with the pretrained weights
model.load_state_dict(torch.load(state_dict_path, map_location=device)['model_state_dict'])
model.load_state_dict(
torch.load(state_dict_path, map_location=device)["model_state_dict"]
)
model.to(device)
# set the model to inference mode
model.eval()
Expand All @@ -138,13 +146,15 @@ def prepare_input(image: Image) -> torch.Tensor:
torch.Tensor: converted image
(shape (1, 3, size, size), normalized on ImageNet)
"""
loader = transforms.Compose([
ConvertRgb(),
Rescale(INPUT_SIZE),
RandomPad(INPUT_SIZE),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
loader = transforms.Compose(
[
ConvertRgb(),
Rescale(INPUT_SIZE),
RandomPad(INPUT_SIZE),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
image = loader(image).float()
return image.unsqueeze(0).to(device)

Expand All @@ -163,6 +173,6 @@ def predict_image(model: Model, img: bytes) -> Union[str, float]:
image = prepare_input(im)
output = model(image)
probs = torch.nn.functional.softmax(output, dim=1).detach().numpy()[0]
res = [(CLASSES[i], round(probs[i]*100,2)) for i in range(len(CLASSES))]
res.sort(key=lambda x:x[1], reverse=True)
res = [(CLASSES[i], round(probs[i] * 100, 2)) for i in range(len(CLASSES))]
res.sort(key=lambda x: x[1], reverse=True)
return res[0]
Loading

0 comments on commit a2e29e1

Please sign in to comment.