-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathembedding_model.py
46 lines (38 loc) · 1.48 KB
/
embedding_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from tensorflow.keras import models
import numpy as np
import tensorflow as tf
from pathlib import Path
from matplotlib import pyplot as plt
from keras.applications import resnet
from keras import metrics
import os
# Default (height, width) that input images are resized to.
target_shape = (200, 200)


class EmbeddingModel:
    """Wrap a saved Keras model that maps images to embedding vectors.

    Args:
        filepath: Path of the saved model, loaded with ``models.load_model``.
        target_shape: ``(height, width)`` every image is resized to before
            being fed to the model.
        preprocess_input: Optional callable applied to each image batch just
            before the model call (e.g. ``resnet.preprocess_input``). When
            ``None``, batches are passed to the model unchanged.
    """

    def __init__(self, filepath, target_shape=target_shape, preprocess_input=None):
        self.model = models.load_model(filepath)
        # Normalise to a tuple so list inputs also work when building shapes.
        self.target_shape = tuple(target_shape)
        self.preprocess_input = preprocess_input

    def preprocess_image(self, filename):
        """Load ``filename`` as a JPEG image and resize it to ``self.target_shape``.

        Returns:
            A float32 tensor of shape ``(height, width, 3)`` with values
            rescaled from uint8 ``[0, 255]`` to ``[0, 1]`` by
            ``convert_image_dtype``.
        """
        image_string = tf.io.read_file(filename)
        image = tf.image.decode_jpeg(image_string, channels=3)
        image = tf.image.convert_image_dtype(image, tf.float32)
        # Use the instance's shape (not the module default) so the resize agrees
        # with the batch buffer allocated in extract_feats.
        image = tf.image.resize(image, self.target_shape)
        return image

    def l2_distance(self, vec1, vec2):
        """Return the Euclidean (L2) distance between two vectors.

        Accepts NumPy arrays, eager tensors or plain Python sequences.
        """
        return np.linalg.norm(np.asarray(vec1) - np.asarray(vec2))

    def _apply_preprocess(self, batch):
        """Apply ``self.preprocess_input`` to ``batch``; identity when it is None."""
        if self.preprocess_input is None:
            return batch
        return self.preprocess_input(batch)

    def extract_feat(self, img_path):
        """Return the model output (embedding) for a single image file."""
        image = self.preprocess_image(img_path)
        # Add a leading batch axis. The copy presumably guards against a
        # preprocess_input that modifies its argument in place.
        batch = np.expand_dims(image, axis=0).copy()
        # NOTE(review): pixels reach preprocess_input scaled to [0, 1]; confirm
        # this matches the preprocessing used when the model was trained.
        return self.model(self._apply_preprocess(batch))[0]

    def extract_feats(self, img_paths):
        """Return a list with one model output per image in ``img_paths``.

        All images are preprocessed into a single float32 batch and passed to
        the model in one call. An empty input yields an empty list without
        calling the model.
        """
        img_paths = list(img_paths)
        if not img_paths:
            return []
        # float32 to match the dtype produced by preprocess_image.
        batch = np.zeros((len(img_paths), *self.target_shape, 3), dtype=np.float32)
        for index, path in enumerate(img_paths):
            batch[index] = self.preprocess_image(path)
        return list(self.model(self._apply_preprocess(batch)))