-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathkeypoint_classifier.py
34 lines (26 loc) · 1.04 KB
/
keypoint_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import numpy as np
from tflite_runtime import interpreter
class KeyPointClassifier(object):
    """Callable wrapper around a TFLite interpreter that returns the top class.

    The interpreter is built once in ``__init__``; each call feeds a single
    sample through it and returns the position of the largest output value.
    """

    def __init__(
        self,
        model_path='models/keypoint_classifier.tflite',
        num_threads=1,
    ):
        """Create the interpreter and cache its input/output tensor details.

        Args:
            model_path: Path of the ``.tflite`` model file handed to the
                interpreter.
            num_threads: Value forwarded as the interpreter's ``num_threads``
                argument.
        """
        self.interpreter = interpreter.Interpreter(
            model_path=model_path,
            num_threads=num_threads,
        )
        # Allocation has to happen before tensors can be set or read.
        self.interpreter.allocate_tensors()

        # Queried once here so every call can reuse the tensor indices.
        self.input_details = self.interpreter.get_input_details()
        self.output_details = self.interpreter.get_output_details()

    def __call__(
        self,
        landmark_list,
    ):
        """Run one inference and return the index of the largest output.

        Args:
            landmark_list: Feature values for one sample. It is wrapped in an
                outer list (adding a leading axis of size 1) and converted to
                ``float32`` before being written to the first input tensor.

        Returns:
            The result of ``np.argmax`` over the squeezed first output tensor.
        """
        input_index = self.input_details[0]['index']
        output_index = self.output_details[0]['index']

        # The extra list level supplies the leading (batch) axis.
        sample = np.array([landmark_list], dtype=np.float32)
        self.interpreter.set_tensor(input_index, sample)
        self.interpreter.invoke()

        scores = self.interpreter.get_tensor(output_index)
        # Drop size-1 axes so argmax indexes the per-class values directly.
        result_index = np.argmax(np.squeeze(scores))
        return result_index