Files
HSAP/algorithms/lane_ufld/code/UFLD/predict0729.py
Chengfang Lu 7c43b44c57 feat: initial HSAP platform
Huaxu Sentinel Active Safety Platform with embedded algorithm code,
Docker Compose setup, and vendored dataset scaffolds for clone-and-run.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-25 16:59:59 +08:00

140 lines
5.5 KiB
Python
Executable File

import torch, os, cv2, glob
from model.model import parsingNet
from utils.common import merge_config
from utils.dist_utils import dist_print
# import torch
import scipy.special, tqdm
import numpy as np
import torchvision.transforms as transforms
from data.dataset import LaneTestDataset
from data.constant import culane_row_anchor, tusimple_row_anchor
from scipy.optimize import curve_fit
from lane_show import is_in_poly, handle_point, poly_fitting, draw_values
import time
import PIL
import re
class PredictLane:
    """Row-anchor lane detector wrapping ``parsingNet``.

    The network and its weights are loaded onto the GPU once in ``__init__``.
    :meth:`predict` then turns one BGR frame into per-lane lists of image-space
    points.
    """

    # Network input resolution (height, width); frames are resized to this.
    _INPUT_H = 288
    _INPUT_W = 800

    def __init__(self, model_path='./model/curb_c599.pth', backbone='34',
                 griding_num=100, cls_num_per_lane=56, row_anchor=None):
        """Build the network and load the checkpoint.

        Args:
            model_path: checkpoint file. It is read with ``torch.load`` and the
                state dict is taken from its ``'model'`` entry.
            backbone: backbone identifier passed to ``parsingNet``.
            griding_num: number of column grid cells per row anchor.
            cls_num_per_lane: number of row anchors per lane. It must match
                ``len(row_anchor)``.
            row_anchor: row-anchor y positions in network-input pixels.
                Defaults to ``tusimple_row_anchor``.
        """
        self.cls_num_per_lane = cls_num_per_lane
        self.griding_num = griding_num
        self.backbone = backbone
        self.row_anchor = tusimple_row_anchor if row_anchor is None else row_anchor

        start_0 = time.time()
        # cls_dim = (grid cells + 1 "no lane" class, row anchors, 4 lanes).
        # The auxiliary segmentation head is not needed at inference time.
        self.net = parsingNet(
            pretrained=False,
            backbone=self.backbone,
            cls_dim=(self.griding_num + 1, self.cls_num_per_lane, 4),
            use_aux=False,
        ).cuda()
        state_dict = torch.load(model_path, map_location='cuda')['model']
        # Strip the 'module.' prefix that (Distributed)DataParallel adds to keys.
        compatible_state_dict = {
            (k[7:] if 'module.' in k else k): v for k, v in state_dict.items()
        }
        # NOTE(review): strict=False silently ignores missing/unexpected keys.
        self.net.load_state_dict(compatible_state_dict, strict=False)
        self.net.eval()
        print('the load net time is : ', time.time() - start_0)

        self.img_transforms = transforms.Compose([
            transforms.Resize((self._INPUT_H, self._INPUT_W)),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        self.count = 0  # number of frames passed to predict()

    @staticmethod
    def _decode_output(out_j, griding_num):
        """Convert raw logits for one image into expected grid locations.

        Args:
            out_j: array of shape (griding_num + 1, rows, lanes). The last entry
                along axis 0 is the "no lane" class.
            griding_num: number of real grid cells.

        Returns:
            A (rows, lanes) float array of 1-based expected cell positions.
            The row order is reversed relative to ``out_j``. An entry is 0
            where the "no lane" class has the highest score.
        """
        out_j = out_j[:, ::-1, :]
        # Softmax-weighted average of the cell index, over real cells only.
        prob = scipy.special.softmax(out_j[:-1, :, :], axis=0)
        idx = (np.arange(griding_num) + 1).reshape(-1, 1, 1)
        loc = np.sum(prob * idx, axis=0)
        loc[np.argmax(out_j, axis=0) == griding_num] = 0
        return loc

    def predict(self, img, show=True):
        """Detect lanes in a single frame.

        Args:
            img: BGR image array, e.g. as returned by ``cv2.imread``. It is not
                modified.
            show: when True (the default), draw the detected points on a copy
                of the frame and display it in the ``'vis'`` window.

        Returns:
            A list of lanes. Each lane is a list of ``(x, y)`` float tuples in
            pixel coordinates of ``img``. A lane is included only if it has
            more than two valid row anchors and at least one point inside the
            ROI.
        """
        # `import PIL` alone does not load the PIL.Image submodule.
        from PIL import Image

        img_h, img_w = img.shape[:2]
        # cv2.imshow expects BGR, so draw on a BGR copy, not on the RGB input.
        vis = img.copy()
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        self.count += 1

        start_1 = time.time()
        img_t = self.img_transforms(Image.fromarray(rgb.astype(np.uint8)))
        imgs = img_t.unsqueeze(0).cuda()  # add batch dim -> (1, 3, H, W)
        print('the predeal time is : ', time.time() - start_1)

        with torch.no_grad():
            start_t = time.time()
            out = self.net(imgs)
            # CUDA launches are asynchronous; sync so the timing is meaningful.
            torch.cuda.synchronize()
            print('the pre time is : ', time.time() - start_t)

        col_sample = np.linspace(0, self._INPUT_W - 1, self.griding_num)
        col_sample_w = col_sample[1] - col_sample[0]
        loc = self._decode_output(out[0].data.cpu().numpy(), self.griding_num)

        # ROI: the whole frame. These are the original 1280x720 corners scaled
        # to the actual frame size.
        poly = [[0, 0], [0, img_h], [img_w, 0], [img_w, img_h]]
        lanes_list = []
        for lane_idx in range(loc.shape[1]):
            lane_loc = loc[:, lane_idx]
            # Require more than two valid row anchors before treating it as a lane.
            if np.sum(lane_loc != 0) <= 2:
                continue
            points_list = []
            lane_x = []
            lane_y = []
            for k in range(loc.shape[0]):
                if lane_loc[k] <= 0:
                    continue
                # Map the grid position / row anchor from network input to image pixels.
                ppp = (
                    int(lane_loc[k] * col_sample_w * img_w / self._INPUT_W) - 1,
                    int(img_h * (self.row_anchor[self.cls_num_per_lane - 1 - k] / self._INPUT_H)) - 1,
                )
                if is_in_poly(ppp, poly):
                    # Collect the in-ROI points (also kept for curve fitting).
                    lane_x.append(ppp[0])
                    lane_y.append(ppp[1])
                    points_list.append((float(ppp[0]), float(ppp[1])))
                    if show:
                        cv2.circle(vis, ppp, 5, (0, 255, 0), -1)
            # NOTE(review): the return value is unused now that curvature drawing
            # is disabled. The call is kept in case handle_point has side effects.
            handle_point(lane_x, lane_y)
            if points_list:
                lanes_list.append(points_list)

        if show:
            cv2.imshow('vis', vis)
            cv2.waitKey(1)
        return lanes_list
def frame_number_key(path):
    """Return the integer frame number from a file name such as ``frame123.png``.

    Used as a sort key so frames are processed in numeric rather than lexical
    order. The pattern is anchored to the end of the file name and ignores
    case, so ``.PNG`` files (which ``glob`` also returns on Windows) work too.

    Raises:
        ValueError: if the file name does not end in ``<digits>.png``.
    """
    match = re.search(r'(\d+)\.png$', os.path.basename(path), re.IGNORECASE)
    if match is None:
        raise ValueError(f'cannot find a frame number in {path!r}')
    return int(match.group(1))


if __name__ == "__main__":
    torch.backends.cudnn.benchmark = True
    data_root = r'C:\data\curb_data\curb_data\2'
    dist_print('start testing...')
    pic_list = glob.glob(os.path.join(data_root, '*.png'))
    pic_list.sort(key=frame_number_key)
    predictor = PredictLane()
    for pic in pic_list:
        print(pic)
        img = cv2.imread(pic)
        if img is None:
            # cv2.imread returns None instead of raising on unreadable files.
            print('failed to read image, skipping: ', pic)
            continue
        lanes_list = predictor.predict(img)