# Copyright (c) OpenMMLab. All rights reserved.
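"""MMAction2 demo script for single-video action recognition.

Builds a recognizer from a config file and a checkpoint, runs inference on a
video file/URL or a rawframes directory, prints the top-5 predicted labels,
and, when ``--out-filename`` is given, writes a video or gif visualization of
the prediction under the ``demo/`` directory via ``get_output``.
"""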
import argparse
import os.path as osp
from operator import itemgetter
from typing import Optional, Tuple

from mmengine import Config, DictAction

from mmaction.apis import inference_recognizer, init_recognizer
from mmaction.visualization import ActionVisualizer


def parse_args():
    parser = argparse.ArgumentParser(description='MMAction2 demo')
    parser.add_argument(
        '--config',
        default='configs/recognition/tsn/'
        'tsn_imagenet-pretrained-r50_8xb32-1x1x3-100e_kinetics400-rgb.py',
        help='test config file path')
    parser.add_argument(
        '--checkpoint',
        default='../weights/tsn_imagenet-pretrained-r50_8xb32-1x1x3-100e_'
        'kinetics400-rgb_20220906-cd10898e.pth',
        help='checkpoint file/url')
    parser.add_argument(
        '--video',
        default='demo/demo.mp4',
        help='video file/url or rawframes directory')
    parser.add_argument(
        '--label',
        default='tools/data/kinetics/label_map_k400.txt',
        help='label file')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. For example, '
        "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")
    parser.add_argument(
        '--device', type=str, default='cuda:0', help='CPU/CUDA device option')
    parser.add_argument(
        '--fps',
        default=30,
        type=int,
        help='specify fps value of the output video when using rawframes to '
        'generate file')
    parser.add_argument(
        '--font-scale',
        default=None,
        type=float,
        help='font scale of the text in output video')
    parser.add_argument(
        '--font-color',
        default='white',
        help='font color of the text in output video')
    parser.add_argument(
        '--target-resolution',
        nargs=2,
        default=None,
        type=int,
        help='Target resolution (w, h) for resizing the frames when using a '
        'video as input. If either dimension is set to -1, the frames are '
        'resized by keeping the existing aspect ratio')
    parser.add_argument('--out-filename', default=None, help='output filename')
    args = parser.parse_args()
    return args


def get_output(
    video_path: str,
    out_filename: str,
    data_sample: str,
    labels: list,
    fps: int = 30,
    font_scale: Optional[float] = None,
    font_color: str = 'white',
    target_resolution: Optional[Tuple[int]] = None,
) -> None:
    """Get demo output using ``moviepy``.

    This function generates a video or gif file from the raw video or frames
    by using ``moviepy``. For more information on some parameters, you can
    refer to: https://github.com/Zulko/moviepy.

    Args:
        video_path (str): The video file path.
        out_filename (str): Output filename for the generated file.
        data_sample (str): Predicted label of the generated file.
        labels (list): Label list of the current dataset.
        fps (int): Number of picture frames to read per second.
            Defaults to 30.
        font_scale (float, optional): Font scale of the text.
            Defaults to None.
        font_color (str): Font color of the text. Defaults to ``white``.
        target_resolution (Tuple[int], optional): Set to
            (desired_width, desired_height) to have resized frames. If
            either dimension is None, the frames are resized by keeping
            the existing aspect ratio. Defaults to None.
    """

    if video_path.startswith(('http://', 'https://')):
        raise NotImplementedError

    # init visualizer; write a gif if the output filename ends with '.gif',
    # otherwise a video
    out_type = 'gif' if osp.splitext(out_filename)[1] == '.gif' else 'video'
    visualizer = ActionVisualizer()
    visualizer.dataset_meta = dict(classes=labels)

    text_cfg = {'colors': font_color}
    if font_scale is not None:
        text_cfg.update({'font_sizes': font_scale})

    # render the prediction onto the frames and save under the 'demo/' folder
    visualizer.add_datasample(
        out_filename,
        video_path,
        data_sample,
        draw_pred=True,
        draw_gt=False,
        text_cfg=text_cfg,
        fps=fps,
        out_type=out_type,
        out_path=osp.join('demo', out_filename),
        target_resolution=target_resolution)
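
# A minimal sketch of calling ``get_output`` directly instead of via the CLI.
# The paths are illustrative assumptions; ``pred_result`` and ``labels`` would
# come from ``inference_recognizer`` and the label file, as in ``main`` below.
# Because the output name ends with '.gif', a gif is written instead of a
# video:
#
#   get_output(
#       'demo/demo.mp4',
#       'demo_out.gif',
#       pred_result,
#       labels,
#       fps=30)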


def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # Build the recognizer from a config file and checkpoint file/url
    model = init_recognizer(cfg, args.checkpoint, device=args.device)
    pred_result = inference_recognizer(model, args.video)

    # Sort class indices by predicted score and keep the five best.
    pred_scores = pred_result.pred_score.tolist()
    score_tuples = tuple(zip(range(len(pred_scores)), pred_scores))
    score_sorted = sorted(score_tuples, key=itemgetter(1), reverse=True)
    top5_label = score_sorted[:5]

    # Map class indices to human-readable names from the label file.
    labels = open(args.label).readlines()
    labels = [x.strip() for x in labels]
    results = [(labels[k[0]], k[1]) for k in top5_label]

    print('The top-5 labels with corresponding scores are:')
    for result in results:
        print(f'{result[0]}: ', result[1])

    if args.out_filename is not None:

        if args.target_resolution is not None:
            # If one dimension is -1, the other must be a positive integer so
            # the aspect ratio can be preserved.
            if args.target_resolution[0] == -1:
                assert isinstance(args.target_resolution[1], int)
                assert args.target_resolution[1] > 0
            if args.target_resolution[1] == -1:
                assert isinstance(args.target_resolution[0], int)
                assert args.target_resolution[0] > 0
            args.target_resolution = tuple(args.target_resolution)

        get_output(
            args.video,
            args.out_filename,
            pred_result,
            labels,
            fps=args.fps,
            font_scale=args.font_scale,
            font_color=args.font_color,
            target_resolution=args.target_resolution)


if __name__ == '__main__':
    main()
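
# Example invocation, relying on the argparse defaults above (a sketch; it
# assumes this file is saved as ``demo.py``, the working directory is the
# MMAction2 repo root, and the default checkpoint has been downloaded to
# ``../weights/``):
#
#   python demo.py --out-filename demo_out.mp4
#
# The top-5 labels are printed to stdout and, because ``--out-filename`` is
# given, the annotated video is written to ``demo/demo_out.mp4``.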