feat: initial HSAP platform

Huaxu Sentinel Active Safety Platform with embedded algorithm code,
Docker Compose setup, and vendored dataset scaffolds for clone-and-run.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
2026-05-25 16:59:59 +08:00
commit 7c43b44c57
1619 changed files with 373355 additions and 0 deletions

View File

@@ -0,0 +1,5 @@
from .backbones import *
from .heads import *
from .nets import *
from .necks import *
from .registry import build_backbones

View File

@@ -0,0 +1,2 @@
from .resnet import ResNet
from .dla34 import DLA

View File

@@ -0,0 +1,460 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import math
import logging
import numpy as np
from os.path import join
import torch
from torch import nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from clrnet.models.registry import BACKBONES
BN_MOMENTUM = 0.1
logger = logging.getLogger(__name__)
def get_model_url(data='imagenet', name='dla34', hash='ba72cf86'):
return join('http://dl.yf.io/dla/models', data,
'{}-{}.pth'.format(name, hash))
def conv3x3(in_planes, out_planes, stride=1):
"3x3 convolution with padding"
return nn.Conv2d(in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False)
class BasicBlock(nn.Module):
def __init__(self, inplanes, planes, stride=1, dilation=1):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(inplanes,
planes,
kernel_size=3,
stride=stride,
padding=dilation,
bias=False,
dilation=dilation)
self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(planes,
planes,
kernel_size=3,
stride=1,
padding=dilation,
bias=False,
dilation=dilation)
self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
self.stride = stride
def forward(self, x, residual=None):
if residual is None:
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out += residual
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 2
def __init__(self, inplanes, planes, stride=1, dilation=1):
super(Bottleneck, self).__init__()
expansion = Bottleneck.expansion
bottle_planes = planes // expansion
self.conv1 = nn.Conv2d(inplanes,
bottle_planes,
kernel_size=1,
bias=False)
self.bn1 = nn.BatchNorm2d(bottle_planes, momentum=BN_MOMENTUM)
self.conv2 = nn.Conv2d(bottle_planes,
bottle_planes,
kernel_size=3,
stride=stride,
padding=dilation,
bias=False,
dilation=dilation)
self.bn2 = nn.BatchNorm2d(bottle_planes, momentum=BN_MOMENTUM)
self.conv3 = nn.Conv2d(bottle_planes,
planes,
kernel_size=1,
bias=False)
self.bn3 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
self.relu = nn.ReLU(inplace=True)
self.stride = stride
def forward(self, x, residual=None):
if residual is None:
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
out += residual
out = self.relu(out)
return out
class BottleneckX(nn.Module):
expansion = 2
cardinality = 32
def __init__(self, inplanes, planes, stride=1, dilation=1):
super(BottleneckX, self).__init__()
cardinality = BottleneckX.cardinality
# dim = int(math.floor(planes * (BottleneckV5.expansion / 64.0)))
# bottle_planes = dim * cardinality
bottle_planes = planes * cardinality // 32
self.conv1 = nn.Conv2d(inplanes,
bottle_planes,
kernel_size=1,
bias=False)
self.bn1 = nn.BatchNorm2d(bottle_planes, momentum=BN_MOMENTUM)
self.conv2 = nn.Conv2d(bottle_planes,
bottle_planes,
kernel_size=3,
stride=stride,
padding=dilation,
bias=False,
dilation=dilation,
groups=cardinality)
self.bn2 = nn.BatchNorm2d(bottle_planes, momentum=BN_MOMENTUM)
self.conv3 = nn.Conv2d(bottle_planes,
planes,
kernel_size=1,
bias=False)
self.bn3 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
self.relu = nn.ReLU(inplace=True)
self.stride = stride
def forward(self, x, residual=None):
if residual is None:
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
out += residual
out = self.relu(out)
return out
class Root(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, residual):
super(Root, self).__init__()
self.conv = nn.Conv2d(in_channels,
out_channels,
1,
stride=1,
bias=False,
padding=(kernel_size - 1) // 2)
self.bn = nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM)
self.relu = nn.ReLU(inplace=True)
self.residual = residual
def forward(self, *x):
children = x
x = self.conv(torch.cat(x, 1))
x = self.bn(x)
if self.residual:
x += children[0]
x = self.relu(x)
return x
class Tree(nn.Module):
def __init__(self,
levels,
block,
in_channels,
out_channels,
stride=1,
level_root=False,
root_dim=0,
root_kernel_size=1,
dilation=1,
root_residual=False):
super(Tree, self).__init__()
if root_dim == 0:
root_dim = 2 * out_channels
if level_root:
root_dim += in_channels
if levels == 1:
self.tree1 = block(in_channels,
out_channels,
stride,
dilation=dilation)
self.tree2 = block(out_channels,
out_channels,
1,
dilation=dilation)
else:
self.tree1 = Tree(levels - 1,
block,
in_channels,
out_channels,
stride,
root_dim=0,
root_kernel_size=root_kernel_size,
dilation=dilation,
root_residual=root_residual)
self.tree2 = Tree(levels - 1,
block,
out_channels,
out_channels,
root_dim=root_dim + out_channels,
root_kernel_size=root_kernel_size,
dilation=dilation,
root_residual=root_residual)
if levels == 1:
self.root = Root(root_dim, out_channels, root_kernel_size,
root_residual)
self.level_root = level_root
self.root_dim = root_dim
self.downsample = None
self.project = None
self.levels = levels
if stride > 1:
self.downsample = nn.MaxPool2d(stride, stride=stride)
# Match CLRerNet/official DLA: project only on leaf when channels differ.
if levels == 1 and in_channels != out_channels:
self.project = nn.Sequential(
nn.Conv2d(in_channels,
out_channels,
kernel_size=1,
stride=1,
bias=False),
nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM))
def forward(self, x, residual=None, children=None):
children = [] if children is None else children
bottom = self.downsample(x) if self.downsample else x
residual = self.project(bottom) if self.project else bottom
if self.level_root:
children.append(bottom)
x1 = self.tree1(x, residual)
if self.levels == 1:
x2 = self.tree2(x1)
x = self.root(x2, x1, *children)
else:
children.append(x1)
x = self.tree2(x1, children=children)
return x
class DLA(nn.Module):
def __init__(self,
levels,
channels,
num_classes=1000,
block=BasicBlock,
residual_root=False,
linear_root=False):
super(DLA, self).__init__()
self.channels = channels
self.num_classes = num_classes
self.base_layer = nn.Sequential(
nn.Conv2d(3,
channels[0],
kernel_size=7,
stride=1,
padding=3,
bias=False),
nn.BatchNorm2d(channels[0], momentum=BN_MOMENTUM),
nn.ReLU(inplace=True))
self.level0 = self._make_conv_level(channels[0], channels[0],
levels[0])
self.level1 = self._make_conv_level(channels[0],
channels[1],
levels[1],
stride=2)
self.level2 = Tree(levels[2],
block,
channels[1],
channels[2],
2,
level_root=False,
root_residual=residual_root)
self.level3 = Tree(levels[3],
block,
channels[2],
channels[3],
2,
level_root=True,
root_residual=residual_root)
self.level4 = Tree(levels[4],
block,
channels[3],
channels[4],
2,
level_root=True,
root_residual=residual_root)
self.level5 = Tree(levels[5],
block,
channels[4],
channels[5],
2,
level_root=True,
root_residual=residual_root)
# for m in self.modules():
# if isinstance(m, nn.Conv2d):
# n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
# m.weight.data.normal_(0, math.sqrt(2. / n))
# elif isinstance(m, nn.BatchNorm2d):
# m.weight.data.fill_(1)
# m.bias.data.zero_()
def _make_level(self, block, inplanes, planes, blocks, stride=1):
downsample = None
if stride != 1 or inplanes != planes:
downsample = nn.Sequential(
nn.MaxPool2d(stride, stride=stride),
nn.Conv2d(inplanes,
planes,
kernel_size=1,
stride=1,
bias=False),
nn.BatchNorm2d(planes, momentum=BN_MOMENTUM),
)
layers = []
layers.append(block(inplanes, planes, stride, downsample=downsample))
for i in range(1, blocks):
layers.append(block(inplanes, planes))
return nn.Sequential(*layers)
def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1):
modules = []
for i in range(convs):
modules.extend([
nn.Conv2d(inplanes,
planes,
kernel_size=3,
stride=stride if i == 0 else 1,
padding=dilation,
bias=False,
dilation=dilation),
nn.BatchNorm2d(planes, momentum=BN_MOMENTUM),
nn.ReLU(inplace=True)
])
inplanes = planes
return nn.Sequential(*modules)
def forward(self, x):
y = []
x = self.base_layer(x)
for i in range(6):
x = getattr(self, 'level{}'.format(i))(x)
y.append(x)
return y[2:]
def load_pretrained_model(self,
data='imagenet',
name='dla34',
hash='ba72cf86'):
# fc = self.fc
if name.endswith('.pth'):
model_weights = torch.load(data + name)
else:
model_url = get_model_url(data, name, hash)
model_weights = model_zoo.load_url(model_url)
self.load_state_dict(model_weights, strict=False)
# self.fc = fc
def dla34(pretrained=True, levels=None, in_channels=None, **kwargs): # DLA-34
model = DLA(levels=levels,
channels=in_channels,
block=BasicBlock,
**kwargs)
if pretrained:
model.load_pretrained_model(data='imagenet',
name='dla34',
hash='ba72cf86')
return model
@BACKBONES.register_module
class DLAWrapper(nn.Module):
def __init__(self,
dla='dla34',
pretrained=True,
levels=[1, 1, 1, 2, 2, 1],
in_channels=[16, 32, 64, 128, 256, 512],
cfg=None):
super(DLAWrapper, self).__init__()
self.cfg = cfg
self.in_channels = in_channels
self.model = eval(dla)(pretrained=pretrained,
levels=levels,
in_channels=in_channels)
def forward(self, x):
x = self.model(x)
return x
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
def fill_fc_weights(layers):
for m in layers.modules():
if isinstance(m, nn.Conv2d):
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def fill_up_weights(up):
w = up.weight.data
f = math.ceil(w.size(2) / 2)
c = (2 * f - 1 - f % 2) / (2. * f)
for i in range(w.size(2)):
for j in range(w.size(3)):
w[0, 0, i, j] = \
(1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c))
for c in range(1, w.size(0)):
w[c, 0, :, :] = w[0, 0, :, :]

View File

@@ -0,0 +1,431 @@
import torch
from torch import nn
import torch.nn.functional as F
from torch.hub import load_state_dict_from_url
from clrnet.models.registry import BACKBONES
model_urls = {
'resnet18':
'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34':
'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50':
'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101':
'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152':
'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
'resnext50_32x4d':
'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
'resnext101_32x8d':
'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
'wide_resnet50_2':
'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
'wide_resnet101_2':
'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=dilation,
groups=groups,
bias=False,
dilation=dilation)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes,
out_planes,
kernel_size=1,
stride=stride,
bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self,
inplanes,
planes,
stride=1,
downsample=None,
groups=1,
base_width=64,
dilation=1,
norm_layer=None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError(
'BasicBlock only supports groups=1 and base_width=64')
# if dilation > 1:
# raise NotImplementedError(
# "Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride, dilation=dilation)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes, dilation=dilation)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self,
inplanes,
planes,
stride=1,
downsample=None,
groups=1,
base_width=64,
dilation=1,
norm_layer=None):
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
@BACKBONES.register_module
class ResNetWrapper(nn.Module):
def __init__(self,
resnet='resnet18',
pretrained=True,
replace_stride_with_dilation=[False, False, False],
out_conv=False,
fea_stride=8,
out_channel=128,
in_channels=[64, 128, 256, 512],
cfg=None):
super(ResNetWrapper, self).__init__()
self.cfg = cfg
self.in_channels = in_channels
self.model = eval(resnet)(
pretrained=pretrained,
replace_stride_with_dilation=replace_stride_with_dilation,
in_channels=self.in_channels)
self.out = None
if out_conv:
out_channel = 512
for chan in reversed(self.in_channels):
if chan < 0: continue
out_channel = chan
break
self.out = conv1x1(out_channel * self.model.expansion,
cfg.featuremap_out_channel)
def forward(self, x):
x = self.model(x)
if self.out:
x[-1] = self.out(x[-1])
return x
class ResNet(nn.Module):
def __init__(self,
block,
layers,
zero_init_residual=False,
groups=1,
width_per_group=64,
replace_stride_with_dilation=None,
norm_layer=None,
in_channels=None):
super(ResNet, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(
replace_stride_with_dilation))
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(3,
self.inplanes,
kernel_size=7,
stride=2,
padding=3,
bias=False)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.in_channels = in_channels
self.layer1 = self._make_layer(block, in_channels[0], layers[0])
self.layer2 = self._make_layer(block,
in_channels[1],
layers[1],
stride=2,
dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block,
in_channels[2],
layers[2],
stride=2,
dilate=replace_stride_with_dilation[1])
if in_channels[3] > 0:
self.layer4 = self._make_layer(
block,
in_channels[3],
layers[3],
stride=2,
dilate=replace_stride_with_dilation[2])
self.expansion = block.expansion
# self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
# self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight,
mode='fan_out',
nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)
layers = []
layers.append(
block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(
block(self.inplanes,
planes,
groups=self.groups,
base_width=self.base_width,
dilation=self.dilation,
norm_layer=norm_layer))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
out_layers = []
for name in ['layer1', 'layer2', 'layer3', 'layer4']:
if not hasattr(self, name):
continue
layer = getattr(self, name)
x = layer(x)
out_layers.append(x)
return out_layers
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
model = ResNet(block, layers, **kwargs)
if pretrained:
print('pretrained model: ', model_urls[arch])
# state_dict = torch.load(model_urls[arch])['net']
state_dict = load_state_dict_from_url(model_urls[arch])
model.load_state_dict(state_dict, strict=False)
return model
def resnet18(pretrained=False, progress=True, **kwargs):
r"""ResNet-18 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
**kwargs)
def resnet34(pretrained=False, progress=True, **kwargs):
r"""ResNet-34 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
**kwargs)
def resnet50(pretrained=False, progress=True, **kwargs):
r"""ResNet-50 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
**kwargs)
def resnet101(pretrained=False, progress=True, **kwargs):
r"""ResNet-101 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained,
progress, **kwargs)
def resnet152(pretrained=False, progress=True, **kwargs):
r"""ResNet-152 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained,
progress, **kwargs)
def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
r"""ResNeXt-50 32x4d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['groups'] = 32
kwargs['width_per_group'] = 4
return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], pretrained,
progress, **kwargs)
def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
r"""ResNeXt-101 32x8d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['groups'] = 32
kwargs['width_per_group'] = 8
return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], pretrained,
progress, **kwargs)
def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
r"""Wide ResNet-50-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
The model is the same as ResNet except for the bottleneck number of channels
which is twice larger in every block. The number of channels in outer 1x1
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['width_per_group'] = 64 * 2
return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], pretrained,
progress, **kwargs)
def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
r"""Wide ResNet-101-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
The model is the same as ResNet except for the bottleneck number of channels
which is twice larger in every block. The number of channels in outer 1x1
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['width_per_group'] = 64 * 2
return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], pretrained,
progress, **kwargs)

View File

@@ -0,0 +1 @@
from .clr_head import CLRHead

View File

@@ -0,0 +1,492 @@
import math
import cv2
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from clrnet.utils.lane import Lane
from clrnet.models.losses.focal_loss import FocalLoss
from clrnet.models.losses.accuracy import accuracy
from clrnet.ops import nms
from clrnet.models.utils.roi_gather import ROIGather, LinearModule
from clrnet.models.utils.seg_decoder import SegDecoder
from clrnet.models.utils.dynamic_assign import assign
from clrnet.models.losses.lineiou_loss import liou_loss
from clrnet.utils.bilinear_grid_sample import bilinear_grid_sample
from ..registry import HEADS
# Set True for ONNX/RKNN export (Gather+Pad instead of GridSample).
USE_BILINEAR_GRID_SAMPLE = True
@HEADS.register_module
class CLRHead(nn.Module):
def __init__(self,
num_points=72,
prior_feat_channels=64,
fc_hidden_dim=64,
num_priors=192,
num_fc=2,
refine_layers=3,
sample_points=36,
cfg=None):
super(CLRHead, self).__init__()
self.cfg = cfg
self.img_w = self.cfg.img_w
self.img_h = self.cfg.img_h
self.n_strips = num_points - 1
self.n_offsets = num_points
self.num_priors = num_priors
self.sample_points = sample_points
self.refine_layers = refine_layers
self.fc_hidden_dim = fc_hidden_dim
self.register_buffer(name='sample_x_indexs', tensor=(torch.linspace(
0, 1, steps=self.sample_points, dtype=torch.float32) *
self.n_strips).long())
self.register_buffer(name='prior_feat_ys', tensor=torch.flip(
(1 - self.sample_x_indexs.float() / self.n_strips), dims=[-1]))
self.register_buffer(name='prior_ys', tensor=torch.linspace(1,
0,
steps=self.n_offsets,
dtype=torch.float32))
self.prior_feat_channels = prior_feat_channels
self._init_prior_embeddings()
init_priors, priors_on_featmap = self.generate_priors_from_embeddings() #None, None
self.register_buffer(name='priors', tensor=init_priors)
self.register_buffer(name='priors_on_featmap', tensor=priors_on_featmap)
# generate xys for feature map
self.seg_decoder = SegDecoder(self.img_h, self.img_w,
self.cfg.num_classes,
self.prior_feat_channels,
self.refine_layers)
reg_modules = list()
cls_modules = list()
for _ in range(num_fc):
reg_modules += [*LinearModule(self.fc_hidden_dim)]
cls_modules += [*LinearModule(self.fc_hidden_dim)]
self.reg_modules = nn.ModuleList(reg_modules)
self.cls_modules = nn.ModuleList(cls_modules)
self.roi_gather = ROIGather(self.prior_feat_channels, self.num_priors,
self.sample_points, self.fc_hidden_dim,
self.refine_layers)
self.reg_layers = nn.Linear(
self.fc_hidden_dim, self.n_offsets + 1 + 2 +
1) # n offsets + 1 length + start_x + start_y + theta
self.cls_layers = nn.Linear(self.fc_hidden_dim, 2)
weights = torch.ones(self.cfg.num_classes)
weights[0] = self.cfg.bg_weight
self.criterion = torch.nn.NLLLoss(ignore_index=self.cfg.ignore_label,
weight=weights)
# init the weights here
self.init_weights()
# function to init layer weights
def init_weights(self):
# initialize heads
for m in self.cls_layers.parameters():
nn.init.normal_(m, mean=0., std=1e-3)
for m in self.reg_layers.parameters():
nn.init.normal_(m, mean=0., std=1e-3)
def pool_prior_features(self, batch_features, num_priors, prior_xs):
'''
pool prior feature from feature map.
Args:
batch_features (Tensor): Input feature maps, shape: (B, C, H, W)
'''
batch_size = batch_features.shape[0]
prior_xs = prior_xs.view(batch_size, num_priors, -1, 1)
prior_ys = self.prior_feat_ys.repeat(batch_size * num_priors).view(
batch_size, num_priors, -1, 1)
prior_xs = prior_xs * 2. - 1.
prior_ys = prior_ys * 2. - 1.
grid = torch.cat((prior_xs, prior_ys), dim=-1)
if USE_BILINEAR_GRID_SAMPLE:
feature = bilinear_grid_sample(batch_features, grid,
align_corners=True).permute(0, 2, 1, 3)
else:
feature = F.grid_sample(batch_features, grid,
align_corners=True).permute(0, 2, 1, 3)
feature = feature.reshape(batch_size * num_priors,
self.prior_feat_channels, self.sample_points,
1)
return feature
def generate_priors_from_embeddings(self):
predictions = self.prior_embeddings.weight # (num_prop, 3)
# 2 scores, 1 start_y, 1 start_x, 1 theta, 1 length, 72 coordinates, score[0] = negative prob, score[1] = positive prob
priors = predictions.new_zeros(
(self.num_priors, 2 + 2 + 2 + self.n_offsets), device=predictions.device)
priors[:, 2:5] = predictions.clone()
priors[:, 6:] = (
priors[:, 3].unsqueeze(1).clone().repeat(1, self.n_offsets) *
(self.img_w - 1) +
((1 - self.prior_ys.repeat(self.num_priors, 1) -
priors[:, 2].unsqueeze(1).clone().repeat(1, self.n_offsets)) *
self.img_h / torch.tan(priors[:, 4].unsqueeze(1).clone().repeat(
1, self.n_offsets) * math.pi + 1e-5))) / (self.img_w - 1)
# init priors on feature map
priors_on_featmap = priors.clone()[..., 6 + self.sample_x_indexs]
return priors, priors_on_featmap
def _init_prior_embeddings(self):
# [start_y, start_x, theta] -> all normalize
self.prior_embeddings = nn.Embedding(self.num_priors, 3)
bottom_priors_nums = self.num_priors * 3 // 4
left_priors_nums, _ = self.num_priors // 8, self.num_priors // 8
strip_size = 0.5 / (left_priors_nums // 2 - 1)
bottom_strip_size = 1 / (bottom_priors_nums // 4 + 1)
for i in range(left_priors_nums):
nn.init.constant_(self.prior_embeddings.weight[i, 0],
(i // 2) * strip_size)
nn.init.constant_(self.prior_embeddings.weight[i, 1], 0.)
nn.init.constant_(self.prior_embeddings.weight[i, 2],
0.16 if i % 2 == 0 else 0.32)
for i in range(left_priors_nums,
left_priors_nums + bottom_priors_nums):
nn.init.constant_(self.prior_embeddings.weight[i, 0], 0.)
nn.init.constant_(self.prior_embeddings.weight[i, 1],
((i - left_priors_nums) // 4 + 1) *
bottom_strip_size)
nn.init.constant_(self.prior_embeddings.weight[i, 2],
0.2 * (i % 4 + 1))
for i in range(left_priors_nums + bottom_priors_nums, self.num_priors):
nn.init.constant_(
self.prior_embeddings.weight[i, 0],
((i - left_priors_nums - bottom_priors_nums) // 2) *
strip_size)
nn.init.constant_(self.prior_embeddings.weight[i, 1], 1.)
nn.init.constant_(self.prior_embeddings.weight[i, 2],
0.68 if i % 2 == 0 else 0.84)
# forward function here
def forward(self, x, **kwargs):
'''
Take pyramid features as input to perform Cross Layer Refinement and finally output the prediction lanes.
Each feature is a 4D tensor.
Args:
x: input features (list[Tensor])
Return:
prediction_list: each layer's prediction result
seg: segmentation result for auxiliary loss
'''
batch_features = list(x[len(x) - self.refine_layers:])
batch_features.reverse()
batch_size = batch_features[-1].shape[0]
if self.training:
self.priors, self.priors_on_featmap = self.generate_priors_from_embeddings()
priors, priors_on_featmap = self.priors.repeat(batch_size, 1,
1), self.priors_on_featmap.repeat(
batch_size, 1, 1)
predictions_lists = []
# iterative refine
prior_features_stages = []
for stage in range(self.refine_layers):
num_priors = priors_on_featmap.shape[1]
prior_xs = torch.flip(priors_on_featmap, dims=[2])
batch_prior_features = self.pool_prior_features(
batch_features[stage], num_priors, prior_xs)
prior_features_stages.append(batch_prior_features)
fc_features = self.roi_gather(prior_features_stages,
batch_features[stage], stage)
fc_features = fc_features.view(num_priors, batch_size,
-1).reshape(batch_size * num_priors,
self.fc_hidden_dim)
cls_features = fc_features.clone()
reg_features = fc_features.clone()
for cls_layer in self.cls_modules:
cls_features = cls_layer(cls_features)
for reg_layer in self.reg_modules:
reg_features = reg_layer(reg_features)
cls_logits = self.cls_layers(cls_features)
reg = self.reg_layers(reg_features)
cls_logits = cls_logits.reshape(
batch_size, -1, cls_logits.shape[1]) # (B, num_priors, 2)
reg = reg.reshape(batch_size, -1, reg.shape[1])
predictions = priors.clone()
predictions[:, :, :2] = cls_logits
predictions[:, :,
2:5] += reg[:, :, :3] # also reg theta angle here
predictions[:, :, 5] = reg[:, :, 3] # length
def tran_tensor(t):
return t.unsqueeze(2).clone().repeat(1, 1, self.n_offsets)
predictions[..., 6:] = (
tran_tensor(predictions[..., 3]) * (self.img_w - 1) +
((1 - self.prior_ys.repeat(batch_size, num_priors, 1) -
tran_tensor(predictions[..., 2])) * self.img_h /
torch.tan(tran_tensor(predictions[..., 4]) * math.pi + 1e-5))) / (self.img_w - 1)
prediction_lines = predictions.clone()
predictions[..., 6:] += reg[..., 4:]
predictions_lists.append(predictions)
if stage != self.refine_layers - 1:
priors = prediction_lines.detach().clone()
priors_on_featmap = priors[..., 6 + self.sample_x_indexs]
if self.training:
seg = None
seg_features = torch.cat([
F.interpolate(feature,
size=[
batch_features[-1].shape[2],
batch_features[-1].shape[3]
],
mode='bilinear',
align_corners=False)
for feature in batch_features
],
dim=1)
seg = self.seg_decoder(seg_features)
output = {'predictions_lists': predictions_lists, 'seg': seg}
return self.loss(output, kwargs['batch'])
return predictions_lists[-1]
def predictions_to_pred(self, predictions):
'''
Convert predictions to internal Lane structure for evaluation.
'''
self.prior_ys = self.prior_ys.to(predictions.device)
self.prior_ys = self.prior_ys.double()
lanes = []
for lane in predictions:
lane_xs = lane[6:] # normalized value
start = min(max(0, int(round(lane[2].item() * self.n_strips))),
self.n_strips)
length = int(round(lane[5].item()))
end = start + length - 1
end = min(end, len(self.prior_ys) - 1)
# end = label_end
# if the prediction does not start at the bottom of the image,
# extend its prediction until the x is outside the image
mask = ~((((lane_xs[:start] >= 0.) & (lane_xs[:start] <= 1.)
).cpu().numpy()[::-1].cumprod()[::-1]).astype(np.bool))
lane_xs[end + 1:] = -2
lane_xs[:start][mask] = -2
lane_ys = self.prior_ys[lane_xs >= 0]
lane_xs = lane_xs[lane_xs >= 0]
lane_xs = lane_xs.flip(0).double()
lane_ys = lane_ys.flip(0)
lane_ys = (lane_ys * (self.cfg.ori_img_h - self.cfg.cut_height) +
self.cfg.cut_height) / self.cfg.ori_img_h
if len(lane_xs) <= 1:
continue
points = torch.stack(
(lane_xs.reshape(-1, 1), lane_ys.reshape(-1, 1)),
dim=1).squeeze(2)
lane = Lane(points=points.cpu().numpy(),
metadata={
'start_x': lane[3],
'start_y': lane[2],
'conf': lane[1]
})
lanes.append(lane)
return lanes
def loss(self,
output,
batch,
cls_loss_weight=2.,
xyt_loss_weight=0.5,
iou_loss_weight=2.,
seg_loss_weight=1.):
if self.cfg.haskey('cls_loss_weight'):
cls_loss_weight = self.cfg.cls_loss_weight
if self.cfg.haskey('xyt_loss_weight'):
xyt_loss_weight = self.cfg.xyt_loss_weight
if self.cfg.haskey('iou_loss_weight'):
iou_loss_weight = self.cfg.iou_loss_weight
if self.cfg.haskey('seg_loss_weight'):
seg_loss_weight = self.cfg.seg_loss_weight
predictions_lists = output['predictions_lists']
targets = batch['lane_line'].clone()
cls_criterion = FocalLoss(alpha=0.25, gamma=2.)
cls_loss = 0
reg_xytl_loss = 0
iou_loss = 0
cls_acc = []
cls_acc_stage = []
for stage in range(self.refine_layers):
predictions_list = predictions_lists[stage]
for predictions, target in zip(predictions_list, targets):
target = target[target[:, 1] == 1]
if len(target) == 0:
# If there are no targets, all predictions have to be negatives (i.e., 0 confidence)
cls_target = predictions.new_zeros(predictions.shape[0]).long()
cls_pred = predictions[:, :2]
cls_loss = cls_loss + cls_criterion(
cls_pred, cls_target).sum()
continue
with torch.no_grad():
matched_row_inds, matched_col_inds = assign(
predictions, target, self.img_w, self.img_h)
# classification targets
cls_target = predictions.new_zeros(predictions.shape[0]).long()
cls_target[matched_row_inds] = 1
cls_pred = predictions[:, :2]
# regression targets -> [start_y, start_x, theta] (all transformed to absolute values), only on matched pairs
reg_yxtl = predictions[matched_row_inds, 2:6]
reg_yxtl[:, 0] *= self.n_strips
reg_yxtl[:, 1] *= (self.img_w - 1)
reg_yxtl[:, 2] *= 180
reg_yxtl[:, 3] *= self.n_strips
target_yxtl = target[matched_col_inds, 2:6].clone()
# regression targets -> S coordinates (all transformed to absolute values)
reg_pred = predictions[matched_row_inds, 6:]
reg_pred *= (self.img_w - 1)
reg_targets = target[matched_col_inds, 6:].clone()
with torch.no_grad():
predictions_starts = torch.clamp(
(predictions[matched_row_inds, 2] *
self.n_strips).round().long(), 0,
self.n_strips) # ensure the predictions starts is valid
target_starts = (target[matched_col_inds, 2] *
self.n_strips).round().long()
target_yxtl[:, -1] -= (predictions_starts - target_starts
) # reg length
# Loss calculation
cls_loss = cls_loss + cls_criterion(cls_pred, cls_target).sum(
) / target.shape[0]
target_yxtl[:, 0] *= self.n_strips
target_yxtl[:, 2] *= 180
reg_xytl_loss = reg_xytl_loss + F.smooth_l1_loss(
reg_yxtl, target_yxtl,
reduction='none').mean()
iou_loss = iou_loss + liou_loss(
reg_pred, reg_targets,
self.img_w, length=15)
# calculate acc
cls_accuracy = accuracy(cls_pred, cls_target)
cls_acc_stage.append(cls_accuracy)
cls_acc.append(sum(cls_acc_stage) / len(cls_acc_stage))
# extra segmentation loss
seg_loss = self.criterion(F.log_softmax(output['seg'], dim=1),
batch['seg'].long())
cls_loss /= (len(targets) * self.refine_layers)
reg_xytl_loss /= (len(targets) * self.refine_layers)
iou_loss /= (len(targets) * self.refine_layers)
loss = cls_loss * cls_loss_weight + reg_xytl_loss * xyt_loss_weight \
+ seg_loss * seg_loss_weight + iou_loss * iou_loss_weight
return_value = {
'loss': loss,
'loss_stats': {
'loss': loss,
'cls_loss': cls_loss * cls_loss_weight,
'reg_xytl_loss': reg_xytl_loss * xyt_loss_weight,
'seg_loss': seg_loss * seg_loss_weight,
'iou_loss': iou_loss * iou_loss_weight
}
}
for i in range(self.refine_layers):
return_value['loss_stats']['stage_{}_acc'.format(i)] = cls_acc[i]
return return_value
def get_lanes(self, output, as_lanes=True):
'''
Convert model output to lanes.
'''
softmax = nn.Softmax(dim=1)
decoded = []
for predictions in output:
# filter out the conf lower than conf threshold
threshold = self.cfg.test_parameters.conf_threshold
scores = softmax(predictions[:, :2])[:, 1]
keep_inds = scores >= threshold
predictions = predictions[keep_inds]
scores = scores[keep_inds]
if predictions.shape[0] == 0:
decoded.append([])
continue
nms_predictions = predictions.detach().clone()
nms_predictions = torch.cat(
[nms_predictions[..., :4], nms_predictions[..., 5:]], dim=-1)
nms_predictions[..., 4] = nms_predictions[..., 4] * self.n_strips
nms_predictions[...,
5:] = nms_predictions[..., 5:] * (self.img_w - 1)
keep, num_to_keep, _ = nms(
nms_predictions,
scores,
overlap=self.cfg.test_parameters.nms_thres,
top_k=self.cfg.max_lanes)
keep = keep[:num_to_keep]
predictions = predictions[keep]
if predictions.shape[0] == 0:
decoded.append([])
continue
predictions[:, 5] = torch.round(predictions[:, 5] * self.n_strips)
if as_lanes:
pred = self.predictions_to_pred(predictions)
else:
pred = predictions
decoded.append(pred)
return decoded

View File

@@ -0,0 +1,77 @@
import mmcv
import torch.nn as nn
@mmcv.jit(coderize=True)
def accuracy(pred, target, topk=1, thresh=None):
"""Calculate accuracy according to the prediction and target.
Args:
pred (torch.Tensor): The model prediction, shape (N, num_class)
target (torch.Tensor): The target of each prediction, shape (N, )
topk (int | tuple[int], optional): If the predictions in ``topk``
matches the target, the predictions will be regarded as
correct ones. Defaults to 1.
thresh (float, optional): If not None, predictions with scores under
this threshold are considered incorrect. Default to None.
Returns:
float | tuple[float]: If the input ``topk`` is a single integer,
the function will return a single float as accuracy. If
``topk`` is a tuple containing multiple integers, the
function will return a tuple containing accuracies of
each ``topk`` number.
"""
assert isinstance(topk, (int, tuple))
if isinstance(topk, int):
topk = (topk, )
return_single = True
else:
return_single = False
maxk = max(topk)
if pred.size(0) == 0:
accu = [pred.new_tensor(0.) for i in range(len(topk))]
return accu[0] if return_single else accu
assert pred.ndim == 2 and target.ndim == 1
assert pred.size(0) == target.size(0)
assert maxk <= pred.size(1), \
f'maxk {maxk} exceeds pred dimension {pred.size(1)}'
pred_value, pred_label = pred.topk(maxk, dim=1)
pred_label = pred_label.t() # transpose to shape (maxk, N)
correct = pred_label.eq(target.view(1, -1).expand_as(pred_label))
if thresh is not None:
# Only prediction values larger than thresh are counted as correct
correct = correct & (pred_value > thresh).t()
res = []
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / pred.size(0)))
return res[0] if return_single else res
class Accuracy(nn.Module):
def __init__(self, topk=(1, ), thresh=None):
"""Module to calculate the accuracy.
Args:
topk (tuple, optional): The criterion used to calculate the
accuracy. Defaults to (1,).
thresh (float, optional): If not None, predictions with scores
under this threshold are considered incorrect. Default to None.
"""
super().__init__()
self.topk = topk
self.thresh = thresh
def forward(self, pred, target):
"""Forward function to calculate accuracy.
Args:
pred (torch.Tensor): Prediction of models.
target (torch.Tensor): Target for each prediction.
Returns:
tuple[float]: The accuracies under different topk criterions.
"""
return accuracy(pred, target, self.topk, self.thresh)

View File

@@ -0,0 +1,191 @@
# pylint: disable-all
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
# Source: https://github.com/kornia/kornia/blob/f4f70fefb63287f72bc80cd96df9c061b1cb60dd/kornia/losses/focal.py
class SoftmaxFocalLoss(nn.Module):
def __init__(self, gamma, ignore_lb=255, *args, **kwargs):
super(SoftmaxFocalLoss, self).__init__()
self.gamma = gamma
self.nll = nn.NLLLoss(ignore_index=ignore_lb)
def forward(self, logits, labels):
scores = F.softmax(logits, dim=1)
factor = torch.pow(1. - scores, self.gamma)
log_score = F.log_softmax(logits, dim=1)
log_score = factor * log_score
loss = self.nll(log_score, labels)
return loss
def one_hot(labels: torch.Tensor,
num_classes: int,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
eps: Optional[float] = 1e-6) -> torch.Tensor:
r"""Converts an integer label x-D tensor to a one-hot (x+1)-D tensor.
Args:
labels (torch.Tensor) : tensor with labels of shape :math:`(N, *)`,
where N is batch size. Each value is an integer
representing correct classification.
num_classes (int): number of classes in labels.
device (Optional[torch.device]): the desired device of returned tensor.
Default: if None, uses the current device for the default tensor type
(see torch.set_default_tensor_type()). device will be the CPU for CPU
tensor types and the current CUDA device for CUDA tensor types.
dtype (Optional[torch.dtype]): the desired data type of returned
tensor. Default: if None, infers data type from values.
Returns:
torch.Tensor: the labels in one hot tensor of shape :math:`(N, C, *)`,
Examples::
>>> labels = torch.LongTensor([[[0, 1], [2, 0]]])
>>> kornia.losses.one_hot(labels, num_classes=3)
tensor([[[[1., 0.],
[0., 1.]],
[[0., 1.],
[0., 0.]],
[[0., 0.],
[1., 0.]]]]
"""
if not torch.is_tensor(labels):
raise TypeError(
"Input labels type is not a torch.Tensor. Got {}".format(
type(labels)))
if not labels.dtype == torch.int64:
raise ValueError(
"labels must be of the same dtype torch.int64. Got: {}".format(
labels.dtype))
if num_classes < 1:
raise ValueError("The number of classes must be bigger than one."
" Got: {}".format(num_classes))
shape = labels.shape
one_hot = torch.zeros(shape[0],
num_classes,
*shape[1:],
device=device,
dtype=dtype)
return one_hot.scatter_(1, labels.unsqueeze(1), 1.0) + eps
def focal_loss(input: torch.Tensor,
target: torch.Tensor,
alpha: float,
gamma: float = 2.0,
reduction: str = 'none',
eps: float = 1e-8) -> torch.Tensor:
r"""Function that computes Focal loss.
See :class:`~kornia.losses.FocalLoss` for details.
"""
if not torch.is_tensor(input):
raise TypeError("Input type is not a torch.Tensor. Got {}".format(
type(input)))
if not len(input.shape) >= 2:
raise ValueError(
"Invalid input shape, we expect BxCx*. Got: {}".format(
input.shape))
if input.size(0) != target.size(0):
raise ValueError(
'Expected input batch_size ({}) to match target batch_size ({}).'.
format(input.size(0), target.size(0)))
n = input.size(0)
out_size = (n, ) + input.size()[2:]
if target.size()[1:] != input.size()[2:]:
raise ValueError('Expected target size {}, got {}'.format(
out_size, target.size()))
if not input.device == target.device:
raise ValueError(
"input and target must be in the same device. Got: {} and {}".
format(input.device, target.device))
# compute softmax over the classes axis
input_soft: torch.Tensor = F.softmax(input, dim=1) + eps
# create the labels one hot tensor
target_one_hot: torch.Tensor = one_hot(target,
num_classes=input.shape[1],
device=input.device,
dtype=input.dtype)
# compute the actual focal loss
weight = torch.pow(-input_soft + 1., gamma)
focal = -alpha * weight * torch.log(input_soft)
loss_tmp = torch.sum(target_one_hot * focal, dim=1)
if reduction == 'none':
loss = loss_tmp
elif reduction == 'mean':
loss = torch.mean(loss_tmp)
elif reduction == 'sum':
loss = torch.sum(loss_tmp)
else:
raise NotImplementedError(
"Invalid reduction mode: {}".format(reduction))
return loss
class FocalLoss(nn.Module):
r"""Criterion that computes Focal loss.
According to [1], the Focal loss is computed as follows:
.. math::
\text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)
where:
- :math:`p_t` is the model's estimated probability for each class.
Arguments:
alpha (float): Weighting factor :math:`\alpha \in [0, 1]`.
gamma (float): Focusing parameter :math:`\gamma >= 0`.
reduction (str, optional): Specifies the reduction to apply to the
output: none | mean | sum. none: no reduction will be applied,
mean: the sum of the output will be divided by the number of elements
in the output, sum: the output will be summed. Default: none.
Shape:
- Input: :math:`(N, C, *)` where C = number of classes.
- Target: :math:`(N, *)` where each value is
:math:`0 ≤ targets[i] ≤ C1`.
Examples:
>>> N = 5 # num_classes
>>> kwargs = {"alpha": 0.5, "gamma": 2.0, "reduction": 'mean'}
>>> loss = kornia.losses.FocalLoss(**kwargs)
>>> input = torch.randn(1, N, 3, 5, requires_grad=True)
>>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
>>> output = loss(input, target)
>>> output.backward()
References:
[1] https://arxiv.org/abs/1708.02002
"""
def __init__(self,
alpha: float,
gamma: float = 2.0,
reduction: str = 'none') -> None:
super(FocalLoss, self).__init__()
self.alpha: float = alpha
self.gamma: float = gamma
self.reduction: str = reduction
self.eps: float = 1e-6
def forward( # type: ignore
self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
return focal_loss(input, target, self.alpha, self.gamma,
self.reduction, self.eps)

View File

@@ -0,0 +1,38 @@
import torch
def line_iou(pred, target, img_w, length=15, aligned=True):
'''
Calculate the line iou value between predictions and targets
Args:
pred: lane predictions, shape: (num_pred, 72)
target: ground truth, shape: (num_target, 72)
img_w: image width
length: extended radius
aligned: True for iou loss calculation, False for pair-wise ious in assign
'''
px1 = pred - length
px2 = pred + length
tx1 = target - length
tx2 = target + length
if aligned:
invalid_mask = target
ovr = torch.min(px2, tx2) - torch.max(px1, tx1)
union = torch.max(px2, tx2) - torch.min(px1, tx1)
else:
num_pred = pred.shape[0]
invalid_mask = target.repeat(num_pred, 1, 1)
ovr = (torch.min(px2[:, None, :], tx2[None, ...]) -
torch.max(px1[:, None, :], tx1[None, ...]))
union = (torch.max(px2[:, None, :], tx2[None, ...]) -
torch.min(px1[:, None, :], tx1[None, ...]))
invalid_masks = (invalid_mask < 0) | (invalid_mask >= img_w)
ovr[invalid_masks] = 0.
union[invalid_masks] = 0.
iou = ovr.sum(dim=-1) / (union.sum(dim=-1) + 1e-9)
return iou
def liou_loss(pred, target, img_w, length=15):
return (1 - line_iou(pred, target, img_w, length)).mean()

View File

@@ -0,0 +1,2 @@
from .fpn import FPN
from .pafpn import PAFPN

View File

@@ -0,0 +1,167 @@
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from ..registry import NECKS
@NECKS.register_module
class FPN(nn.Module):
def __init__(self,
in_channels,
out_channels,
num_outs,
start_level=0,
end_level=-1,
add_extra_convs=False,
extra_convs_on_inputs=True,
relu_before_extra_convs=False,
no_norm_on_lateral=False,
conv_cfg=None,
norm_cfg=None,
attention=False,
act_cfg=None,
upsample_cfg=dict(mode='nearest'),
init_cfg=dict(type='Xavier',
layer='Conv2d',
distribution='uniform'),
cfg=None):
super(FPN, self).__init__()
assert isinstance(in_channels, list)
self.in_channels = in_channels
self.out_channels = out_channels
self.num_ins = len(in_channels)
self.num_outs = num_outs
self.attention = attention
self.relu_before_extra_convs = relu_before_extra_convs
self.no_norm_on_lateral = no_norm_on_lateral
self.upsample_cfg = upsample_cfg.copy()
if end_level == -1:
self.backbone_end_level = self.num_ins
assert num_outs >= self.num_ins - start_level
else:
# if end_level < inputs, no extra level is allowed
self.backbone_end_level = end_level
assert end_level <= len(in_channels)
assert num_outs == end_level - start_level
self.start_level = start_level
self.end_level = end_level
self.add_extra_convs = add_extra_convs
assert isinstance(add_extra_convs, (str, bool))
if isinstance(add_extra_convs, str):
# Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output'
assert add_extra_convs in ('on_input', 'on_lateral', 'on_output')
elif add_extra_convs: # True
if extra_convs_on_inputs:
# TODO: deprecate `extra_convs_on_inputs`
warnings.simplefilter('once')
warnings.warn(
'"extra_convs_on_inputs" will be deprecated in v2.9.0,'
'Please use "add_extra_convs"', DeprecationWarning)
self.add_extra_convs = 'on_input'
else:
self.add_extra_convs = 'on_output'
self.lateral_convs = nn.ModuleList()
self.fpn_convs = nn.ModuleList()
for i in range(self.start_level, self.backbone_end_level):
l_conv = ConvModule(
in_channels[i],
out_channels,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
act_cfg=act_cfg,
inplace=False)
fpn_conv = ConvModule(out_channels,
out_channels,
3,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
inplace=False)
self.lateral_convs.append(l_conv)
self.fpn_convs.append(fpn_conv)
# add extra conv layers (e.g., RetinaNet)
extra_levels = num_outs - self.backbone_end_level + self.start_level
if self.add_extra_convs and extra_levels >= 1:
for i in range(extra_levels):
if i == 0 and self.add_extra_convs == 'on_input':
in_channels = self.in_channels[self.backbone_end_level - 1]
else:
in_channels = out_channels
extra_fpn_conv = ConvModule(in_channels,
out_channels,
3,
stride=2,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
inplace=False)
self.fpn_convs.append(extra_fpn_conv)
def forward(self, inputs):
"""Forward function."""
assert len(inputs) >= len(self.in_channels)
if len(inputs) > len(self.in_channels):
for _ in range(len(inputs) - len(self.in_channels)):
del inputs[0]
# build laterals
laterals = [
lateral_conv(inputs[i + self.start_level])
for i, lateral_conv in enumerate(self.lateral_convs)
]
# build top-down path
used_backbone_levels = len(laterals)
for i in range(used_backbone_levels - 1, 0, -1):
# In some cases, fixing `scale factor` (e.g. 2) is preferred, but
# it cannot co-exist with `size` in `F.interpolate`.
if 'scale_factor' in self.upsample_cfg:
laterals[i - 1] += F.interpolate(laterals[i],
**self.upsample_cfg)
else:
prev_shape = laterals[i - 1].shape[2:]
laterals[i - 1] += F.interpolate(laterals[i],
size=prev_shape,
**self.upsample_cfg)
# build outputs
# part 1: from original levels
outs = [
self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
]
# part 2: add extra levels
if self.num_outs > len(outs):
# use max pool to get more levels on top of outputs
# (e.g., Faster R-CNN, Mask R-CNN)
if not self.add_extra_convs:
for i in range(self.num_outs - used_backbone_levels):
outs.append(F.max_pool2d(outs[-1], 1, stride=2))
# add conv layers on top of original feature maps (RetinaNet)
else:
if self.add_extra_convs == 'on_input':
extra_source = inputs[self.backbone_end_level - 1]
elif self.add_extra_convs == 'on_lateral':
extra_source = laterals[-1]
elif self.add_extra_convs == 'on_output':
extra_source = outs[-1]
else:
raise NotImplementedError
outs.append(self.fpn_convs[used_backbone_levels](extra_source))
for i in range(used_backbone_levels + 1, self.num_outs):
if self.relu_before_extra_convs:
outs.append(self.fpn_convs[i](F.relu(outs[-1])))
else:
outs.append(self.fpn_convs[i](outs[-1]))
return tuple(outs)

View File

@@ -0,0 +1,154 @@
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import auto_fp16
from ..registry import NECKS
from .fpn import FPN
@NECKS.register_module
class PAFPN(FPN):
"""Path Aggregation Network for Instance Segmentation.
This is an implementation of the `PAFPN in Path Aggregation Network
<https://arxiv.org/abs/1803.01534>`_.
Args:
in_channels (List[int]): Number of input channels per scale.
out_channels (int): Number of output channels (used at each scale)
num_outs (int): Number of output scales.
start_level (int): Index of the start input backbone level used to
build the feature pyramid. Default: 0.
end_level (int): Index of the end input backbone level (exclusive) to
build the feature pyramid. Default: -1, which means the last level.
add_extra_convs (bool): Whether to add conv layers on top of the
original feature maps. Default: False.
extra_convs_on_inputs (bool): Whether to apply extra conv on
the original feature from the backbone. Default: False.
relu_before_extra_convs (bool): Whether to apply relu before the extra
conv. Default: False.
no_norm_on_lateral (bool): Whether to apply norm on lateral.
Default: False.
conv_cfg (dict): Config dict for convolution layer. Default: None.
norm_cfg (dict): Config dict for normalization layer. Default: None.
act_cfg (str): Config dict for activation layer in ConvModule.
Default: None.
"""
def __init__(self,
in_channels,
out_channels,
num_outs,
start_level=0,
end_level=-1,
add_extra_convs=False,
extra_convs_on_inputs=True,
relu_before_extra_convs=False,
no_norm_on_lateral=False,
conv_cfg=None,
norm_cfg=None,
act_cfg=None,
cfg=None,
attention=False):
super(PAFPN, self).__init__(in_channels,
out_channels,
num_outs,
start_level,
end_level,
add_extra_convs,
extra_convs_on_inputs,
relu_before_extra_convs,
no_norm_on_lateral,
conv_cfg,
norm_cfg,
attention,
act_cfg,
cfg=cfg)
# add extra bottom up pathway
self.downsample_convs = nn.ModuleList()
self.pafpn_convs = nn.ModuleList()
for i in range(self.start_level + 1, self.backbone_end_level):
d_conv = ConvModule(out_channels,
out_channels,
3,
stride=2,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
inplace=False)
pafpn_conv = ConvModule(out_channels,
out_channels,
3,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
inplace=False)
self.downsample_convs.append(d_conv)
self.pafpn_convs.append(pafpn_conv)
def forward(self, inputs):
"""Forward function."""
assert len(inputs) >= len(self.in_channels)
if len(inputs) > len(self.in_channels):
for _ in range(len(inputs) - len(self.in_channels)):
del inputs[0]
# build laterals
laterals = [
lateral_conv(inputs[i + self.start_level])
for i, lateral_conv in enumerate(self.lateral_convs)
]
# build top-down path
used_backbone_levels = len(laterals)
for i in range(used_backbone_levels - 1, 0, -1):
prev_shape = laterals[i - 1].shape[2:]
laterals[i - 1] += F.interpolate(laterals[i],
size=prev_shape,
mode='nearest')
# build outputs
# part 1: from original levels
inter_outs = [
self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
]
# part 2: add bottom-up path
for i in range(0, used_backbone_levels - 1):
inter_outs[i + 1] += self.downsample_convs[i](inter_outs[i])
outs = []
outs.append(inter_outs[0])
outs.extend([
self.pafpn_convs[i - 1](inter_outs[i])
for i in range(1, used_backbone_levels)
])
# part 3: add extra levels
if self.num_outs > len(outs):
# use max pool to get more levels on top of outputs
# (e.g., Faster R-CNN, Mask R-CNN)
if not self.add_extra_convs:
for i in range(self.num_outs - used_backbone_levels):
outs.append(F.max_pool2d(outs[-1], 1, stride=2))
# add conv layers on top of original feature maps (RetinaNet)
else:
if self.add_extra_convs == 'on_input':
orig = inputs[self.backbone_end_level - 1]
outs.append(self.fpn_convs[used_backbone_levels](orig))
elif self.add_extra_convs == 'on_lateral':
outs.append(self.fpn_convs[used_backbone_levels](
laterals[-1]))
elif self.add_extra_convs == 'on_output':
outs.append(self.fpn_convs[used_backbone_levels](outs[-1]))
else:
raise NotImplementedError
for i in range(used_backbone_levels + 1, self.num_outs):
if self.relu_before_extra_convs:
outs.append(self.fpn_convs[i](F.relu(outs[-1])))
else:
outs.append(self.fpn_convs[i](outs[-1]))
return tuple(outs)

View File

@@ -0,0 +1 @@
from .detector import Detector

View File

@@ -0,0 +1,36 @@
import torch.nn as nn
import torch
from clrnet.models.registry import NETS
from ..registry import build_backbones, build_aggregator, build_heads, build_necks
@NETS.register_module
class Detector(nn.Module):
def __init__(self, cfg):
super(Detector, self).__init__()
self.cfg = cfg
self.backbone = build_backbones(cfg)
self.aggregator = build_aggregator(cfg) if cfg.haskey('aggregator') else None
self.neck = build_necks(cfg) if cfg.haskey('neck') else None
self.heads = build_heads(cfg)
def get_lanes(self):
return self.heads.get_lanes(output)
def forward(self, batch):
output = {}
fea = self.backbone(batch['img'] if isinstance(batch, dict) else batch)
if self.aggregator:
fea[-1] = self.aggregator(fea[-1])
if self.neck:
fea = self.neck(fea)
if self.training:
output = self.heads(fea, batch=batch)
else:
output = self.heads(fea)
return output

View File

@@ -0,0 +1,45 @@
from clrnet.utils import Registry, build_from_cfg
import torch.nn as nn
BACKBONES = Registry('backbones')
AGGREGATORS = Registry('aggregators')
HEADS = Registry('heads')
NECKS = Registry('necks')
NETS = Registry('nets')
def build(cfg, registry, default_args=None):
if isinstance(cfg, list):
modules = [
build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
]
return nn.Sequential(*modules)
else:
return build_from_cfg(cfg, registry, default_args)
def build_backbones(cfg):
return build(cfg.backbone, BACKBONES, default_args=dict(cfg=cfg))
def build_necks(cfg):
return build(cfg.necks, NECKS, default_args=dict(cfg=cfg))
def build_aggregator(cfg):
return build(cfg.aggregator, AGGREGATORS, default_args=dict(cfg=cfg))
def build_heads(cfg):
return build(cfg.heads, HEADS, default_args=dict(cfg=cfg))
def build_head(split_cfg, cfg):
return build(split_cfg, HEADS, default_args=dict(cfg=cfg))
def build_net(cfg):
return build(cfg.net, NETS, default_args=dict(cfg=cfg))
def build_necks(cfg):
return build(cfg.neck, NECKS, default_args=dict(cfg=cfg))

View File

@@ -0,0 +1,140 @@
import torch
from clrnet.models.losses.lineiou_loss import line_iou
def distance_cost(predictions, targets, img_w):
"""
repeat predictions and targets to generate all combinations
use the abs distance as the new distance cost
"""
num_priors = predictions.shape[0]
num_targets = targets.shape[0]
predictions = torch.repeat_interleave(
predictions, num_targets, dim=0
)[...,
6:] # repeat_interleave'ing [a, b] 2 times gives [a, a, b, b] ((np + nt) * 78)
targets = torch.cat(
num_priors *
[targets])[...,
6:] # applying this 2 times on [c, d] gives [c, d, c, d]
invalid_masks = (targets < 0) | (targets >= img_w)
lengths = (~invalid_masks).sum(dim=1)
distances = torch.abs((targets - predictions))
distances[invalid_masks] = 0.
distances = distances.sum(dim=1) / (lengths.float() + 1e-9)
distances = distances.view(num_priors, num_targets)
return distances
def focal_cost(cls_pred, gt_labels, alpha=0.25, gamma=2, eps=1e-12):
"""
Args:
cls_pred (Tensor): Predicted classification logits, shape
[num_query, num_class].
gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
Returns:
torch.Tensor: cls_cost value
"""
cls_pred = cls_pred.sigmoid()
neg_cost = -(1 - cls_pred + eps).log() * (1 - alpha) * cls_pred.pow(gamma)
pos_cost = -(cls_pred + eps).log() * alpha * (1 - cls_pred).pow(gamma)
cls_cost = pos_cost[:, gt_labels] - neg_cost[:, gt_labels]
return cls_cost
def dynamic_k_assign(cost, pair_wise_ious):
"""
Assign grouth truths with priors dynamically.
Args:
cost: the assign cost.
pair_wise_ious: iou of grouth truth and priors.
Returns:
prior_idx: the index of assigned prior.
gt_idx: the corresponding ground truth index.
"""
matching_matrix = torch.zeros_like(cost)
ious_matrix = pair_wise_ious
ious_matrix[ious_matrix < 0] = 0.
n_candidate_k = 4
topk_ious, _ = torch.topk(ious_matrix, n_candidate_k, dim=0)
dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1)
num_gt = cost.shape[1]
for gt_idx in range(num_gt):
_, pos_idx = torch.topk(cost[:, gt_idx],
k=dynamic_ks[gt_idx].item(),
largest=False)
matching_matrix[pos_idx, gt_idx] = 1.0
del topk_ious, dynamic_ks, pos_idx
matched_gt = matching_matrix.sum(1)
if (matched_gt > 1).sum() > 0:
_, cost_argmin = torch.min(cost[matched_gt > 1, :], dim=1)
matching_matrix[matched_gt > 1, 0] *= 0.0
matching_matrix[matched_gt > 1, cost_argmin] = 1.0
prior_idx = matching_matrix.sum(1).nonzero()
gt_idx = matching_matrix[prior_idx].argmax(-1)
return prior_idx.flatten(), gt_idx.flatten()
def assign(
predictions,
targets,
img_w,
img_h,
distance_cost_weight=3.,
cls_cost_weight=1.,
):
'''
computes dynamicly matching based on the cost, including cls cost and lane similarity cost
Args:
predictions (Tensor): predictions predicted by each stage, shape: (num_priors, 78)
targets (Tensor): lane targets, shape: (num_targets, 78)
return:
matched_row_inds (Tensor): matched predictions, shape: (num_targets)
matched_col_inds (Tensor): matched targets, shape: (num_targets)
'''
predictions = predictions.detach().clone()
predictions[:, 3] *= (img_w - 1)
predictions[:, 6:] *= (img_w - 1)
targets = targets.detach().clone()
# distances cost
distances_score = distance_cost(predictions, targets, img_w)
distances_score = 1 - (distances_score / torch.max(distances_score)
) + 1e-2 # normalize the distance
# classification cost
cls_score = focal_cost(predictions[:, :2], targets[:, 1].long())
num_priors = predictions.shape[0]
num_targets = targets.shape[0]
target_start_xys = targets[:, 2:4] # num_targets, 2
target_start_xys[..., 0] *= (img_h - 1)
prediction_start_xys = predictions[:, 2:4]
prediction_start_xys[..., 0] *= (img_h - 1)
start_xys_score = torch.cdist(prediction_start_xys, target_start_xys,
p=2).reshape(num_priors, num_targets)
start_xys_score = (1 - start_xys_score / torch.max(start_xys_score)) + 1e-2
target_thetas = targets[:, 4].unsqueeze(-1)
theta_score = torch.cdist(predictions[:, 4].unsqueeze(-1),
target_thetas,
p=1).reshape(num_priors, num_targets) * 180
theta_score = (1 - theta_score / torch.max(theta_score)) + 1e-2
cost = -(distances_score * start_xys_score * theta_score
)**2 * distance_cost_weight + cls_score * cls_cost_weight
iou = line_iou(predictions[..., 6:], targets[..., 6:], img_w, aligned=False)
matched_row_inds, matched_col_inds = dynamic_k_assign(cost, iou)
return matched_row_inds, matched_col_inds

View File

@@ -0,0 +1,136 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
def LinearModule(hidden_dim):
return nn.ModuleList(
[nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(inplace=True)])
class FeatureResize(nn.Module):
def __init__(self, size=(10, 25)):
super(FeatureResize, self).__init__()
self.size = size
def forward(self, x):
x = F.interpolate(x, self.size)
return x.flatten(2)
class ROIGather(nn.Module):
'''
ROIGather module for gather global information
Args:
in_channels: prior feature channels
num_priors: prior numbers we predefined
sample_points: the number of sampled points when we extract feature from line
fc_hidden_dim: the fc output channel
refine_layers: the total number of layers to build refine
'''
def __init__(self,
in_channels,
num_priors,
sample_points,
fc_hidden_dim,
refine_layers,
mid_channels=48):
super(ROIGather, self).__init__()
self.in_channels = in_channels
self.num_priors = num_priors
self.f_key = ConvModule(in_channels=self.in_channels,
out_channels=self.in_channels,
kernel_size=1,
stride=1,
padding=0,
norm_cfg=dict(type='BN'))
self.f_query = nn.Sequential(
nn.Conv1d(in_channels=num_priors,
out_channels=num_priors,
kernel_size=1,
stride=1,
padding=0,
groups=num_priors),
nn.ReLU(),
)
self.f_value = nn.Conv2d(in_channels=self.in_channels,
out_channels=self.in_channels,
kernel_size=1,
stride=1,
padding=0)
self.W = nn.Conv1d(in_channels=num_priors,
out_channels=num_priors,
kernel_size=1,
stride=1,
padding=0,
groups=num_priors)
self.resize = FeatureResize()
nn.init.constant_(self.W.weight, 0)
nn.init.constant_(self.W.bias, 0)
self.convs = nn.ModuleList()
self.catconv = nn.ModuleList()
for i in range(refine_layers):
self.convs.append(
ConvModule(in_channels,
mid_channels, (9, 1),
padding=(4, 0),
bias=False,
norm_cfg=dict(type='BN')))
self.catconv.append(
ConvModule(mid_channels * (i + 1),
in_channels, (9, 1),
padding=(4, 0),
bias=False,
norm_cfg=dict(type='BN')))
self.fc = nn.Linear(sample_points * fc_hidden_dim, fc_hidden_dim)
self.fc_norm = nn.LayerNorm(fc_hidden_dim)
def roi_fea(self, x, layer_index):
feats = []
for i, feature in enumerate(x):
feat_trans = self.convs[i](feature)
feats.append(feat_trans)
cat_feat = torch.cat(feats, dim=1)
cat_feat = self.catconv[layer_index](cat_feat)
return cat_feat
def forward(self, roi_features, x, layer_index):
'''
Args:
roi_features: prior feature, shape: (Batch * num_priors, prior_feat_channel, sample_point, 1)
x: feature map
layer_index: currently on which layer to refine
Return:
roi: prior features with gathered global information, shape: (Batch, num_priors, fc_hidden_dim)
'''
roi = self.roi_fea(roi_features, layer_index)
bs = x.size(0)
roi = roi.contiguous().view(bs * self.num_priors, -1)
roi = F.relu(self.fc_norm(self.fc(roi)))
roi = roi.view(bs, self.num_priors, -1)
query = roi
value = self.resize(self.f_value(x))
query = self.f_query(query)
key = self.f_key(x)
value = value.permute(0, 2, 1)
key = self.resize(key)
sim_map = torch.matmul(query, key)
sim_map = (self.in_channels**-.5) * sim_map
sim_map = F.softmax(sim_map, dim=-1)
context = torch.matmul(sim_map, value)
context = self.W(context)
roi = roi + F.dropout(context, p=0.1, training=self.training)
return roi

View File

@@ -0,0 +1,29 @@
import torch.nn as nn
import torch.nn.functional as F
class SegDecoder(nn.Module):
'''
Optionaly seg decoder
'''
def __init__(self,
image_height,
image_width,
num_class,
prior_feat_channels=64,
refine_layers=3):
super().__init__()
self.dropout = nn.Dropout2d(0.1)
self.conv = nn.Conv2d(prior_feat_channels * refine_layers, num_class,
1)
self.image_height = image_height
self.image_width = image_width
def forward(self, x):
x = self.dropout(x)
x = self.conv(x)
x = F.interpolate(x,
size=[self.image_height, self.image_width],
mode='bilinear',
align_corners=False)
return x