feat: initial HSAP platform
Huaxu Sentinel Active Safety Platform with embedded algorithm code, Docker Compose setup, and vendored dataset scaffolds for clone-and-run. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -0,0 +1,6 @@
|
||||
from .lane_det_tester import LaneDetTester
|
||||
from .lane_det_trainer import LaneDetTrainer
|
||||
from .seg_tester import SegTester
|
||||
from .seg_trainer import SegTrainer
|
||||
from .lane_det_visualizer import LaneDetDir, LaneDetVideo
|
||||
from .seg_visualizer import SegDir, SegVideo
|
||||
@@ -0,0 +1,299 @@
|
||||
# Define every component in one line
|
||||
# cfg: config file, pure dict
|
||||
# args: command line args from argparse
|
||||
import os
|
||||
import torch
|
||||
import cv2
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
try:
|
||||
import ujson as json
|
||||
except ImportError:
|
||||
import json
|
||||
|
||||
from ..datasets import DATASETS, dict_collate_fn
|
||||
from ..losses import LOSSES
|
||||
from ..lr_schedulers import LR_SCHEDULERS
|
||||
from ..models import MODELS
|
||||
from ..optimizers import OPTIMIZERS
|
||||
from ..transforms import TRANSFORMS
|
||||
from ..ddp_utils import init_distributed_mode, is_main_process
|
||||
from ..common import load_checkpoint
|
||||
|
||||
|
||||
def get_collate_fn(name):
    """Resolve a collate function from its configured name.

    Only the name ``'dict_collate_fn'`` is recognised; every other value
    (``None`` included) yields ``None``.
    """
    return dict_collate_fn if name == 'dict_collate_fn' else None
|
||||
|
||||
|
||||
def get_sampler(ddp, dataset):
    """Build the sampler for ``dataset``.

    A ``DistributedSampler`` is used when ``ddp`` is truthy, a
    ``RandomSampler`` otherwise.
    """
    if not ddp:
        return torch.utils.data.RandomSampler(dataset)
    return torch.utils.data.distributed.DistributedSampler(dataset)
|
||||
|
||||
|
||||
class BaseRunner(ABC):
    """Common base for the trainer/tester/visualizer runners.

    The constructor builds ``self.model`` from ``cfg['model']``. Several
    helpers read ``self._cfg``, which is not assigned here; subclasses set it
    before calling them.
    """

    def __init__(self, cfg):
        # cudnn.version() yields None when cuDNN is not available; comparing
        # None with an int would raise TypeError, so guard for it.
        cudnn_version = torch.backends.cudnn.version()
        if cudnn_version is not None and cudnn_version < 8000:
            torch.backends.cudnn.benchmark = True
        self.model = MODELS.from_dict(cfg['model'])

    @abstractmethod
    def run(self, *args, **kwargs):
        pass

    def clean(self, *args, **kwargs):
        # Cleanups and a hook for after-run messages/ops
        if hasattr(self, '_cfg') and 'exp_dir' in self._cfg:
            print('Files saved at: {}.\nTensorboard log at: {}'.format(
                self._cfg['exp_dir'],
                os.path.join(self._cfg['save_dir'], 'tb_logs', self._cfg['exp_name'])
            ))

    def get_device_and_move_model(self, *args, **kwargs):
        """Move the model to ``cuda:0`` when CUDA is available, else CPU; return the device."""
        device = torch.device('cpu')
        if torch.cuda.is_available():
            device = torch.device('cuda:0')
        print(device)
        self.model.to(device)

        return device

    def load_checkpoint(self, ckpt_filename):
        """Load model weights from ``ckpt_filename``; a ``None`` filename is a no-op."""
        # [Possible BC-Break] Get rid of scheduler and optimizer loading
        if ckpt_filename is not None:
            # Resolves to the module-level load_checkpoint helper, not this method
            load_checkpoint(net=self.model, lr_scheduler=None, optimizer=None, filename=ckpt_filename)

    def get_dataset_statics(self, dataset, map_dataset_statics, exist_ok=False):
        """Copy attributes of a dataset into ``self._cfg``.

        Args:
            dataset: a dataset object, or the name of an attribute of the
                ``datasets`` package to read the attributes from.
            map_dataset_statics: iterable of attribute names, or ``None`` to
                do nothing.
            exist_ok: when true, keys already present in ``self._cfg`` are
                left untouched.
        """
        assert hasattr(self, '_cfg')
        if map_dataset_statics is None:
            return
        for k in map_dataset_statics:
            if exist_ok and k in self._cfg:
                continue
            if isinstance(dataset, str):
                # Relative import, consistent with the module-level
                # `from ..datasets import ...` of this file
                from .. import datasets
                attr = getattr(vars(datasets)[dataset], k)
            else:
                attr = getattr(dataset, k)
            self._cfg[k] = attr

    def init_exp_dir(self, cfg, cfg_prefix=None):
        """Create the experiment directory and dump ``cfg`` into it as JSON.

        The directory is ``save_dir/exp_name`` (both read from ``self._cfg``)
        and is stored back as ``self._cfg['exp_dir']``. The dump is named
        ``<cfg_prefix>_cfg.json``, or ``cfg.json`` when no prefix is given.
        """
        # Init work directory and save parsed configs for reference
        assert hasattr(self, '_cfg')
        exp_dir = os.path.join(self._cfg['save_dir'], self._cfg['exp_name'])
        os.makedirs(exp_dir, exist_ok=True)
        self._cfg['exp_dir'] = exp_dir
        # The default prefix is None, which cannot be concatenated to a str
        cfg_filename = 'cfg.json' if cfg_prefix is None else cfg_prefix + '_cfg.json'
        with open(os.path.join(exp_dir, cfg_filename), 'w') as f:
            f.write(json.dumps(cfg, indent=4))

    @staticmethod
    def update_cfg(cfg, updates):
        """Merge ``updates`` (a dict or an argparse-style namespace) into ``cfg`` in place.

        Returns whatever ``cfg.update`` returns, i.e. ``None`` for a plain dict.
        """
        # Update by argparse object/dict
        if not isinstance(updates, dict):
            updates = vars(updates)
        return cfg.update(updates)

    @staticmethod
    def write_mp_log(log_file, content, append=True):
        """Write ``content`` to ``log_file`` under an exclusive ``flock``.

        The file is opened in append mode by default, in write mode when
        ``append`` is false.
        """
        # Multi-processing log writing
        import fcntl
        with open(log_file, 'a' if append else 'w') as f:
            # Safe writing with locks
            fcntl.flock(f, fcntl.LOCK_EX)
            try:
                f.write(content)
                # Flush while the lock is still held; otherwise the buffered
                # text is only written at close(), after the lock is released.
                f.flush()
            finally:
                fcntl.flock(f, fcntl.LOCK_UN)
|
||||
|
||||
|
||||
class BaseTrainer(BaseRunner):
    """Shared training setup built from the parsed config dict.

    The constructor places the model on its device (optionally wrapping it for
    distributed training), creates the experiment directory and the
    tensorboard writer, loads the checkpoint, and builds the data loaders,
    optimizer, LR scheduler and loss.
    """

    def __init__(self, cfg, map_dataset_statics=None):
        super().__init__(cfg)
        self._cfg = cfg['train']
        net_without_ddp, self.device = self.get_device_and_move_model()
        if 'val_num_steps' in self._cfg:
            self._cfg['validation'] = self._cfg['val_num_steps'] > 0
        self.init_exp_dir(cfg, 'train')
        self.writer = self.get_writer()
        self.load_checkpoint(self._cfg['checkpoint'])

        # Dataset
        self.collate_fn = get_collate_fn(self._cfg['collate_fn'])
        transforms = TRANSFORMS.from_dict(cfg['train_augmentation'])
        dataset = DATASETS.from_dict(cfg['dataset'],
                                     transforms=transforms)
        self.get_dataset_statics(dataset, map_dataset_statics)
        self.train_sampler = get_sampler(self._cfg['distributed'], dataset)
        self.dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                                      batch_size=self._cfg['batch_size'],
                                                      collate_fn=self.collate_fn,
                                                      sampler=self.train_sampler,
                                                      num_workers=self._cfg['workers'])
        validation_set = self.get_validation_dataset(cfg)
        self.validation_loader = None
        if validation_set is not None:
            # Validation batch size falls back to the training batch size
            val_bs = self._cfg.get('val_batch_size')
            if val_bs is None:
                val_bs = self._cfg['batch_size']
            self.validation_loader = torch.utils.data.DataLoader(dataset=validation_set,
                                                                 batch_size=val_bs,
                                                                 num_workers=self._cfg['workers'],
                                                                 shuffle=False,
                                                                 collate_fn=self.collate_fn)

        # Optimizer, LR scheduler, etc.
        self.optimizer = self.get_optimizer(cfg['optimizer'], net_without_ddp)
        self.lr_scheduler = LR_SCHEDULERS.from_dict(cfg['lr_scheduler'],
                                                    optimizer=self.optimizer,
                                                    len_loader=len(self.dataloader))
        self.criterion = LOSSES.from_dict(cfg['loss'])

    def get_device_and_move_model(self):
        """Move the model to the configured device and wrap it for DDP if requested.

        Returns:
            ``(net_without_ddp, device)``: the unwrapped module and the
            ``torch.device`` built from ``self._cfg['device']``.
        """
        init_distributed_mode(self._cfg)
        device = torch.device(self._cfg['device'])
        print(device)
        self.model.to(device)

        if self._cfg['distributed']:
            self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
        net_without_ddp = self.model
        if self._cfg['distributed']:
            self.model = torch.nn.parallel.DistributedDataParallel(self.model,
                                                                   device_ids=[self._cfg['gpu']],
                                                                   find_unused_parameters=True)
            net_without_ddp = self.model.module

        return net_without_ddp, device

    def get_writer(self):
        """Return a tensorboard ``SummaryWriter`` on the main process, ``None`` elsewhere."""
        if not is_main_process():
            return None
        return SummaryWriter(os.path.join(self._cfg['save_dir'],
                                          'tb_logs',
                                          self._cfg['exp_name']))

    @staticmethod
    def get_optimizer(optimizer_cfg, net):
        """Build the optimizer, resolving string placeholders into parameter groups.

        ``optimizer_cfg`` may hold a ``'parameters'`` list of group dicts whose
        ``'params'`` entry is a string. A parameter belongs to a group when
        the string is a substring of its name; the special string
        ``'__others__'`` collects every parameter matched by no other group
        (and, as before, keeps only the ``'params'`` key for that group).
        Without ``'parameters'``, all of ``net.parameters()`` is used.

        The caller's config is not modified: the dict and its group dicts are
        copied before placeholders are replaced.
        """
        optimizer_cfg = dict(optimizer_cfg)  # shallow copy, so pop() stays local
        parameters = optimizer_cfg.pop('parameters', None)
        if parameters is None:  # For BC
            parameters = net.parameters()
        else:  # replace str with actual parameter groups
            for group in parameters:
                assert isinstance(group['params'], str), 'Use string as placeholder in your config!'
            group_keys = [group['params'] for group in parameters if group['params'] != '__others__']
            named_params = list(net.named_parameters())
            resolved = []
            for group in parameters:
                key = group['params']
                if key == '__others__':
                    resolved.append({'params': [p for name, p in named_params
                                                if all(group_key not in name for group_key in group_keys)]})
                else:
                    new_group = dict(group)
                    new_group['params'] = [p for name, p in named_params if key in name]
                    resolved.append(new_group)
            parameters = resolved

        return OPTIMIZERS.from_dict(optimizer_cfg, parameters=parameters)

    @abstractmethod
    def run(self, *args, **kwargs):
        pass

    def get_validation_dataset(self, *args, **kwargs):
        """Hook for subclasses; the base trainer has no validation set."""
        return None

    def clean(self):
        super().clean()
        # The writer only exists on the main process
        if self.writer is not None:
            self.writer.close()
|
||||
|
||||
|
||||
class BaseTester(BaseRunner):
    """Shared evaluation setup: experiment dir, device, checkpoint and test loader."""

    # Indexed by the 1-based ``state`` value of the test config
    image_sets = ['val']

    def __init__(self, cfg, map_dataset_statics=None):
        super().__init__(cfg)
        test_cfg = cfg['test']
        self._cfg = test_cfg
        self.init_exp_dir(cfg, self.image_sets[test_cfg['state'] - 1])
        self.device = self.get_device_and_move_model()
        self.load_checkpoint(test_cfg['checkpoint'])

        # Dataset: a dedicated 'test_dataset' entry takes precedence over 'dataset'
        transforms = TRANSFORMS.from_dict(cfg['test_augmentation'])
        if 'test_dataset' in cfg:
            dataset_cfg = cfg['test_dataset']
        else:
            dataset_cfg = cfg['dataset']
        dataset = DATASETS.from_dict(dataset_cfg,
                                     image_set=self.image_sets[test_cfg['state'] - 1],
                                     transforms=transforms)
        self.get_dataset_statics(dataset, map_dataset_statics)

        # Dataloader
        self.dataloader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=test_cfg['batch_size'],
            collate_fn=get_collate_fn(test_cfg['collate_fn']),
            num_workers=test_cfg['workers'],
            shuffle=False)

    @abstractmethod
    def run(self, *args, **kwargs):
        pass
|
||||
|
||||
|
||||
class BaseVisualizer(BaseRunner):
    """Shared visualization setup; subclasses provide ``get_loader`` and ``run``."""

    # Dataset attributes copied into the config and converted to tensors
    dataset_tensor_statistics = []
    # Dataset attributes copied into the config as they are
    dataset_statistics = []

    def __init__(self, cfg):
        super().__init__(cfg)
        # Prefer a dedicated 'vis' section, fall back to 'test'
        if 'vis' in cfg:
            self._cfg = cfg['vis']
        else:
            self._cfg = cfg['test']
        self.dataloader, dataset = self.get_loader(cfg)
        wanted = set(self.dataset_statistics) | set(self.dataset_tensor_statistics)
        self.get_dataset_statics(dataset, wanted, exist_ok=True)
        for name in self.dataset_tensor_statistics:
            self._cfg[name] = torch.tensor(self._cfg[name])
        if self._cfg['pred']:
            self.device = self.get_device_and_move_model()
            self.load_checkpoint(self._cfg['checkpoint'])
            for name in self.dataset_tensor_statistics:
                self._cfg[name] = self._cfg[name].to(self.device)
            # Try eval() with a profiling flag first; fall back to the plain
            # call when that keyword raises TypeError
            try:
                self.model.eval(profiling=True)
            except TypeError:
                self.model.eval()

    @abstractmethod
    def run(self, *args, **kwargs):
        pass

    @abstractmethod
    def get_loader(self, *args, **kwargs):
        pass
|
||||
|
||||
|
||||
class BaseVideoVisualizer(BaseVisualizer):
    """Visualizer that reads frames from a video loader and writes an XVID video."""

    def __init__(self, cfg):
        super().__init__(cfg)
        # fps and resolution are taken from the loader built by get_loader()
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        self.writer = cv2.VideoWriter(self._cfg['save_path'],
                                      fourcc,
                                      self.dataloader.fps,
                                      self.dataloader.resolution)

    def get_loader(self, cfg):
        """Return ``(loader, dataset_name)``; the loader comes from ``cfg['vis_dataset']``
        when present, otherwise from a ``VideoLoader`` config assembled here."""
        if 'vis_dataset' in cfg:
            loader_cfg = cfg['vis_dataset']
        else:
            loader_cfg = {
                'name': 'VideoLoader',
                'filename': self._cfg['video_path'],
                'batch_size': self._cfg['batch_size'],
            }
        test_transforms = TRANSFORMS.from_dict(cfg['test_augmentation'])
        loader = DATASETS.from_dict(loader_cfg, transforms=test_transforms)

        return loader, cfg['dataset']['name']

    @abstractmethod
    def run(self):
        pass

    def clean(self):
        super().clean()
        self.writer.release()
|
||||
@@ -0,0 +1,126 @@
|
||||
import os
|
||||
import torch
|
||||
try:
|
||||
import ujson as json
|
||||
except ImportError:
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
if torch.__version__ >= '1.6.0':
|
||||
from torch.cuda.amp import autocast
|
||||
else:
|
||||
from ..torch_amp_dummy import autocast
|
||||
|
||||
from .base import BaseTester
|
||||
from ..seg_utils import ConfusionMatrix
|
||||
from ..lane_det_utils import lane_as_segmentation_inference
|
||||
|
||||
|
||||
class LaneDetTester(BaseTester):
    """Lane detection tester.

    ``state == 1`` runs the fast pixel-wise evaluation; any other state writes
    predictions to disk in the format selected by the dataset name.
    """

    image_sets = ['valfast', 'test', 'val']

    def __init__(self, cfg):
        super().__init__(cfg)
        self.fast_eval = self._cfg['state'] == 1

    def run(self):
        if self.fast_eval:
            _, x = self.fast_evaluate(self.model, self.device, self.dataloader,
                                      self._cfg['mixed_precision'], self._cfg['input_size'], self._cfg['num_classes'])
            self.write_mp_log('log.txt', self._cfg['exp_name'] + ' validation: ' + str(x) + '\n')
        else:
            self.test_one_set(self.model, self.device, self.dataloader, self._cfg['mixed_precision'],
                              [self._cfg['input_size'], self._cfg['original_size']],
                              self._cfg['gap'], self._cfg['ppl'], self._cfg['thresh'],
                              self._cfg['dataset_name'], self._cfg['seg'], self._cfg['max_lane'], self._cfg['exp_name'])

    @staticmethod
    def _write_lane_file(file_path, lane_coordinates):
        # Write one text line of "x y " pairs per non-empty lane, creating the
        # parent directory of file_path when needed.
        dir_name = os.path.dirname(file_path)
        if dir_name:  # makedirs('') would raise for a bare filename
            os.makedirs(dir_name, exist_ok=True)
        with open(file_path, "w") as f:
            for lane in lane_coordinates:
                if lane:  # No printing for []
                    for (x, y) in lane:
                        print("{} {}".format(x, y), end=" ", file=f)
                    print(file=f)

    @staticmethod
    @torch.no_grad()
    def test_one_set(net, device, loader, mixed_precision, input_sizes, gap, ppl, thresh, dataset,
                     seg, max_lane=0, exp_name=None):
        """Predict on one loader and save the predictions to disk.

        Args:
            net: the model; ``net.inference`` is used when ``seg`` is false.
            device: device the image batches are moved to.
            loader: yields ``(images, filenames)`` batches.
            mixed_precision: passed to ``autocast``.
            input_sizes: [input size, test original size, ...].
            gap, ppl, thresh: forwarded to the inference function.
            dataset: ``'culane'``, ``'tusimple'`` or ``'llamas'``.
            seg: use ``lane_as_segmentation_inference`` instead of ``net.inference``.
            max_lane: 0 -> unlimited number of lanes.
            exp_name: base name of the ``./output/<exp_name>.json`` file
                written for ``'tusimple'``.

        Raises:
            ValueError: for any other ``dataset`` value.
        """
        # Adapted from harryhan618/SCNN_Pytorch
        all_lanes = []
        net.eval()
        for images, filenames in tqdm(loader):
            images = images.to(device)
            with autocast(mixed_precision):
                if seg:
                    batch_coordinates = lane_as_segmentation_inference(net, images,
                                                                       input_sizes, gap, ppl, thresh, dataset, max_lane)
                else:
                    batch_coordinates = net.inference(images, input_sizes, gap, ppl, dataset, max_lane)

            # Parse coordinates
            for filename, lane_coordinates in zip(filenames[:len(batch_coordinates)], batch_coordinates):
                if dataset == 'culane':
                    # Save each lane to disk
                    LaneDetTester._write_lane_file(filename, lane_coordinates)
                elif dataset == 'tusimple':
                    # Save lanes to a single file
                    formatted = {
                        "h_samples": [160 + y * 10 for y in range(ppl)],
                        "lanes": [[c[0] for c in lane] for lane in lane_coordinates],
                        "run_time": 0,
                        "raw_file": filename
                    }
                    all_lanes.append(json.dumps(formatted))
                elif dataset == 'llamas':
                    # save each lane in images in xxx.lines.txt
                    LaneDetTester._write_lane_file(filename.replace("_color_rect", ""), lane_coordinates)
                else:
                    raise ValueError('Unsupported dataset: {}'.format(dataset))

        if dataset == 'tusimple':
            os.makedirs('./output', exist_ok=True)
            with open('./output/' + exp_name + '.json', 'w') as f:
                for lane in all_lanes:
                    print(lane, end="\n", file=f)

    @staticmethod
    @torch.no_grad()
    def fast_evaluate(net, device, loader, mixed_precision, output_size, num_classes):
        """Fast evaluation (e.g. on the validation set) by pixel-wise mean IoU.

        Returns:
            ``(global pixel accuracy, mean IoU)``, both scaled by 100.
        """
        net.eval()
        conf_mat = ConfusionMatrix(num_classes)
        for image, target in tqdm(loader):
            image, target = image.to(device), target.to(device)
            with autocast(mixed_precision):
                output = net(image)['out']
                # Resize the output to output_size before the per-pixel argmax
                output = torch.nn.functional.interpolate(output, size=output_size,
                                                         mode='bilinear', align_corners=True)
                conf_mat.update(target.flatten(), output.argmax(1).flatten())
        conf_mat.reduce_from_all_processes()

        acc_global, acc, iu = conf_mat.compute()
        print((
            'global correct: {:.2f}\n'
            'average row correct: {}\n'
            'IoU: {}\n'
            'mean IoU: {:.2f}'
        ).format(
            acc_global.item() * 100,
            ['{:.2f}'.format(i) for i in (acc * 100).tolist()],
            ['{:.2f}'.format(i) for i in (iu * 100).tolist()],
            iu.mean().item() * 100))

        return acc_global.item() * 100, iu.mean().item() * 100
|
||||
@@ -0,0 +1,128 @@
|
||||
import os
|
||||
import torch
|
||||
import time
|
||||
if torch.__version__ >= '1.6.0':
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
else:
|
||||
from ..torch_amp_dummy import autocast, GradScaler
|
||||
|
||||
from ..common import save_checkpoint
|
||||
from ..ddp_utils import reduce_dict, is_main_process
|
||||
from .lane_det_tester import LaneDetTester
|
||||
from .base import BaseTrainer, DATASETS, TRANSFORMS
|
||||
|
||||
|
||||
class LaneDetTrainer(BaseTrainer):
    """Trainer for lane detection; the criterion receives the model and runs it itself."""

    def __init__(self, cfg):
        super().__init__(cfg)

    def run(self):
        """Train for ``num_epochs``, logging averaged losses and saving ``model.pt``.

        With validation enabled the checkpoint is written whenever the fast
        evaluation mIoU improves; otherwise it is written once after training.
        """
        # Should be the same as segmentation, given customized loss classes
        train_cfg = self._cfg

        def save_model():
            save_checkpoint(net=self.model.module if train_cfg['distributed'] else self.model,
                            optimizer=None,
                            lr_scheduler=None,
                            filename=os.path.join(train_cfg['exp_dir'], 'model.pt'))

        self.model.train()
        # Dict logging for every loss (too many losses in this task)
        running_loss = None
        steps_per_epoch = len(self.dataloader)
        loss_num_steps = int(steps_per_epoch / 10) if steps_per_epoch > 10 else 1
        if train_cfg['mixed_precision']:
            scaler = GradScaler()

        # Training
        best_validation = 0
        epoch = 0
        while epoch < train_cfg['num_epochs']:
            self.model.train()
            if train_cfg['distributed']:
                self.train_sampler.set_epoch(epoch)
            time_now = time.time()
            for i, data in enumerate(self.dataloader, 0):
                if train_cfg['seg']:
                    inputs, labels, existence = data
                    inputs = inputs.to(self.device)
                    labels = labels.to(self.device)
                    existence = existence.to(self.device)
                else:
                    inputs, labels = data
                    inputs = inputs.to(self.device)
                    if train_cfg['collate_fn'] is None:
                        labels = labels.to(self.device)
                    else:
                        # Seems slow
                        labels = [{k: v.to(self.device) for k, v in label.items()} for label in labels]
                self.optimizer.zero_grad()

                with autocast(train_cfg['mixed_precision']):
                    # To support intermediate losses for SAD
                    if train_cfg['seg']:
                        loss, log_dict = self.criterion(inputs, labels, existence,
                                                        self.model, train_cfg['input_size'])
                    else:
                        loss, log_dict = self.criterion(inputs, labels,
                                                        self.model)

                if train_cfg['mixed_precision']:
                    scaler.scale(loss).backward()
                    scaler.step(self.optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    self.optimizer.step()

                self.lr_scheduler.step()

                log_dict = reduce_dict(log_dict)
                if running_loss is None:  # Because different methods may have different values to log
                    running_loss = {k: 0.0 for k in log_dict.keys()}
                for k in log_dict.keys():
                    running_loss[k] += log_dict[k]
                current_step_num = int(epoch * steps_per_epoch + i + 1)

                # Record losses
                if current_step_num % loss_num_steps == (loss_num_steps - 1):
                    for k in running_loss.keys():
                        mean_loss = running_loss[k] / loss_num_steps
                        print('[%d, %d] %s: %.4f' % (epoch + 1, i + 1, k, mean_loss))
                        # Logging only once
                        if is_main_process():
                            self.writer.add_scalar(k, mean_loss, current_step_num)
                        running_loss[k] = 0.0

                # Record checkpoints
                if train_cfg['validation']:
                    assert train_cfg['seg'], 'Only segmentation based methods can be fast evaluated!'
                    at_val_step = current_step_num % train_cfg['val_num_steps'] == (train_cfg['val_num_steps'] - 1)
                    at_last_step = current_step_num == train_cfg['num_epochs'] * steps_per_epoch
                    if at_val_step or at_last_step:
                        test_pixel_accuracy, test_mIoU = LaneDetTester.fast_evaluate(
                            loader=self.validation_loader,
                            device=self.device,
                            net=self.model,
                            num_classes=train_cfg['num_classes'],
                            output_size=train_cfg['input_size'],
                            mixed_precision=train_cfg['mixed_precision'])
                        if is_main_process():
                            self.writer.add_scalar('test pixel accuracy',
                                                   test_pixel_accuracy,
                                                   current_step_num)
                            self.writer.add_scalar('test mIoU',
                                                   test_mIoU,
                                                   current_step_num)
                        # fast_evaluate switched the model to eval mode
                        self.model.train()

                        # Record best model (straight to disk)
                        if test_mIoU > best_validation:
                            best_validation = test_mIoU
                            save_model()

            epoch += 1
            print('Epoch time: %.2fs' % (time.time() - time_now))

        # For no-evaluation mode
        if not train_cfg['validation']:
            save_model()

    def get_validation_dataset(self, cfg):
        """Build the validation dataset from ``cfg``, or return ``None`` when validation is off."""
        if self._cfg['validation']:
            return DATASETS.from_dict(cfg['validation_dataset'],
                                      transforms=TRANSFORMS.from_dict(cfg['test_augmentation']))
        return None
|
||||
@@ -0,0 +1,186 @@
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from abc import abstractmethod
|
||||
if torch.__version__ >= '1.6.0':
|
||||
from torch.cuda.amp import autocast
|
||||
else:
|
||||
from ..torch_amp_dummy import autocast
|
||||
|
||||
from .base import BaseVisualizer, BaseVideoVisualizer, get_collate_fn
|
||||
from ..datasets import DATASETS
|
||||
from ..transforms import TRANSFORMS
|
||||
from ..lane_det_utils import lane_as_segmentation_inference
|
||||
from ..vis_utils import lane_detection_visualize_batched, save_images
|
||||
|
||||
|
||||
def lane_label_process_fn(label):
    """Parse lane labels in the CULane text format.

    Args:
        label: path of a label txt file, or its content as a list of lines.
            Each line holds whitespace-separated ``x y x y ...`` values.

    Returns:
        A list with one ``(N, 2)`` float array per non-blank line.
    """
    if isinstance(label, str):
        with open(label, 'r') as f:
            label = f.readlines()
    target = []
    for line in label:
        # split() without a separator tolerates repeated/trailing whitespace
        values = [float(x) for x in line.split()]
        if not values:
            # Blank line: nothing to parse (float('') would raise)
            continue
        target.append(np.array(values).reshape(-1, 2))

    return target
|
||||
|
||||
|
||||
class LaneDetVisualizer(BaseVisualizer):
    """Base lane detection visualizer providing the shared inference helper."""

    dataset_statistics = ['keypoint_color']
    color_pool = [[0, 0, 0],
                  [128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
                  [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
                  [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
                  [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100],
                  [0, 0, 0]]

    @torch.no_grad()
    def lane_inference(self, images, original_size=None):
        """Run lane inference on a batch.

        Returns:
            ``(cps, keypoints)``: ``cps`` are the Bézier control points (only
            set for non-seg models with the ``'bezier'`` style, else ``None``);
            ``keypoints`` holds, per image, one numpy array per lane.
        """
        vis_cfg = self._cfg
        control_points = None
        if original_size is None:
            original_size = vis_cfg['original_size']
        with autocast(vis_cfg['mixed_precision']):
            if vis_cfg['seg']:  # Seg methods
                keypoints = lane_as_segmentation_inference(
                    self.model, images,
                    [vis_cfg['input_size'], original_size],
                    vis_cfg['gap'],
                    vis_cfg['ppl'],
                    vis_cfg['thresh'],
                    vis_cfg['dataset_name'],
                    vis_cfg['max_lane'])
            else:
                want_cps = vis_cfg['style'] == 'bezier'
                outputs = self.model.inference(
                    images,
                    [vis_cfg['input_size'], original_size],
                    vis_cfg['gap'],
                    vis_cfg['ppl'],
                    vis_cfg['dataset_name'],
                    vis_cfg['max_lane'],
                    return_cps=want_cps)
                if want_cps:
                    control_points, keypoints = outputs
                else:
                    keypoints = outputs

        lanes_per_image = []
        for image_lanes in keypoints:
            lanes_per_image.append([np.array(lane) for lane in image_lanes])
        return control_points, lanes_per_image

    @abstractmethod
    def run(self, *args, **kwargs):
        pass

    @abstractmethod
    def get_loader(self, *args, **kwargs):
        pass
|
||||
|
||||
|
||||
class LaneDetDir(LaneDetVisualizer):
    """Visualize lanes for a folder of images and save the rendered results."""

    dataset_tensor_statistics = ['colors']

    def __init__(self, cfg):
        super().__init__(cfg)
        save_path = self._cfg['save_path']
        if save_path is not None:
            os.makedirs(save_path, exist_ok=True)
        if self._cfg['use_color_pool']:
            # Swap in the class-level color pool, keeping dtype/device of the
            # existing colors tensor
            current = self._cfg['colors']
            self._cfg['colors'] = torch.tensor(self.color_pool,
                                               dtype=current.dtype,
                                               device=current.device)

    def get_loader(self, cfg):
        """Return ``(dataloader, dataset_name)`` for the image folder to visualize."""
        if 'vis_dataset' in cfg:
            dataset_cfg = cfg['vis_dataset']
        else:
            vis_cfg = self._cfg
            dataset_cfg = {
                'name': 'ImageFolderLaneDataset',
                'root_image': vis_cfg['image_path'],
                'root_keypoint': vis_cfg['keypoint_path'],
                'root_gt_keypoint': vis_cfg['gt_keypoint_path'],
                'root_mask': vis_cfg['mask_path'],
                'root_output': vis_cfg['save_path'],
                'image_suffix': vis_cfg['image_suffix'],
                'keypoint_suffix': vis_cfg['keypoint_suffix'],
                'gt_keypoint_suffix': vis_cfg['gt_keypoint_suffix'],
                'mask_suffix': vis_cfg['mask_suffix'],
            }
        dataset = DATASETS.from_dict(dataset_cfg,
                                     transforms=TRANSFORMS.from_dict(cfg['test_augmentation']),
                                     keypoint_process_fn=lane_label_process_fn)
        # Use dicts for customized target
        dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                                 batch_size=self._cfg['batch_size'],
                                                 collate_fn=get_collate_fn('dict_collate_fn'),
                                                 num_workers=self._cfg['workers'],
                                                 shuffle=False)

        return dataloader, cfg['dataset']['name']

    def run(self):
        """Render every batch of the loader and save the images under their filenames."""

        def none_if_all_missing(items):
            # Collapse a list made only of None entries into a single None
            return None if items.count(None) == len(items) else items

        for imgs, original_imgs, targets in tqdm(self.dataloader):
            filenames = [t['filename'] for t in targets]
            keypoints = none_if_all_missing([t['keypoint'] for t in targets])
            gt_keypoints = none_if_all_missing([t['gt_keypoint'] for t in targets])
            masks = none_if_all_missing([t['mask'] for t in targets])
            if masks is not None:
                masks = torch.stack(masks)
            cps = None
            if self._cfg['pred']:  # Inference keypoints
                if masks is not None:
                    masks = masks.to(self.device)
                imgs = imgs.to(self.device)
                original_imgs = original_imgs.to(self.device)
                cps, keypoints = self.lane_inference(imgs, original_imgs.shape[2:])
            results = lane_detection_visualize_batched(original_imgs,
                                                       masks=masks,
                                                       keypoints=keypoints,
                                                       control_points=cps,
                                                       gt_keypoints=gt_keypoints,
                                                       mask_colors=self._cfg['colors'],
                                                       keypoint_color=self._cfg['keypoint_color'],
                                                       std=None, mean=None, style=self._cfg['style'],
                                                       compare_gt_metric=self._cfg['metric'])
            save_images(results, filenames=filenames)
|
||||
|
||||
|
||||
class LaneDetVideo(BaseVideoVisualizer, LaneDetVisualizer):
    """Visualize lanes on a video and write the rendered frames to the video writer."""

    def __init__(self, cfg):
        super().__init__(cfg)

    def run(self):
        # Must do inference
        for imgs, original_imgs in tqdm(self.dataloader):
            cps, keypoints = None, None
            if self._cfg['pred']:
                imgs = imgs.to(self.device)
                original_imgs = original_imgs.to(self.device)
                cps, keypoints = self.lane_inference(imgs, original_imgs.shape[2:])
            frames = lane_detection_visualize_batched(original_imgs,
                                                      masks=None,
                                                      keypoints=keypoints,
                                                      control_points=cps,
                                                      mask_colors=None,
                                                      keypoint_color=self._cfg['keypoint_color'],
                                                      std=None, mean=None, style=self._cfg['style'])
            # Reverse the order of the last (channel) axis before writing
            frames = frames[..., [2, 1, 0]]
            for frame_idx in range(frames.shape[0]):
                self.writer.write(frames[frame_idx])
|
||||
|
||||
|
||||
class LaneDetDataset(BaseVisualizer):
    """Placeholder visualizer: ``get_loader`` and ``run`` are empty stubs."""

    def __init__(self, cfg):
        super().__init__(cfg)

    def get_loader(self, cfg):
        # Not implemented yet
        return None

    def run(self):
        # Not implemented yet
        return None
|
||||
@@ -0,0 +1,77 @@
|
||||
import os
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
if torch.__version__ >= '1.6.0':
|
||||
from torch.cuda.amp import autocast
|
||||
else:
|
||||
from ..torch_amp_dummy import autocast
|
||||
|
||||
from ..seg_utils import ConfusionMatrix
|
||||
from .base import BaseTester
|
||||
|
||||
|
||||
class SegTester(BaseTester):
    """Semantic segmentation tester reporting pixel accuracy and IoU."""

    def __init__(self, cfg):
        super().__init__(cfg, map_dataset_statics=['categories'])

    def run(self):
        """Evaluate the loader, append the IoU to ``log.txt`` and save the full report."""
        test_cfg = self._cfg
        acc, iou, res_str = self.test_one_set(self.dataloader, self.device, self.model,
                                              test_cfg['num_classes'], test_cfg['categories'],
                                              test_cfg['original_size'], test_cfg['encoder_size'],
                                              test_cfg['mixed_precision'],
                                              test_cfg['selector'], test_cfg['eval_classes'],
                                              test_cfg['encoder_only'])
        self.write_mp_log('log.txt', test_cfg['exp_name'] + ': ' + str(iou) + '\n')
        if test_cfg['state'] == 1:
            prefix = 'val'
        else:
            prefix = 'custom_state_' + str(test_cfg['state'])
        result_path = os.path.join(test_cfg['exp_dir'], prefix + '_result.txt')
        with open(result_path, 'w') as f:
            f.write(res_str)

    @staticmethod
    @torch.no_grad()
    def test_one_set(loader, device, net, num_classes, categories, output_size, labels_size, mixed_precision,
                     selector=None, classes=None, encoder_only=False):
        """Evaluate on one data loader.

        With ``encoder_only`` the targets are resized to ``labels_size``
        (nearest); otherwise the outputs are resized to ``output_size``
        (bilinear). ``selector`` and ``classes`` select part of the classes as
        the reported metric (for SYNTHIA).

        Returns:
            ``(pixel accuracy * 100, IoU * 100, report string)``; the IoU is
            the mean over ``selector`` when one is given, else over all classes.
        """
        # Copied and modified from torch/vision/references/segmentation
        net.eval()
        conf_mat = ConfusionMatrix(num_classes)
        for image, target in tqdm(loader):
            image = image.to(device)
            target = target.to(device)
            with autocast(mixed_precision):
                output = net(image)['out']
                if encoder_only:
                    # Add a leading dim and cast to float for the nearest resize,
                    # then restore the integer labels
                    target = target.unsqueeze(0)
                    if target.dtype not in (torch.float32, torch.float64):
                        target = target.to(torch.float32)
                    target = torch.nn.functional.interpolate(target, size=labels_size, mode='nearest')
                    target = target.to(torch.int64)
                    target = target.squeeze(0)
                else:
                    output = torch.nn.functional.interpolate(output, size=output_size, mode='bilinear',
                                                             align_corners=True)
                conf_mat.update(target.flatten(), output.argmax(1).flatten())
        conf_mat.reduce_from_all_processes()

        acc_global, acc, iu = conf_mat.compute()
        pixel_acc = acc_global.item() * 100
        mean_iou = iu.mean().item() * 100
        if selector is None:
            selected_iou = -1
            iou = mean_iou
        else:
            selected_iou = iu[selector].mean().item() * 100
            iou = selected_iou
        res_str = (
            'All classes: {}\n'
            'Pixel acc: {:.2f}\n'
            'Pixel acc (per-class): {}\n'
            'IoU (per-class): {}\n'
            'Mean IoU: {:.2f}\n'
            'Mean IoU-{}: {:.2f}').format(
            categories,
            pixel_acc,
            ['{:.2f}'.format(i) for i in (acc * 100).tolist()],
            ['{:.2f}'.format(i) for i in (iu * 100).tolist()],
            mean_iou,
            -1 if classes is None else classes,
            selected_iou)
        print(res_str)

        return pixel_acc, iou, res_str
|
||||
@@ -0,0 +1,150 @@
|
||||
import os
|
||||
import torch
|
||||
import time
|
||||
if torch.__version__ >= '1.6.0':
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
else:
|
||||
from ..torch_amp_dummy import autocast, GradScaler
|
||||
|
||||
from ..common import save_checkpoint
|
||||
from ..seg_utils import ConfusionMatrix
|
||||
from ..ddp_utils import is_main_process, is_dist_avail_and_initialized, get_world_size
|
||||
from .base import BaseTrainer, DATASETS, TRANSFORMS
|
||||
from .seg_tester import SegTester
|
||||
|
||||
|
||||
class SegTrainer(BaseTrainer):
    """Training runner for semantic segmentation models.

    Runs the optimization loop, periodically validates through
    ``SegTester.test_one_set`` and keeps the checkpoint with the best
    validation mIoU in ``<exp_dir>/model.pt``.
    """

    def __init__(self, cfg):
        super().__init__(cfg, map_dataset_statics=['categories'])

    def run(self):
        """Train for ``num_epochs`` epochs with periodic logging and validation.

        Reads from ``self._cfg``: ``mixed_precision``, ``num_epochs``,
        ``distributed``, ``num_classes``, ``encoder_only``, ``encoder_size``,
        ``input_size``, ``val_num_steps``, ``categories``, ``original_size``,
        ``selector``, ``eval_classes`` and ``exp_dir``.
        """
        best_mIoU = 0
        self.model.train()
        epoch = 0
        running_loss = 0.0
        # The loss is logged about 10 times per epoch. Clamp the interval to 1 so
        # a loader with fewer than 10 batches does not cause a modulo by zero.
        loss_num_steps = max(1, int(len(self.dataloader) / 10))
        if self._cfg['mixed_precision']:
            scaler = GradScaler()

        # Training
        while epoch < self._cfg['num_epochs']:
            self.model.train()
            if self._cfg['distributed']:
                self.train_sampler.set_epoch(epoch)
            # Training-set confusion matrix, accumulated on-the-fly for this epoch
            conf_mat = ConfusionMatrix(self._cfg['num_classes'])
            time_now = time.time()
            for i, data in enumerate(self.dataloader, 0):
                inputs, labels = data
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                self.optimizer.zero_grad()

                with autocast(self._cfg['mixed_precision']):
                    outputs = self.model(inputs)['out']

                    if self._cfg['encoder_only']:
                        # Resize labels to the encoder output size (nearest neighbor);
                        # interpolate needs a floating-point tensor with an extra leading dim
                        labels = labels.unsqueeze(0)
                        if labels.dtype not in (torch.float32, torch.float64):
                            labels = labels.to(torch.float32)
                        labels = torch.nn.functional.interpolate(labels, size=self._cfg['encoder_size'], mode='nearest')
                        labels = labels.to(torch.int64)
                        labels = labels.squeeze(0)
                    else:
                        # Resize predictions to the input size instead
                        outputs = torch.nn.functional.interpolate(outputs, size=self._cfg['input_size'],
                                                                  mode='bilinear', align_corners=True)
                    conf_mat.update(labels.flatten(), outputs.argmax(1).flatten())
                    loss = self.criterion(outputs, labels)

                if self._cfg['mixed_precision']:
                    scaler.scale(loss).backward()
                    scaler.step(self.optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    self.optimizer.step()

                self.lr_scheduler.step()
                running_loss += loss.item()
                # Average the accumulated loss over all processes (if distributed)
                running_loss = torch.tensor([running_loss], dtype=loss.dtype, device=loss.device)
                if is_dist_avail_and_initialized():
                    torch.distributed.all_reduce(running_loss)
                running_loss = (running_loss / get_world_size()).item()
                current_step_num = int(epoch * len(self.dataloader) + i + 1)

                if current_step_num % loss_num_steps == (loss_num_steps - 1):
                    print('[%d, %d] loss: %.4f' % (epoch + 1, i + 1, running_loss / loss_num_steps))
                    if is_main_process():
                        self.writer.add_scalar('training loss',
                                               running_loss / loss_num_steps,
                                               current_step_num)
                    running_loss = 0.0

                # Validate and find the best snapshot
                if current_step_num % self._cfg['val_num_steps'] == (self._cfg['val_num_steps'] - 1):
                    test_pixel_accuracy, test_mIoU, _ = SegTester.test_one_set(
                        loader=self.validation_loader, device=self.device, net=self.model,
                        num_classes=self._cfg['num_classes'], categories=self._cfg['categories'],
                        output_size=self._cfg['original_size'],
                        labels_size=self._cfg['encoder_size'],
                        selector=self._cfg['selector'],
                        classes=self._cfg['eval_classes'],
                        mixed_precision=self._cfg['mixed_precision'],
                        encoder_only=self._cfg['encoder_only'])
                    if is_main_process():
                        self.writer.add_scalar('test pixel accuracy',
                                               test_pixel_accuracy,
                                               current_step_num)
                        self.writer.add_scalar('test mIoU',
                                               test_mIoU,
                                               current_step_num)
                    # Switch back to training mode after evaluation
                    self.model.train()

                    # Record best model (straight to disk)
                    if test_mIoU > best_mIoU:
                        best_mIoU = test_mIoU
                        save_checkpoint(net=self.model.module if self._cfg['distributed'] else self.model,
                                        optimizer=None,
                                        lr_scheduler=None,
                                        filename=os.path.join(self._cfg['exp_dir'], 'model.pt'))

            # Evaluate training accuracies (same metric as validation, but must be on-the-fly to save time)
            conf_mat.reduce_from_all_processes()
            acc_global, acc, iu = conf_mat.compute()
            print(self._cfg['categories'])
            print((
                'Pixel acc: {:.2f}\n'
                'Pixel acc (per-class): {}\n'
                'IoU (per-class): {}\n'
                'Mean IoU: {:.2f}').format(
                    acc_global.item() * 100,
                    ['{:.2f}'.format(i) for i in (acc * 100).tolist()],
                    ['{:.2f}'.format(i) for i in (iu * 100).tolist()],
                    iu.mean().item() * 100))

            train_pixel_acc = acc_global.item() * 100
            train_mIoU = iu.mean().item() * 100
            if is_main_process():
                self.writer.add_scalar('train pixel accuracy',
                                       train_pixel_acc,
                                       epoch + 1)
                self.writer.add_scalar('train mIoU',
                                       train_mIoU,
                                       epoch + 1)

            epoch += 1
            print('Epoch time: %.2fs' % (time.time() - time_now))

    def get_validation_dataset(self, cfg):
        """Build the validation dataset, or return None when validation is disabled.

        Uses ``cfg['test_dataset']`` when present, otherwise ``cfg['dataset']``,
        with ``image_set='val'`` and the transforms from ``cfg['test_augmentation']``.
        """
        if not self._cfg['validation']:
            return None
        validation_transforms = TRANSFORMS.from_dict(cfg['test_augmentation'])
        validation_set = DATASETS.from_dict(cfg['test_dataset'] if 'test_dataset' in cfg.keys() else cfg['dataset'],
                                            image_set='val',
                                            transforms=validation_transforms)

        return validation_set

    def clean(self):
        """Run the base-class cleanup, then print an evaluation reminder on the main process."""
        super().clean()
        if is_main_process():
            print('Segmentation models used to be evaluated upon training, now please run a separate --val for eval!')
|
||||
@@ -0,0 +1,142 @@
|
||||
import os
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from abc import abstractmethod
|
||||
from PIL import Image
|
||||
if torch.__version__ >= '1.6.0':
|
||||
from torch.cuda.amp import autocast
|
||||
else:
|
||||
from ..torch_amp_dummy import autocast
|
||||
|
||||
from .base import BaseVisualizer, BaseVideoVisualizer, get_collate_fn
|
||||
from ..datasets import DATASETS
|
||||
from ..transforms import TRANSFORMS, ToTensor, functional as F
|
||||
from ..vis_utils import segmentation_visualize_batched, save_images, \
|
||||
find_transform_by_name, get_transform_attr_by_name, tensor_image_to_numpy
|
||||
|
||||
|
||||
def seg_label_process_fn(label):
    """Open a segmentation label image by filename and convert it to a tensor.

    The file is opened with PIL and passed through ``ToTensor.label_to_tensor``.
    """
    return ToTensor.label_to_tensor(Image.open(label))
|
||||
|
||||
|
||||
class SegVisualizer(BaseVisualizer):
    """Common base for segmentation visualizers; provides shared inference."""

    # NOTE(review): presumably read by the base class to load these dataset
    # statistics as tensors -- confirm in BaseVisualizer.
    dataset_tensor_statistics = ['colors']

    @torch.no_grad()
    def seg_inference(self, images, original_size=None, pad_crop=False):
        """Run the model on ``images`` and return per-pixel class indices.

        Segmentation methods share a simple, unified output format, so one
        post-process suffices: bilinear resize of ``model(images)['out']``
        followed by an argmax over dim 1.

        Args:
            images: network input batch.
            original_size: target size; falls back to ``self._cfg['original_size']``
                when None.
            pad_crop: if True (VOC style transform), resize to the input's
                spatial size and crop the top-left ``original_size`` region
                instead of resizing directly to ``original_size``.
        """
        target_size = self._cfg['original_size'] if original_size is None else original_size
        with autocast(self._cfg['mixed_precision']):
            logits = self.model(images)['out']
            resize_to = images.shape[2:] if pad_crop else target_size
            labels = torch.nn.functional.interpolate(logits, size=resize_to,
                                                     mode='bilinear', align_corners=True)
            if pad_crop:
                labels = F.crop(labels, 0, 0, target_size[0], target_size[1])
        return labels.argmax(1)

    @abstractmethod
    def run(self, *args, **kwargs):
        """Execute the visualization; implemented by subclasses."""

    @abstractmethod
    def get_loader(self, *args, **kwargs):
        """Build the data loader; implemented by subclasses."""
|
||||
|
||||
|
||||
class SegDir(SegVisualizer):
    """Visualize segmentation predictions or existing labels for an image folder."""

    def __init__(self, cfg):
        super().__init__(cfg)
        os.makedirs(self._cfg['save_path'], exist_ok=True)
        self.pad_crop = find_transform_by_name(cfg['test_augmentation'], 'ZeroPad')
        # Optional lookup tensor that remaps raw label ids (from the LabelMap transform)
        self.id_map = None
        if self._cfg['map_id']:
            id_map = get_transform_attr_by_name(cfg['test_augmentation'], 'LabelMap', attr='label_id_map')
            self.id_map = torch.tensor(id_map)

    def get_loader(self, cfg):
        """Build the visualization data loader.

        Uses ``cfg['vis_dataset']`` when present; otherwise an
        ``ImageFolderDataset`` config is assembled from the path/suffix
        entries in ``self._cfg``.

        Returns:
            Tuple of (dataloader, ``cfg['dataset']['name']``).
        """
        if 'vis_dataset' in cfg.keys():
            dataset_cfg = cfg['vis_dataset']
        else:
            dataset_cfg = dict(
                name='ImageFolderDataset',
                root_image=self._cfg['image_path'],
                root_target=self._cfg['target_path'],
                root_output=self._cfg['save_path'],
                image_suffix=self._cfg['image_suffix'],
                target_suffix=self._cfg['target_suffix']
            )
        dataset = DATASETS.from_dict(dataset_cfg,
                                     transforms=TRANSFORMS.from_dict(cfg['test_augmentation']),
                                     target_process_fn=seg_label_process_fn)

        collate_fn = get_collate_fn('dict_collate_fn')  # Use dicts for customized target
        dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                                 batch_size=self._cfg['batch_size'],
                                                 collate_fn=collate_fn,
                                                 num_workers=self._cfg['workers'],
                                                 shuffle=False)

        return dataloader, cfg['dataset']['name']

    def run(self):
        """Render and save a visualization for every batch in the loader.

        With ``self._cfg['pred']`` set, targets come from model inference;
        otherwise the loaded labels are used, remapped through ``self.id_map``
        when one is configured.
        """
        for imgs, original_imgs, targets in tqdm(self.dataloader):
            filenames = [i['filename'] for i in targets]
            targets = [i['target'] for i in targets]
            # Identity check instead of list.count(None), which would compare
            # tensors against None with ==
            if all(t is None for t in targets):
                targets = None
            else:
                targets = torch.stack(targets)
            if self._cfg['pred']:  # Inference
                imgs = imgs.to(self.device)
                original_imgs = original_imgs.to(self.device)
                targets = self.seg_inference(imgs, original_imgs.shape[2:], pad_crop=self.pad_crop)
            elif self.id_map is not None and targets is not None:
                # Ids outside the map are sent to index 0 before the lookup;
                # skipped when there are no labels to remap
                targets[targets >= self.id_map.shape[0]] = 0
                targets = self.id_map[targets]
            results = segmentation_visualize_batched(original_imgs,
                                                     targets,
                                                     colors=self._cfg['colors'],
                                                     std=None, mean=None)
            save_images(results, filenames=filenames)
|
||||
|
||||
|
||||
class SegVideo(BaseVideoVisualizer, SegVisualizer):
    """Visualize segmentation results frame by frame and write them to a video."""

    def __init__(self, cfg):
        super().__init__(cfg)
        self.pad_crop = find_transform_by_name(cfg['test_augmentation'], 'ZeroPad')

    def run(self):
        """Overlay results on each batch of frames and push them to ``self.writer``."""
        channel_flip = [2, 1, 0]  # reverse the last (channel) axis before writing
        for net_inputs, raw_frames in tqdm(self.dataloader):
            predictions = None
            if self._cfg['pred']:  # Inference
                net_inputs = net_inputs.to(self.device)
                raw_frames = raw_frames.to(self.device)
                predictions = self.seg_inference(net_inputs, raw_frames.shape[2:], pad_crop=self.pad_crop)
            rendered = segmentation_visualize_batched(raw_frames,
                                                      predictions,
                                                      colors=self._cfg['colors'],
                                                      std=None, mean=None)
            flipped = tensor_image_to_numpy(rendered)[..., channel_flip]
            for frame in flipped:
                self.writer.write(frame)
|
||||
|
||||
|
||||
class SegDataset(SegVisualizer):
    """Placeholder visualizer for datasets; both hooks are currently no-ops."""

    def __init__(self, cfg):
        super().__init__(cfg)

    def get_loader(self, cfg):
        """Not implemented yet; returns None."""
        return None

    def run(self):
        """Not implemented yet; does nothing."""
        return None
|
||||
Reference in New Issue
Block a user