-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlocal_app.py
53 lines (37 loc) · 1.6 KB
/
local_app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from flask import Flask, request, render_template
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, ResNetForImageClassification
app = Flask(__name__)

# Hugging Face checkpoint to serve. Both the preprocessing object and the model
# are loaded from this single name so that changing it actually swaps the model
# (previously the literal was repeated and this variable was unused). It must be
# a checkpoint loadable by ResNetForImageClassification.
model_name = "microsoft/resnet-18"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = ResNetForImageClassification.from_pretrained(model_name)
# Inference only: put the model in evaluation mode once at startup.
model.eval()
# Define a Flask API route to accept images
@app.route("/predict", methods=["POST"])
def predict():
    """Classify an uploaded image and return its predicted label.

    Expects a multipart/form-data POST carrying the image file in the
    ``image`` field.

    Returns:
        The predicted label (``model.config.id2label`` entry) as plain text on
        success; otherwise a ``(message, 400)`` tuple if the ``image`` field is
        missing or the upload cannot be decoded as an image.
    """
    print("Received request for prediction")
    image_file = request.files.get("image")
    if image_file is None:
        return "No image file provided in the 'image' field", 400
    # show the image that the user has sent
    print(image_file)

    try:
        # Image.open is lazy; convert() forces the decode here (inside the try)
        # and normalises RGBA / palette / grayscale uploads to 3-channel RGB so
        # the feature extractor does not receive an unexpected channel layout.
        # The `with` closes the decoder's handle once the RGB copy exists.
        with Image.open(image_file) as uploaded:
            image = uploaded.convert("RGB")
    except OSError:  # covers PIL.UnidentifiedImageError and truncated files
        return "Uploaded file is not a readable image", 400

    image = image.resize((224, 224))
    inputs = feature_extractor(image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    # model predicts one of the 1000 ImageNet classes
    predicted_label = logits.argmax(-1).item()
    label = model.config.id2label[predicted_label]
    print(f"Predicted class: {label}")
    # Return the prediction to the user as plain text
    return label
# Serve the HTML front end from the app's default templates folder.
@app.route("/", methods=["GET"])
def index():
    """Render and return the ``index.html`` template."""
    page = "index.html"
    return render_template(page)
if __name__ == "__main__":
    import os

    # debug=True enables Flask's interactive debugger, which can execute
    # arbitrary code from the browser; only use it on a trusted local machine.
    # It stays on by default (as before); set LOCAL_APP_DEBUG=0 to disable it
    # without editing the file.
    debug_enabled = os.environ.get("LOCAL_APP_DEBUG", "1") != "0"
    app.run(debug=debug_enabled)