mdz/pytorch/MPIIFaceGaze/1_scripts/1_save_onnx.py

53 lines
1.9 KiB
Python

#!/usr/bin/env python
import os
import argparse
import torch
from gaze_estimation import create_model, get_default_config
RED = '\033[31m'  # ANSI escape sequence: set foreground color to red
RESET = '\033[0m'  # ANSI escape sequence: reset all terminal attributes
# Refuse to run on PyTorch releases other than 2.0.1 / 1.9.x.
# NOTE(review): presumably required by the downstream ONNX compile step — confirm.
ver = torch.__version__
if not (("2.0.1" in ver) or ("1.9" in ver)):
    # An explicit raise (instead of `assert`) keeps this guard active under `python -O`.
    raise RuntimeError(f"{RED}Unsupported PyTorch version: {ver}{RESET}")
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--config', default="config/demo_mpiifacegaze_resnet_simple_14.yaml",type=str)
parser.add_argument('--weight', default="../weights/mpiifacegaze_resnet_simple.pth",type=str)
# parser.add_argument('--weight', default="../weights/checkpoint_0015.pth",type=str) # metric test
parser.add_argument('--output-path', '-o', default="../2_compile/fmodel/", type=str)
parser.add_argument('--trace_model', default="MPIIFaceGaze_1x3x224x224.onnx", type=str)
args = parser.parse_args()
if not os.path.exists(args.output_path):
os.makedirs(args.output_path)
config = get_default_config()
config.merge_from_file(args.config)
device = torch.device(config.device)
model = create_model(config)
if args.weight is not None:
checkpoint = torch.load(args.weight, map_location=device)
model.load_state_dict(checkpoint['model'])
model.eval()
if config.mode == 'MPIIGaze':
x = torch.zeros((1, 1, 36, 60), dtype=torch.float32, device=device)
y = torch.zeros((1, 2), dtype=torch.float32, device=device)
data = (x, y)
elif config.mode == 'MPIIFaceGaze':
x = torch.zeros((1, 3, 224, 224), dtype=torch.float32, device=device)
data = (x, )
else:
raise ValueError
trace_model = args.output_path + args.trace_model
torch.onnx.export(model, data, trace_model, opset_version=11)
print("successful export model to ", trace_model)
# Script entry point: run the export only when executed directly, not on import.
if __name__ == "__main__":
    main()