Files
HSAP/algorithms/lane_ufld/code/UFLD/evaluation/eval_wrapper.py

245 lines
12 KiB
Python
Raw Normal View History

import json
import os
import platform

import numpy as np
import scipy
import scipy.special
import torch

from data.dataloader import get_test_loader
from evaluation.tusimple.lane import LaneEval
from utils.dist_utils import is_main_process, dist_print, get_rank, get_world_size, dist_tqdm, synchronize
# Device that evaluation inputs are moved to: the GPU when CUDA is usable, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def generate_lines(out, shape, names, output_path, griding_num, localization_type='abs', flip_updown=False):
    """Decode a batch of grid-classification outputs into per-image ``*lines.txt`` files.

    Args:
        out: Batched network output. ``out[j]`` must support ``.data.cpu().numpy()``
            and is indexed as (grid class, row anchor, lane); class index
            ``griding_num`` is treated as "no lane at this row".
        shape: Sequence whose second entry (``shape[1]``) is the width used to
            derive the size of one grid cell.
        names: One relative image path per sample. The last three characters
            (presumably the file extension - verify against the data loader)
            are replaced by ``lines.txt`` to form the output file name.
        output_path: Root directory under which the files are written.
        griding_num: Number of column cells per row anchor.
        localization_type: ``'abs'`` takes the arg-max cell; ``'rel'`` takes the
            softmax-weighted expected cell over the ``griding_num`` real cells.
        flip_updown: If true, reverse the row-anchor axis before decoding.

    Raises:
        NotImplementedError: If ``localization_type`` is neither ``'abs'`` nor ``'rel'``.
    """
    col_sample = np.linspace(0, shape[1] - 1, griding_num)
    col_sample_w = col_sample[1] - col_sample[0]  # width of one grid cell

    for j in range(out.shape[0]):
        out_j = out[j].data.cpu().numpy()
        if flip_updown:
            out_j = out_j[:, ::-1, :]

        if localization_type == 'abs':
            out_j = np.argmax(out_j, axis=0)
            # Map the "no lane" class to 0 and shift real cells to 1..griding_num.
            out_j[out_j == griding_num] = -1
            out_j = out_j + 1
        elif localization_type == 'rel':
            # Expected cell position (1-based) over the real cells only.
            prob = scipy.special.softmax(out_j[:-1, :, :], axis=0)
            idx = (np.arange(griding_num) + 1).reshape(-1, 1, 1)
            loc = np.sum(prob * idx, axis=0)
            out_j = np.argmax(out_j, axis=0)
            loc[out_j == griding_num] = 0  # 0 marks "no lane"
            out_j = loc
        else:
            raise NotImplementedError

        name = names[j]
        line_save_path = os.path.join(output_path, name[:-3] + 'lines.txt')
        save_dir, _ = os.path.split(line_save_path)
        if save_dir:
            # exist_ok avoids the check-then-create race when several
            # processes write into the same directory concurrently.
            os.makedirs(save_dir, exist_ok=True)

        with open(line_save_path, 'w') as fp:
            for i in range(out_j.shape[1]):
                # Only emit lanes that have more than two located points.
                if np.sum(out_j[:, i] != 0) > 2:
                    for k in range(out_j.shape[0]):
                        if out_j[k, i] > 0:
                            # 1640x590 matches the image size handed to the CULane
                            # evaluator; 800 is presumably the network input width
                            # and 20 the row-anchor spacing in pixels - TODO confirm.
                            fp.write(
                                '%d %d ' % (int(out_j[k, i] * col_sample_w * 1640 / 800) - 1, int(590 - k * 20) - 1))
                    fp.write('\n')
def run_test(net, data_root, exp_name, work_dir, griding_num, use_aux, distributed, batch_size=8, test_list=None):
    """Run ``net`` over the CULane test loader and write per-image lane files.

    Files are written below ``os.path.join(work_dir, exp_name)`` by
    ``generate_lines`` using relative localization and a flipped row axis.

    Args:
        net: Model called as ``net(imgs)``; expected to already be in eval mode.
        data_root: Dataset root handed to ``get_test_loader``.
        exp_name: Name of the output sub-directory inside ``work_dir``.
        work_dir: Directory that receives the output sub-directory.
        griding_num: Number of column cells, forwarded to ``generate_lines``.
        use_aux: Whether the model may return ``(cls_out, seg_out)``.
        distributed: Forwarded to ``get_test_loader``.
        batch_size: Test batch size.
        test_list: Optional test list forwarded to ``get_test_loader``.
    """
    # torch.backends.cudnn.benchmark = True
    output_path = os.path.join(work_dir, exp_name)
    if is_main_process():
        # Also creates work_dir itself when it is missing; harmless if present.
        os.makedirs(output_path, exist_ok=True)
    synchronize()  # other ranks wait until the directory exists

    loader = get_test_loader(batch_size, data_root, 'CULane', distributed, test_list=test_list)
    for imgs, names in dist_tqdm(loader):
        # Use the module-level device so evaluation also works without CUDA,
        # consistent with run_test_tusimple.
        imgs = imgs.to(device)
        with torch.no_grad():
            out = net(imgs)
        # Only unpack a real (cls, seg) pair; a plain tensor with batch size 2
        # also has len() == 2 and must not be split.
        if use_aux and isinstance(out, (tuple, list)) and len(out) == 2:
            out = out[0]
        generate_lines(out, imgs[0, 0].shape, names, output_path, griding_num,
                       localization_type='rel', flip_updown=True)
def generate_tusimple_lines(out,shape,griding_num,localization_type='rel'):
    """Convert one sample's grid output into TuSimple-style lane x-coordinates.

    Args:
        out: Single-sample output indexed as (grid class, row anchor, lane),
            supporting ``.data.cpu().numpy()``; class ``griding_num`` means
            "no lane at this row".
        shape: Accepted for interface compatibility; not read by this function.
        griding_num: Number of column cells.
        localization_type: ``'rel'`` uses the softmax-weighted expected cell;
            any other value uses the arg-max cell directly.

    Returns:
        A list with one entry per lane; each entry is a list with one integer
        x-coordinate per row anchor (scaled to a 1280-wide image), or ``-2``
        where no lane was found.
    """
    scores = out.data.cpu().numpy()
    cells = np.argmax(scores, axis=0)

    if localization_type == 'rel':
        weights = scipy.special.softmax(scores[:-1, :, :], axis=0)
        bins = np.arange(griding_num).reshape(-1, 1, 1)
        expected = np.sum(weights * bins, axis=0)
        # Keep the "no lane" marker so it can be recognised below.
        expected[cells == griding_num] = griding_num
        cells = expected

    def to_pixel(cell):
        # Map a (possibly fractional) cell index to an x pixel, or -2 for "no lane".
        if cell != griding_num:
            return int(round((cell + 0.5) * 1280.0 / (griding_num - 1)))
        return -2

    return [
        [to_pixel(cell) for cell in cells[:, lane_idx]]
        for lane_idx in range(cells.shape[1])
    ]
def run_test_tusimple(net, data_root, work_dir, exp_name, griding_num, use_aux, distributed, batch_size=1, test_list=None):
    """Run ``net`` over the TuSimple test loader and dump predictions as JSON lines.

    Each process writes its own file ``<work_dir>/<exp_name>.<rank>.txt`` with
    one JSON object per image (keys ``lanes``, ``h_samples``, ``raw_file``,
    ``run_time``).

    Args:
        net: Model called as ``net(imgs)``; expected to already be in eval mode.
        data_root: Dataset root handed to ``get_test_loader``.
        work_dir: Directory that receives the prediction file.
        exp_name: Prefix of the prediction file name.
        griding_num: Number of column cells, forwarded to ``generate_tusimple_lines``.
        use_aux: Whether the model may return ``(cls_out, seg_out)``.
        distributed: Forwarded to ``get_test_loader``.
        batch_size: Test batch size.
        test_list: Optional test list forwarded to ``get_test_loader``.
    """
    output_path = os.path.join(work_dir, exp_name + '.%d.txt' % get_rank())
    # Row positions reported with every prediction: 160, 170, ..., 710.
    h_samples = list(range(160, 720, 10))

    # The context manager closes the file even if inference raises midway.
    with open(output_path, 'w') as fp:
        loader = get_test_loader(batch_size, data_root, 'Tusimple', distributed, test_list=test_list)
        for imgs, names in dist_tqdm(loader):
            imgs = imgs.to(device)
            with torch.no_grad():
                out = net(imgs)
            # Only unpack a real (cls, seg) pair; a plain tensor with batch
            # size 2 also has len() == 2 and must not be split.
            if use_aux and isinstance(out, (tuple, list)) and len(out) == 2:
                out = out[0]
            for sample_idx, name in enumerate(names):
                record = {
                    'lanes': generate_tusimple_lines(out[sample_idx], imgs[0, 0].shape, griding_num),
                    'h_samples': h_samples,
                    'raw_file': name,
                    'run_time': 10,  # fixed placeholder value, not a measurement
                }
                fp.write(json.dumps(record) + '\n')
def combine_tusimple_test(work_dir,exp_name):
    """Merge the per-rank TuSimple prediction files into ``<exp_name>.txt``.

    Reads ``<work_dir>/<exp_name>.<rank>.txt`` for every rank in
    ``range(get_world_size())``, drops repeated records for the same image
    (keeping the first occurrence) and writes the result to
    ``<work_dir>/<exp_name>.txt``.

    Records are identified by their JSON ``raw_file`` value. Searching for the
    substring ``'clips'`` instead would give every record whose path lacks that
    substring the same key, so all but one of them would be discarded.

    Args:
        work_dir: Directory containing the per-rank files.
        exp_name: Common prefix of the per-rank and merged file names.
    """
    def record_key(line):
        # Identify a record by its image path; fall back to the raw text for
        # lines that are not valid prediction records so they stay distinct.
        try:
            return json.loads(line)['raw_file']
        except (ValueError, KeyError, TypeError):
            return line

    all_res = []
    for rank in range(get_world_size()):
        rank_path = os.path.join(work_dir, exp_name + '.%d.txt' % rank)
        with open(rank_path, 'r') as fp:
            all_res.extend(fp.readlines())

    seen = set()
    all_res_no_dup = []
    for res in all_res:
        key = record_key(res)
        if key not in seen:
            seen.add(key)
            all_res_no_dup.append(res)

    output_path = os.path.join(work_dir, exp_name + '.txt')
    with open(output_path, 'w') as fp:
        fp.writelines(all_res_no_dup)
def eval_lane(net, dataset, data_root, work_dir, griding_num, use_aux, distributed, test_list=None, skip_eval=False):
    """Run inference on a test set and report lane-detection metrics.

    Args:
        net: Model to evaluate; switched to eval mode here.
        dataset: ``'CULane'`` or ``'Tusimple'``; any other value only switches
            the model to eval mode.
        data_root: Dataset root directory.
        work_dir: Directory for temporary predictions and evaluator output.
        griding_num: Number of column cells of the model output.
        use_aux: Whether the model may return an auxiliary segmentation output.
        distributed: Forwarded to the test runners.
        test_list: Optional test list forwarded to the test runners.
        skip_eval: TuSimple only - write predictions but skip the metric step.
    """
    net.eval()
    if dataset == 'CULane':
        run_test(net, data_root, 'culane_eval_tmp', work_dir, griding_num, use_aux, distributed, test_list=test_list)
        synchronize()  # wait for all results
        if is_main_process():
            res = call_culane_eval(data_root, 'culane_eval_tmp', work_dir)
            TP, FP, FN = 0, 0, 0
            for k, v in res.items():
                # The evaluator's F-measure field may contain 'nan'; report 0 then.
                val = float(v['Fmeasure']) if 'nan' not in v['Fmeasure'] else 0
                TP += int(v['tp'])
                FP += int(v['fp'])
                FN += int(v['fn'])
                dist_print(k, val)
            # Guard every ratio: a model that predicts nothing (or a split with
            # no ground truth) would otherwise raise ZeroDivisionError.
            P = TP * 1.0 / (TP + FP) if (TP + FP) > 0 else 0.0
            R = TP * 1.0 / (TP + FN) if (TP + FN) > 0 else 0.0
            F = 2 * P * R / (P + R) if (P + R) > 0 else 0.0
            dist_print(F)
        synchronize()
    elif dataset == 'Tusimple':
        exp_name = 'tusimple_eval_tmp'
        run_test_tusimple(net, data_root, work_dir, exp_name, griding_num, use_aux, distributed, test_list=test_list)
        synchronize()  # wait for all results
        if is_main_process():
            # NOTE(review): only rank 0's prediction file is scored; in a
            # multi-process run the other ranks' files are not merged here
            # (combine_tusimple_test exists for that) - confirm this is intended.
            pred_path = os.path.join(work_dir, exp_name + '.0.txt')
            label_json = os.path.join(data_root, 'test_label.json')
            if skip_eval or not os.path.isfile(label_json):
                dist_print('skip TuSimple metrics (no test_label.json); predictions:', pred_path)
            else:
                res = LaneEval.bench_one_submit(pred_path, label_json)
                res = json.loads(res)
                for r in res:
                    dist_print(r['name'], r['value'])
        synchronize()
def read_helper(path):
    """Parse an evaluator result file into a ``{key: value}`` dict of strings.

    The first line is skipped. The remaining text is split on single spaces
    and read as alternating ``key:`` / value tokens; the last character of
    each key token (the colon) is dropped. Values are returned verbatim, so a
    value that ended a line still carries its trailing newline (``int()`` and
    ``float()`` accept that).

    Args:
        path: Path of the result file to read.

    Returns:
        Dict mapping each key to its raw string value.
    """
    # Context manager so the handle is closed instead of leaked.
    with open(path, 'r') as fp:
        body = fp.readlines()[1:]
    tokens = ' '.join(body).split(' ')
    keys = [token[:-1] for token in tokens[0::2]]
    values = tokens[1::2]
    return dict(zip(keys, values))
def call_culane_eval(data_dir, exp_name,output_path):
    """Run the external CULane evaluator on every test split and collect the results.

    For each of the nine split lists under ``<data_dir>/list/test_split`` the
    evaluator binary is invoked through the shell, writing its report to
    ``<output_path>/txt/``. The reports are then parsed with ``read_helper``.

    Args:
        data_dir: CULane dataset root; a trailing ``/`` is appended if missing.
        exp_name: Sub-directory of ``output_path`` holding the predicted lane files.
        output_path: Working directory containing the predictions; also
            receives the ``txt`` report directory.

    Returns:
        Dict mapping ``res_<split>`` to the parsed report of that split.
    """
    if data_dir[-1] != '/':
        data_dir = data_dir + '/'
    detect_dir = os.path.join(output_path, exp_name) + '/'

    # Evaluator settings, passed on the command line below.
    w_lane = 30
    iou = 0.5  # Set iou to 0.3 or 0.5
    im_w = 1640
    im_h = 590
    frame = 1

    # (split list relative to data_dir, report file name, key in the result dict)
    splits = (
        ('list/test_split/test0_normal.txt', 'out0_normal.txt', 'res_normal'),
        ('list/test_split/test1_crowd.txt', 'out1_crowd.txt', 'res_crowd'),
        ('list/test_split/test2_hlight.txt', 'out2_hlight.txt', 'res_hlight'),
        ('list/test_split/test3_shadow.txt', 'out3_shadow.txt', 'res_shadow'),
        ('list/test_split/test4_noline.txt', 'out4_noline.txt', 'res_noline'),
        ('list/test_split/test5_arrow.txt', 'out5_arrow.txt', 'res_arrow'),
        ('list/test_split/test6_curve.txt', 'out6_curve.txt', 'res_curve'),
        ('list/test_split/test7_cross.txt', 'out7_cross.txt', 'res_cross'),
        ('list/test_split/test8_night.txt', 'out8_night.txt', 'res_night'),
    )
    # Order in which results are inserted into the returned dict.
    report_order = (0, 1, 8, 4, 3, 5, 2, 6, 7)

    txt_dir = os.path.join(output_path, 'txt')
    if not os.path.exists(txt_dir):
        os.mkdir(txt_dir)

    eval_cmd = './evaluation/culane/evaluate'
    if platform.system() == 'Windows':
        eval_cmd = eval_cmd.replace('/', os.sep)

    cmd_template = '%s -a %s -d %s -i %s -l %s -w %s -t %s -c %s -r %s -f %s -o %s'
    report_files = []
    for list_rel, report_name, _ in splits:
        list_file = os.path.join(data_dir, list_rel)
        report_file = os.path.join(output_path, 'txt', report_name)
        report_files.append(report_file)
        os.system(cmd_template % (eval_cmd, data_dir, detect_dir, data_dir, list_file,
                                  w_lane, iou, im_w, im_h, frame, report_file))

    res_all = {}
    for split_idx in report_order:
        res_all[splits[split_idx][2]] = read_helper(report_files[split_idx])
    return res_all