Files
HSAP/algorithms/lane_ufld/code.embedded.bak/UFLD/pth2onnx_and_check.py

226 lines
7.7 KiB
Python
Raw Normal View History

# encoding: utf-8
import time
import tkinter.filedialog
import cv2
import numpy as np
import onnx
import onnxruntime
import scipy.special
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
import torch
import torch.backends.cudnn
import torchvision.transforms as transforms
from PIL import Image as imim
from torch import nn
from data.constant import tusimple_row_anchor
from model.model import parsingNet
from utils.common import merge_config
import glob
# 一些预定设置
torch.backends.cudnn.benchmark = True
args, cfg = merge_config()
griding_num = 100
img_w, img_h = 1280, 720
row_anchor = tusimple_row_anchor
cls_num_per_lane = 56
model_path = "/home/ljk/桌面/ClrNet_onnx_modify/UFLD/ep327_ufld.pth"
# model_path = "/home/ljk/桌面/ep015_2ch_res18.pth"
folder_path = "/home/ljk/桌面/pic_1228_lane/pic1_720/*.jpg"
def func(x, a, b, c, d):
return a * pow(x, 3) + b * pow(x, 2) + c * x + d
class ufld_2(nn.Module):
def __init__(self):
super().__init__()
# self.path = "/home/ljk/桌面/semantic_code/1.lane_detect/3.model_weight/1.pth/shuangyujing_model_SpeedUp.pth"
# print(self.path)
# self.net = torch.load(self.path)
backbone_ = '18'
self.net = parsingNet(pretrained=False, backbone=backbone_, cls_dim=(cfg.griding_num + 1, 56, 2),
use_aux=False).cuda()
self.path = model_path
print(self.path)
state_dict = torch.load(self.path, map_location='cpu')['model']
compatible_state_dict = {}
for k, v in state_dict.items():
if 'module.' in k:
compatible_state_dict[k[7:]] = v
else:
compatible_state_dict[k] = v
self.net.load_state_dict(compatible_state_dict, strict=False)
self.net.eval()
# self.onnx_model_name = self.path[:-4] + ".onnx"
self.onnx_model_name = "/home/ljk/桌面/ClrNet_onnx_modify/UFLD/ep327_ufld.onnx"
def export_onnx(self):
x = torch.randn(1, 3, 288, 800).cuda()
# torch.onnx.export 可以检查剪枝模型出了什么错
with torch.no_grad():
torch.onnx.export(self.net, x, self.onnx_model_name, export_params=True, opset_version=11,
input_names=['input'],
output_names=['output'])
print('\n', "start check...")
onnx_model = onnx.load(self.onnx_model_name)
try:
onnx.checker.check_model(onnx_model)
except:
print("Model incorrect")
else:
print("Model correct")
def output_cmp(self, img_path):
start = time.time()
img = imim.open(img_path)
img_transforms = transforms.Compose([
transforms.Resize((288, 800)),
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
img = img_transforms(img).unsqueeze_(0).cuda()
# 原模型的输出-------------------
with torch.no_grad():
out = self.net(img)
pth_output = to_numpy(out)
end = time.time()
cost_time = end - start
print("pth_cost_time:", cost_time * 1000, " ms")
print("原模型前5*4个向量输出\n", pth_output[0][0][:5])
start = time.time()
ort_session = onnxruntime.InferenceSession(self.onnx_model_name)
end = time.time()
cost_time = end - start
print("onnx_cost_time:", cost_time * 1000, " ms")
ort_inputs = {'input': to_numpy(img)}
ort_out = ort_session.run(['output'], ort_inputs)[0]
print("onnx的输出大小", ort_out.shape)
print("onnx前5*4个向量输出\n", ort_out[0][0][:5])
# np.testing.assert_allclose(pth_output, ort_out, rtol=1e-01, atol=1e-03)
print("*****************************")
print("精度对比结束,满足精度!")
print("*****************************")
# _ = pro_process(pth_output, img_path)
print(img_path)
_ = pro_process(ort_out, img_path)
# cv2.namedWindow('vis_onnx')
# cv2.imshow("vis_onnx", pro_process(ort_out, img_path))
def get_path():
filename = tkinter.filedialog.askopenfilename()
return filename
# 工具函数tensor->numpy
def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
# 工具函数:后处理与可视化
def pro_process(out, path):
col_sample = np.linspace(0, 800 - 1, griding_num)
col_sample_w = col_sample[1] - col_sample[0]
out_j = out[0]
out_j = out_j[:, ::-1, :]
prob = scipy.special.softmax(out_j[:-1, :, :], axis=0)
idx = np.arange(griding_num) + 1
idx = idx.reshape(-1, 1, 1)
loc = np.sum(prob * idx, axis=0)
out_j = np.argmax(out_j, axis=0)
loc[out_j == griding_num] = 0
out_j = loc
vis = cv2.imread(path)
left_points = []
right_points = []
for i in range(out_j.shape[1]):
if np.sum(out_j[:, i] != 0) > 2:
if i == 1:
print("Left Line")
else:
print("Right Line")
count = 0
for k in range(out_j.shape[0]):
if out_j[k][i] > 0:
x = int(out_j[k][i] * col_sample_w * img_w / 800) - 1
y = int(img_h * (row_anchor[cls_num_per_lane - 1 - k] / 288)) - 1
ppp = (x, y)
# print(ppp)
if i == 0:
left_points.append(ppp)
else:
right_points.append(ppp)
cv2.circle(vis, ppp, 2, (255, 0, 0), -1)
count += 1
print(count)
print("\n")
# if len(left_points) == 0 or len(right_points) == 0:
# return vis
# left_points_x = np.reshape(left_points, (len(left_points), -1))[:, 1]
# left_points_y = np.reshape(left_points, (len(left_points), -1))[:, 0]
# right_points_x = np.reshape(right_points, (len(right_points), -1))[:, 1]
# right_points_y = np.reshape(right_points, (len(right_points), -1))[:, 0]
# popt_left, _ = curve_fit(func, left_points_x, left_points_y)
# popt_right, _ = curve_fit(func, right_points_x, right_points_y)
# right_points_y_func = []
# left_points_y_func = []
# for i in range(len(left_points_x)):
# ppp = (int(func(left_points_x[i], popt_left[0], popt_left[1], popt_left[2], popt_left[3])),
# left_points_x[i])
# # cv2.circle(vis, ppp, 5, (0, 255, 0), -1)
# left_points_y_func.append(
# int(func(left_points_x[i], popt_left[0], popt_left[1], popt_left[2], popt_left[3])))
# for i in range(len(right_points_x)):
# ppp = (
# int(func(right_points_x[i], popt_right[0], popt_right[1], popt_right[2], popt_right[3])),
# right_points_x[i])
# # cv2.circle(vis, ppp, 5, (0, 255, 0), -1)
# right_points_y_func.append(
# int(func(right_points_x[i], popt_right[0], popt_right[1], popt_right[2], popt_right[3])))
# left_points_y_func = np.reshape(left_points_y_func, (len(left_points_y_func), -1))
# right_points_y_func = np.reshape(right_points_y_func, (len(right_points_y_func), -1))
# left_points_y_func = np.squeeze(left_points_y_func)
# right_points_y_func = np.squeeze(right_points_y_func)
cv2.namedWindow('vis_onnx', 0)
cv2.imshow("vis_onnx", vis)
cv2.waitKey(0)
# cv2.destroyAllWindows()
# cv2.imwrite(path, vis)
return vis
def main():
files = glob.glob(folder_path)
files = sorted(files)
model = ufld_2()
model.export_onnx()
for i in range(len(files)):
model.output_cmp(files[i])
if __name__ == '__main__':
main()