-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
162 lines (129 loc) · 6.26 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import os
import time
import darknet
import numpy as np
from model.lane_line import draw_lane_lines, draw_stop_line
from model.plate import LPR
from model.car import get_license_plate, speed_measure, draw_speed_info
from model.util.point_util import *
from model.conf import conf
from model.detect_color import traffic_light
from model.zebra import Zebra, get_zebra_line, draw_zebra_line
from model.comity_pedestrian import judge_comity_pedestrian, Comity_Pedestrian
from model.traffic_flow import get_traffic_flow, Traffic_Flow
from model.retrograde import get_retrograde_cars
from model.running_red_lights import judge_running_car
from model.illegal_parking import find_illegal_area, find_illegal_parking_cars
from model.save_img import save_illegal_car, create_save_file
import cv2
class Data:
    """Per-run state shared between the detection/analysis stages in ``YOLO``.

    All mutable state (tracks, lane geometry, violation id lists, counters) is
    created per instance in ``__init__``. Previously these were class-level
    lists, so every ``Data`` object aliased the same lists and in-place updates
    such as ``make_track(boxes, data.tracks)`` mutated class-wide state.

    The label names and their colors are read-only lookup tables. They remain
    class attributes and are loaded once, when the class is defined.
    """

    class_names = get_names(conf.names_path)  # label names
    colors = get_colors(class_names)  # color assigned to each label

    def __init__(self):
        self.tracks = []  # trajectory for each tracking id
        self.illegal_boxes_number = []  # tracking ids of cars that changed lanes illegally
        self.lane_lines = []  # lane lines
        self.stop_line = []  # stop line
        self.lanes = []  # lanes
        self.illegal_area = []  # no-parking area
        self.illegal_parking_numbers = []  # ids of illegally parked vehicles
        self.zebra_line = Zebra(0, 0, 0, 0)  # zebra crossing (placeholder until first frame)
        self.speeds = []  # speed information
        self.traffic_flow = 0  # traffic flow
        self.init_flag = True  # True until the first frame has been processed
        self.retrograde_cars_number = []  # ids of wrong-way (retrograde) cars
        self.no_comity_pedestrian_cars_number = []  # ids of cars that did not yield to pedestrians
        self.true_running_car = []  # tracking ids of cars that ran a red light
        # NOTE(review): passed into and returned from judge_running_car together with
        # true_running_car; presumably candidate red-light runners — confirm.
        self.running_car = []
        # NOTE(review): state handed to judge_running_car; its meaning is not visible here.
        self.origin = []
class Model:
    """Load the tracker, the darknet detector and the licence-plate recognizer.

    Every handle is stored on the instance, because ``YOLO`` reads them as
    ``model.encoder``, ``model.tracker``, ``model.netMain``, ``model.metaMain``,
    ``model.image_width``, ``model.image_height``, ``model.darknet_image`` and
    ``model.plate_model``. Previously all but the tracker pair were plain locals
    and were lost when ``__init__`` returned.
    """

    def __init__(self):
        # Tracker model: encoder and tracker returned by init_deep_sort().
        self.encoder, self.tracker = init_deep_sort()

        # darknet model; paths are passed as ASCII-encoded bytes.
        # NOTE(review): the trailing (0, 1) arguments are presumably darknet's
        # `clear` and `batch` parameters — confirm against darknet.py.
        self.netMain = darknet.load_net_custom(
            conf.cfg_path.encode("ascii"), conf.weight_path.encode("ascii"), 0, 1
        )
        self.metaMain = darknet.load_meta(conf.radar_data_path.encode("ascii"))

        # Network input size. Frames are resized to this before detection.
        self.image_width = darknet.network_width(self.netMain)
        self.image_height = darknet.network_height(self.netMain)
        # Reusable image buffer of the network input size (3 channels).
        self.darknet_image = darknet.make_image(self.image_width, self.image_height, 3)

        # Licence-plate recognition model.
        self.plate_model = LPR(conf.plate_cascade, conf.plate_model12, conf.plate_ocr_plate_all_gru)
def YOLO():
    """Run detection, tracking and traffic-violation analysis on ``conf.video_path``.

    Frames are read until the stream ends or ESC is pressed. Any other key
    pauses playback until a further key press. On the first frame the function
    also locates the zebra crossing, lane lines, stop line, lanes and
    no-parking area.

    The video capture and the display window are always released before
    returning. The function now returns instead of calling ``exit()``.
    """
    # ``lane_line`` is used below as a module (get_lane_lines / get_lanes),
    # but the file only imports two functions from it.
    from model import lane_line

    data = Data()
    model = Model()
    comity_pedestrian = Comity_Pedestrian()
    traffic_flow = Traffic_Flow()
    print("Starting the YOLO loop...")
    cap = cv2.VideoCapture(conf.video_path)
    out_win = "result"
    try:
        # Create the display window once instead of on every frame.
        cv2.namedWindow(out_win, cv2.WINDOW_NORMAL)
        while True:
            prev_time = time.time()
            ret, frame_read = cap.read()
            if not ret or frame_read is None:
                # End of the stream, or the source could not be read.
                return

            if data.init_flag:
                # One-time scene setup from the first frame.
                create_save_file()
                data.zebra_line = get_zebra_line(frame_read)
                data.lane_lines, data.stop_line = lane_line.get_lane_lines(frame_read, data.zebra_line)
                data.lanes = lane_line.get_lanes(frame_read, data.lane_lines)
                data.illegal_area = find_illegal_area(frame_read, data.lanes, data.stop_line)
                traffic_flow.pre_time = time.time()
                data.init_flag = False

            frame_rgb = cv2.cvtColor(frame_read, cv2.COLOR_BGR2RGB)
            frame_resized = cv2.resize(
                frame_rgb, (model.image_width, model.image_height), interpolation=cv2.INTER_LINEAR
            )
            darknet.copy_image_from_bytes(model.darknet_image, frame_resized.tobytes())
            # Each detection: class id, confidence, (x, y, w, h)
            detections = darknet.detect_image(
                model.netMain, model.metaMain, model.darknet_image, thresh=conf.thresh
            )
            # Each box: class id, confidence, center, top-left, bottom-right,
            # tracking id (-1 = not yet confirmed), class data (obj)
            boxes = convert_output(detections)
            # Update the tracker.
            boxes = tracker_update(
                boxes, frame_resized, model.encoder, model.tracker, conf.trackerConf.track_label
            )
            # Map the boxes from network-input size back to the original frame size.
            boxes = cast_origin(boxes, model.image_width, model.image_height, frame_rgb.shape)
            # Append the traffic-light color to the end of each box.
            boxes = traffic_light(boxes, frame_read)
            # Build the trajectories.
            make_track(boxes, data.tracks)
            # Estimate speeds.
            speed_measure(data.tracks, float(time.time() - prev_time), data.speeds)
            # Licence-plate recognition.
            boxes = get_license_plate(boxes, frame_rgb, model.plate_model)
            # Cars that did not yield to pedestrians.
            data.no_comity_pedestrian_cars_number = judge_comity_pedestrian(
                frame_rgb, data.tracks, comity_pedestrian
            )
            # Red-light running.
            if boxes:
                data.running_car, data.true_running_car = judge_running_car(
                    frame_read, data.origin, data.running_car,
                    boxes, data.tracks,
                    data.stop_line, data.lane_lines,
                )
            # Illegal lane changes.
            judge_illegal_change_lanes(frame_rgb, boxes, data.lane_lines, data.illegal_boxes_number)
            # Traffic flow.
            data.traffic_flow = get_traffic_flow(frame_rgb, traffic_flow, data.tracks, time.time())
            print("车流量为:%d" % data.traffic_flow)
            # Wrong-way (retrograde) vehicles.
            data.retrograde_cars_number = get_retrograde_cars(
                frame_rgb, data.lane_lines, data.tracks, data.retrograde_cars_number
            )
            # Illegal parking.
            data.illegal_parking_numbers = find_illegal_parking_cars(
                data.illegal_area, data.tracks, data.illegal_parking_numbers
            )
            # Save images of offending vehicles.
            save_illegal_car(frame_rgb, data, boxes)

            # Draw the results.
            frame_rgb = draw_result(frame_rgb, boxes, data)
            draw_zebra_line(frame_rgb, data.zebra_line)
            draw_lane_lines(frame_rgb, data.lane_lines)
            draw_stop_line(frame_rgb, data.stop_line)
            # Optional debug output:
            # draw_speed_info(frame_rgb, data.speeds, boxes)
            # print_info(boxes, time.time() - prev_time, data.class_names)

            # Display: imshow expects BGR, and the annotated frame is RGB.
            cv2.imshow(out_win, cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR))
            key = cv2.waitKey(1)
            if key == 27:  # ESC quits
                return
            if key >= 0:  # any other key pauses until the next key press
                cv2.waitKey(0)
    finally:
        cap.release()
        cv2.destroyAllWindows()
# Script entry point: run the detection/analysis loop when this file is executed directly.
if __name__ == "__main__":
    YOLO()