# mdz/pytorch/TSN/1_scripts/1_save.py
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from operator import itemgetter
from typing import Optional, Tuple

import torch
import torch.nn as nn
from mmengine import Config, DictAction

from mmaction.apis import inference_recognizer, init_recognizer
from mmaction.visualization import ActionVisualizer
# To export a model for a different number of segments, modify the following
# parameters:
NET_TRACED_PATH = "./traced_model/TSN_traced_25frames_1x3x224x224.onnx"  # export path of the traced model
SEGMENTS = 25  # number of segments the video is split into; the source code defaults to 25

class New_Net(nn.Module):
    """Wraps the TSN backbone, average pool and classification head so that
    the segment-consensus pipeline can be exported as a single graph."""

    def __init__(self, backbone, head, cls_head):
        super(New_Net, self).__init__()
        self.backbone = backbone
        self.head = head
        self.cls_head = cls_head

    def forward(self, x):
        features = self.backbone(x)
        x0 = self.head(features)  # [1, 2048, 1, 1]
        x0 = x0.squeeze(-1).squeeze(-1)  # [1, 2048]
        x25 = x0.expand(SEGMENTS, -1)  # [25, 2048]
        x25 = x25.unsqueeze(0)  # [1, 25, 2048]
        # AvgConsensus(): output [1, 1, 2048], i.e. [N, 1, in_channels]
        consensus_out = x25.mean(dim=1, keepdim=True)
        tmp = consensus_out.view(consensus_out.size(0), -1)  # [1, 2048]
        out = self.cls_head(tmp)
        return out
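
# Note: because `x25` is just SEGMENTS identical copies of `x0`, the mean over
# dim=1 returns `x0` unchanged; the expand/mean pair only reproduces the
# AvgConsensus graph structure for SEGMENTS segments in the exported model.
# A minimal sanity check of that equivalence (illustrative only):
#
#     x0 = torch.randn(1, 2048)
#     x25 = x0.expand(SEGMENTS, -1).unsqueeze(0)  # [1, SEGMENTS, 2048]
#     assert torch.allclose(x25.mean(dim=1), x0)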

def parse_args():
    parser = argparse.ArgumentParser(description='MMAction2 demo')
    parser.add_argument(
        '--config',
        default="configs/recognition/tsn/tsn_imagenet-pretrained-r50_8xb32-1x1x3-100e_kinetics400-rgb.py",
        help='test config file path')
    parser.add_argument(
        '--checkpoint',
        default="../weights/tsn_imagenet-pretrained-r50_8xb32-1x1x3-100e_kinetics400-rgb_20220906-cd10898e.pth",
        help='checkpoint file/url')
    parser.add_argument(
        '--video',
        default="demo/demo.mp4",
        help='video file/url or rawframes directory')
    parser.add_argument(
        '--label',
        default="tools/data/kinetics/label_map_k400.txt",
        help='label file')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. For example, '
        "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")
    parser.add_argument(
        '--device', type=str, default='cuda:0', help='CPU/CUDA device option')
    parser.add_argument(
        '--fps',
        default=30,
        type=int,
        help='specify fps value of the output video when using rawframes to '
        'generate file')
    parser.add_argument(
        '--font-scale',
        default=None,
        type=float,
        help='font scale of the text in output video')
    parser.add_argument(
        '--font-color',
        default='white',
        help='font color of the text in output video')
    parser.add_argument(
        '--target-resolution',
        nargs=2,
        default=None,
        type=int,
        help='Target resolution (w, h) for resizing the frames when using a '
        'video as input. If either dimension is set to -1, the frames are '
        'resized by keeping the existing aspect ratio')
    parser.add_argument('--out-filename', default=None, help='output filename')
    args = parser.parse_args()
    return args

def get_output(
    video_path: str,
    out_filename: str,
    data_sample: str,
    labels: list,
    fps: int = 30,
    font_scale: Optional[float] = None,
    font_color: str = 'white',
    target_resolution: Optional[Tuple[int]] = None,
) -> None:
    """Get demo output using ``moviepy``.

    This function will generate a video file or a gif file from the raw video
    or frames, by using ``moviepy``. For more information on some parameters,
    you can refer to: https://github.com/Zulko/moviepy.

    Args:
        video_path (str): The video file path.
        out_filename (str): Output filename for the generated file.
        data_sample (ActionDataSample): Prediction result of the video.
        labels (list): Label list of the current dataset.
        fps (int): Number of picture frames to read per second. Defaults to 30.
        font_scale (float, optional): Font scale of the text. Defaults to None.
        font_color (str): Font color of the text. Defaults to ``white``.
        target_resolution (Tuple[int], optional): Set to
            (desired_width, desired_height) to have resized frames. If
            either dimension is None, the frames are resized by keeping
            the existing aspect ratio. Defaults to None.
    """
    if video_path.startswith(('http://', 'https://')):
        raise NotImplementedError

    # init visualizer
    out_type = 'gif' if osp.splitext(out_filename)[1] == '.gif' else 'video'
    visualizer = ActionVisualizer()
    visualizer.dataset_meta = dict(classes=labels)

    text_cfg = {'colors': font_color}
    if font_scale is not None:
        text_cfg.update({'font_sizes': font_scale})

    visualizer.add_datasample(
        out_filename,
        video_path,
        data_sample,
        draw_pred=True,
        draw_gt=False,
        text_cfg=text_cfg,
        fps=fps,
        out_type=out_type,
        out_path=osp.join('demo', out_filename),
        target_resolution=target_resolution)

def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # Build the recognizer from a config file and a checkpoint file/url
    model = init_recognizer(cfg, args.checkpoint, device=args.device)
    model.eval()
    pred_result = inference_recognizer(model, args.video)

    # ====== trace model ====== #
    with torch.no_grad():
        Net = New_Net(model.backbone, model.cls_head.avg_pool,
                      model.cls_head.fc_cls).cpu()
        tin = torch.randn(1, 3, 224, 224)
        Net.eval()
        tout_net = Net(tin).cpu()  # output: [1, num_classes], e.g. [1, 400] for Kinetics-400
        torch.onnx.export(Net, tin.cpu(), NET_TRACED_PATH, opset_version=11)
        # tmodel = torch.jit.trace(Net, tin.cpu())
        # tmodel.save(NET_TRACED_PATH)
    print('Successfully exported the traced model to', NET_TRACED_PATH)

    pred_scores = pred_result.pred_score.tolist()
    score_tuples = tuple(zip(range(len(pred_scores)), pred_scores))
    score_sorted = sorted(score_tuples, key=itemgetter(1), reverse=True)
    top5_label = score_sorted[:5]

    labels = open(args.label).readlines()
    labels = [x.strip() for x in labels]
    results = [(labels[k[0]], k[1]) for k in top5_label]

    print('The top-5 labels with corresponding scores are:')
    for result in results:
        print(f'{result[0]}: ', result[1])

    if args.out_filename is not None:
        if args.target_resolution is not None:
            if args.target_resolution[0] == -1:
                assert isinstance(args.target_resolution[1], int)
                assert args.target_resolution[1] > 0
            if args.target_resolution[1] == -1:
                assert isinstance(args.target_resolution[0], int)
                assert args.target_resolution[0] > 0
            args.target_resolution = tuple(args.target_resolution)

        get_output(
            args.video,
            args.out_filename,
            pred_result,
            labels,
            fps=args.fps,
            font_scale=args.font_scale,
            font_color=args.font_color,
            target_resolution=args.target_resolution)

if __name__ == '__main__':
    main()
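
# Example usage (a sketch; assumes the script is run from a directory where
# the default --config, --video and --label paths resolve, e.g. the mmaction2
# root, with the checkpoint downloaded to the default --checkpoint path):
#
#     python 1_save.py --device cpu
#
# On success, the exported ONNX graph is written to NET_TRACED_PATH.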