Files
HSAP/algorithms/lane_ufld/code/UFLD/model/backbone.py
Chengfang Lu 7c43b44c57 feat: initial HSAP platform
Huaxu Sentinel Active Safety Platform with embedded algorithm code,
Docker Compose setup, and vendored dataset scaffolds for clone-and-run.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-25 16:59:59 +08:00

99 lines
3.4 KiB
Python
Executable File

import torch
import torch.nn.modules
import torchvision
from model.vovnet import VOVNET_ALIASES, STAGE_SPECS
# Names accepted by the ``resnet`` wrapper class defined in this module.
RESNET_BACKBONES = [
    '18', '34', '50', '101', '152',
    '50next', '101next',
    '50wide', '101wide',
]

# Every alias registered in ``model.vovnet.VOVNET_ALIASES``.
# NOTE(review): the constant is spelled "VOVMNET"; the name is kept as-is
# because other modules may import it.
VOVMNET_BACKBONES = list(VOVNET_ALIASES)

# Full list of backbone identifiers: ResNet variants, VoVNet aliases, then 'vgg'.
SUPPORTED_BACKBONES = [*RESNET_BACKBONES, *VOVMNET_BACKBONES, 'vgg']
def is_vovnet(backbone):
    """Return True when ``str(backbone)`` is one of the entries of ``VOVNET_ALIASES``.

    The argument is converted with ``str`` first, so non-string identifiers
    are compared by their string form.
    """
    name = str(backbone)
    return name in VOVNET_ALIASES
def get_backbone_spec(backbone):
    """Describe the channel layout of a backbone for parsingNet pool / aux heads.

    Args:
        backbone: Backbone identifier; it is converted with ``str`` before
            any comparison.

    Returns:
        A dict with three keys:
            'pool_in' -- a single channel count,
            'aux'     -- a 3-tuple of channel counts,
            'family'  -- one of 'vov', 'resnet' or 'vgg'.

        Any identifier that is not a VoVNet alias, '18', '34' or 'vgg' falls
        through to the 2048-channel 'resnet' layout; no validation is done.
    """
    name = str(backbone)

    if is_vovnet(name):
        # Function-scope import under a private alias, as in the original code.
        from model.vovnet import VOVNET_ALIASES as _aliases

        stage_channels = STAGE_SPECS[_aliases[name]]['stage_out_ch']
        return {
            'pool_in': stage_channels[3],
            'aux': (stage_channels[1], stage_channels[2], stage_channels[3]),
            'family': 'vov',
        }

    # '18', '34' and 'vgg' share the same 512-channel numbers and differ
    # only in the reported family.
    if name in ['18', '34']:
        family = 'resnet'
    elif name == 'vgg':
        family = 'vgg'
    else:
        return {'pool_in': 2048, 'aux': (512, 1024, 2048), 'family': 'resnet'}

    return {'pool_in': 512, 'aux': (128, 256, 512), 'family': family}
def build_backbone(backbone, pretrained=False):
    """Instantiate the backbone module selected by ``backbone``.

    Args:
        backbone: Backbone identifier; converted with ``str`` before dispatch.
        pretrained: Forwarded unchanged to the selected constructor.

    Returns:
        The object returned by ``vovnet(...)`` for VoVNet aliases, a
        ``vgg16bn`` instance for 'vgg', and a ``resnet`` instance for
        everything else (``resnet`` itself rejects unknown names).
    """
    name = str(backbone)

    # The VoVNet check comes first, exactly as in the original dispatch order.
    if is_vovnet(name):
        from model.vovnet import vovnet

        network = vovnet(name, pretrained=pretrained)
    elif name == 'vgg':
        network = vgg16bn(pretrained=pretrained)
    else:
        network = resnet(name, pretrained=pretrained)
    return network
class vgg16bn(torch.nn.Module):
    """Feature extractor built from torchvision's ``vgg16_bn``.

    The children of ``vgg16_bn(...).features`` are collected into a list and
    re-assembled into a ``torch.nn.Sequential`` that keeps indices 0-32 and
    34-42, i.e. the child at index 33 and everything from index 43 onward
    is left out.
    NOTE(review): the omitted entries are presumably pooling layers in
    torchvision's layout -- confirm against the installed torchvision.
    """

    def __init__(self, pretrained=False):
        """Build the trimmed feature stack.

        Args:
            pretrained: Forwarded to ``torchvision.models.vgg16_bn``.
        """
        super().__init__()
        features = torchvision.models.vgg16_bn(pretrained=pretrained).features
        layers = list(features.children())
        kept = layers[:33] + layers[34:43]
        self.model = torch.nn.Sequential(*kept)

    def forward(self, x):
        """Run ``x`` through the trimmed feature stack and return the result."""
        return self.model(x)
class resnet(torch.nn.Module):
    """Wrapper around a torchvision ResNet-family model.

    The constructor builds the requested torchvision model and re-attaches
    its ``conv1``, ``bn1``, ``relu``, ``maxpool`` and ``layer1``..``layer4``
    sub-modules to this module. ``forward`` returns the outputs of
    ``layer2``, ``layer3`` and ``layer4``.
    """

    def __init__(self, layers, pretrained=False):
        """Select and build the torchvision model named by ``layers``.

        Args:
            layers: One of '18', '34', '50', '101', '152', '50next',
                '101next', '50wide', '101wide'.
            pretrained: Forwarded to the torchvision constructor.

        Raises:
            NotImplementedError: If ``layers`` is not one of the names above.
        """
        super().__init__()

        # Identifier -> name of the constructor in ``torchvision.models``.
        constructors = {
            '18': 'resnet18',
            '34': 'resnet34',
            '50': 'resnet50',
            '101': 'resnet101',
            '152': 'resnet152',
            '50next': 'resnext50_32x4d',
            '101next': 'resnext101_32x8d',
            '50wide': 'wide_resnet50_2',
            '101wide': 'wide_resnet101_2',
        }
        # Only string identifiers can match; anything else is rejected below.
        constructor_name = constructors.get(layers) if isinstance(layers, str) else None
        if constructor_name is None:
            raise NotImplementedError(f'Unknown ResNet backbone: {layers}')

        model = getattr(torchvision.models, constructor_name)(pretrained=pretrained)

        # Attach the sub-modules in the same order as the original assignments
        # so that module registration order is unchanged.
        for attr in ('conv1', 'bn1', 'relu', 'maxpool',
                     'layer1', 'layer2', 'layer3', 'layer4'):
            setattr(self, attr, getattr(model, attr))

    def forward(self, x):
        """Return the ``(layer2, layer3, layer4)`` outputs for input ``x``."""
        stem = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        feat2 = self.layer2(self.layer1(stem))
        feat3 = self.layer3(feat2)
        feat4 = self.layer4(feat3)
        return feat2, feat3, feat4