Files
HSAP/algorithms/lane_ufld/code/UFLD/profile_model.py
Chengfang Lu 7c43b44c57 feat: initial HSAP platform
Huaxu Sentinel Active Safety Platform with embedded algorithm code,
Docker Compose setup, and vendored dataset scaffolds for clone-and-run.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-25 16:59:59 +08:00

140 lines
4.9 KiB
Python

"""Count parameters and FLOPs for a UFLD parsingNet checkpoint."""
import argparse
import os
import re
import torch
from model.backbone import is_vovnet
from model.model import parsingNet
from utils.common import checkpoint_state_dict
from utils.config import Config
# TuSimple / lane0 UFLD row anchors (fixed in original repo)
# build_net() uses this value as the middle entry of the model's cls_dim tuple.
CLS_NUM_PER_LANE = 56
# (height, width) of the dummy input image; main() uses these two values as the
# defaults for --height and --width.
INPUT_SIZE = (288, 800)
def parse_cfg_txt(cfg_path):
    """Parse saved cfg.txt from train.py (Config repr line).

    The file is read as UTF-8 and the span from the first ``{`` to the last
    ``}`` is parsed as a Python dict literal. Text before or after that span
    is ignored.

    Args:
        cfg_path: Path to the ``cfg.txt`` file.

    Returns:
        The parsed ``dict``.

    Raises:
        ValueError: If the file contains no ``{...}`` span, the span is not a
            valid Python literal, or the literal is not a dict.
    """
    # Function-scope import so the block does not depend on the file header.
    import ast

    with open(cfg_path, encoding='utf-8') as f:
        text = f.read()

    # Greedy match with DOTALL: the dict repr may span several lines.
    m = re.search(r"\{.*\}", text, re.DOTALL)
    if not m:
        raise ValueError(f'cannot parse dict from {cfg_path}')

    # cfg.txt uses single-quoted python dict repr. literal_eval accepts only
    # literals (str, numbers, tuples, lists, dicts, sets, bools, None), so a
    # tampered or corrupted file cannot execute code the way eval() would.
    try:
        cfg = ast.literal_eval(m.group())
    except (ValueError, SyntaxError, TypeError) as err:
        raise ValueError(f'cannot parse dict from {cfg_path}') from err

    # ``{1, 2}`` is a valid literal too, but callers rely on ``dict.get``.
    if not isinstance(cfg, dict):
        raise ValueError(f'cannot parse dict from {cfg_path}')
    return cfg
def build_net(backbone, griding_num, num_lanes, use_aux, pretrained=False):
    """Construct a ``parsingNet`` for the given architecture settings.

    Args:
        backbone: Backbone identifier; converted with ``str()`` before being
            passed on.
        griding_num: Number of grid cells; the first ``cls_dim`` entry is
            ``griding_num + 1`` (presumably one extra "no lane" bin — confirm
            against the model definition).
        num_lanes: Last entry of ``cls_dim``.
        use_aux: Forwarded unchanged as ``use_aux``.
        pretrained: Forwarded unchanged as ``pretrained``.

    Returns:
        The constructed ``parsingNet`` instance.
    """
    # The middle entry is the module-level row-anchor count.
    cls_dim = (griding_num + 1, CLS_NUM_PER_LANE, num_lanes)
    net_kwargs = {
        'pretrained': pretrained,
        'backbone': str(backbone),
        'cls_dim': cls_dim,
        'use_aux': use_aux,
    }
    return parsingNet(**net_kwargs)
def count_params(net):
    """Count parameter elements of ``net``.

    Args:
        net: Object exposing ``parameters()`` that yields items with
            ``numel()`` and ``requires_grad``.

    Returns:
        ``(total, trainable)``: the element count over all parameters, and
        the element count over those with ``requires_grad`` set.
    """
    n_all = 0
    n_trainable = 0
    for param in net.parameters():
        size = param.numel()
        n_all += size
        if param.requires_grad:
            n_trainable += size
    return n_all, n_trainable
def count_flops(net, input_size):
x = torch.randn(1, 3, *input_size)
from torch.utils.flop_counter import FlopCounterMode
with FlopCounterMode(display=False) as fc:
with torch.no_grad():
net(x)
total = fc.get_total_flops()
breakdown = fc.get_flop_counts().get('Global', {})
return total, breakdown, tuple(net(x).shape)
def main():
    """Command-line entry point: build the model, profile it, print a report.

    Settings are resolved in increasing priority:

    1. Built-in defaults (backbone ``'18'``, ``griding_num=100``,
       ``num_lanes=4``, ``use_aux=False``).
    2. ``--config`` if given; otherwise ``<run_dir>/cfg.txt`` if that file
       exists.
    3. Explicit ``--backbone`` / ``--griding_num`` / ``--num_lanes`` /
       ``--use_aux`` flags.

    The checkpoint is ``--model_path`` if given; otherwise ``cfg.test_model``
    (config branch) or the first of ``best.pth`` / ``latest.pth`` /
    ``model.pt`` found in ``run_dir`` (cfg.txt branch). A checkpoint is
    optional. The report states whether one was actually loaded.
    """
    parser = argparse.ArgumentParser(description='UFLD params / FLOPs profiler')
    parser.add_argument(
        '--run_dir',
        default='log/20250702_165153_lr_1e-05_b_32_ufld_2lanes_res18',
        help='training log dir containing cfg.txt and optional best.pth',
    )
    parser.add_argument('--config', default=None, help='config .py (overrides cfg.txt)')
    parser.add_argument('--model_path', default=None, help='.pth checkpoint (optional)')
    parser.add_argument('--backbone', default=None)
    parser.add_argument('--griding_num', type=int, default=None)
    parser.add_argument('--num_lanes', type=int, default=None)
    # default=None (instead of store_true's usual False) makes "flag absent"
    # distinguishable from True, so an absent flag does not override the
    # value read from the config.
    parser.add_argument('--use_aux', action='store_true', default=None)
    parser.add_argument('--height', type=int, default=INPUT_SIZE[0])
    parser.add_argument('--width', type=int, default=INPUT_SIZE[1])
    args = parser.parse_args()

    # 1. Built-in defaults.
    backbone = '18'
    griding_num = 100
    num_lanes = 4
    use_aux = False
    model_path = args.model_path

    # 2. Config file, or the cfg.txt saved in the run directory.
    if args.config:
        cfg = Config.fromfile(args.config)
        backbone = cfg.backbone
        griding_num = cfg.griding_num
        num_lanes = cfg.num_lanes
        use_aux = getattr(cfg, 'use_aux', False)
        if model_path is None and getattr(cfg, 'test_model', None):
            model_path = cfg.test_model
    elif args.run_dir and os.path.isfile(os.path.join(args.run_dir, 'cfg.txt')):
        cfg = parse_cfg_txt(os.path.join(args.run_dir, 'cfg.txt'))
        backbone = cfg.get('backbone', backbone)
        griding_num = cfg.get('griding_num', griding_num)
        num_lanes = cfg.get('num_lanes', num_lanes)
        use_aux = cfg.get('use_aux', use_aux)
        if model_path is None:
            # Take the first checkpoint file that exists in the run dir.
            for name in ('best.pth', 'latest.pth', 'model.pt'):
                p = os.path.join(args.run_dir, name)
                if os.path.isfile(p):
                    model_path = p
                    break

    # 3. Explicit command-line overrides win over everything above.
    if args.backbone is not None:
        backbone = args.backbone
    if args.griding_num is not None:
        griding_num = args.griding_num
    if args.num_lanes is not None:
        num_lanes = args.num_lanes
    if args.use_aux is not None:
        use_aux = args.use_aux

    input_size = (args.height, args.width)
    net = build_net(backbone, griding_num, num_lanes, use_aux, pretrained=False)

    # Remember whether weights were really loaded so the report below does
    # not claim a checkpoint that does not exist on disk.
    weights_loaded = bool(model_path) and os.path.isfile(model_path)
    if weights_loaded:
        # strict=False: keys that do not match the built model are ignored
        # rather than raising.
        net.load_state_dict(checkpoint_state_dict(model_path, map_location='cpu'), strict=False)
    net.eval()

    total, trainable = count_params(net)
    flops, breakdown, out_shape = count_flops(net, input_size)

    print('=== UFLD model profile ===')
    if weights_loaded:
        print('checkpoint:', os.path.abspath(model_path))
    elif model_path:
        print('checkpoint: not found, weights not loaded:', os.path.abspath(model_path))
    if args.run_dir:
        print('run_dir:', os.path.abspath(args.run_dir))
    arch = f'VoVNet({backbone})' if is_vovnet(backbone) else f'ResNet{backbone}'
    print(
        f'arch: {arch}, griding_num={griding_num}, num_lanes={num_lanes}, '
        f'use_aux={use_aux}, input=1x3x{input_size[0]}x{input_size[1]}'
    )
    print(f'output shape: {out_shape}')
    print(f'Parameters: {total:,} ({total / 1e6:.4f} M)')
    print(f'Trainable: {trainable:,}')
    print(f'FLOPs (PyTorch FlopCounterMode, 1 image): {flops:,} ({flops / 1e9:.4f} GFLOPs)')
    print('FLOPs breakdown (top ops):')
    # Show the six operators with the largest FLOP counts, largest first.
    for op, v in sorted(breakdown.items(), key=lambda kv: -kv[1])[:6]:
        print(f' {op}: {v / 1e9:.4f} G')
# Run the profiler only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()