# -*- coding:utf-8 -*-
# This script saves (exports) the model in TorchScript format.
|
|
import argparse
|
|
import sys
|
|
import os
|
|
import time
|
|
|
|
sys.path.append('./') # to run '$ python *.py' files in subdirectories
|
|
sys.path.append(R"../0_yolov7")
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
import models
|
|
from models.experimental import attempt_load
|
|
from utils.activations import Hardswish, SiLU
|
|
from utils.general import set_logging, check_img_size
|
|
from utils.torch_utils import select_device
|
|
from models.common import *
|
|
from models.yolo import *
|
|
|
|
# ANSI escape sequences used to highlight the version error on a terminal.
RED = '\033[31m'  # switch the foreground colour to red
RESET = '\033[0m'  # reset every terminal attribute to its default

# The script only runs when the installed torch version string contains
# "1.6" or "1.9"; anything else aborts at import time with a red message.
ver = torch.__version__
assert any(tag in ver for tag in ("1.6", "1.9")), f"{RED}Unsupported PyTorch version: {ver}{RESET}"
|
|
|
|
# ----------------------------
# Convolution equivalent of ReOrg
# ----------------------------
# A 2x2 convolution kernel bank with 12 output and 3 input channels. Each
# output channel copies exactly one input channel at one position of the 2x2
# window, so every filter holds a single 1 and zeros elsewhere.
REORG_WEIGHTS = torch.zeros(12, 3, 2, 2)
for out_ch in range(12):
    # quad selects the window position (0..3), in_ch the copied input channel
    quad, in_ch = divmod(out_ch, 3)
    row = quad % 2  # kernel row: alternates 0, 1, 0, 1 across the four groups
    col = quad // 2  # kernel column: 0 for output channels 0-5, 1 for 6-11
    REORG_WEIGHTS[out_ch, in_ch, row, col] = 1

# Zero bias: the convolution only rearranges values.
REORG_BIAS = torch.zeros(12)
|
|
|
|
def reorg_forward(self, x):
    """Apply the fixed ReOrg-equivalent convolution to ``x``.

    Runs a stride-2, unpadded 2x2 convolution using the module-level
    ``REORG_WEIGHTS`` and ``REORG_BIAS`` tensors and returns the result.
    ``self`` is accepted so the function can be installed as a method, but
    it is not used.

    NOTE(review): ``F`` is not imported explicitly in this file; it is
    presumably provided by one of the star imports - confirm.
    NOTE(review): the weight and bias tensors are created without a device
    argument, so ``x`` is assumed to live on the same device - confirm.
    """
    return F.conv2d(x, REORG_WEIGHTS, REORG_BIAS, stride=2, padding=0, dilation=1, groups=1)
|
|
|
|
# Monkey-patch ReOrg so that calling an instance runs reorg_forward
# (a fixed stride-2 convolution) instead of the class's own call logic.
ReOrg.__call__ = reorg_forward
|
|
|
|
|
|
def detect_forward(self, x):
    """Export-oriented replacement for the detection head's forward pass.

    Args:
        x: list with one feature map per detection layer (``self.nl``
            entries). The list is modified in place: every entry is replaced
            by the output of the matching conv in ``self.m``.

    Returns:
        * training/export path (``self.training`` or ``self.export`` set):
          the list ``x`` holding the raw conv outputs, each still shaped
          ``(bs, channels, ny, nx)`` - no reshape is applied so the exported
          graph ends at the convolutions.
        * otherwise the decoded predictions, depending on the head flags:
          ``torch.cat(z, 1)`` for ``end2end`` or ``concat``,
          ``(self.convert(z),)`` for ``include_nms``, else
          ``(torch.cat(z, 1), x)``.

    The reshape to ``(bs, na, ny, nx, no)`` is performed only on the
    inference path. The decode code below indexes and splits along that
    trailing ``no`` axis, so it cannot run on the raw 4-D conv output.
    """
    z = []  # inference output
    # Export mode reuses the training path so that raw conv maps are returned.
    self.training |= self.export
    for i in range(self.nl):
        x[i] = self.m[i](x[i])  # conv
        bs, _, ny, nx = x[i].shape
        if not self.training:  # inference
            # (bs, na*no, ny, nx) -> (bs, na, ny, nx, no); skipped for export
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
            # Rebuild the cached grid whenever the feature-map size changed.
            if self.grid[i].shape[2:4] != x[i].shape[2:4]:
                self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
            y = x[i].sigmoid()
            if not torch.onnx.is_in_onnx_export():
                y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
            else:
                # Same decode written without in-place slice assignment.
                xy, wh, conf = y.split((2, 2, self.nc + 1), 4)
                xy = xy * (2. * self.stride[i]) + (self.stride[i] * (self.grid[i] - 0.5))  # new xy
                wh = wh ** 2 * (4 * self.anchor_grid[i].data)  # new wh
                y = torch.cat((xy, wh, conf), 4)
            z.append(y.view(bs, -1, self.no))

    if self.training:
        out = x
    elif self.end2end:
        out = torch.cat(z, 1)
    elif self.include_nms:
        z = self.convert(z)
        out = (z, )
    elif self.concat:
        out = torch.cat(z, 1)
    else:
        out = (torch.cat(z, 1), x)
    return out
|
|
|
|
# Monkey-patch Detect so that its forward pass is the detect_forward
# function defined in this script.
Detect.forward = detect_forward
|
|
if __name__ == '__main__':
    # Command-line entry point: load a checkpoint, prepare the model for
    # export, trace it with torch.jit.trace and save the TorchScript file.
    parser = argparse.ArgumentParser()
    # Raw string: '\w' and '\y' are not valid escape sequences in a normal literal.
    parser.add_argument('--weights', type=str, default=R'..\weights\yolov7.pt', help='weights path')
    parser.add_argument('--img_size', nargs='+', type=int, default=[640, 640], help='image size')  # height, width
    parser.add_argument('--export_dir', type=str, default=R'..\2_compile\fmodel', help='directory for the exported TorchScript file')
    parser.add_argument('--batch-size', type=int, default=1, help='batch size')
    parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes')
    parser.add_argument('--grid', action='store_true', help='export Detect() layer grid')
    parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    opt = parser.parse_args()

    opt.img_size *= 2 if len(opt.img_size) == 1 else 1  # expand a single value to [size, size]

    set_logging()
    t = time.time()

    # Load PyTorch model
    device = select_device(opt.device)
    model = attempt_load(opt.weights, map_location=device)  # load FP32 model
    labels = model.names

    # The ReOrg-equivalent convolution constants are created on the default
    # device at import time; move them next to the model so the patched
    # ReOrg call also works when a CUDA device is selected.
    REORG_WEIGHTS = REORG_WEIGHTS.to(device)
    REORG_BIAS = REORG_BIAS.to(device)

    # Checks
    gs = int(max(model.stride))  # grid size (max stride)
    opt.img_size = [check_img_size(x, gs) for x in opt.img_size]  # verify img_size are gs-multiples

    # Input
    img = torch.zeros(opt.batch_size, 3, *opt.img_size).to(device)  # dummy image for the dry run

    # Update model
    for k, m in model.named_modules():
        m._non_persistent_buffers_set = set()  # pytorch 1.9.0 compatibility
        if isinstance(m, models.common.Conv):  # assign export-friendly activations
            # Only Hardswish is swapped; every other activation (SiLU included)
            # is deliberately left as it is.
            if isinstance(m.act, nn.Hardswish):
                m.act = Hardswish()

    model.model[-1].export = not opt.grid  # set Detect() layer grid export
    y = model(img)  # dry run

    # TorchScript export
    try:
        print('\nStarting TorchScript export with torch %s...' % torch.__version__)
        os.makedirs(opt.export_dir, exist_ok=True)

        # Derive the output name from the weights file name. Backslashes are
        # normalised first so both Windows-style and POSIX-style paths work.
        weights_name = os.path.basename(opt.weights.replace('\\', '/'))
        stem = os.path.splitext(weights_name)[0]
        f = os.path.join(opt.export_dir, f'{stem}_{opt.img_size[0]}x{opt.img_size[1]}.pt')  # filename

        # Trace with a batch-1 random input created on the model's device.
        example = torch.rand(1, 3, opt.img_size[0], opt.img_size[1], dtype=torch.float, device=device)
        ts = torch.jit.trace(model, example, strict=False)
        ts.save(f)
        print('TorchScript export success, saved as %s' % f)
    except Exception as e:
        print('TorchScript export failure: %s' % e)
|