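# Evaluation script for a GazeTR-style gaze estimation model: loads a trained
# checkpoint, runs it over the data of one held-out subject described by a test
# YAML config, writes per-sample predictions to a log file, and reports the
# mean angular error.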
import os, sys

# Make the GazeTR model code and the local helper modules importable.
sys.path.append(R"../0_GazeTR")

base_dir = os.getcwd()
sys.path.insert(0, base_dir)

import model
import importlib
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import cv2, yaml, copy
from easydict import EasyDict as edict
import ctools, gtools
import argparse


def parser_args():
    parser = argparse.ArgumentParser(description='Pytorch Basic Model Testing')
    parser.add_argument('--weight', default="../weights/Iter_80_trans6.pt", type=str,
                        help='path to the trained model checkpoint')
    parser.add_argument('-t', '--target', type=str, default="config/test/config_mpii_test.yaml",
                        help='path to the test config')
    parser.add_argument('-p', '--person', type=int, default=0,
                        help='index of the subject held out for testing')
    args = parser.parse_args()

    return args
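

# Example invocation (argument defaults match the parser above; replace
# "test.py" with this file's actual name):
#   python test.py --weight ../weights/Iter_80_trans6.pt \
#       -t config/test/config_mpii_test.yaml -p 0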


def main(test):

    # ===============================> Setup <============================
    reader = importlib.import_module("reader." + test.reader)

    data = test.data
    load = test.load

    # ==============================> Read Data <========================
    data, folder = ctools.readfolder(data, [test.person])

    testname = folder[test.person]

    dataset = reader.loader(data, 500, num_workers=4, shuffle=True)

    # Relies on the module-level `args` parsed in __main__.
    modelpath = args.weight
    logpath = os.path.join("log/", f'{test.savename}/{testname}')

    if not os.path.exists(logpath):
        os.makedirs(logpath)

    # =============================> Test <==============================
    # Checkpoint iterations to evaluate: begin, begin + step, ..., end.
    begin = load.begin_step
    end = load.end_step
    step = load.steps

    for saveiter in range(begin, end + step, step):
        print(f"Test {saveiter}")

        # ----------------------Load Model------------------------------
        net = model.Model()

        statedict = torch.load(
            modelpath,
            map_location="cpu"
        )
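
        # NOTE: `modelpath` comes from --weight and does not change with
        # `saveiter`; the iteration index only affects the log file name below.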
        net.load_state_dict(statedict)
        net.eval()

        length = len(dataset)
        accs = 0
        count = 0

        # -----------------------Open log file--------------------------
        logname = f"{saveiter}.log"

        outfile = open(os.path.join(logpath, logname), 'w')
        outfile.write("name results gts\n")

        # -------------------------Testing------------------------------
        with torch.no_grad():

            for j, (data, label) in enumerate(dataset):

                # Keep tensors on CPU; restore .cuda() here to move inputs to the GPU.
                for key in data:
                    if key != 'name':
                        data[key] = data[key]  # .cuda()

                names = data["name"]
                gts = label

                # Forward pass
                gazes = net(data)

                for k, gaze in enumerate(gazes):

                    gaze = gaze.cpu().detach().numpy()
                    gt = gts.numpy()[k]

                    count += 1
                    accs += gtools.angular(
                        gtools.gazeto3d(gaze),
                        gtools.gazeto3d(gt)
                    )
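
                    # The metric is the mean angular error: gtools.gazeto3d is
                    # expected to convert a 2D gaze label to a 3D unit direction
                    # and gtools.angular to return the angle between the two
                    # directions (exact conventions depend on the gtools
                    # implementation shipped with this repo).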

                    name = [names[k]]
                    gaze = [str(u) for u in gaze]
                    gt = [str(u) for u in gt]
                    log = name + [",".join(gaze)] + [",".join(gt)]
                    outfile.write(" ".join(log) + "\n")

        log_line = f"[{saveiter}] Total Num: {count}, avg: {accs / count}"
        outfile.write(log_line)
        print(log_line)
        outfile.close()

if __name__ == "__main__":
|
|
args = parser_args()
|
|
# Read model from train config and Test data in test config.
|
|
# train_conf = edict(yaml.load(open(args.source), Loader=yaml.FullLoader))
|
|
test_conf = edict(yaml.load(open(args.target), Loader=yaml.FullLoader))
|
|
test_conf = test_conf.test
|
|
test_conf.person = args.person
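
    # Sketch of the fields the test config is expected to provide, inferred
    # from the attribute accesses in main() (names and values are illustrative):
    #   test:
    #     reader: ...          # module name under reader/
    #     data: ...            # passed to ctools.readfolder
    #     savename: ...        # used for the log directory
    #     load:
    #       begin_step: ...
    #       end_step: ...
    #       steps: ...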

    # print("=======================>(Begin) Config of training<======================")
    # print(ctools.DictDumps(train_conf))
    # print("=======================>(End) Config of training<======================")
    # print("")

    print("=======================>(Begin) Config for test<======================")
    print(ctools.DictDumps(test_conf))
    print("=======================>(End) Config for test<======================")

    # main(train_conf.train, test_conf)
    main(test_conf)