# mdz/pytorch/bytetrack/1_scripts/2_save_infer.py

import argparse
import os
import os.path as osp
import time
import cv2
import torch
from loguru import logger
from yolox.data.data_augment import preproc
from yolox.exp import get_exp
from yolox.utils import fuse_model, get_model_info, postprocess
from yolox.utils.visualize import plot_tracking
from yolox.tracker.byte_tracker import BYTETracker
from yolox.tracking_utils.timer import Timer
IMAGE_EXT = [".jpg", ".jpeg", ".webp", ".bmp", ".png"]
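

# make_parser(): command-line options for the demo, covering the input source,
# exp file, traced-model path, detection thresholds, and ByteTrack tracking parameters.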
def make_parser():
    parser = argparse.ArgumentParser("ByteTrack Demo!")
    parser.add_argument(
        "--demo", default="video", help="demo type, e.g. image, video and webcam"
    )
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")
    parser.add_argument(
        # "--path", default="./datasets/mot/train/MOT17-05-FRCNN/img1", help="path to images or video"
        "--path", default="./videos/palace.mp4", help="path to images or video"
    )
    parser.add_argument("--camid", type=int, default=0, help="webcam demo camera id")
    parser.add_argument(
        "--save_result",
        action="store_true",
        default=True,
        help="whether to save the inference result of image/video",
    )
    # exp file
    parser.add_argument(
        "-f",
        "--exp_file",
        default="exps/example/mot/yolox_s_mix_det.py",
        type=str,
        help="please input your experiment description file",
    )
    parser.add_argument("--model_path", default="../2_compile/fmodel/bytetrack_s_608x1088_traced.pt", type=str, help="traced model path")
    parser.add_argument("--export", default=True, type=bool, help="export this model")
    parser.add_argument(
        "--device",
        default="cpu",
        type=str,
        help="device to run our model, can either be cpu or gpu",
    )
    parser.add_argument("--conf", default=0.25, type=float, help="test conf")
    parser.add_argument("--nms", default=0.5, type=float, help="test nms threshold")
    parser.add_argument("--fps", default=30, type=int, help="frame rate (fps)")
    parser.add_argument(
        "--fp16",
        dest="fp16",
        default=False,
        action="store_true",
        help="Adopting mixed precision evaluating.",
    )
    parser.add_argument(
        "--fuse",
        dest="fuse",
        default=False,
        action="store_true",
        help="Fuse conv and bn for testing.",
    )
    parser.add_argument(
        "--trt",
        dest="trt",
        default=False,
        action="store_true",
        help="Using TensorRT model for testing.",
    )
    # tracking args
    parser.add_argument("--track_thresh", type=float, default=0.5, help="tracking confidence threshold")
    parser.add_argument("--track_buffer", type=int, default=30, help="the frames for keeping lost tracks")
    parser.add_argument("--match_thresh", type=float, default=0.8, help="matching threshold for tracking")
    parser.add_argument(
        "--aspect_ratio_thresh", type=float, default=1.6,
        help="threshold for filtering out boxes whose aspect ratio is above the given value."
    )
    parser.add_argument('--min_box_area', type=float, default=10, help='filter out tiny boxes')
    parser.add_argument("--mot20", dest="mot20", default=False, action="store_true", help="test mot20.")
    parser.add_argument("--trace_inputshape", type=int, nargs='+', default=[608, 1088], help="input shape: h w")
    return parser
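

# get_image_list(): recursively collect every image file (matched by extension) under a directory.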
def get_image_list(path):
    image_names = []
    for maindir, subdir, file_name_list in os.walk(path):
        for filename in file_name_list:
            apath = osp.join(maindir, filename)
            ext = osp.splitext(apath)[1]
            if ext in IMAGE_EXT:
                image_names.append(apath)
    return image_names
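

# write_results(): dump tracking results in the MOT challenge text format,
# one line per track per frame: frame,id,x1,y1,w,h,score,-1,-1,-1.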
def write_results(filename, results):
    save_format = '{frame},{id},{x1},{y1},{w},{h},{s},-1,-1,-1\n'
    with open(filename, 'w') as f:
        for frame_id, tlwhs, track_ids, scores in results:
            for tlwh, track_id, score in zip(tlwhs, track_ids, scores):
                if track_id < 0:
                    continue
                x1, y1, w, h = tlwh
                line = save_format.format(frame=frame_id, id=track_id, x1=round(x1, 1), y1=round(y1, 1), w=round(w, 1), h=round(h, 1), s=round(score, 2))
                f.write(line)
    logger.info('save results to {}'.format(filename))
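

# Predictor wraps the traced ByteTrack detector (YOLOX-S by default): it preprocesses
# a frame, runs the forward pass, and decodes + filters the raw head outputs into the
# detections that are handed to BYTETracker.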
class Predictor(object):
    def __init__(
        self,
        model,
        exp,
        trt_file=None,
        decoder=None,
        device=torch.device("cpu"),
        fp16=False
    ):
        self.model = model
        self.decoder = decoder
        self.num_classes = exp.num_classes
        self.confthre = exp.test_conf
        self.nmsthre = exp.nmsthre
        self.device = device
        self.fp16 = fp16
        self.export = False
        if trt_file is not None:
            from torch2trt import TRTModule
            model_trt = TRTModule()
            model_trt.load_state_dict(torch.load(trt_file))
            x = torch.ones((1, 3, args.trace_inputshape[0], args.trace_inputshape[1]), device=device)
            self.model(x)
            self.model = model_trt
        self.rgb_means = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)
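
    # decode_outputs(): convert grid-relative predictions to absolute image coordinates.
    # For each FPN level the predicted centers are offset by the cell grid and scaled by
    # the level stride, and the predicted sizes are exponentiated and scaled likewise.
    # Note that the feature-map sizes in `hw` are hard-coded for the 608x1088 trace.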
    def decode_outputs(self, outputs, dtype):
        grids = []
        strides = []
        strides_1 = [8, 16, 32]
        # hw = [(64, 120), (32, 60), (16, 30)]  # feature-map sizes for a 512x960 input
        hw = [(76, 136), (38, 68), (19, 34)]  # feature-map sizes for a 608x1088 input
        for (hsize, wsize), stride in zip(hw, strides_1):
            yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)])
            grid = torch.stack((xv, yv), 2).view(1, -1, 2)
            grids.append(grid)
            shape = grid.shape[:2]
            strides.append(torch.full((*shape, 1), stride))
        grids = torch.cat(grids, dim=1).type(dtype)
        strides = torch.cat(strides, dim=1).type(dtype)
        outputs[..., :2] = (outputs[..., :2] + grids) * strides
        outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides
        return outputs
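
    # inference(): preprocess the frame, run the traced model, reassemble the nine raw
    # head outputs (obj / reg / cls for each of the three FPN levels) into a single
    # [batch, n_anchors_all, 5 + num_classes] tensor, decode it, and apply confidence
    # filtering + NMS via postprocess().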
    def inference(self, img, timer):
        img_info = {"id": 0}
        if isinstance(img, str):
            img_info["file_name"] = osp.basename(img)
            img = cv2.imread(img)
        else:
            img_info["file_name"] = None
        height, width = img.shape[:2]
        img_info["height"] = height
        img_info["width"] = width
        img_info["raw_img"] = img
        img, ratio = preproc(img, args.trace_inputshape, self.rgb_means, self.std)
        img_info["ratio"] = ratio
        img = torch.from_numpy(img).unsqueeze(0).float()
        if self.fp16:
            img = img.half()  # to FP16
        with torch.no_grad():
            timer.tic()
            # inference
            outputs_l = self.model(img)
            # post-processing
            outputs = []
            for i in range(3):
                obj_output = outputs_l[0 + 3 * i]
                reg_output = outputs_l[1 + 3 * i]
                cls_output = outputs_l[2 + 3 * i]
                output = torch.cat([reg_output, obj_output.sigmoid(), cls_output.sigmoid()], 1)
                outputs.append(output)
            # [batch, n_anchors_all, 5 + num_classes]
            outputs = torch.cat([x.flatten(start_dim=2) for x in outputs], dim=2).permute(0, 2, 1)
            self.decode_outputs(outputs, dtype=torch.float32)
            outputs = postprocess(outputs, self.num_classes, self.confthre, self.nmsthre)
            # logger.info("Infer time: {:.4f}s".format(time.time() - t0))
        return outputs, img_info
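

# image_demo(): run detection + ByteTrack over a folder of images (or a single image),
# draw the online tracks on each frame, and optionally save the annotated images plus
# a MOT-format result file.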
def image_demo(predictor, vis_folder, current_time, args):
    if osp.isdir(args.path):
        files = get_image_list(args.path)
    else:
        files = [args.path]
    files.sort()
    tracker = BYTETracker(args, frame_rate=args.fps)
    timer = Timer()
    results = []
    for frame_id, img_path in enumerate(files, 1):
        outputs, img_info = predictor.inference(img_path, timer)
        if outputs[0] is not None:
            online_targets = tracker.update(outputs[0], [img_info['height'], img_info['width']], args.trace_inputshape)
            online_tlwhs = []
            online_ids = []
            online_scores = []
            for t in online_targets:
                tlwh = t.tlwh
                tid = t.track_id
                vertical = tlwh[2] / tlwh[3] > args.aspect_ratio_thresh
                if tlwh[2] * tlwh[3] > args.min_box_area and not vertical:
                    online_tlwhs.append(tlwh)
                    online_ids.append(tid)
                    online_scores.append(t.score)
                    # save results
                    results.append(
                        f"{frame_id},{tid},{tlwh[0]:.2f},{tlwh[1]:.2f},{tlwh[2]:.2f},{tlwh[3]:.2f},{t.score:.2f},-1,-1,-1\n"
                    )
            timer.toc()
            online_im = plot_tracking(
                img_info['raw_img'], online_tlwhs, online_ids, frame_id=frame_id, fps=1. / timer.average_time
            )
        else:
            timer.toc()
            online_im = img_info['raw_img']
        # result_image = predictor.visual(outputs[0], img_info, predictor.confthre)
        if args.save_result:
            timestamp = time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
            save_folder = osp.join(vis_folder, timestamp)
            os.makedirs(save_folder, exist_ok=True)
            cv2.imwrite(osp.join(save_folder, osp.basename(img_path)), online_im)
        if frame_id % 20 == 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1. / max(1e-5, timer.average_time)))
        ch = cv2.waitKey(0)
        if ch == 27 or ch == ord("q") or ch == ord("Q"):
            break
    if args.save_result:
        res_file = osp.join(vis_folder, f"{timestamp}.txt")
        with open(res_file, 'w') as f:
            f.writelines(results)
        logger.info(f"save results to {res_file}")
def imageflow_demo(predictor, vis_folder, current_time, args):
    cap = cv2.VideoCapture(args.path if args.demo == "video" else args.camid)
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)  # float
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float
    fps = cap.get(cv2.CAP_PROP_FPS)
    timestamp = time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
    save_folder = osp.join(vis_folder, timestamp)
    os.makedirs(save_folder, exist_ok=True)
    if args.demo == "video":
        save_path = osp.join(save_folder, args.path.split("/")[-1])
    else:
        save_path = osp.join(save_folder, "camera.mp4")
    logger.info(f"video save_path is {save_path}")
    vid_writer = cv2.VideoWriter(
        save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))
    )
    tracker = BYTETracker(args, frame_rate=30)
    timer = Timer()
    frame_id = 0
    results = []
    while True:
        if frame_id % 20 == 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1. / max(1e-5, timer.average_time)))
        ret_val, frame = cap.read()
        if ret_val:
            outputs, img_info = predictor.inference(frame, timer)
            if outputs[0] is not None:
                online_targets = tracker.update(outputs[0], [img_info['height'], img_info['width']], args.trace_inputshape)
                online_tlwhs = []
                online_ids = []
                online_scores = []
                for t in online_targets:
                    tlwh = t.tlwh
                    tid = t.track_id
                    vertical = tlwh[2] / tlwh[3] > args.aspect_ratio_thresh
                    if tlwh[2] * tlwh[3] > args.min_box_area and not vertical:
                        online_tlwhs.append(tlwh)
                        online_ids.append(tid)
                        online_scores.append(t.score)
                        results.append(
                            f"{frame_id},{tid},{tlwh[0]:.2f},{tlwh[1]:.2f},{tlwh[2]:.2f},{tlwh[3]:.2f},{t.score:.2f},-1,-1,-1\n"
                        )
                timer.toc()
                online_im = plot_tracking(
                    img_info['raw_img'], online_tlwhs, online_ids, frame_id=frame_id + 1, fps=1. / timer.average_time
                )
            else:
                timer.toc()
                online_im = img_info['raw_img']
            if args.save_result:
                vid_writer.write(online_im)
            ch = cv2.waitKey(1)
            if ch == 27 or ch == ord("q") or ch == ord("Q"):
                break
        else:
            break
        frame_id += 1
    if args.save_result:
        res_file = osp.join(vis_folder, f"{timestamp}.txt")
        with open(res_file, 'w') as f:
            f.writelines(results)
        logger.info(f"save results to {res_file}")
def main(exp, args):
    if not args.experiment_name:
        args.experiment_name = exp.exp_name
    output_dir = osp.join(exp.output_dir, args.experiment_name)
    os.makedirs(output_dir, exist_ok=True)
    if args.save_result:
        vis_folder = osp.join(output_dir, "save_track_vis")
        os.makedirs(vis_folder, exist_ok=True)
    if args.trt:
        args.device = "gpu"
    args.device = torch.device("cuda" if args.device == "gpu" else "cpu")
    logger.info("Args: {}".format(args))
    if args.conf is not None:
        exp.test_conf = args.conf
    if args.nms is not None:
        exp.nmsthre = args.nms
    # model = exp.get_model().to(args.device)
    model = exp.get_model()
    model.eval()
    # load the traced (TorchScript) model
    model = torch.jit.load(args.model_path)
    if args.fuse:
        logger.info("\tFusing model...")
        model = fuse_model(model)
    if args.fp16:
        model = model.half()  # to FP16
    if args.trt:
        assert not args.fuse, "TensorRT model does not support model fusing!"
        trt_file = osp.join(output_dir, "model_trt.pth")
        assert osp.exists(
            trt_file
        ), "TensorRT model is not found!\n Run python3 tools/trt.py first!"
        model.head.decode_in_inference = False
        decoder = model.head.decode_outputs
        logger.info("Using TensorRT to inference")
    else:
        trt_file = None
        decoder = None
    predictor = Predictor(model, exp, trt_file, decoder, args.device, args.fp16)
    if args.export:
        predictor.export = True
    current_time = time.localtime()
    if args.demo == "image":
        image_demo(predictor, vis_folder, current_time, args)
    elif args.demo == "video" or args.demo == "webcam":
        imageflow_demo(predictor, vis_folder, current_time, args)


if __name__ == "__main__":
    args = make_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)
    main(exp, args)
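
# Example invocation (a sketch using the default arguments defined above; adjust the
# exp file, traced-model path, and input video to your own environment):
#   python 2_save_infer.py --demo video \
#       -f exps/example/mot/yolox_s_mix_det.py \
#       --model_path ../2_compile/fmodel/bytetrack_s_608x1088_traced.pt \
#       --path ./videos/palace.mp4 --trace_inputshape 608 1088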