# mdz/pytorch/rdn/1_scripts/0_infer.py
#
# 79 lines
# 3.2 KiB
# Python

import argparse
import sys
sys.path.append(R"../0_rdn")
import torch
import torch.backends.cudnn as cudnn
import numpy as np
import PIL.Image as pil_image
from models import RDN
from utils import convert_rgb_to_y, denormalize, calc_psnr
# Nominal (width, height) in pixels used when preparing the test image; the
# script reports this as the model input shape.
INP_W, INP_H = 256, 256
def parse_args(argv=None):
    """Build the command-line interface and parse ``argv`` (``sys.argv`` if None).

    Every option has a default, so the script runs without arguments; the
    default paths are relative to the current working directory.
    ``--input-width`` / ``--input-height`` / ``--device`` default to the
    previously hard-coded values (INP_W x INP_H low-res input, CPU inference).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights-file', type=str, required=False, default="../weights/rdn_x2.pth")
    parser.add_argument('--image-file', type=str, required=False, default="./Set5/HR/butterfly.png")
    parser.add_argument('--num-features', type=int, default=64)
    parser.add_argument('--growth-rate', type=int, default=64)
    parser.add_argument('--num-blocks', type=int, default=16)
    parser.add_argument('--num-layers', type=int, default=8)
    parser.add_argument('--scale', type=int, default=2)
    parser.add_argument('--input-width', type=int, default=INP_W,
                        help='width in pixels of the low-resolution model input')
    parser.add_argument('--input-height', type=int, default=INP_H,
                        help='height in pixels of the low-resolution model input')
    parser.add_argument('--device', type=str, default='cpu',
                        help="torch device to run on, e.g. 'cpu' or 'cuda'")
    return parser.parse_args(argv)


def load_model(args, device):
    """Construct an RDN from ``args``, copy checkpoint weights into it, set eval mode.

    Model parameters that are absent from the checkpoint keep their initial
    values; checkpoint entries the model does not have are rejected.

    Raises:
        KeyError: if the checkpoint contains a name not in the model's state_dict.
    """
    model = RDN(scale_factor=args.scale,
                num_channels=3,
                num_features=args.num_features,
                growth_rate=args.growth_rate,
                num_blocks=args.num_blocks,
                num_layers=args.num_layers).to(device)
    state_dict = model.state_dict()
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint = torch.load(args.weights_file, map_location='cpu')
    for name, param in checkpoint.items():
        if name not in state_dict:
            raise KeyError(name)
        # copy_ writes in place and moves the CPU-loaded tensor to the model's device.
        state_dict[name].copy_(param)
    model.eval()
    return model


def prepare_images(image_file, scale, width, height):
    """Load ``image_file`` and return ``(hr, lr, bicubic)`` PIL RGB images.

    ``hr`` is ``width*scale x height*scale``; ``lr`` is ``hr`` bicubically
    downscaled to exactly ``width x height`` (the model input); ``bicubic`` is
    ``lr`` bicubically upscaled back to the ``hr`` size, as a baseline.
    """
    with pil_image.open(image_file) as src:
        image = src.convert('RGB')
    hr_size = (width * scale, height * scale)
    # Resize by ``scale`` (not a fixed 2x) so the LR input really is width x height.
    hr = image.resize(hr_size, resample=pil_image.BICUBIC)
    lr = hr.resize((width, height), resample=pil_image.BICUBIC)
    bicubic = lr.resize(hr_size, resample=pil_image.BICUBIC)
    return hr, lr, bicubic


def image_to_tensor(img, device='cpu'):
    """Convert an H x W x 3 uint8 image (PIL image or array) to a 1 x 3 x H x W float32 tensor scaled by 1/255."""
    chw = np.asarray(img, dtype=np.float32).transpose(2, 0, 1)
    return torch.from_numpy(chw[np.newaxis] / 255.0).to(device)


def shave(img, border):
    """Drop ``border`` pixels from each edge of a 2-D array/tensor (``border == 0`` keeps it whole)."""
    rows, cols = img.shape[0], img.shape[1]
    return img[border:rows - border, border:cols - border]


def main(argv=None):
    """Run RDN super-resolution on one image, print PSNR, and show the results."""
    args = parse_args(argv)
    # benchmark only affects cuDNN (CUDA) convolutions; it is a no-op on CPU.
    cudnn.benchmark = True
    device = torch.device(args.device)

    model = load_model(args, device)

    hr_img, lr_img, bicubic = prepare_images(
        args.image_file, args.scale, args.input_width, args.input_height)
    lr = image_to_tensor(lr_img, device)
    hr = image_to_tensor(hr_img, device)

    print('inputs shape: ', lr.shape[-1], 'x', lr.shape[-2])
    print('upscale: ', args.scale)

    with torch.no_grad():
        preds = model(lr).squeeze(0)
    preds_255 = denormalize(preds)

    # PSNR on the channel produced by convert_rgb_to_y, ignoring a ``scale``-pixel
    # border on each side.
    preds_y = shave(convert_rgb_to_y(preds_255, dim_order='chw'), args.scale)
    hr_y = shave(convert_rgb_to_y(denormalize(hr.squeeze(0)), dim_order='chw'), args.scale)
    psnr = calc_psnr(hr_y, preds_y)
    print('PSNR: {:.2f}'.format(psnr))

    output = pil_image.fromarray(preds_255.permute(1, 2, 0).byte().cpu().numpy(), mode="RGB")
    bicubic.show("bicubic")
    output.show("rdn")


if __name__ == '__main__':
    main()