# mdz/pytorch/ETTrack/1_scripts/2_save_infer.py
# Inference driver that runs ET-Track with traced (TorchScript) sub-models:
# net1 replaces the template branch, net2 replaces the search/track branch.

import os
import sys
import time
from collections import OrderedDict
import types

# Make the project root importable BEFORE pulling in the project modules below.
prj_path = os.path.join(os.path.dirname(__file__), '..')
if prj_path not in sys.path:
    sys.path.append(prj_path)

from pytracking.evaluation import get_dataset
from pytracking.evaluation.running import _save_tracker_output
from pytracking.evaluation import Tracker
import torch
import numpy as np
from tracking.basic_model.et_tracker import ET_Tracker
from pytracking.tracker.et_tracker.et_tracker import TransconverTracker
def ET_Tracker_template(self, z):
    """Compute the target (template) embedding on the first frame.

    Deployment variant: instead of running ``self.backbone_net`` directly,
    load the traced TorchScript model ``net1`` and store its output in the
    module-level global ``zf`` so the patched update step can feed it to
    the traced search-branch model.

    z: template image crop (presumably [1,3,127,127] per NET1_PATH's name
    — TODO confirm against the tracing script).
    """
    # original: self.zf = self.backbone_net(z)
    # ICRAFT NOTE: replaced by loading the traced model net1.
    net1 = torch.jit.load(NET1_PATH)
    global zf
    zf = net1(z)  # stored as a global instead of self.zf for deployment
def ET_Tracker_forward(self, x, zf):
    """Search-branch forward pass, rewritten for deployment tracing.

    Args:
        x: search-region tensor (inline notes say [1,3,288,288] — TODO confirm).
        zf: template feature (inline notes say [1,96,8,8]), passed explicitly
            instead of being read from ``self.zf`` so the traced graph carries
            no hidden member state.

    Returns:
        (cls, reg): classification map [1,1,16,16] and box-regression map
        [1,4,16,16] per the inline shape notes.
    """
    # [1,3,288,288] -> [1,96,16,16]
    xf = self.backbone_net(x)
    # Batch Normalization before Corr.
    # ICRAFT NOTE: for deployment the member variable self.zf becomes a
    # forward input.
    # original: zf, xf = self.neck(self.zf, xf)
    zf, xf = self.neck(zf, xf)  # [1,96,8,8], [1,96,16,16] <- [1,96,8,8], [1,96,16,16]
    # ICRAFT NOTE: dict outputs are not supported by the deployment toolchain,
    # so feature_fusor's returned dict is unpacked into two tensors.
    # original: feat_dict = self.feature_fusor(zf, xf)
    # pixelwise correlation: cls [1,128,16,16], reg [1,128,16,16]
    feat_cls, feat_reg = self.feature_fusor(zf, xf)
    # Classification head.
    c = self.cls_branch_1(feat_cls)  # originally feat_dict['cls']
    c = self.cls_branch_2(c)
    c = self.cls_branch_3(c)
    c = self.cls_branch_4(c)
    c = self.cls_branch_5(c)
    c = self.cls_branch_6(c)
    c = self.cls_pred_head(c)  # [1,1,16,16]
    # Bounding-box regression head.
    b = self.bbreg_branch_1(feat_reg)  # originally feat_dict['reg']
    b = self.bbreg_branch_2(b)
    b = self.bbreg_branch_3(b)
    b = self.bbreg_branch_4(b)
    b = self.bbreg_branch_5(b)
    b = self.bbreg_branch_6(b)
    b = self.bbreg_branch_7(b)
    b = self.bbreg_branch_8(b)
    b = self.reg_pred_head(b)  # [1,4,16,16]
    # ICRAFT NOTE: dict outputs unsupported — return the tensors directly.
    # original: return oup['cls'], oup['reg']
    return c, b
# Monkey-patch ET_Tracker so template()/forward() run the traced sub-models
# (and take zf as an explicit input) instead of the original network code.
ET_Tracker.template = ET_Tracker_template
ET_Tracker.forward = ET_Tracker_forward
def TransconverTracker_update(self, x_crops, target_pos, target_sz, window, scale_z, p, debug=False, writer=None):
    """One tracking update step driven by the traced search model (net2).

    Args:
        x_crops: search-region crop tensor fed to net2.
        target_pos: previous target center (x, y) in image coordinates.
        target_sz: previous target size (w, h), already scaled by ``scale_z``.
        window: cosine-window penalty map, same shape as the score map.
        scale_z: crop-to-image scale factor.
        p: tracker hyper-parameters (penalty_k, window_influence, lr,
           instance_size).
        debug: when True, additionally return the full raw score map.
        writer: unused here; kept for interface compatibility.

    Returns:
        (target_pos, target_sz, best_score) — plus the raw ``cls_score`` map
        when ``debug`` is True.
    """
    # original: cls_score, bbox_pred = self.net.track(x_crops.to(self.params.device))
    # ICRAFT NOTE: replaced by the traced model net2; ``zf`` is the global
    # template feature produced by ET_Tracker_template.
    # NOTE(review): the traced module is re-deserialized from disk on every
    # frame; caching it would be faster — left as-is to preserve behavior.
    net2 = torch.jit.load(NET2_PATH)
    cls_score, bbox_pred = net2(x_crops.to(self.params.device), zf)
    cls_score = torch.sigmoid(cls_score).squeeze().cpu().data.numpy()
    # Decode box offsets into corner coordinates on the search grid.
    bbox_pred = bbox_pred.squeeze().cpu().data.numpy()
    pred_x1 = self.grid_to_search_x - bbox_pred[0, ...]
    pred_y1 = self.grid_to_search_y - bbox_pred[1, ...]
    pred_x2 = self.grid_to_search_x + bbox_pred[2, ...]
    pred_y2 = self.grid_to_search_y + bbox_pred[3, ...]
    # Size penalty: punish large scale / aspect-ratio changes.
    s_c = self.change(self.sz(pred_x2 - pred_x1, pred_y2 - pred_y1) / (self.sz_wh(target_sz)))  # scale penalty
    r_c = self.change((target_sz[0] / target_sz[1]) / ((pred_x2 - pred_x1) / (pred_y2 - pred_y1)))  # ratio penalty
    penalty = np.exp(-(r_c * s_c - 1) * p.penalty_k)
    pscore = penalty * cls_score
    # Cosine-window penalty (favors positions near the previous center).
    pscore = pscore * (1 - p.window_influence) + window * p.window_influence
    # Pick the best-scoring cell.
    r_max, c_max = np.unravel_index(pscore.argmax(), pscore.shape)
    # Convert the winning cell back to a box in crop coordinates.
    pred_x1 = pred_x1[r_max, c_max]
    pred_y1 = pred_y1[r_max, c_max]
    pred_x2 = pred_x2[r_max, c_max]
    pred_y2 = pred_y2[r_max, c_max]
    pred_xs = (pred_x1 + pred_x2) / 2
    pred_ys = (pred_y1 + pred_y2) / 2
    pred_w = pred_x2 - pred_x1
    pred_h = pred_y2 - pred_y1
    # Offset from the crop center, then undo the crop scaling.
    diff_xs = pred_xs - p.instance_size // 2
    diff_ys = pred_ys - p.instance_size // 2
    diff_xs, diff_ys, pred_w, pred_h = diff_xs / scale_z, diff_ys / scale_z, pred_w / scale_z, pred_h / scale_z
    target_sz = target_sz / scale_z
    # Size learning rate: blend old and new size by penalized confidence.
    lr = penalty[r_max, c_max] * cls_score[r_max, c_max] * p.lr
    # New state in image coordinates.
    res_xs = target_pos[0] + diff_xs
    res_ys = target_pos[1] + diff_ys
    res_w = pred_w * lr + (1 - lr) * target_sz[0]
    res_h = pred_h * lr + (1 - lr) * target_sz[1]
    target_pos = np.array([res_xs, res_ys])
    target_sz = target_sz * (1 - lr) + lr * np.array([res_w, res_h])
    if debug:
        return target_pos, target_sz, cls_score[r_max, c_max], cls_score
    else:
        return target_pos, target_sz, cls_score[r_max, c_max]
TransconverTracker.update =TransconverTracker_update
# Module-level globals shared by the patched template()/update() steps.
TRACE_PATH = "../2_compile/fmodel/"
NET1_PATH = TRACE_PATH + "ettrack_net1_1x3x127x127_traced.pt"
NET2_PATH = TRACE_PATH + "ettrack_net2_1x3x288x288_traced.pt"
# Template-feature placeholder; overwritten by ET_Tracker_template on frame 1.
zf = torch.zeros(1, 96, 8, 8)
if __name__ == '__main__':
    # Run the patched tracker over every sequence of the chosen dataset and
    # save per-sequence results, mirroring pytracking's run_sequence flow.
    dataset_name = 'lasot'
    tracker_name = 'et_tracker'
    tracker_param = 'et_tracker'
    visualization = None
    debug = None
    visdom_info = None
    run_id = 2405101502

    dataset = get_dataset(dataset_name)
    tracker = Tracker(tracker_name, tracker_param, run_id)
    params = tracker.get_parameters()

    # Resolve visualization/debug flags, falling back to parameter defaults.
    visualization_ = visualization
    debug_ = debug
    if debug is None:
        debug_ = getattr(params, 'debug', 0)
    if visualization is None:
        if debug is None:
            visualization_ = getattr(params, 'visualization', False)
        else:
            visualization_ = True if debug else False
    params.visualization = visualization_
    params.debug = debug_

    for seq in dataset[:]:
        print(seq)

        def _results_exist():
            # True when result files for this sequence are already on disk.
            if seq.dataset == 'oxuva':
                vid_id, obj_id = seq.name.split('_')[:2]
                pred_file = os.path.join(tracker.results_dir, '{}_{}.csv'.format(vid_id, obj_id))
                return os.path.isfile(pred_file)
            elif seq.object_ids is None:
                bbox_file = '{}/{}.txt'.format(tracker.results_dir, seq.name)
                return os.path.isfile(bbox_file)
            else:
                bbox_files = ['{}/{}_{}.txt'.format(tracker.results_dir, seq.name, obj_id) for obj_id in seq.object_ids]
                missing = [not os.path.isfile(f) for f in bbox_files]
                return sum(missing) == 0

        visdom_info = {} if visdom_info is None else visdom_info
        # Skip sequences that already have saved results (unless debugging).
        if _results_exist() and not debug:
            print('FPS: {}'.format(-1))
            continue

        print('Tracker: {} {} {} , Sequence: {}'.format(tracker.name, tracker.parameter_name, tracker.run_id, seq.name))
        tracker._init_visdom(visdom_info, debug_)
        if visualization_ and tracker.visdom is None:
            tracker.init_visualization()

        # Get init information and build a fresh tracker instance.
        init_info = seq.init_info()
        et_tracker = tracker.create_tracker(params)

        output = {'target_bbox': [],
                  'time': [],
                  'segmentation': [],
                  'object_presence_score': []}

        def _store_outputs(tracker_out: dict, defaults=None):
            # Append per-frame values into `output`, using `defaults` as fallback.
            defaults = {} if defaults is None else defaults
            for key in output.keys():
                val = tracker_out.get(key, defaults.get(key, None))
                if key in tracker_out or val is not None:
                    output[key].append(val)

        # Initialize on the first frame.
        image = tracker._read_image(seq.frames[0])
        if et_tracker.params.visualization and tracker.visdom is None:
            tracker.visualize(image, init_info.get('init_bbox'))
        start_time = time.time()
        out = et_tracker.initialize(image, init_info)
        if out is None:
            out = {}
        prev_output = OrderedDict(out)
        init_default = {'target_bbox': init_info.get('init_bbox'),
                        'time': time.time() - start_time,
                        'segmentation': init_info.get('init_mask'),
                        'object_presence_score': 1.}
        _store_outputs(out, init_default)

        # Track through the remaining frames.
        for frame_num, frame_path in enumerate(seq.frames[1:], start=1):
            image = tracker._read_image(frame_path)
            start_time = time.time()
            info = seq.frame_info(frame_num)
            info['previous_output'] = prev_output
            out = et_tracker.track(image, info)
            prev_output = OrderedDict(out)
            _store_outputs(out, {'time': time.time() - start_time})
            segmentation = out['segmentation'] if 'segmentation' in out else None
            if tracker.visdom is not None:
                tracker.visdom_draw_tracking(image, out['target_bbox'], segmentation)
            elif et_tracker.params.visualization:
                tracker.visualize(image, out['target_bbox'], segmentation)

        # Drop keys that only hold the init entry (nothing was tracked for them).
        for key in ['target_bbox', 'segmentation']:
            if key in output and len(output[key]) <= 1:
                output.pop(key)
        output['image_shape'] = image.shape[:2]
        output['object_presence_score_threshold'] = et_tracker.params.get('object_presence_score_threshold', 0.55)

        sys.stdout.flush()
        # Per-frame times may be dicts of sub-timings or plain floats.
        if isinstance(output['time'][0], (dict, OrderedDict)):
            exec_time = sum([sum(times.values()) for times in output['time']])
            num_frames = len(output['time'])
        else:
            exec_time = sum(output['time'])
            num_frames = len(output['time'])
        print('FPS: {}'.format(num_frames / exec_time))

        if not debug:
            _save_tracker_output(seq, tracker, output)