mdz/pytorch/A2Net/1_scripts/2_save_infer.py

227 lines
9.1 KiB
Python

import sys
sys.path.insert(0, '../0_A2Net/')
from models.model import BaseNet
import torch
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.nn.parallel import gather
import torch.optim.lr_scheduler
import dataset as myDataLoader
import Transforms as myTransforms
from metric_tool import ConfuseMatrixMeter
from PIL import Image
import os, time
import numpy as np
from argparse import ArgumentParser
import onnx
import onnxruntime
def BCEDiceLoss(inputs, targets):
    """Return binary cross-entropy plus Dice loss, i.e. ``BCE + 1 - dice``.

    Args:
        inputs: predicted probabilities (``F.binary_cross_entropy`` requires
            values in [0, 1]); same shape as ``targets``.
        targets: ground-truth labels as a float tensor.

    Returns:
        A scalar tensor. BCE uses the default mean reduction; the Dice
        coefficient is computed over the whole tensor (not per sample) with a
        1e-5 smoothing term so an empty prediction and target give dice == 1.
    """
    smooth = 1e-5
    overlap = (inputs * targets).sum()
    total = inputs.sum() + targets.sum()
    dice_coeff = (2 * overlap + smooth) / (total + smooth)
    return F.binary_cross_entropy(inputs, targets) + 1 - dice_coeff
def BCE(inputs, targets):
    """Return the mean binary cross-entropy between ``inputs`` and ``targets``.

    Thin wrapper over ``torch.nn.functional.binary_cross_entropy`` with its
    default reduction; ``inputs`` must already be probabilities.
    """
    return F.binary_cross_entropy(inputs, targets)
def _save_change_map(pr, gt, path):
    """Write an RGB map comparing prediction ``pr`` with ground truth ``gt`` to ``path``.

    Colours: true positive white, false positive red, false negative cyan,
    true negative (and any pixel matching none of those) black.
    """
    vis = np.zeros((gt.shape[0], gt.shape[1], 3), dtype=np.uint8)
    vis[(pr == 1) & (gt == 1)] = (255, 255, 255)  # TP: white
    vis[(pr == 1) & (gt == 0)] = (255, 0, 0)      # FP: red
    vis[(pr == 0) & (gt == 1)] = (0, 255, 255)    # FN: cyan
    # TN pixels keep the zero (black) background.
    Image.fromarray(vis).save(path)


@torch.no_grad()
def val(args, val_loader, model, epoch):
    """Run the exported ONNX model over ``val_loader``, save change maps and score them.

    Channels 0:3 of each image batch are fed as the first ONNX input and
    channels 3:6 as the second. The first model output is thresholded at 0.5
    to form the prediction, a colour-coded change map is written to
    ``args.vis_dir`` for every sample, and a 2-class confusion matrix is
    accumulated.

    Args:
        args: parsed CLI namespace; reads ``onGPU``, ``model_onnx`` and ``vis_dir``.
        val_loader: DataLoader whose ``sampler.data_source.file_list`` names
            the samples. Names are looked up as ``batch_idx * batch_size + i``,
            which assumes the loader is not shuffled.
        model: unused -- inference always goes through an onnxruntime session
            built from ``args.model_onnx``. Kept for call compatibility.
        epoch: unused; kept for call compatibility.

    Returns:
        ``(average_loss, scores)``. ``average_loss`` is the mean over batches
        of the summed BCE+Dice loss of the four ONNX outputs (``nan`` if the
        loader is empty). ``scores`` is ``ConfuseMatrixMeter.get_scores()``.
    """
    sal_eval = ConfuseMatrixMeter(n_class=2)
    epoch_loss = []
    total_batches = len(val_loader)
    print(total_batches)
    file_list = val_loader.sampler.data_source.file_list
    # NOTE(review): batch_size is None when a custom batch_sampler is used;
    # fall back to 1 in that case.
    batch_size = val_loader.batch_size or 1

    # Build the session once: constructing it per batch reloads and
    # re-optimises the whole graph on every iteration.
    ort_session = onnxruntime.InferenceSession(args.model_onnx)
    ort_inputs = ort_session.get_inputs()
    pre_name, post_name = ort_inputs[0].name, ort_inputs[1].name

    for batch_idx, (img, target) in enumerate(val_loader):
        start_time = time.time()
        # onnxruntime consumes host numpy arrays, so the images always stay on
        # the CPU (calling .numpy() on a CUDA tensor raises).
        pre_img = np.ascontiguousarray(img[:, 0:3].float().cpu().numpy())
        post_img = np.ascontiguousarray(img[:, 3:6].float().cpu().numpy())
        target_var = target.float()
        if args.onGPU:
            target_var = target_var.cuda()

        ort_outs = ort_session.run(None, {pre_name: pre_img, post_name: post_img})
        # Four side outputs all contribute to the loss; move them next to the
        # target so the loss never mixes devices.
        outputs = [torch.as_tensor(ort_outs[k]).to(target_var.device) for k in range(4)]
        loss = sum(BCEDiceLoss(out, target_var) for out in outputs)
        pred = (outputs[0] > 0.5).long()
        time_taken = time.time() - start_time
        epoch_loss.append(loss.item())

        # Score and save every sample in the batch, not just the first one.
        for i in range(pred.shape[0]):
            pr = pred[i, 0].cpu().numpy()
            gt = target_var[i, 0].cpu().numpy()
            img_name = file_list[batch_idx * batch_size + i]
            _save_change_map(pr, gt, os.path.join(args.vis_dir, img_name))
            f1 = sal_eval.update_cm(pr, gt)

        if batch_idx % 5 == 0:
            print('\r[%d/%d] F1: %.3f loss: %.3f time: %.3f'
                  % (batch_idx, total_batches, f1, loss.item(), time_taken), end='')

    average_epoch_loss_val = sum(epoch_loss) / len(epoch_loss) if epoch_loss else float('nan')
    scores = sal_eval.get_scores()
    return average_epoch_loss_val, scores
def _resolve_file_root(name):
    """Map a dataset alias (LEVIR, BCDD, SYSU, CDD, testLEVIR) to its data directory.

    Raises:
        TypeError: if ``name`` is not a known alias. TypeError is kept for
            compatibility with the original script.
    """
    roots = {
        # Full LEVIR patches live at
        # /home/guan/Documents/Datasets/ChangeDetection/LEVIR-CD_256_patches
        'LEVIR': '../0_A2Net/samples',
        'BCDD': '/home/guan/Documents/Datasets/ChangeDetection/BCDD',
        'SYSU': '/home/guan/Documents/Datasets/ChangeDetection/SYSU',
        'CDD': '/home/guan/Documents/Datasets/ChangeDetection/CDD',
        'testLEVIR': '../0_A2Net/samples',
    }
    try:
        return roots[name]
    except KeyError:
        raise TypeError('%s has not defined' % name) from None


def ValidateSegmentation(args):
    """Evaluate the exported ONNX change-detection model on the test split.

    Side effects:
      * ``args.file_root`` is overwritten in place with the resolved directory.
      * ``<savedir>/<logFile>`` receives a tab-separated score line. A header
        is written first if the file does not exist yet.
      * ``<vis_dir>/results.mat`` receives the score dict via ``scipy.io.savemat``.
      * ``<vis_dir>`` receives per-image change maps (written by ``val``).

    Raises:
        TypeError: if ``args.file_root`` is not a known dataset alias.
    """
    torch.backends.cudnn.benchmark = True
    SEED = 2333
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)

    # The PyTorch network is only built to report its parameter count; the
    # evaluation itself runs the ONNX export. Parameter count is
    # device-independent, so the model is not moved to the GPU.
    model = BaseNet(3, 1)
    args.file_root = _resolve_file_root(args.file_root)
    os.makedirs(args.vis_dir, exist_ok=True)

    total_params = sum(np.prod(p.size()) for p in model.parameters())
    print('Total network parameters (excluding idr): ' + str(total_params))

    # Normalisation for the 6-channel input (channels 0:3 and 3:6 are the two
    # stacked images); the same three values are repeated for each image.
    mean = [0.406, 0.456, 0.485, 0.406, 0.456, 0.485]
    std = [0.225, 0.224, 0.229, 0.225, 0.224, 0.229]
    valDataset = myTransforms.Compose([
        myTransforms.Normalize(mean=mean, std=std),
        myTransforms.Scale(args.inWidth, args.inHeight),
        myTransforms.ToTensor(),
    ])
    test_data = myDataLoader.Dataset("test", file_root=args.file_root, transform=valDataset)
    testLoader = torch.utils.data.DataLoader(
        test_data, shuffle=False,
        batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=False)

    # os.path.join so a savedir without a trailing slash still yields
    # "<savedir>/<logFile>" instead of "<savedir><logFile>".
    os.makedirs(args.savedir, exist_ok=True)
    logFileLoc = os.path.join(args.savedir, args.logFile)
    if not os.path.isfile(logFileLoc):
        with open(logFileLoc, 'w') as logger:
            logger.write("Parameters: %s" % (str(total_params)))
            logger.write(
                "\n%s\t%s\t%s\t%s\t%s\t%s" % ('Epoch', 'Kappa', 'IoU', 'F1', 'R', 'P'))

    # onnx.load fails early on a missing/corrupt file; val() builds its own
    # onnxruntime session from args.model_onnx.
    model = onnx.load(args.model_onnx)
    print("load model in ", args.model_onnx)
    loss_test, score_test = val(args, testLoader, model, 0)
    print("\nTest :\t Kappa (te) = %.4f\t IoU (te) = %.4f\t F1 (te) = %.4f\t R (te) = %.4f\t P (te) = %.4f"
          % (score_test['Kappa'], score_test['IoU'], score_test['F1'], score_test['recall'],
             score_test['precision']))
    with open(logFileLoc, 'a') as logger:
        logger.write("\n%s\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f" % ('Test', score_test['Kappa'], score_test['IoU'],
                                                                       score_test['F1'], score_test['recall'],
                                                                       score_test['precision']))

    import scipy.io as scio
    scio.savemat(os.path.join(args.vis_dir, 'results.mat'), score_test)
    torch.cuda.empty_cache()
def build_arg_parser():
    """Return the command-line parser for this ONNX evaluation script.

    ``max_steps``, ``step_loss``, ``lr``, ``lr_mode``, ``resume``, ``weight``
    and ``ms`` are accepted for compatibility with the training script but
    are not read by ``ValidateSegmentation`` or ``val``.
    """
    parser = ArgumentParser()
    parser.add_argument('--file_root', default="LEVIR",
                        help='Data directory | LEVIR | BCDD | SYSU | CDD | testLEVIR')
    parser.add_argument('--inWidth', type=int, default=256, help='Width of RGB image')
    parser.add_argument('--inHeight', type=int, default=256, help='Height of RGB image')
    parser.add_argument('--max_steps', type=int, default=40000, help='Max. number of iterations')
    parser.add_argument('--num_workers', type=int, default=3, help='No. of parallel threads')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
    parser.add_argument('--step_loss', type=int, default=100, help='Decrease learning rate after how many epochs')
    parser.add_argument('--lr', type=float, default=5e-4, help='Initial learning rate')
    parser.add_argument('--lr_mode', default='poly', help='Learning rate policy, step or poly')
    parser.add_argument('--savedir', default='../0_A2Net/results', help='Directory to save the results')
    parser.add_argument('--resume', default=None, help='Use this checkpoint to continue training | '
                                                       './results_ep100/checkpoint.pth.tar')
    parser.add_argument('--logFile', default='testLog.txt',
                        help='File that stores the training and validation logs')
    # Any value other than (case-insensitive) "true" is parsed as False.
    parser.add_argument('--onGPU', default=False, type=lambda x: (str(x).lower() == 'true'),
                        help='Run on CPU or GPU. If TRUE, then GPU.')
    parser.add_argument('--weight', default='', type=str, help='pretrained weight, can be a non-strict copy')
    parser.add_argument('--ms', type=int, default=0, help='apply multi-scale training, default False')
    # Path of the exported ONNX model to evaluate.
    parser.add_argument('--model_onnx', default='../2_compile/fmodel/A2Net_256x256_traced.onnx',
                        help='onnx model path')
    parser.add_argument('--vis_dir', default='./Predict/LEVIR/save_infer/', help='save the visualize results path')
    return parser


if __name__ == '__main__':
    args = build_arg_parser().parse_args()
    print('Called with args:')
    print(args)
    ValidateSegmentation(args)