181 lines
7.0 KiB
Python
181 lines
7.0 KiB
Python
|
|
import os
|
||
|
|
import pickle as pkl
|
||
|
|
import cv2
|
||
|
|
|
||
|
|
from .registry import DATASETS
|
||
|
|
import numpy as np
|
||
|
|
from tqdm import tqdm
|
||
|
|
from .base_dataset import BaseDataset
|
||
|
|
|
||
|
|
# Dataset sub-directories. The dataset class joins SPLIT_DIRECTORIES values and
# TEST_IMGS_DIR onto its `data_root`, so these are paths relative to that root.
TRAIN_LABELS_DIR = 'labels/train'
# NOTE(review): despite the name, this points at the *validation* labels.
TEST_LABELS_DIR = 'labels/valid'
TEST_IMGS_DIR = 'color_images/test'

# Split name -> label directory. The 'test' split is deliberately absent: it is
# handled separately by the dataset class and has no label directory here.
SPLIT_DIRECTORIES = {
    'train': TRAIN_LABELS_DIR,
    'val': TEST_LABELS_DIR,
}
|
||
|
|
from clrnet.utils.llamas_utils import get_horizontal_values_for_four_lanes
|
||
|
|
import clrnet.utils.llamas_metric as llamas_metric
|
||
|
|
|
||
|
|
|
||
|
|
@DATASETS.register_module
class LLAMAS(BaseDataset):
    """LLAMAS lane dataset: builds per-image records and scores predictions.

    For the ``train`` and ``val`` splits every ``*.json`` label file under the
    split's label directory is converted into a record holding the image path,
    a generated segmentation-mask path and the lane points. The resulting list
    is pickled to ``cache/llamas_<split>.pkl`` (relative to the current working
    directory) and reused on later runs. For the ``test`` split only the image
    files are listed and every record has an empty lane list.

    ``self.cfg`` is read (``ori_img_h`` / ``ori_img_w``) but never assigned in
    this class; it is presumably set by ``BaseDataset.__init__`` -- confirm.
    """

    def __init__(self, data_root, split='train', processes=None, cfg=None):
        """Validate the split, initialise the base dataset and load records.

        Args:
            data_root: Root directory of the dataset on disk.
            split: ``'train'``, ``'val'`` or ``'test'``.
            processes: Forwarded unchanged to ``BaseDataset.__init__``.
            cfg: Forwarded unchanged to ``BaseDataset.__init__``.

        Raises:
            ValueError: If ``split`` is not one of the supported splits.
        """
        # Fail fast, before the base class does any work. ValueError is a
        # subclass of Exception, so existing `except Exception` handlers still
        # catch it.
        if split != 'test' and split not in SPLIT_DIRECTORIES:
            raise ValueError('Split `{}` does not exist.'.format(split))

        self.split = split
        self.data_root = data_root
        super().__init__(data_root, split, processes, cfg)

        # The test split has no label directory, so `labels_dir` is only
        # defined for train/val.
        if split != 'test':
            self.labels_dir = os.path.join(self.data_root,
                                           SPLIT_DIRECTORIES[split])

        self.data_infos = []
        self.load_annotations()

    def get_img_heigth(self, _):
        """Return the configured original image height.

        The misspelled name is kept on purpose: it is part of the public
        interface and may be called from outside this file.
        """
        return self.cfg.ori_img_h

    def get_img_width(self, _):
        """Return the configured original image width."""
        return self.cfg.ori_img_w

    def get_metrics(self, lanes, _):
        """Return placeholder metrics: four lists with one entry per lane.

        The first two lists are all zeros and the last two are all ones.
        """
        count = len(lanes)
        # Placeholders
        return [0] * count, [0] * count, [1] * count, [1] * count

    def get_img_path(self, json_path):
        """Map a label path to the image path relative to the dataset root.

        Keeps the last three ``/``-separated components of ``json_path``,
        prefixes them with ``color_images`` and swaps ``.json`` for
        ``_color_rect.png``.
        """
        # /foo/bar/test/folder/image_label.ext --> test/folder/image_label.ext
        base_name = '/'.join(json_path.split('/')[-3:])
        return os.path.join('color_images',
                            base_name.replace('.json', '_color_rect.png'))

    def get_img_name(self, json_path):
        """Return the image file name (no directories) for a label path."""
        file_name = json_path.split('/')[-1]
        return file_name.replace('.json', '_color_rect.png')

    def get_json_paths(self):
        """Return every ``*.json`` path under ``self.labels_dir``, sorted.

        ``os.walk`` yields entries in an arbitrary, filesystem-dependent
        order; sorting makes the record order (and therefore the cache
        content) reproducible, matching the test branch of
        ``load_annotations`` which already sorts its records.
        """
        json_paths = [
            os.path.join(root, file_name)
            for root, _, files in os.walk(self.labels_dir)
            for file_name in files
            if file_name.endswith('.json')
        ]
        return sorted(json_paths)

    def _collect_test_infos(self):
        """Build label-free records for every ``.png`` under the test images."""
        imgs_dir = os.path.join(self.data_root, TEST_IMGS_DIR)
        infos = []
        for root, _, files in os.walk(imgs_dir):
            # Name of the directory that directly contains the image.
            folder = root.split('/')[-1]
            for file_name in files:
                if not file_name.endswith('.png'):
                    continue
                infos.append({
                    'img_path': os.path.join(root, file_name),
                    'img_name': os.path.join(TEST_IMGS_DIR, folder, file_name),
                    'lanes': [],
                    'relative_path': os.path.join(folder, file_name),
                })
        infos.sort(key=lambda info: info['img_path'])
        return infos

    def _read_lanes(self, json_path):
        """Load the lanes of one label file as lists of ``(x, y)`` points.

        Returns:
            Lanes ordered by the x coordinate of their first (smallest-y)
            point; each lane is a list of ``(x, y)`` tuples in increasing y.
        """
        rows = range(self.cfg.ori_img_h)
        lanes = get_horizontal_values_for_four_lanes(json_path)
        # Pair the i-th value of each lane with y = i and drop negative x.
        # Every y occurs at most once per lane and in increasing order, so the
        # points are already unique and sorted by y: no de-duplication or
        # re-sorting is needed.
        lanes = [[(x, y) for x, y in zip(lane, rows) if x >= 0]
                 for lane in lanes]
        # Keep only lanes with at least 3 points (this also drops empty ones).
        lanes = [lane for lane in lanes if len(lane) > 2]
        lanes.sort(key=lambda lane: lane[0][0])
        return lanes

    def _write_seg_mask(self, mask_path, lanes):
        """Rasterise ``lanes`` into a label image and write it to ``mask_path``.

        Lane ``i`` (0-based) is drawn with pixel value ``i + 1`` in all three
        channels; untouched pixels stay 0.
        """
        # Canvas size is hard-coded; presumably the LLAMAS image size (rows,
        # cols) -- confirm against cfg.ori_img_h / cfg.ori_img_w.
        # uint8 is the depth the PNG is stored at and is 8x smaller in memory
        # than the float64 default of np.zeros.
        seg = np.zeros((717, 1276, 3), dtype=np.uint8)
        for lane_idx, lane in enumerate(lanes):
            label = lane_idx + 1
            for start, end in zip(lane, lane[1:]):
                cv2.line(seg, (round(start[0]), start[1]),
                         (round(end[0]), end[1]),
                         (label, label, label),
                         thickness=15)
        cv2.imwrite(mask_path, seg)

    def load_annotations(self):
        """Populate ``self.data_infos`` (and ``self.max_lanes`` for train/val).

        * ``test``: list the images only; no cache and no ``max_lanes``.
        * ``train`` / ``val``: load ``cache/llamas_<split>.pkl`` if it exists;
          otherwise parse every label file, write a segmentation mask next to
          it (same path with ``.png`` instead of ``.json``) and create the
          cache.
        """
        # the labels are not public for the test set yet
        if self.split == 'test':
            self.data_infos = self._collect_test_infos()
            return

        # Waiting for the dataset to load is tedious, let's cache it
        os.makedirs('cache', exist_ok=True)
        cache_path = 'cache/llamas_{}.pkl'.format(self.split)
        if os.path.exists(cache_path):
            # NOTE: unpickling can execute arbitrary code. The cache is assumed
            # to be a trusted file written by this method; never point this at
            # a pickle obtained from an untrusted source.
            with open(cache_path, 'rb') as cache_file:
                self.data_infos = pkl.load(cache_file)
            # `default=0` keeps an empty cache from raising ValueError.
            self.max_lanes = max(
                (len(anno['lanes']) for anno in self.data_infos), default=0)
            return

        self.max_lanes = 0
        print("Searching annotation files...")
        json_paths = self.get_json_paths()
        print('{} annotations found.'.format(len(json_paths)))

        for json_path in tqdm(json_paths):
            lanes = self._read_lanes(json_path)

            # generate seg labels
            mask_path = json_path.replace('.json', '.png')
            self._write_seg_mask(mask_path, lanes)

            relative_path = self.get_img_path(json_path)
            self.max_lanes = max(self.max_lanes, len(lanes))
            self.data_infos.append({
                'img_path': os.path.join(self.data_root, relative_path),
                'img_name': relative_path,
                'mask_path': mask_path,
                'lanes': lanes,
                'relative_path': relative_path,
            })

        with open(cache_path, 'wb') as cache_file:
            pkl.dump(self.data_infos, cache_file)

    def assign_class_to_lanes(self, lanes):
        """Pair lanes, in order, with the labels ``l0``, ``l1``, ``r0``, ``r1``.

        ``zip`` stops at the shorter input, so extra lanes beyond four are
        ignored and fewer lanes yield fewer keys.
        """
        return dict(zip(['l0', 'l1', 'r0', 'r1'], lanes))

    def get_prediction_string(self, pred):
        """Serialise predicted lanes as text, one lane per line.

        Args:
            pred: Iterable of callables; each is called with an array of
                normalised y values and must return an array of normalised x
                values of the same length.

        Returns:
            Newline-joined lines of space-separated ``x y`` pairs in pixel
            units with 5 decimals, ordered from the largest y to the smallest.
            Lanes with no point inside ``0 <= x < 1`` are omitted.
        """
        # Rows 300..716, normalised by the configured image height.
        ys = np.arange(300, 717, 1) / (self.cfg.ori_img_h - 1)
        out = []
        for lane in pred:
            xs = lane(ys)
            # Validity is decided on the normalised x, before rescaling.
            valid_mask = (xs >= 0) & (xs < 1)
            xs = xs * (self.cfg.ori_img_w - 1)
            lane_xs = xs[valid_mask]
            lane_ys = ys[valid_mask] * (self.cfg.ori_img_h - 1)
            lane_xs, lane_ys = lane_xs[::-1], lane_ys[::-1]
            lane_str = ' '.join([
                '{:.5f} {:.5f}'.format(x, y) for x, y in zip(lane_xs, lane_ys)
            ])
            if lane_str != '':
                out.append(lane_str)

        return '\n'.join(out)

    def evaluate(self, predictions, output_basedir):
        """Write one ``.lines.txt`` file per prediction and score them.

        Args:
            predictions: Sequence aligned by index with ``self.data_infos``.
            output_basedir: Directory that receives the prediction files; the
                last two components of each record's ``relative_path`` are
                kept below it.

        Returns:
            ``None`` for the test split (nothing is scored); otherwise the F1
            score at IoU threshold 0.5 as reported by
            ``llamas_metric.eval_predictions``.
        """
        print('Generating prediction output...')
        for idx, pred in enumerate(predictions):
            relative_path = self.data_infos[idx]['relative_path']
            output_filename = '/'.join(relative_path.split('/')[-2:]).replace(
                '_color_rect.png', '.lines.txt')
            output_filepath = os.path.join(output_basedir, output_filename)
            os.makedirs(os.path.dirname(output_filepath), exist_ok=True)
            output = self.get_prediction_string(pred)
            with open(output_filepath, 'w') as out_file:
                out_file.write(output)

        if self.split == 'test':
            return None

        result = llamas_metric.eval_predictions(
            output_basedir,
            self.labels_dir,
            iou_thresholds=np.linspace(0.5, 0.95, 10),
            unofficial=False)
        return result[0.5]['F1']
|