mdz/pytorch/MMPSPNet/1_scripts/1_save.py

88 lines
3.2 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import sys
sys.path.append(R"../0_MMPSPNet")
from argparse import ArgumentParser
from mmengine.model import revert_sync_batchnorm
from mmseg.apis import init_model
import torch
from mmseg.models.decode_heads.psp_head import PPM
from mmseg.models.utils import resize
RED = '\033[31m' # 设置前景色为红色
RESET = '\033[0m' # 重置所有属性到默认值
ver = torch.__version__
assert ("1.6" in ver) or ("1.9" in ver), f"{RED}Unsupported PyTorch version: {ver}{RESET}"
def parse_args():
parser = ArgumentParser()
parser.add_argument('--img',default="../0_MMPSPNet/demo/demo.png", help='Image file')
parser.add_argument('--config',default="../0_MMPSPNet/configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py", help='Config file')
parser.add_argument('--checkpoint',default="../weights/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth", help='Checkpoint file')
parser.add_argument('--out-file', default="./outputs/", help='Path to output file')
parser.add_argument("--trace_path", type=str, default="../2_compile/fmodel/mmpspnet_traced_1024x2048.pt", help="path of traced model")
parser.add_argument(
'--device', default='cpu', help='Device used for inference')
parser.add_argument(
'--opacity',
type=float,
default=0.5,
help='Opacity of painted segmentation map. In (0, 1] range.')
parser.add_argument(
'--title', default='result', help='The image identifier.')
args = parser.parse_args()
return args
def new_forward(self, x):
"""Psp_head.py new Forward function."""
ppm_outs = []
for ppm in self:
ppm_out = ppm(x)#[1,512,1,1]
size_final = x.size()[2:]
upsampled_ppm_out = resize(
ppm_out,
size=(16,32),
mode='bilinear',
align_corners=self.align_corners)##False
upsampled_ppm_out = resize(
upsampled_ppm_out,
size=(64,128),
mode='bilinear',
align_corners=self.align_corners)##False
upsampled_ppm_out = resize(
upsampled_ppm_out,
size=size_final,##[128,256]
mode='bilinear',
align_corners=self.align_corners)##False
ppm_outs.append(upsampled_ppm_out)
return ppm_outs
if __name__ == '__main__':
args = parse_args()
# 加载模型
model = init_model(args.config, args.checkpoint, device=args.device)
if args.device == 'cpu':
model = revert_sync_batchnorm(model)
input = torch.randn(1,3,1024,2048)
# 等效替换upsample
PPM.__call__ = new_forward
out = model(input)
print(out.shape)##没有接后处理,最后一个conv:19*512->out:([1, 19, 128,256])
# 删除前处理
del model.data_preprocessor
del model.auxiliary_head
# 将self.dropout置为None, 前向过程中不再调用self.dropout
model.decode_head.dropout = None
# 模型导出
tmodel = torch.jit.trace(model,input)
torch.jit.save(tmodel, args.trace_path)
print("successful save model in ", args.trace_path)