import os import os.path as osp import numpy as np from .base_dataset import BaseDataset from .registry import DATASETS import clrnet.utils.culane_metric as culane_metric import cv2 from tqdm import tqdm import logging import pickle as pkl LIST_FILE = { 'train': 'list/train_gt.txt', 'val': 'list/val.txt', 'test': 'list/test.txt', } CATEGORYS = { 'normal': 'list/test_split/test0_normal.txt', 'crowd': 'list/test_split/test1_crowd.txt', 'hlight': 'list/test_split/test2_hlight.txt', 'shadow': 'list/test_split/test3_shadow.txt', 'noline': 'list/test_split/test4_noline.txt', 'arrow': 'list/test_split/test5_arrow.txt', 'curve': 'list/test_split/test6_curve.txt', 'cross': 'list/test_split/test7_cross.txt', 'night': 'list/test_split/test8_night.txt', } @DATASETS.register_module class CULane(BaseDataset): def __init__(self, data_root, split, processes=None, cfg=None): super().__init__(data_root, split, processes=processes, cfg=cfg) self.list_path = osp.join(data_root, LIST_FILE[split]) self.split = split self.load_annotations() def load_annotations(self): self.logger.info('Loading CULane annotations...') # Waiting for the dataset to load is tedious, let's cache it os.makedirs('cache', exist_ok=True) cache_path = 'cache/culane_{}.pkl'.format(self.split) if os.path.exists(cache_path): with open(cache_path, 'rb') as cache_file: self.data_infos = pkl.load(cache_file) self.max_lanes = max( len(anno['lanes']) for anno in self.data_infos) return self.data_infos = [] with open(self.list_path) as list_file: for line in list_file: infos = self.load_annotation(line.split()) self.data_infos.append(infos) # cache data infos to file with open(cache_path, 'wb') as cache_file: pkl.dump(self.data_infos, cache_file) def load_annotation(self, line): infos = {} img_line = line[0] img_line = img_line[1 if img_line[0] == '/' else 0::] img_path = os.path.join(self.data_root, img_line) infos['img_name'] = img_line infos['img_path'] = img_path if len(line) > 1: mask_line = line[1] mask_line = mask_line[1 if mask_line[0] == '/' else 0::] mask_path = os.path.join(self.data_root, mask_line) infos['mask_path'] = mask_path if len(line) > 2: exist_list = [int(l) for l in line[2:]] infos['lane_exist'] = np.array(exist_list) anno_path = img_path[:-3] + 'lines.txt' # remove sufix jpg and add lines.txt with open(anno_path, 'r') as anno_file: data = [ list(map(float, line.split())) for line in anno_file.readlines() ] lanes = [[(lane[i], lane[i + 1]) for i in range(0, len(lane), 2) if lane[i] >= 0 and lane[i + 1] >= 0] for lane in data] lanes = [list(set(lane)) for lane in lanes] # remove duplicated points lanes = [lane for lane in lanes if len(lane) > 2] # remove lanes with less than 2 points lanes = [sorted(lane, key=lambda x: x[1]) for lane in lanes] # sort by y infos['lanes'] = lanes return infos def get_prediction_string(self, pred): ys = np.arange(270, 590, 8) / self.cfg.ori_img_h out = [] for lane in pred: xs = lane(ys) valid_mask = (xs >= 0) & (xs < 1) xs = xs * self.cfg.ori_img_w lane_xs = xs[valid_mask] lane_ys = ys[valid_mask] * self.cfg.ori_img_h lane_xs, lane_ys = lane_xs[::-1], lane_ys[::-1] lane_str = ' '.join([ '{:.5f} {:.5f}'.format(x, y) for x, y in zip(lane_xs, lane_ys) ]) if lane_str != '': out.append(lane_str) return '\n'.join(out) def evaluate(self, predictions, output_basedir): loss_lines = [[], [], [], []] print('Generating prediction output...') for idx, pred in enumerate(predictions): output_dir = os.path.join( output_basedir, os.path.dirname(self.data_infos[idx]['img_name'])) output_filename = os.path.basename( self.data_infos[idx]['img_name'])[:-3] + 'lines.txt' os.makedirs(output_dir, exist_ok=True) output = self.get_prediction_string(pred) with open(os.path.join(output_dir, output_filename), 'w') as out_file: out_file.write(output) for cate, cate_file in CATEGORYS.items(): result = culane_metric.eval_predictions(output_basedir, self.data_root, os.path.join(self.data_root, cate_file), iou_thresholds=[0.5], official=True) result = culane_metric.eval_predictions(output_basedir, self.data_root, self.list_path, iou_thresholds=np.linspace(0.5, 0.95, 10), official=True) return result[0.5]['F1']