222 lines
8.7 KiB
Python
222 lines
8.7 KiB
Python
import sys
|
|
sys.path.insert(0, '../0_A2Net/')
|
|
from models.model import BaseNet
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import torch.backends.cudnn as cudnn
|
|
from torch.nn.parallel import gather
|
|
import torch.optim.lr_scheduler
|
|
|
|
import dataset as myDataLoader
|
|
import Transforms as myTransforms
|
|
from metric_tool import ConfuseMatrixMeter
|
|
from PIL import Image
|
|
|
|
import os, time
|
|
import numpy as np
|
|
from argparse import ArgumentParser
|
|
|
|
|
|
def BCEDiceLoss(inputs, targets):
    """Binary cross-entropy plus soft-Dice loss.

    Args:
        inputs: probabilities in [0, 1] (as required by
            ``F.binary_cross_entropy``).
        targets: tensor of the same shape as ``inputs`` with values in [0, 1].

    Returns:
        Scalar tensor ``bce + 1 - dice``. The Dice coefficient is computed
        over the whole tensor (all batch elements pooled together) with a
        ``1e-5`` smoothing term in both numerator and denominator.
    """
    smooth = 1e-5
    overlap = (inputs * targets).sum()
    denominator = inputs.sum() + targets.sum() + smooth
    dice_coeff = (2 * overlap + smooth) / denominator
    bce_term = F.binary_cross_entropy(inputs, targets)
    return bce_term + 1 - dice_coeff
|
|
|
|
|
|
def BCE(inputs, targets):
    """Return the mean binary cross-entropy between ``inputs`` and ``targets``.

    ``inputs`` must already be probabilities in [0, 1]; this is a thin
    wrapper over ``F.binary_cross_entropy`` with its default reduction.
    """
    return F.binary_cross_entropy(inputs, targets)
|
|
|
|
|
|
def _colorize_change_map(pr, gt):
    """Render a prediction / ground-truth pair as an RGB error map.

    Colours: white = true positive, red = false positive,
    black = true negative, cyan = false negative.

    Args:
        pr: 2-D array of predicted labels (0/1).
        gt: 2-D array of ground-truth labels (0/1), same shape as ``pr``.

    Returns:
        ``uint8`` array of shape ``(H, W, 3)``.
    """
    color_map = np.zeros((gt.shape[0], gt.shape[1], 3), dtype=np.uint8)
    color_map[np.logical_and(pr == 1, gt == 1)] = (255, 255, 255)  # TP: white
    color_map[np.logical_and(pr == 1, gt == 0)] = (255, 0, 0)      # FP: red
    # TN pixels stay black because the array is zero-initialised.
    color_map[np.logical_and(pr == 0, gt == 1)] = (0, 255, 255)    # FN: cyan
    return color_map


@torch.no_grad()
def val(args, val_loader, model, epoch):
    """Evaluate ``model`` on ``val_loader``, saving a change map per sample.

    For every sample, the model's first output is thresholded at 0.5. An RGB
    error map (see ``_colorize_change_map``) is written to ``args.vis_dir``
    under the sample's file name, and the binary confusion matrix is updated.

    Args:
        args: parsed CLI namespace; reads ``onGPU`` and ``vis_dir``.
        val_loader: DataLoader yielding ``(img, target)`` batches. ``img``
            holds the pre-change image in channels 0-2 and the post-change
            image in channels 3-5. It must iterate its dataset in order
            (``shuffle=False``) so that samples line up with
            ``val_loader.sampler.data_source.file_list``.
        model: network called as ``model(pre, post)`` that returns four
            outputs, each scored with ``BCEDiceLoss`` against the target.
        epoch: unused; kept for interface compatibility.

    Returns:
        ``(average_loss, scores)``, where ``scores`` comes from
        ``ConfuseMatrixMeter.get_scores()``. ``average_loss`` is NaN when the
        loader yields no batches.
    """
    model.eval()

    sal_eval_val = ConfuseMatrixMeter(n_class=2)
    epoch_loss = []

    total_batches = len(val_loader)
    print(total_batches)

    # Running index into file_list across batches, so every sample in a
    # batch (not just the first) gets its own file name.
    file_list = val_loader.sampler.data_source.file_list
    sample_idx = 0

    for step, (img, target) in enumerate(val_loader):
        pre_img = img[:, 0:3]
        post_img = img[:, 3:6]

        start_time = time.time()

        if args.onGPU:
            pre_img = pre_img.cuda()
            target = target.cuda()
            post_img = post_img.cuda()

        pre_img = pre_img.float()
        post_img = post_img.float()
        target = target.float()

        # Run the model; all four outputs contribute to the reported loss.
        output, output2, output3, output4 = model(pre_img, post_img)
        loss = (BCEDiceLoss(output, target) + BCEDiceLoss(output2, target)
                + BCEDiceLoss(output3, target) + BCEDiceLoss(output4, target))

        # Only the first output is binarised into the change prediction.
        pred = (output > 0.5).long()

        # NOTE: no torch.cuda.synchronize(), so on GPU this host-side timing
        # may not include kernels that are still queued.
        time_taken = time.time() - start_time

        epoch_loss.append(loss.item())

        # ``model`` is used unwrapped (no DataParallel), so ``pred`` already
        # lives on a single device and needs no cross-GPU gather.
        for b in range(pred.shape[0]):
            pr = pred[b, 0].cpu().numpy()
            gt = target[b, 0].cpu().numpy()

            # Save the change map under this sample's own file name.
            change_map = Image.fromarray(_colorize_change_map(pr, gt))
            change_map.save(os.path.join(args.vis_dir, file_list[sample_idx]))
            sample_idx += 1

            f1 = sal_eval_val.update_cm(pr, gt)

        if step % 5 == 0:
            print('\r[%d/%d] F1: %.3f loss: %.3f time: %.3f'
                  % (step, total_batches, f1, loss.item(), time_taken),
                  end='')

    # Guard against an empty loader instead of raising ZeroDivisionError.
    average_epoch_loss_val = (sum(epoch_loss) / len(epoch_loss)
                              if epoch_loss else float('nan'))
    scores = sal_eval_val.get_scores()

    return average_epoch_loss_val, scores
|
|
|
|
|
|
def _resolve_data_root(name):
    """Return the dataset directory registered for the short name ``name``.

    Raises:
        TypeError: if ``name`` is not one of the known dataset names.
    """
    data_roots = {
        # Alternative LEVIR root seen in earlier revisions:
        # '/home/guan/Documents/Datasets/ChangeDetection/LEVIR-CD_256_patches'
        'LEVIR': '../0_A2Net/samples',
        'BCDD': '/home/guan/Documents/Datasets/ChangeDetection/BCDD',
        'SYSU': '/home/guan/Documents/Datasets/ChangeDetection/SYSU',
        'CDD': '/home/guan/Documents/Datasets/ChangeDetection/CDD',
        'testLEVIR': '../0_A2Net/samples',
    }
    if name not in data_roots:
        raise TypeError('%s has not defined' % name)
    return data_roots[name]


def ValidateSegmentation(args):
    """Evaluate the best saved checkpoint on the ``"test"`` split.

    Builds ``BaseNet(3, 1)`` and loads ``best_model.pth`` from the checkpoint
    directory derived from ``args``. It then runs ``val`` over the test
    split, appends the scores to the log file, and saves them to
    ``results.mat`` inside ``args.vis_dir``.

    Side effects on ``args``:
        ``args.savedir`` gains a ``_<dataset>_iter_<max_steps>_lr_<lr>/``
        suffix, and ``args.file_root`` is replaced by the dataset directory.

    Raises:
        TypeError: if ``args.file_root`` is not a known dataset name.
    """
    torch.backends.cudnn.benchmark = True
    SEED = 2333
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)

    model = BaseNet(3, 1)

    # The checkpoint directory is keyed by the dataset *name*, so it must be
    # built before file_root is replaced by the dataset path.
    args.savedir = args.savedir + '_' + args.file_root + '_iter_' + str(args.max_steps) + '_lr_' + str(args.lr) + '/'
    args.file_root = _resolve_data_root(args.file_root)

    os.makedirs(args.savedir, exist_ok=True)
    os.makedirs(args.vis_dir, exist_ok=True)

    if args.onGPU:
        model = model.cuda()

    # numel() counts a 0-dim parameter as 1; np.prod(()) would return 1.0 and
    # silently turn the total into a float.
    total_params = sum(p.numel() for p in model.parameters())
    print('Total network parameters (excluding idr): ' + str(total_params))

    # One entry per input channel: pre-change image (0-2), then post-change (3-5).
    mean = [0.406, 0.456, 0.485, 0.406, 0.456, 0.485]
    std = [0.225, 0.224, 0.229, 0.225, 0.224, 0.229]

    # compose the data with transforms
    valDataset = myTransforms.Compose([
        myTransforms.Normalize(mean=mean, std=std),
        myTransforms.Scale(args.inWidth, args.inHeight),
        myTransforms.ToTensor()
    ])

    test_data = myDataLoader.Dataset("test", file_root=args.file_root, transform=valDataset)
    # shuffle=False matters: val() maps loader order onto the dataset's file_list.
    testLoader = torch.utils.data.DataLoader(
        test_data, shuffle=False,
        batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=False)

    logFileLoc = args.savedir + args.logFile
    # Append mode creates the file when missing, so no separate 'w' branch is
    # needed; the context manager closes it even if a later step raises.
    with open(logFileLoc, 'a') as logger:
        logger.write("Parameters: %s" % (str(total_params)))
        logger.write(
            "\n%s\t%s\t%s\t%s\t%s\t%s" % ('Epoch', 'Kappa', 'IoU', 'F1', 'R', 'P'))

    # Load the checkpoint onto CPU first so GPU-saved weights load anywhere;
    # load_state_dict copies them into the model's own device.
    model_file_name = args.savedir + 'best_model.pth'
    state_dict = torch.load(model_file_name, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)

    _loss_test, score_test = val(args, testLoader, model, 0)

    print("\nTest :\t Kappa (te) = %.4f\t IoU (te) = %.4f\t F1 (te) = %.4f\t R (te) = %.4f\t P (te) = %.4f" \
          % (score_test['Kappa'], score_test['IoU'], score_test['F1'], score_test['recall'], score_test['precision']))

    with open(logFileLoc, 'a') as logger:
        logger.write("\n%s\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f" % ('Test', score_test['Kappa'], score_test['IoU'],
                                                                      score_test['F1'], score_test['recall'],
                                                                      score_test['precision']))

    import scipy.io as scio

    scio.savemat(os.path.join(args.vis_dir, 'results.mat'), score_test)

    torch.cuda.empty_cache()
|
|
|
|
|
|
if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--file_root', default="LEVIR",
                        help='Dataset name | LEVIR | BCDD | SYSU | CDD | testLEVIR')
    parser.add_argument('--inWidth', type=int, default=256, help='Width of RGB image')
    parser.add_argument('--inHeight', type=int, default=256, help='Height of RGB image')
    # --max_steps and --lr are only used here to locate the checkpoint directory.
    parser.add_argument('--max_steps', type=int, default=40000,
                        help='Max. number of training iterations (selects the checkpoint directory)')
    parser.add_argument('--num_workers', type=int, default=3, help='No. of parallel threads')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
    # NOTE(review): --step_loss, --lr_mode, --resume, --weight and --ms are not
    # read by ValidateSegmentation or val in this file.
    parser.add_argument('--step_loss', type=int, default=100, help='Decrease learning rate after how many epochs')
    parser.add_argument('--lr', type=float, default=5e-4,
                        help='Initial learning rate (selects the checkpoint directory)')
    parser.add_argument('--lr_mode', default='poly', help='Learning rate policy, step or poly')
    parser.add_argument('--savedir', default='../0_A2Net/results',
                        help='Prefix of the directory holding best_model.pth and the test log')
    parser.add_argument('--resume', default=None, help='Use this checkpoint to continue training | '
                                                       './results_ep100/checkpoint.pth.tar')
    parser.add_argument('--logFile', default='testLog.txt',
                        help='Log file (inside the checkpoint directory) that test scores are appended to')
    parser.add_argument('--onGPU', default=False, type=lambda x: (str(x).lower() == 'true'),
                        help='Run on CPU or GPU. If TRUE, then GPU.')
    parser.add_argument('--weight', default='', type=str, help='pretrained weight, can be a non-strict copy')
    parser.add_argument('--ms', type=int, default=0, help='apply multi-scale training, default False')
    parser.add_argument('--vis_dir', default='./Predict/LEVIR/infer/',
                        help='Directory where change maps and results.mat are saved')

    args = parser.parse_args()
    print('Called with args:')
    print(args)

    ValidateSegmentation(args)
|