# mdz/pytorch/TSN/1_scripts/0_infer.py
#
# 158 lines
# 5.9 KiB
# Python

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from operator import itemgetter
from typing import Optional, Tuple
from mmengine import Config, DictAction
from mmaction.apis import inference_recognizer, init_recognizer
from mmaction.visualization import ActionVisualizer
def parse_args():
    """Parse the command-line options of the recognition demo.

    Every option has a default, so the script can be launched without any
    argument. ``--cfg-options`` is parsed with ``DictAction`` and later
    merged into the loaded config by the caller.

    Returns:
        argparse.Namespace: The parsed command-line arguments.
    """
    parser = argparse.ArgumentParser(description='MMAction2 demo')
    add = parser.add_argument

    # Model / data locations.
    add('--config',
        default='configs/recognition/tsn/'
        'tsn_imagenet-pretrained-r50_8xb32-1x1x3-100e_kinetics400-rgb.py',
        help='test config file path')
    add('--checkpoint',
        default='../weights/'
        'tsn_imagenet-pretrained-r50_8xb32-1x1x3-100e_kinetics400-rgb'
        '_20220906-cd10898e.pth',
        help='checkpoint file/url')
    add('--video',
        default='demo/demo.mp4',
        help='video file/url or rawframes directory')
    add('--label',
        default='tools/data/kinetics/label_map_k400.txt',
        help='label file')

    # Runtime overrides.
    add('--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. For example, '
        "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")
    add('--device',
        type=str,
        default='cuda:0',
        help='CPU/CUDA device option')

    # Options that only affect the optional visualization output.
    add('--fps',
        type=int,
        default=30,
        help='specify fps value of the output video when using rawframes to '
        'generate file')
    add('--font-scale',
        type=float,
        default=None,
        help='font scale of the text in output video')
    add('--font-color',
        default='white',
        help='font color of the text in output video')
    add('--target-resolution',
        nargs=2,
        type=int,
        default=None,
        help='Target resolution (w, h) for resizing the frames when using a '
        'video as input. If either dimension is set to -1, the frames are '
        'resized by keeping the existing aspect ratio')
    add('--out-filename', default=None, help='output filename')

    return parser.parse_args()
def get_output(
    video_path: str,
    out_filename: str,
    data_sample: object,
    labels: list,
    fps: int = 30,
    font_scale: Optional[float] = None,
    font_color: str = 'white',
    target_resolution: Optional[Tuple[int, int]] = None,
) -> None:
    """Write a visualization of the prediction through ``ActionVisualizer``.

    The prediction is forwarded to ``ActionVisualizer.add_datasample`` with
    prediction drawing enabled and ground-truth drawing disabled. The output
    path handed to the visualizer is always ``demo/<out_filename>``.

    Args:
        video_path (str): Path of the input video. ``http://`` and
            ``https://`` URLs are rejected.
        out_filename (str): Output filename for the generated file. A
            ``.gif`` extension selects gif output, anything else selects
            video output.
        data_sample: Prediction result handed to the visualizer (``main``
            passes the object returned by ``inference_recognizer``).
        labels (list): Label list of current dataset, stored as the
            ``classes`` entry of the visualizer's ``dataset_meta``.
        fps (int): fps value forwarded to the visualizer. Defaults to 30.
        font_scale (float, optional): Font scale of the text, forwarded as
            ``font_sizes`` in the text config. When None, the key is left
            out of the text config. Defaults to None.
        font_color (str): Font color of the text. Defaults to ``white``.
        target_resolution (Tuple[int, int], optional): Desired
            ``(width, height)`` of the resized frames, forwarded unchanged
            to the visualizer. The command-line help describes ``-1`` as the
            "keep aspect ratio" marker for a dimension. Defaults to None.

    Raises:
        NotImplementedError: If ``video_path`` is an http(s) URL.
    """
    if video_path.startswith(('http://', 'https://')):
        raise NotImplementedError(
            'Visualizing a remote video is not supported, got: '
            f'{video_path}')

    # init visualizer
    out_type = 'gif' if osp.splitext(out_filename)[1] == '.gif' else 'video'
    visualizer = ActionVisualizer()
    visualizer.dataset_meta = dict(classes=labels)

    # Only pass a font size when the user asked for one.
    text_cfg = {'colors': font_color}
    if font_scale is not None:
        text_cfg['font_sizes'] = font_scale

    visualizer.add_datasample(
        out_filename,
        video_path,
        data_sample,
        draw_pred=True,
        draw_gt=False,
        text_cfg=text_cfg,
        fps=fps,
        out_type=out_type,
        out_path=osp.join('demo', out_filename),
        target_resolution=target_resolution)
def main():
    """Run the recognizer on one video and print the top-5 predictions.

    Steps:
        1. Load the config and merge any ``--cfg-options`` overrides.
        2. Build the recognizer and run inference on ``--video``.
        3. Print the five highest-scoring labels with their scores.
        4. If ``--out-filename`` is given, validate ``--target-resolution``
           and write the visualization via ``get_output``.

    Raises:
        ValueError: If one dimension of ``--target-resolution`` is ``-1``
            and the other one is not a positive integer.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # Build the recognizer from a config file and checkpoint file/url
    model = init_recognizer(cfg, args.checkpoint, device=args.device)
    pred_result = inference_recognizer(model, args.video)

    # Pair each score with its class index and keep the five best.
    # sorted() is stable, so tied scores stay in ascending index order.
    pred_scores = pred_result.pred_score.tolist()
    score_sorted = sorted(
        enumerate(pred_scores), key=itemgetter(1), reverse=True)
    top5_label = score_sorted[:5]

    # Use a context manager so the label file is closed deterministically.
    with open(args.label) as label_file:
        labels = [line.strip() for line in label_file]

    results = [(labels[index], score) for index, score in top5_label]

    print('The top-5 labels with corresponding scores are:')
    for label, score in results:
        print(f'{label}: ', score)

    if args.out_filename is None:
        return

    if args.target_resolution is not None:
        width, height = args.target_resolution
        # -1 marks a dimension to be derived from the aspect ratio, so the
        # other dimension must be a real (positive) size. Raise instead of
        # assert: assertions are stripped when Python runs with -O.
        if (width == -1 and height <= 0) or (height == -1 and width <= 0):
            raise ValueError(
                '--target-resolution: when one dimension is -1 the other '
                f'must be a positive integer, got ({width}, {height})')
        args.target_resolution = (width, height)

    get_output(
        args.video,
        args.out_filename,
        pred_result,
        labels,
        fps=args.fps,
        font_scale=args.font_scale,
        font_color=args.font_color,
        target_resolution=args.target_resolution)
# Run the demo only when executed as a script, not when imported.
if __name__ == '__main__':
    main()