# File: mdz/pytorch/PolyLaneNet/1_scripts/0_infer.py
#
# (168 lines, 5.9 KiB, Python)

# python .\0_infer.py --cfg .\config.yaml --exp_name tusimple --view
import os
import sys
import random
import logging
import argparse
import subprocess
from time import time
import cv2
import numpy as np
import torch
import sys
sys.path.append(R"../0_PolyLaneNet")
from lib.config import Config
from utils.evaluator import Evaluator
def test(model, test_loader, evaluator, exp_root, cfg, view, epoch, max_batches=None, verbose=True, device=None):
    """Run the model over ``test_loader``, accumulating loss and predictions.

    Args:
        model: network called as ``model(images)`` that also exposes ``loss``
            (returning ``(loss_tensor, loss_dict)``) and ``decode`` methods.
        test_loader: iterable of ``(images, labels, img_idxs)`` batches that
            supports ``len()``; ``test_loader.dataset.draw_annotation`` is used
            when ``view`` is true.
        evaluator: object with ``add_prediction(img_idxs, preds, time_per_image)``,
            or ``None`` to skip collecting predictions.
        exp_root: experiment directory; when ``epoch > 0`` the checkpoint
            ``<exp_root>/models/model_<epoch:03d>.pt`` is loaded.
        cfg: config providing ``get_loss_parameters()`` and ``get_test_parameters()``.
        view: if true, draw each batch's first prediction and show it in an
            OpenCV window, blocking until a key is pressed.
        epoch: checkpoint epoch to load; values ``<= 0`` keep the current weights.
        max_batches: stop after this many batches (``None`` means all).
        verbose: log progress and summary information.
        device: device to move batches to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        ``(evaluator, mean_loss)``. ``mean_loss`` is the average per-batch loss,
        or ``nan`` when no batch was processed.
    """
    if verbose:
        logging.info("Starting testing.")
    # Optionally restore weights from a checkpoint.
    if epoch > 0:
        checkpoint_path = os.path.join(exp_root, "models", "model_{:03d}.pt".format(epoch))
        # Deserialize on CPU; load_state_dict copies the values into the model's existing parameters.
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['model'])
    if device is None:
        # Infer the device so this function also works when imported;
        # the module-level ``device`` only exists when run as a script.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    model.eval()
    criterion_parameters = cfg.get_loss_parameters()
    test_parameters = cfg.get_test_parameters()
    criterion = model.loss
    loss = 0
    total_iters = 0
    test_t0 = time()
    loss_dict = {}
    with torch.no_grad():
        for idx, (images, labels, img_idxs) in enumerate(test_loader):
            if max_batches is not None and idx >= max_batches:
                break
            if verbose:
                logging.info("Testing iteration: {}/{}".format(idx + 1, len(test_loader)))
            images = images.to(device)
            labels = labels.to(device)
            # Wall-clock time of the forward pass only.
            t0 = time()
            outputs = model(images)
            t = time() - t0
            loss_i, loss_dict_i = criterion(outputs, labels, **criterion_parameters)
            loss += loss_i.item()
            total_iters += 1
            for key, value in loss_dict_i.items():
                loss_dict[key] = loss_dict.get(key, 0) + value
            outputs = model.decode(outputs, labels, **test_parameters)
            if evaluator is not None:
                lane_outputs, _ = outputs
                evaluator.add_prediction(img_idxs, lane_outputs.cpu().numpy(), t / images.shape[0])
            if view:
                outputs, extra_outputs = outputs
                # Only the first item of the batch is drawn, using the batch index as the image index.
                preds = test_loader.dataset.draw_annotation(
                    idx,
                    pred=outputs[0].cpu().numpy(),
                    cls_pred=extra_outputs[0].cpu().numpy() if extra_outputs is not None else None)
                cv2.imshow('pred', preds)
                cv2.waitKey(0)
    if verbose:
        logging.info("Testing time: {:.4f}".format(time() - test_t0))
    if total_iters == 0:
        # Nothing was evaluated (empty loader or max_batches == 0): avoid dividing by zero.
        if verbose:
            logging.warning("No test batches were processed.")
        return evaluator, float('nan')
    out_line = []
    for key in loss_dict:
        loss_dict[key] /= total_iters
        out_line.append('{}: {:.4f}'.format(key, loss_dict[key]))
    if verbose:
        logging.info(', '.join(out_line))
    return evaluator, loss / total_iters
def get_code_state():
    """Return a text snapshot of the git commit hash followed by the working-tree diff.

    Both values come from running ``git`` in the current working directory;
    their stdout is decoded as UTF-8.
    """
    def _git_output(*git_args):
        # Run one git subcommand and return its decoded stdout.
        completed = subprocess.run(['git', *git_args], stdout=subprocess.PIPE)
        return completed.stdout.decode('utf-8')

    commit_hash = _git_output('rev-parse', 'HEAD')
    working_diff = _git_output('diff')
    header = "Git hash: {}".format(commit_hash)
    separator = '\n*************\nGit diff:\n*************\n'
    return header + separator + working_diff
def log_on_exception(exc_type, exc_value, exc_traceback):
    """``sys.excepthook``-compatible handler that logs an uncaught exception.

    The exception and its traceback are recorded at ERROR level on the root
    logger under the message "Uncaught exception".
    """
    info = (exc_type, exc_value, exc_traceback)
    logging.error("Uncaught exception", exc_info=info)
def str2bool(value):
    """Convert a command-line string such as "true", "False" or "1" to a bool.

    argparse's ``type=bool`` treats every non-empty string (including
    "False") as True, so it cannot be used to switch a flag off.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError("Expected a boolean value, got {!r}".format(value))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Lane regression")
    parser.add_argument("--exp_name", default="tusimple", help="Experiment name")
    parser.add_argument("--cfg", default="config.yaml", help="Config file")
    parser.add_argument("--epoch", type=int, default=2695, help="Epoch to test the model on")
    parser.add_argument("--batch_size", type=int, help="Number of images per batch")
    # Bare ``--view`` means True; ``--view False`` disables the preview window.
    # The default stays True.
    parser.add_argument("--view", type=str2bool, nargs="?", const=True, default=True,
                        help="Show predictions")
    args = parser.parse_args()
    cfg = Config(args.cfg)
    # Set up seeds
    torch.manual_seed(cfg['seed'])
    np.random.seed(cfg['seed'])
    random.seed(cfg['seed'])
    # Set up logging
    exp_root = os.path.join(cfg['exps_dir'], os.path.basename(os.path.normpath(args.exp_name)))
    logging.basicConfig(
        format="[%(asctime)s] [%(levelname)s] %(message)s",
        level=logging.INFO,
        handlers=[
            logging.FileHandler(os.path.join(exp_root, "test_log.txt")),
            logging.StreamHandler(),
        ],
    )
    sys.excepthook = log_on_exception
    logging.info("Experiment name: {}".format(args.exp_name))
    logging.info("Config:\n" + str(cfg))
    logging.info("Args:\n" + str(args))
    # Device configuration. Kept as a module-level name because ``test`` may read it as a global.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Hyper parameters
    batch_size = cfg["batch_size"] if args.batch_size is None else args.batch_size
    # Model
    model = cfg.get_model().to(device)
    test_epoch = args.epoch
    # Get data set
    test_dataset = cfg.get_dataset("test")
    # When viewing, only the first item of each batch is drawn (indexed by batch
    # number), so force one image per batch.
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size if not args.view else 1,
                                              shuffle=False,
                                              num_workers=1)
    # Eval results
    evaluator = Evaluator(test_loader.dataset, exp_root)
    _, mean_loss = test(model, test_loader, evaluator, exp_root, cfg, epoch=test_epoch, view=args.view)
    logging.info("Mean test loss: {:.4f}".format(mean_loss))
    evaluator.exp_name = args.exp_name
    eval_str, _ = evaluator.eval(label='{}_{}'.format(os.path.basename(args.exp_name), test_epoch))
    logging.info(eval_str)