194 lines
7.4 KiB
Python
194 lines
7.4 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
||
import argparse
import os
import os.path as osp
from operator import itemgetter
from typing import Optional, Tuple

import torch
import torch.nn as nn
from mmengine import Config, DictAction

from mmaction.apis import inference_recognizer, init_recognizer
from mmaction.visualization import ActionVisualizer
|
||
|
||
# To export a model for a different number of segments, change SEGMENTS below.
# The export path is derived from SEGMENTS so the file name always states the
# segment count that is baked into the exported graph.
SEGMENTS = 25  # number of segments the video is split into; upstream default is 25
NET_TRACED_PATH = (
    f"./traced_model/TSN_traced_{SEGMENTS}frames_1x3x224x224.onnx"
)  # ONNX export path
|
||
|
||
class New_Net(nn.Module):
    """Export wrapper: backbone -> spatial pooling -> avg consensus -> classifier.

    Args:
        backbone (nn.Module): Feature extractor applied to the input batch.
        head (nn.Module): Pooling module. Its output's last two dims are
            squeezed, so it is expected to return ``[N, C, 1, 1]``
            (``main`` passes ``model.cls_head.avg_pool``).
        cls_head (nn.Module): Classifier applied to the ``[N, C]`` consensus
            feature (``main`` passes ``model.cls_head.fc_cls``).
        num_segments (int, optional): Number of copies averaged by the
            consensus step. When None, the module-level ``SEGMENTS`` constant
            is used, which keeps the original three-argument call working.

    NOTE(review): every segment slot holds the *same* per-sample feature, so
    the average consensus is mathematically an identity (up to floating-point
    rounding) and the output equals ``cls_head(pooled_features)``. The
    exported graph therefore classifies each input image on its own,
    regardless of ``num_segments``. Presumably the replication is kept only so
    the graph contains the consensus step; confirm this is intended.
    """

    def __init__(self, backbone, head, cls_head, num_segments=None):
        super().__init__()
        self.backbone = backbone
        self.head = head
        self.cls_head = cls_head
        self.num_segments = SEGMENTS if num_segments is None else num_segments

    def forward(self, x):
        """Return classifier output for ``x``, shape ``[N, out_features]``."""
        features = self.backbone(x)
        pooled = self.head(features)  # e.g. [1, 2048, 1, 1] for the export input
        pooled = pooled.squeeze(-1).squeeze(-1)  # [N, C]
        # Replicate each sample's feature along a new segment axis:
        # [N, C] -> [N, num_segments, C]. Expanding after unsqueeze(1) works
        # for any batch size; the previous expand(SEGMENTS, -1) required N == 1.
        segments = pooled.unsqueeze(1).expand(-1, self.num_segments, -1)
        # AvgConsensus: [N, 1, C], i.e. [N, 1, in_channels].
        consensus_out = segments.mean(dim=1, keepdim=True)
        flat = consensus_out.view(consensus_out.size(0), -1)  # [N, C]
        return self.cls_head(flat)
|
||
|
||
|
||
def parse_args():
    """Define the demo's command-line options and parse ``sys.argv``.

    Returns:
        argparse.Namespace: The parsed options (config, checkpoint, video,
        label map, config overrides, device and output-rendering settings).
    """
    parser = argparse.ArgumentParser(description='MMAction2 demo')
    add = parser.add_argument

    # Model definition, weights, input video and label map.
    add('--config',
        default=('configs/recognition/tsn/'
                 'tsn_imagenet-pretrained-r50_8xb32-1x1x3-100e_kinetics400-rgb.py'),
        help='test config file path')
    add('--checkpoint',
        default=('../weights/'
                 'tsn_imagenet-pretrained-r50_8xb32-1x1x3-100e_kinetics400-rgb'
                 '_20220906-cd10898e.pth'),
        help='checkpoint file/url')
    add('--video',
        default='demo/demo.mp4',
        help='video file/url or rawframes directory')
    add('--label',
        default='tools/data/kinetics/label_map_k400.txt',
        help='label file')

    # Config overrides and execution device.
    cfg_options_help = (
        'override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. For example, '
        "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")
    add('--cfg-options', nargs='+', action=DictAction, help=cfg_options_help)
    add('--device', type=str, default='cuda:0', help='CPU/CUDA device option')

    # Rendering of the optional output video / gif.
    add('--fps', default=30, type=int,
        help=('specify fps value of the output video when using rawframes to '
              'generate file'))
    add('--font-scale', default=None, type=float,
        help='font scale of the text in output video')
    add('--font-color', default='white',
        help='font color of the text in output video')
    add('--target-resolution', nargs=2, default=None, type=int,
        help=('Target resolution (w, h) for resizing the frames when using a '
              'video as input. If either dimension is set to -1, the frames are '
              'resized by keeping the existing aspect ratio'))
    add('--out-filename', default=None, help='output filename')

    return parser.parse_args()
|
||
|
||
|
||
def get_output(
    video_path: str,
    out_filename: str,
    data_sample: object,
    labels: list,
    fps: int = 30,
    font_scale: Optional[float] = None,
    font_color: str = 'white',
    target_resolution: Optional[Tuple[int, int]] = None,
) -> None:
    """Render the predicted label onto the input video and save the result.

    Builds an ``ActionVisualizer`` whose class list is ``labels`` and calls
    ``add_datasample`` with the prediction drawn (ground truth is not drawn).
    The file is written to ``demo/<out_filename>``. A ``.gif`` extension (any
    letter case) selects GIF output; anything else selects video output. For
    more information on some parameters see https://github.com/Zulko/moviepy.

    Args:
        video_path (str): The local video file path. http(s) URLs are not
            supported.
        out_filename (str): Output filename, joined under ``demo/``.
        data_sample: Prediction returned by ``inference_recognizer``
            (presumably an ``ActionDataSample``), passed straight through to
            the visualizer.
        labels (list): Label list of current dataset.
        fps (int): Frame rate forwarded to the visualizer. Defaults to 30.
        font_scale (float, optional): Font size of the text. When None, the
            visualizer's own default is used.
        font_color (str): Font color of the text. Defaults to ``'white'``.
        target_resolution (Tuple[int, int], optional): ``(width, height)`` to
            resize frames to. The ``--target-resolution`` CLI help uses -1
            for a dimension that should follow the aspect ratio. Defaults to
            None (no resizing requested).

    Raises:
        NotImplementedError: If ``video_path`` starts with http:// or https://.
    """
    if video_path.startswith(('http://', 'https://')):
        raise NotImplementedError

    # init visualizer; compare the extension case-insensitively so that
    # "out.GIF" is rendered as a GIF as well.
    is_gif = osp.splitext(out_filename)[1].lower() == '.gif'
    out_type = 'gif' if is_gif else 'video'
    visualizer = ActionVisualizer()
    visualizer.dataset_meta = dict(classes=labels)

    text_cfg = {'colors': font_color}
    if font_scale is not None:
        text_cfg['font_sizes'] = font_scale

    visualizer.add_datasample(
        out_filename,
        video_path,
        data_sample,
        draw_pred=True,
        draw_gt=False,
        text_cfg=text_cfg,
        fps=fps,
        out_type=out_type,
        out_path=osp.join('demo', out_filename),
        target_resolution=target_resolution)
|
||
|
||
|
||
def _export_onnx(model, export_path):
    """Wrap the recognizer's backbone, pool and fc in ``New_Net`` and export ONNX.

    NOTE: ``.cpu()`` moves modules in place, so ``model.backbone``,
    ``model.cls_head.avg_pool`` and ``model.cls_head.fc_cls`` are on the CPU
    after this call.
    """
    out_dir = osp.dirname(export_path)
    if out_dir:
        # Create the target directory up front; the export writes directly
        # to ``export_path``.
        os.makedirs(out_dir, exist_ok=True)

    with torch.no_grad():
        net = New_Net(model.backbone, model.cls_head.avg_pool,
                      model.cls_head.fc_cls).cpu()
        net.eval()
        dummy_input = torch.randn(1, 3, 224, 224)
        # Sanity forward pass before exporting. The result is the output of
        # fc_cls, shape [1, out_features], not the pooled [1, 2048, 1, 1].
        net(dummy_input)
        torch.onnx.export(net, dummy_input, export_path, opset_version=11)
    print("successful traced net in", export_path)


def _load_labels(label_path):
    """Read the label map: one label per line, surrounding whitespace stripped."""
    with open(label_path, encoding='utf-8') as f:
        return [line.strip() for line in f]


def _top_k(scores, k=5):
    """Return ``(index, score)`` pairs for the ``k`` highest scores, best first."""
    return sorted(enumerate(scores), key=itemgetter(1), reverse=True)[:k]


def _validate_target_resolution(target_resolution):
    """Check a ``(w, h)`` pair where -1 means "follow the aspect ratio".

    Returns:
        tuple: ``target_resolution`` as a tuple.

    Raises:
        ValueError: If one dimension is -1 and the other is not a positive int.
    """
    width, height = target_resolution
    if width == -1 and not (isinstance(height, int) and height > 0):
        raise ValueError(
            'target_resolution: height must be a positive int when width is -1')
    if height == -1 and not (isinstance(width, int) and width > 0):
        raise ValueError(
            'target_resolution: width must be a positive int when height is -1')
    return tuple(target_resolution)


def main():
    """Run recognition on a video, export the ONNX model and print the top-5."""
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # Build the recognizer from a config file and checkpoint file/url
    model = init_recognizer(cfg, args.checkpoint, device=args.device)
    model.eval()
    # Inference runs before the export, which moves the shared backbone/head
    # modules to the CPU in place.
    pred_result = inference_recognizer(model, args.video)

    # ====== export model ======
    _export_onnx(model, NET_TRACED_PATH)

    pred_scores = pred_result.pred_score.tolist()
    top5_label = _top_k(pred_scores, 5)

    labels = _load_labels(args.label)
    results = [(labels[idx], score) for idx, score in top5_label]

    print('The top-5 labels with corresponding scores are:')
    for label, score in results:
        print(f'{label}: ', score)

    if args.out_filename is not None:
        if args.target_resolution is not None:
            args.target_resolution = _validate_target_resolution(
                args.target_resolution)

        get_output(
            args.video,
            args.out_filename,
            pred_result,
            labels,
            fps=args.fps,
            font_scale=args.font_scale,
            font_color=args.font_color,
            target_resolution=args.target_resolution)
|
||
|
||
|
||
if __name__ == "__main__":
    # Run the demo only when this file is executed as a script, not on import.
    main()
|