79 lines
3.2 KiB
Python
79 lines
3.2 KiB
Python
import argparse
|
|
import sys
|
|
sys.path.append(R"../0_rdn")
|
|
import torch
|
|
import torch.backends.cudnn as cudnn
|
|
import numpy as np
|
|
import PIL.Image as pil_image
|
|
|
|
from models import RDN
|
|
from utils import convert_rgb_to_y, denormalize, calc_psnr
|
|
# Size, in pixels, of the low-resolution image that is fed to the network.
INP_W = 256
INP_H = 256


def parse_args(argv=None):
    """Parse the command-line options of the demo.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.

    Exits with a usage error (``SystemExit``) when ``--scale`` is smaller
    than 1, because the resize arithmetic and the border cropping below are
    meaningless for such values.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights-file', type=str, required=False, default="../weights/rdn_x2.pth")
    parser.add_argument('--image-file', type=str, required=False, default="./Set5/HR/butterfly.png")
    parser.add_argument('--num-features', type=int, default=64)
    parser.add_argument('--growth-rate', type=int, default=64)
    parser.add_argument('--num-blocks', type=int, default=16)
    parser.add_argument('--num-layers', type=int, default=8)
    parser.add_argument('--scale', type=int, default=2)
    args = parser.parse_args(argv)
    if args.scale < 1:
        parser.error('--scale must be a positive integer')
    return args


def hr_size(scale, lr_width=INP_W, lr_height=INP_H):
    """Return the (width, height) of the high-resolution reference image.

    The reference is sized so that shrinking it by ``scale`` yields exactly
    ``lr_width`` x ``lr_height`` pixels, whatever the scale factor is.
    """
    return lr_width * scale, lr_height * scale


def image_to_tensor(image, device):
    """Convert an H x W x C image into a 1 x C x H x W float32 tensor.

    Args:
        image: Anything ``np.array`` turns into a 3-D array laid out as
            height x width x channels (e.g. an RGB PIL image).
        device: Target torch device.

    Returns:
        A float32 tensor whose values are the input values divided by 255.
    """
    chw = np.array(image).astype(np.float32).transpose([2, 0, 1])
    return torch.from_numpy(np.expand_dims(chw, 0) / 255.0).to(device)


def load_weights(model, weights_file):
    """Copy the parameters stored in ``weights_file`` into ``model`` in place.

    Raises:
        KeyError: If the checkpoint contains a name the model does not have.
    """
    state_dict = model.state_dict()
    # The identity map_location keeps every tensor on the CPU, so checkpoints
    # saved from a GPU run load on machines without one.
    checkpoint = torch.load(weights_file, map_location=lambda storage, loc: storage)
    for name, param in checkpoint.items():
        if name not in state_dict:
            raise KeyError(name)
        state_dict[name].copy_(param)


def main(argv=None):
    """Super-resolve one image with RDN, print its PSNR and show the result.

    The input image is resized to a fixed reference size, downscaled
    bicubically to ``INP_W`` x ``INP_H``, restored by the network and compared
    with the reference on the luminance channel. The bicubic upscaling and
    the network output are displayed; nothing is written to disk.
    """
    args = parse_args(argv)

    cudnn.benchmark = True
    device = 'cpu'

    model = RDN(scale_factor=args.scale,
                num_channels=3,
                num_features=args.num_features,
                growth_rate=args.growth_rate,
                num_blocks=args.num_blocks,
                num_layers=args.num_layers).to(device)
    load_weights(model, args.weights_file)
    model.eval()

    # Size the reference from the requested scale (not a fixed 2x) so the
    # network input is INP_W x INP_H for every --scale.  Both dimensions are
    # exact multiples of the scale, so no further cropping is needed.
    hr = pil_image.open(args.image_file).convert('RGB').resize(hr_size(args.scale))
    lr = hr.resize((hr.width // args.scale, hr.height // args.scale), resample=pil_image.BICUBIC)

    # Plain bicubic upscaling, shown next to the network output as a baseline.
    bicubic = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)

    lr = image_to_tensor(lr, device)
    hr = image_to_tensor(hr, device)
    print('inputs shape: ',INP_W,'x',INP_H)
    print('upscale: ',args.scale)

    with torch.no_grad():
        preds = model(lr).squeeze(0)

    preds_y = convert_rgb_to_y(denormalize(preds), dim_order='chw')
    hr_y = convert_rgb_to_y(denormalize(hr.squeeze(0)), dim_order='chw')

    # Ignore a border of `scale` pixels on every side when measuring PSNR.
    preds_y = preds_y[args.scale:-args.scale, args.scale:-args.scale]
    hr_y = hr_y[args.scale:-args.scale, args.scale:-args.scale]

    psnr = calc_psnr(hr_y, preds_y)
    print('PSNR: {:.2f}'.format(psnr))

    output = pil_image.fromarray(denormalize(preds).permute(1, 2, 0).byte().cpu().numpy(), mode="RGB")
    bicubic.show("bicubic")
    output.show("rdn")


if __name__ == '__main__':
    main()