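"""ByteTrack tracking demo, adapted from YOLOX's tools/demo_track.py.

Besides the original image/video/webcam demo, this script can export a traced
TorchScript model (--export, --trace_path, --trace_inputshape): the Focus layer
and the YOLOX head forward are monkey-patched (see focus_forward / new_forward
below) so that the traced graph uses a fixed stride-2 convolution instead of
tensor slicing and returns raw per-level obj/reg/cls maps instead of decoded
boxes.

Example invocation (script name and paths are assumptions, adjust to your setup):
    python demo_track.py --demo video --path ./videos/palace.mp4 --save_result
"""
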
import argparse
import os
import os.path as osp
import time

import cv2
import torch
from loguru import logger

from yolox.data.data_augment import preproc
from yolox.exp import get_exp
from yolox.utils import fuse_model, get_model_info, postprocess
from yolox.utils.visualize import plot_tracking
from yolox.tracker.byte_tracker import BYTETracker
from yolox.tracking_utils.timer import Timer
from yolox.models.network_blocks import Focus
from yolox.models.yolo_head import YOLOXHead

IMAGE_EXT = [".jpg", ".jpeg", ".webp", ".bmp", ".png"]

RED = '\033[31m'    # ANSI escape: set the foreground color to red
RESET = '\033[0m'   # ANSI escape: reset all attributes to their defaults
ver = torch.__version__
assert ("1.6" in ver) or ("1.9" in ver), f"{RED}Unsupported PyTorch version: {ver}{RESET}"


def make_parser():
    parser = argparse.ArgumentParser("ByteTrack Demo!")
    parser.add_argument(
        "--demo", default="video", help="demo type, e.g. image, video or webcam"
    )
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")

    parser.add_argument(
        # "--path", default="./datasets/mot/train/MOT17-05-FRCNN/img1", help="path to images or video"
        "--path", default="./videos/palace.mp4", help="path to images or video"
    )
    parser.add_argument("--camid", type=int, default=0, help="webcam demo camera id")
    parser.add_argument(
        "--save_result",
        action="store_true",
        default=True,
        help="whether to save the inference result of image/video",
    )

    # exp file
    parser.add_argument(
        "-f",
        "--exp_file",
        default="exps/example/mot/yolox_s_mix_det.py",
        type=str,
        help="please input your experiment description file",
    )
    parser.add_argument("-c", "--ckpt", default="../weights/bytetrack_s_mot17.pth.tar", type=str, help="ckpt for eval")
    # note: argparse's type=bool treats any non-empty string as True, so export effectively stays enabled
    parser.add_argument("--export", default=True, type=bool, help="export this model")
    parser.add_argument(
        "--device",
        default="cpu",
        type=str,
        help="device to run our model, can either be cpu or gpu",
    )
    parser.add_argument("--conf", default=None, type=float, help="test conf")
    parser.add_argument("--nms", default=None, type=float, help="test nms threshold")
    parser.add_argument("--tsize", default=None, type=int, help="test img size")
    parser.add_argument("--fps", default=30, type=int, help="frame rate (fps)")
    parser.add_argument(
        "--fp16",
        dest="fp16",
        default=False,
        action="store_true",
        help="Adopting mixed-precision evaluation.",
    )
    parser.add_argument(
        "--fuse",
        dest="fuse",
        default=False,
        action="store_true",
        help="Fuse conv and bn for testing.",
    )
    parser.add_argument(
        "--trt",
        dest="trt",
        default=False,
        action="store_true",
        help="Using TensorRT model for testing.",
    )
    # tracking args
    parser.add_argument("--track_thresh", type=float, default=0.5, help="tracking confidence threshold")
    parser.add_argument("--track_buffer", type=int, default=30, help="number of frames to keep lost tracks")
    parser.add_argument("--match_thresh", type=float, default=0.8, help="matching threshold for tracking")
    parser.add_argument(
        "--aspect_ratio_thresh", type=float, default=1.6,
        help="threshold for filtering out boxes whose aspect ratio is above the given value."
    )
    parser.add_argument('--min_box_area', type=float, default=10, help='filter out tiny boxes')
    parser.add_argument("--mot20", dest="mot20", default=False, action="store_true", help="test mot20.")

    parser.add_argument("--trace_path", type=str, default="../2_compile/fmodel/bytetrack_s_608x1088_traced.pt", help="name & path of the traced model")
    parser.add_argument("--trace_inputshape", type=int, nargs='+', default=[608, 1088], help="input shape: h w")
    return parser


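# Focus normally rearranges every 2x2 spatial patch into channels by strided
# slicing (in the order top-left, bottom-left, top-right, bottom-right).
# FOCUS_WEIGHT reproduces that same rearrangement as a fixed 2x2, stride-2
# convolution, which traces/exports more cleanly than slice ops: output
# channel i copies input channel i % 3 from one of the four patch positions.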
FOCUS_WEIGHT = torch.zeros(12, 3, 2, 2)
for i in range(12):
    j = i % 3               # source channel
    k = (i // 3) % 2        # row offset inside the 2x2 patch
    l = 0 if i < 6 else 1   # column offset inside the 2x2 patch
    FOCUS_WEIGHT[i, j, k, l] = 1
FOCUS_BIAS = torch.zeros(12)


def focus_forward(self, x):
    # shape of x (b, c, h, w) -> y (b, 4c, h/2, w/2)
    x = torch.nn.functional.conv2d(x, FOCUS_WEIGHT, bias=FOCUS_BIAS, stride=2, padding=0, dilation=1, groups=1)
    return self.conv(x)


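# Replacement forward for YOLOXHead. The training branch is kept unchanged; in
# inference it forces decode_in_inference to False and returns output_new, a
# flat list holding the raw obj/reg/cls maps of every FPN level, so that box
# decoding and NMS stay outside the traced/exported graph.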
def new_forward(self, xin, labels=None, imgs=None):
    outputs = []
    output_new = []
    origin_preds = []
    x_shifts = []
    y_shifts = []
    expanded_strides = []

    for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(
        zip(self.cls_convs, self.reg_convs, self.strides, xin)
    ):
        x = self.stems[k](x)
        cls_x = x
        reg_x = x

        cls_feat = cls_conv(cls_x)
        reg_feat = reg_conv(reg_x)

        obj_output = self.obj_preds[k](reg_feat)
        reg_output = self.reg_preds[k](reg_feat)
        cls_output = self.cls_preds[k](cls_feat)

        if self.training:
            output = torch.cat([reg_output, obj_output, cls_output], 1)
            output, grid = self.get_output_and_grid(
                output, k, stride_this_level, xin[0].type()
            )
            x_shifts.append(grid[:, :, 0])
            y_shifts.append(grid[:, :, 1])
            expanded_strides.append(
                torch.zeros(1, grid.shape[1])
                .fill_(stride_this_level)
                .type_as(xin[0])
            )
            if self.use_l1:
                batch_size = reg_output.shape[0]
                hsize, wsize = reg_output.shape[-2:]
                reg_output = reg_output.view(
                    batch_size, self.n_anchors, 4, hsize, wsize
                )
                reg_output = reg_output.permute(0, 1, 3, 4, 2).reshape(
                    batch_size, -1, 4
                )
                origin_preds.append(reg_output.clone())

        else:
            output = torch.cat(
                [reg_output, obj_output.sigmoid(), cls_output.sigmoid()], 1
            )

            output_new.append(obj_output)
            output_new.append(reg_output)
            output_new.append(cls_output)

        outputs.append(output)

    if self.training:
        return self.get_losses(
            imgs,
            x_shifts,
            y_shifts,
            expanded_strides,
            labels,
            torch.cat(outputs, 1),
            origin_preds,
            dtype=xin[0].dtype,
        )
    else:
        self.hw = [x.shape[-2:] for x in outputs]
        # [batch, n_anchors_all, 85]
        outputs = torch.cat(
            [x.flatten(start_dim=2) for x in outputs], dim=2
        ).permute(0, 2, 1)
        self.decode_in_inference = False
        if self.decode_in_inference:
            return self.decode_outputs(outputs, dtype=xin[0].type())
        else:
            return output_new


def get_image_list(path):
    image_names = []
    for maindir, subdir, file_name_list in os.walk(path):
        for filename in file_name_list:
            apath = osp.join(maindir, filename)
            ext = osp.splitext(apath)[1]
            if ext in IMAGE_EXT:
                image_names.append(apath)
    return image_names


def write_results(filename, results):
    save_format = '{frame},{id},{x1},{y1},{w},{h},{s},-1,-1,-1\n'
    with open(filename, 'w') as f:
        for frame_id, tlwhs, track_ids, scores in results:
            for tlwh, track_id, score in zip(tlwhs, track_ids, scores):
                if track_id < 0:
                    continue
                x1, y1, w, h = tlwh
                line = save_format.format(frame=frame_id, id=track_id, x1=round(x1, 1), y1=round(y1, 1), w=round(w, 1), h=round(h, 1), s=round(score, 2))
                f.write(line)
    logger.info('save results to {}'.format(filename))


class Predictor(object):
    def __init__(
        self,
        model,
        exp,
        trt_file=None,
        decoder=None,
        device=torch.device("cpu"),
        fp16=False
    ):
        self.model = model
        self.decoder = decoder
        self.num_classes = exp.num_classes
        self.confthre = exp.test_conf
        self.nmsthre = exp.nmsthre
        self.test_size = exp.test_size
        self.device = device
        self.fp16 = fp16
        self.export = False
        if trt_file is not None:
            from torch2trt import TRTModule

            model_trt = TRTModule()
            model_trt.load_state_dict(torch.load(trt_file))

            x = torch.ones((1, 3, exp.test_size[0], exp.test_size[1]), device=device)
            self.model(x)
            self.model = model_trt
        self.rgb_means = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)

    def inference(self, img, timer):
        img_info = {"id": 0}
        if isinstance(img, str):
            img_info["file_name"] = osp.basename(img)
            img = cv2.imread(img)
        else:
            img_info["file_name"] = None

        height, width = img.shape[:2]
        img_info["height"] = height
        img_info["width"] = width
        img_info["raw_img"] = img

        img, ratio = preproc(img, self.test_size, self.rgb_means, self.std)
        img_info["ratio"] = ratio
        img = torch.from_numpy(img).unsqueeze(0).float()
        if self.fp16:
            img = img.half()  # to FP16
        with torch.no_grad():
            timer.tic()
            if self.export:
                self.model.head.decode_in_inference = False
                # patch the model so the traced graph avoids slicing and box decoding
                Focus.__call__ = focus_forward
                YOLOXHead.__call__ = new_forward
                # export the traced model (args is the module-level namespace parsed in __main__)
                dummy_in = torch.randn(1, 3, args.trace_inputshape[0], args.trace_inputshape[1], dtype=torch.float)
                trcnet = torch.jit.trace(self.model, dummy_in)
                trcnet.save(args.trace_path)
                print("successfully saved traced model to", args.trace_path)
                exit()
            ####################################################################
            # HookSwitch.hook = True

            outputs = self.model(img)
            # excel_res_path = R"D:\code\算力带宽_ByteTrack.xlsx"
            # record(excel_res_path, "Bytetrack", [img.size()[2], img.size()[3]])
            ####################################################################
            if self.decoder is not None:
                outputs = self.decoder(outputs, dtype=outputs.type())
            outputs = postprocess(
                outputs, self.num_classes, self.confthre, self.nmsthre
            )
            # logger.info("Infer time: {:.4f}s".format(time.time() - t0))
        return outputs, img_info


def image_demo(predictor, vis_folder, current_time, args):
    if osp.isdir(args.path):
        files = get_image_list(args.path)
    else:
        files = [args.path]
    files.sort()
    tracker = BYTETracker(args, frame_rate=args.fps)
    timer = Timer()
    results = []

    for frame_id, img_path in enumerate(files, 1):
        outputs, img_info = predictor.inference(img_path, timer)
        if outputs[0] is not None:
            online_targets = tracker.update(outputs[0], [img_info['height'], img_info['width']], exp.test_size)
            online_tlwhs = []
            online_ids = []
            online_scores = []
            for t in online_targets:
                tlwh = t.tlwh
                tid = t.track_id
                vertical = tlwh[2] / tlwh[3] > args.aspect_ratio_thresh
                if tlwh[2] * tlwh[3] > args.min_box_area and not vertical:
                    online_tlwhs.append(tlwh)
                    online_ids.append(tid)
                    online_scores.append(t.score)
                    # save results
                    results.append(
                        f"{frame_id},{tid},{tlwh[0]:.2f},{tlwh[1]:.2f},{tlwh[2]:.2f},{tlwh[3]:.2f},{t.score:.2f},-1,-1,-1\n"
                    )
            timer.toc()
            online_im = plot_tracking(
                img_info['raw_img'], online_tlwhs, online_ids, frame_id=frame_id, fps=1. / timer.average_time
            )
        else:
            timer.toc()
            online_im = img_info['raw_img']

        # result_image = predictor.visual(outputs[0], img_info, predictor.confthre)
        if args.save_result:
            timestamp = time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
            save_folder = osp.join(vis_folder, timestamp)
            os.makedirs(save_folder, exist_ok=True)
            cv2.imwrite(osp.join(save_folder, osp.basename(img_path)), online_im)

        if frame_id % 20 == 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1. / max(1e-5, timer.average_time)))

        ch = cv2.waitKey(0)
        if ch == 27 or ch == ord("q") or ch == ord("Q"):
            break

    if args.save_result:
        res_file = osp.join(vis_folder, f"{timestamp}.txt")
        with open(res_file, 'w') as f:
            f.writelines(results)
        logger.info(f"save results to {res_file}")


def imageflow_demo(predictor, vis_folder, current_time, args):
    cap = cv2.VideoCapture(args.path if args.demo == "video" else args.camid)
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)  # float
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float
    fps = cap.get(cv2.CAP_PROP_FPS)
    timestamp = time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
    save_folder = osp.join(vis_folder, timestamp)
    os.makedirs(save_folder, exist_ok=True)
    if args.demo == "video":
        save_path = osp.join(save_folder, args.path.split("/")[-1])
    else:
        save_path = osp.join(save_folder, "camera.mp4")
    logger.info(f"video save_path is {save_path}")
    vid_writer = cv2.VideoWriter(
        save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))
    )
    tracker = BYTETracker(args, frame_rate=30)
    timer = Timer()
    frame_id = 0
    results = []
    while True:
        if frame_id % 20 == 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1. / max(1e-5, timer.average_time)))
        ret_val, frame = cap.read()
        if ret_val:
            outputs, img_info = predictor.inference(frame, timer)
            if outputs[0] is not None:
                online_targets = tracker.update(outputs[0], [img_info['height'], img_info['width']], exp.test_size)
                online_tlwhs = []
                online_ids = []
                online_scores = []
                for t in online_targets:
                    tlwh = t.tlwh
                    tid = t.track_id
                    vertical = tlwh[2] / tlwh[3] > args.aspect_ratio_thresh
                    if tlwh[2] * tlwh[3] > args.min_box_area and not vertical:
                        online_tlwhs.append(tlwh)
                        online_ids.append(tid)
                        online_scores.append(t.score)
                        results.append(
                            f"{frame_id},{tid},{tlwh[0]:.2f},{tlwh[1]:.2f},{tlwh[2]:.2f},{tlwh[3]:.2f},{t.score:.2f},-1,-1,-1\n"
                        )
                timer.toc()
                online_im = plot_tracking(
                    img_info['raw_img'], online_tlwhs, online_ids, frame_id=frame_id + 1, fps=1. / timer.average_time
                )
            else:
                timer.toc()
                online_im = img_info['raw_img']
            if args.save_result:
                vid_writer.write(online_im)
            ch = cv2.waitKey(1)
            if ch == 27 or ch == ord("q") or ch == ord("Q"):
                break
        else:
            break
        frame_id += 1

    if args.save_result:
        res_file = osp.join(vis_folder, f"{timestamp}.txt")
        with open(res_file, 'w') as f:
            f.writelines(results)
        logger.info(f"save results to {res_file}")


def main(exp, args):
    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    output_dir = osp.join(exp.output_dir, args.experiment_name)
    os.makedirs(output_dir, exist_ok=True)

    if args.save_result:
        vis_folder = osp.join(output_dir, "track_vis")
        os.makedirs(vis_folder, exist_ok=True)

    if args.trt:
        args.device = "gpu"
    args.device = torch.device("cuda" if args.device == "gpu" else "cpu")

    logger.info("Args: {}".format(args))

    if args.conf is not None:
        exp.test_conf = args.conf
    if args.nms is not None:
        exp.nmsthre = args.nms
    if args.tsize is not None:
        exp.test_size = (args.tsize, args.tsize)

    # model = exp.get_model().to(args.device)
    model = exp.get_model()
    logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))
    model.eval()

    if not args.trt:
        if args.ckpt is None:
            ckpt_file = osp.join(output_dir, "best_ckpt.pth.tar")
        else:
            ckpt_file = args.ckpt
        logger.info("loading checkpoint")
        ckpt = torch.load(ckpt_file, map_location="cpu")
        # load the model state dict
        model.load_state_dict(ckpt["model"])
        logger.info("loaded checkpoint done.")

    if args.fuse:
        logger.info("\tFusing model...")
        model = fuse_model(model)

    if args.fp16:
        model = model.half()  # to FP16

    if args.trt:
        assert not args.fuse, "TensorRT model does not support model fusing!"
        trt_file = osp.join(output_dir, "model_trt.pth")
        assert osp.exists(
            trt_file
        ), "TensorRT model is not found!\n Run python3 tools/trt.py first!"
        model.head.decode_in_inference = False
        decoder = model.head.decode_outputs
        logger.info("Using TensorRT to inference")
    else:
        trt_file = None
        decoder = None

    predictor = Predictor(model, exp, trt_file, decoder, args.device, args.fp16)
    if args.export:
        predictor.export = True
    current_time = time.localtime()
    if args.demo == "image":
        image_demo(predictor, vis_folder, current_time, args)
    elif args.demo == "video" or args.demo == "webcam":
        imageflow_demo(predictor, vis_folder, current_time, args)


if __name__ == "__main__":
    args = make_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)

    main(exp, args)