email:zk@likedge.top
This post describes how to adapt YOLOv5 for pruning with NNI and how to test the pruned model.
First, import the required packages:

```python
import time  # needed for the timing code below

import torch, torchvision
import torch.nn as nn
from rich import print

from nni.algorithms.compression.v2.pytorch.pruning import (
    L1NormPruner, L2NormPruner, FPGMPruner, ActivationAPoZRankPruner)
from nni.compression.pytorch.speedup import ModelSpeedup
from nni.compression.pytorch.utils.counter import count_flops_params

from utils.general import check_img_size
from models.common import Conv
from models.experimental import attempt_load
from models.yolo import Detect
from utils.activations import SiLU
```

Next, load the model:
```python
device = torch.device("cuda:1")
model = attempt_load('/backup/nni/yolov5/output_pruned/deepsort_det20211202.pt',
                     map_location=device, inplace=True, fuse=True)  # load FP32 model
model.eval()
```

This gives us the model object, which at this point carries the names and metadata of every layer; that information will be needed later.
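Not part of the original post, but if you want to see which layer names are available for the pruning config, a minimal sketch is to walk `named_modules()` and print the convolution layers:

```python
# Print the Conv2d layers so their full names can be copied into the pruning config later.
for name, module in model.named_modules():
    if isinstance(module, nn.Conv2d):
        print(name, module)
```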
```python
for k, m in model.named_modules():
    if isinstance(m, Conv):  # assign export-friendly activations
        if isinstance(m.act, nn.SiLU):
            m.act = SiLU()
    elif isinstance(m, Detect):
        m.inplace = False
        m.onnx_dynamic = False
        if hasattr(m, 'forward_export'):
            m.forward = m.forward_export  # assign custom forward (optional)
```
Next, iterate over all of the model's modules. The convolution layers keep the same activation behaviour; `nn.SiLU` is simply swapped for the export-friendly `SiLU` implementation, i.e.:
```python
class SiLU(nn.Module):  # export-friendly version of nn.SiLU()
    @staticmethod
    def forward(x):
        return x * torch.sigmoid(x)
```

The Detect head additionally has its ONNX dynamic configuration turned off (`m.onnx_dynamic = False`).
```python
imgsz = (640, 640)
imgsz *= 2 if len(imgsz) == 1 else 1  # expand
gs = int(max(model.stride))  # grid size (max stride)
imgsz = [check_img_size(x, gs) for x in imgsz]  # verify imgsz values are gs-multiples
im = torch.zeros(1, 3, *imgsz).to(device)  # dummy image, BCHW
dummy_input = im
```

Set up the dummy input `im`.
```python
cfg_list = [
    {
        'sparsity': 0.3,
        'op_types': ['Conv2d'],
        'op_names': [
            'model.0.conv', 'model.1.conv',
            'model.2.cv1.conv', 'model.2.cv2.conv', 'model.2.cv3.conv',
            'model.2.m.0.cv1.conv', 'model.2.m.0.cv2.conv', 'model.2.m.1.cv1.conv', 'model.2.m.1.cv2.conv',
            'model.2.m.2.cv1.conv', 'model.2.m.2.cv2.conv', 'model.2.m.3.cv1.conv', 'model.2.m.3.cv2.conv',
            'model.3.conv',
            'model.4.cv1.conv', 'model.4.cv2.conv', 'model.4.cv3.conv',
            'model.4.m.0.cv1.conv', 'model.4.m.0.cv2.conv', 'model.4.m.1.cv1.conv', 'model.4.m.1.cv2.conv',
            'model.4.m.2.cv1.conv', 'model.4.m.2.cv2.conv', 'model.4.m.3.cv1.conv', 'model.4.m.3.cv2.conv',
            'model.4.m.4.cv1.conv', 'model.4.m.4.cv2.conv', 'model.4.m.5.cv1.conv', 'model.4.m.5.cv2.conv',
            'model.4.m.6.cv1.conv', 'model.4.m.6.cv2.conv', 'model.4.m.7.cv1.conv', 'model.4.m.7.cv2.conv',
            'model.5.conv',
            'model.6.cv1.conv', 'model.6.cv2.conv', 'model.6.cv3.conv',
            'model.6.m.0.cv1.conv', 'model.6.m.0.cv2.conv', 'model.6.m.1.cv1.conv', 'model.6.m.1.cv2.conv',
            'model.6.m.2.cv1.conv', 'model.6.m.2.cv2.conv', 'model.6.m.3.cv1.conv', 'model.6.m.3.cv2.conv',
            'model.6.m.4.cv1.conv', 'model.6.m.4.cv2.conv', 'model.6.m.5.cv1.conv', 'model.6.m.5.cv2.conv',
            'model.6.m.6.cv1.conv', 'model.6.m.6.cv2.conv', 'model.6.m.7.cv1.conv', 'model.6.m.7.cv2.conv',
            'model.6.m.8.cv1.conv', 'model.6.m.8.cv2.conv', 'model.6.m.9.cv1.conv', 'model.6.m.9.cv2.conv',
            'model.6.m.10.cv1.conv', 'model.6.m.10.cv2.conv', 'model.6.m.11.cv1.conv', 'model.6.m.11.cv2.conv',
            'model.7.conv',
            'model.8.cv1.conv', 'model.8.cv2.conv', 'model.8.cv3.conv',
            'model.8.m.0.cv1.conv', 'model.8.m.0.cv2.conv', 'model.8.m.1.cv1.conv', 'model.8.m.1.cv2.conv',
            'model.8.m.2.cv1.conv', 'model.8.m.2.cv2.conv', 'model.8.m.3.cv1.conv', 'model.8.m.3.cv2.conv',
            'model.9.cv1.conv', 'model.9.cv2.conv',
            'model.10.conv',
            'model.13.cv1.conv', 'model.13.cv2.conv', 'model.13.cv3.conv',
            'model.13.m.0.cv1.conv', 'model.13.m.0.cv2.conv', 'model.13.m.1.cv1.conv', 'model.13.m.1.cv2.conv',
            'model.13.m.2.cv1.conv', 'model.13.m.2.cv2.conv', 'model.13.m.3.cv1.conv', 'model.13.m.3.cv2.conv',
            'model.14.conv',
            'model.17.cv1.conv', 'model.17.cv2.conv', 'model.17.cv3.conv',
            'model.17.m.0.cv1.conv', 'model.17.m.0.cv2.conv', 'model.17.m.1.cv1.conv', 'model.17.m.1.cv2.conv',
            'model.17.m.2.cv1.conv', 'model.17.m.2.cv2.conv', 'model.17.m.3.cv1.conv', 'model.17.m.3.cv2.conv',
            'model.18.conv',
            'model.20.cv1.conv', 'model.20.cv2.conv', 'model.20.cv3.conv',
            'model.20.m.0.cv1.conv', 'model.20.m.0.cv2.conv', 'model.20.m.1.cv1.conv', 'model.20.m.1.cv2.conv',
            'model.20.m.2.cv1.conv', 'model.20.m.2.cv2.conv', 'model.20.m.3.cv1.conv', 'model.20.m.3.cv2.conv',
            'model.21.conv',
            'model.23.cv1.conv', 'model.23.cv2.conv', 'model.23.cv3.conv',
            'model.23.m.0.cv1.conv', 'model.23.m.0.cv2.conv', 'model.23.m.1.cv1.conv', 'model.23.m.1.cv2.conv',
            'model.23.m.2.cv1.conv', 'model.23.m.2.cv2.conv', 'model.23.m.3.cv1.conv', 'model.23.m.3.cv2.conv',
        ]
    },
    {
        'op_names': ['model.24.m.0', 'model.24.m.1', 'model.24.m.2'],
        'exclude': True
    }
]
```
Set up the pruning config: add every convolution layer to the to-prune list, and remember to exclude the three convolutions of the final Detect head (`model.24`). Since writing out every name by hand is tedious, a programmatic alternative is sketched below.
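This alternative is my addition, not from the original post; it assumes the Detect head sits at `model.24`, as in the exclude entry above:

```python
# Hypothetical alternative: collect all Conv2d names except the Detect head automatically.
prune_names = [name for name, module in model.named_modules()
               if isinstance(module, nn.Conv2d) and not name.startswith('model.24.')]
cfg_list = [
    {'sparsity': 0.3, 'op_types': ['Conv2d'], 'op_names': prune_names},
    {'op_names': ['model.24.m.0', 'model.24.m.1', 'model.24.m.2'], 'exclude': True},
]
```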
```python
pruner = L1NormPruner(model, cfg_list)
# pruner = L2NormPruner(model, cfg_list)
# pruner = FPGMPruner(model, cfg_list)
_, masks = pruner.compress()
# print(masks)
pruner.export_model(model_path='deepsort_yolov5m.pt', mask_path='deepsort_mask.pt')
pruner.show_pruned_weights()
pruner._unwrap_model()
```

Pick a pruner (L1NormPruner here; L2NormPruner or FPGMPruner can be used instead), call `compress()` to compute the pruning masks, export the masked model and mask file, show the pruned weights, and finally unwrap the pruner.
print("im.shape:",dummy_input.shape)
# 1.start = time.time()for _ in range(100): use_mask_out = model(dummy_input) # print(use_mask_out[0].shape)
print('elapsed time_before_pruned: ', (time.time() - start)*100)
This measures the forward-pass time of the masked model before the speed-up step.
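Note that CUDA kernels run asynchronously, so wall-clock timing of GPU inference can be misleading; a sketch of a more careful measurement (my addition, not part of the original flow) synchronizes the device around the loop and disables gradients:

```python
# Time 100 forward passes with explicit CUDA synchronization.
with torch.no_grad():
    torch.cuda.synchronize(device)
    start = time.time()
    for _ in range(100):
        _ = model(dummy_input)
    torch.cuda.synchronize(device)
    print('avg forward time (s):', (time.time() - start) / 100)
```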
```python
m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file="deepsort_mask.pt")  # mask file exported above
m_speedup.speedup_model()
model.eval()
_, _, _ = count_flops_params(model, dummy_input)

torch.save(model, "output_pruned/pruned_deepsortdetv2.pt")
```

ModelSpeedup physically removes the masked channels (note that the mask file name must match the one exported above); `count_flops_params` then reports the FLOPs and parameter counts of the compacted model, which is saved to disk.
```python
start = time.time()
for _ in range(10):
    use_mask_out = model(dummy_input)
print(get_parameter_number(model))  # get_parameter_number: user helper, sketched below
print('elapsed time when use mask: ', (time.time() - start) * 100)
```
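`get_parameter_number` is not defined anywhere in the snippets above; a minimal sketch, assuming it should report total and trainable parameter counts:

```python
def get_parameter_number(net):
    # Count total and trainable parameters of a model.
    total = sum(p.numel() for p in net.parameters())
    trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
    return {'Total': total, 'Trainable': trainable}
```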
The mask file and the pruned model have now both been saved to disk.
```python
# Load the pruned model back for testing
model_to_test = torch.load("output_pruned/pruned_deepsortdetv2.pt")
```
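As a quick sanity check (my addition), the reloaded model can be run on the same dummy input to confirm the sped-up graph still produces outputs of the expected shape:

```python
# Forward pass through the reloaded pruned model.
model_to_test.to(device).eval()
with torch.no_grad():
    out = model_to_test(dummy_input)
print('pruned model output shape:', out[0].shape)
```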
That's all.