Huaxu Sentinel Active Safety Platform with embedded algorithm code, Docker Compose setup, and vendored dataset scaffolds for clone-and-run. Co-authored-by: Cursor <cursoragent@cursor.com>
140 lines
5.5 KiB
Python
Executable File
140 lines
5.5 KiB
Python
Executable File
import torch, os, cv2, glob
|
|
from model.model import parsingNet
|
|
from utils.common import merge_config
|
|
from utils.dist_utils import dist_print
|
|
# import torch
|
|
import scipy.special, tqdm
|
|
import numpy as np
|
|
import torchvision.transforms as transforms
|
|
from data.dataset import LaneTestDataset
|
|
from data.constant import culane_row_anchor, tusimple_row_anchor
|
|
from scipy.optimize import curve_fit
|
|
from lane_show import is_in_poly, handle_point, poly_fitting, draw_values
|
|
import time
|
|
import PIL
|
|
import re
|
|
|
|
|
|
class PredictLane:
    """Row-anchor lane/curb detector built around a ``parsingNet`` checkpoint.

    Constructing an instance builds the network on the GPU and loads the
    weights stored in ``./model/curb_c599.pth``.  :meth:`predict` then handles
    one frame per call.
    """

    # Resolution (height, width) the frame is resized to before inference.
    INPUT_HEIGHT = 288
    INPUT_WIDTH = 800
    # Pixel space the predicted points are scaled into.
    # NOTE(review): fixed values; presumably the camera frames are 1280x720 —
    # confirm before feeding frames of another size.
    FRAME_WIDTH = 1280
    FRAME_HEIGHT = 720

    def __init__(self):
        """Build the network, load the checkpoint and prepare preprocessing."""
        self.cls_num_per_lane = 56
        self.griding_num = 100
        self.backbone = '34'

        start_0 = time.time()
        # The auxiliary segmentation branch is not needed for testing.
        self.net = parsingNet(
            pretrained=False,
            backbone=self.backbone,
            cls_dim=(self.griding_num + 1, self.cls_num_per_lane, 4),
            use_aux=False,
        ).cuda()

        state_dict = torch.load('./model/curb_c599.pth', map_location='cuda')['model']
        # strict=False: keys that do not match the network are ignored.
        self.net.load_state_dict(self._strip_module_prefix(state_dict), strict=False)
        self.net.eval()
        count_0 = time.time() - start_0
        print('the load net time is : ', count_0)

        self.img_transforms = transforms.Compose([
            transforms.Resize((self.INPUT_HEIGHT, self.INPUT_WIDTH)),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        # Number of frames passed to predict() so far.
        self.count = 0

    @staticmethod
    def _strip_module_prefix(state_dict):
        """Return a copy of *state_dict* with a leading ``module.`` removed from keys.

        Only a genuine prefix is stripped; a key that merely contains
        ``module.`` somewhere in the middle is kept unchanged.  (Such a prefix
        is typically left behind by a DataParallel-style wrapper.)
        """
        prefix = 'module.'
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    @staticmethod
    def _expected_locations(logits, griding_num):
        """Turn raw per-cell scores into one grid location per (row, lane).

        Args:
            logits: array indexed as ``[grid cell, row anchor, lane]`` whose
                first axis has ``griding_num + 1`` entries; the last entry is
                the "no lane here" class.
            griding_num: number of real grid cells along the first axis.

        Returns:
            A ``[row anchor, lane]`` float array holding the softmax-weighted
            mean cell index (1-based), or 0 where "no lane" scored highest.
            The row axis is reversed relative to the input.
        """
        logits = np.asarray(logits)[:, ::-1, :]
        prob = scipy.special.softmax(logits[:-1, :, :], axis=0)
        cell_index = (np.arange(griding_num) + 1).reshape(-1, 1, 1)
        loc = np.sum(prob * cell_index, axis=0)
        loc[np.argmax(logits, axis=0) == griding_num] = 0
        return loc

    def predict(self, img):
        """Detect lanes in one frame and show them in the ``vis`` window.

        Args:
            img: image array in BGR channel order (as returned by
                ``cv2.imread``).

        Returns:
            A list with one entry per detected lane; each entry is a list of
            ``(x, y)`` float tuples scaled to FRAME_WIDTH x FRAME_HEIGHT.
            Lanes without any accepted point are omitted.
        """
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        self.count += 1

        start_1 = time.time()
        img_i = PIL.Image.fromarray(img.astype(np.uint8))
        img_t = self.img_transforms(img_i)
        img_w, img_h = self.FRAME_WIDTH, self.FRAME_HEIGHT
        row_anchor = tusimple_row_anchor
        imgs = img_t.reshape(1, 3, self.INPUT_HEIGHT, self.INPUT_WIDTH).cuda()
        count_1 = time.time() - start_1
        print('the predeal time is : ', count_1)

        with torch.no_grad():
            # NOTE(review): CUDA work may run asynchronously, so this wall-clock
            # figure can under-report the real inference time.
            start_t = time.time()
            out = self.net(imgs)
            count_t = time.time() - start_t
            print('the pre time is : ', count_t)

        col_sample = np.linspace(0, self.INPUT_WIDTH - 1, self.griding_num)
        col_sample_w = col_sample[1] - col_sample[0]
        out_j = self._expected_locations(out[0].data.cpu().numpy(), self.griding_num)

        vis = img
        poly = [[0, 0], [0, 720], [1280, 0], [1280, 720]]  # ROI region
        lanes_list = []
        for i in range(out_j.shape[1]):
            points_list = []
            # Only lanes with more than two located rows are considered.
            if np.sum(out_j[:, i] != 0) > 2:
                lane_x = []
                lane_y = []
                for k in range(out_j.shape[0]):
                    if out_j[k, i] > 0:
                        ppp = (int(out_j[k, i] * col_sample_w * img_w / self.INPUT_WIDTH) - 1,
                               int(img_h * (row_anchor[self.cls_num_per_lane - 1 - k] / self.INPUT_HEIGHT)) - 1)
                        # Explicit comparison kept: the helper's return type is
                        # not visible from this file.
                        if is_in_poly(ppp, poly) == True:
                            # Collect the accepted point coordinates for later fitting.
                            lane_x.append(ppp[0])
                            lane_y.append(ppp[1])
                            points_list.append((float(ppp[0]), float(ppp[1])))
                            cv2.circle(vis, ppp, 5, (0, 255, 0), -1)
                # Result is currently unused: the curvature/offset drawing that
                # consumed it is disabled.  The call itself is kept as-is.
                lx, ly, rx, ry = handle_point(lane_x, lane_y)
            if points_list:
                lanes_list.append(points_list)

        # `vis` is RGB at this point while imshow interprets data as BGR, so
        # convert back to keep the displayed colours correct.
        cv2.imshow('vis', cv2.cvtColor(vis, cv2.COLOR_RGB2BGR))
        cv2.waitKey(1)
        return lanes_list
|
|
|
|
|
|
def numeric_sort_key(path):
    """Return a sort key that orders ``*.png`` paths by their trailing number.

    Paths whose file name ends in ``<digits>.png`` sort first, by the integer
    value of those digits.  Any other path sorts after them (by name) instead
    of raising, so one stray file cannot abort the whole run.
    """
    match = re.search(r'(\d+)\.png$', path, re.IGNORECASE)
    if match is None:
        return (1, 0, path)
    return (0, int(match.group(1)), path)


if __name__ == "__main__":
    # Runs the detector over every PNG frame in `data_root`, in numeric order.
    torch.backends.cudnn.benchmark = True
    data_root = r'C:\data\curb_data\curb_data\2'
    dist_print('start testing...')

    pic_list = sorted(glob.glob(os.path.join(data_root, '*.png')), key=numeric_sort_key)
    a = PredictLane()
    for pic in pic_list:
        print(pic)
        img = cv2.imread(pic)
        if img is None:
            # imread signals an unreadable file by returning None.
            print('skip unreadable image : ', pic)
            continue
        lanes_list = a.predict(img)
|