import argparse
import sys

# Make the RDN model definition and helper functions importable.
sys.path.append(r"../0_rdn")

import torch
import torch.backends.cudnn as cudnn
import numpy as np
import PIL.Image as pil_image

from models import RDN
from utils import convert_rgb_to_y, denormalize, calc_psnr
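
# Fixed spatial size of the network input; with the default scale of 2 the
# low-resolution tensor fed to the trace below is INP_W x INP_H.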
INP_W = 256
INP_H = 256

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
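    # The upstream demo required --weights-file and --image-file on the command
    # line; defaults are supplied below so the script runs as-is.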
    parser.add_argument('--weights-file', type=str, required=False, default="../weights/rdn_x2.pth")
    parser.add_argument('--image-file', type=str, required=False, default="../2_compile/qtset/set14/img_011_SRF_2_HR.png")
    parser.add_argument('--num-features', type=int, default=64)
    parser.add_argument('--growth-rate', type=int, default=64)
    parser.add_argument('--num-blocks', type=int, default=16)
    parser.add_argument('--num-layers', type=int, default=8)
    parser.add_argument('--scale', type=int, default=2)
    args = parser.parse_args()

    # cudnn.benchmark only affects CUDA convolutions; this script runs on the CPU.
    cudnn.benchmark = True
    device = 'cpu'

    # Build RDN with the same hyper-parameters the pretrained weights expect.
    model = RDN(scale_factor=args.scale,
                num_channels=3,
                num_features=args.num_features,
                growth_rate=args.growth_rate,
                num_blocks=args.num_blocks,
                num_layers=args.num_layers).to(device)

    state_dict = model.state_dict()
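    # Copy each pretrained tensor into the model, failing loudly on any
    # checkpoint key the model does not have.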
    for n, p in torch.load(args.weights_file, map_location=lambda storage, loc: storage).items():
        if n in state_dict.keys():
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)

    model.eval()

    # Resize the test image to twice the traced input size, then snap its
    # dimensions to a multiple of the scale factor.
    image = pil_image.open(args.image_file).convert('RGB').resize([INP_W * 2, INP_H * 2])

    image_width = (image.width // args.scale) * args.scale
    image_height = (image.height // args.scale) * args.scale

    # HR reference, bicubic-downscaled LR network input, and a plain bicubic
    # upscaling as a baseline for comparison.
    hr = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
    lr = hr.resize((hr.width // args.scale, hr.height // args.scale), resample=pil_image.BICUBIC)
    bicubic = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)
    # bicubic.save(args.image_file.replace('.', '_bicubic_x{}.'.format(args.scale)))
    # bicubic.show("bicubic")
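
    # HWC uint8 images -> normalized NCHW float32 tensors in [0, 1].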
    lr = np.expand_dims(np.array(lr).astype(np.float32).transpose([2, 0, 1]), 0) / 255.0
    hr = np.expand_dims(np.array(hr).astype(np.float32).transpose([2, 0, 1]), 0) / 255.0
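    # Move the normalized arrays onto the target device as torch tensors.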
    lr = torch.from_numpy(lr).to(device)
    hr = torch.from_numpy(hr).to(device)
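
    # Note: the reported input shape assumes the default scale of 2, where the
    # LR input is exactly INP_W x INP_H.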
    print('input shape: ', INP_W, 'x', INP_H)
    print('upscale: ', args.scale)

    with torch.no_grad():
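        # One forward pass for shape reporting, then a TorchScript trace of the
        # model at this fixed input shape.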
        preds = model(lr).squeeze(0)
        traced_model = torch.jit.trace(model, lr)
        traced_model.save("../2_compile/fmodel/rdn_2x_" + str(INP_H) + "x" + str(INP_W) + ".pt")
    print('output shape: ', preds.shape)
    print('traced done')
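
    # The remainder is optional evaluation and visualization kept from the
    # upstream RDN demo; uncomment as needed.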
    # preds_y = convert_rgb_to_y(denormalize(preds), dim_order='chw')
    # hr_y = convert_rgb_to_y(denormalize(hr.squeeze(0)), dim_order='chw')

    # Crop a scale-sized border before computing PSNR on the Y channel.
    # preds_y = preds_y[args.scale:-args.scale, args.scale:-args.scale]
    # hr_y = hr_y[args.scale:-args.scale, args.scale:-args.scale]

    # psnr = calc_psnr(hr_y, preds_y)
    # print('PSNR: {:.2f}'.format(psnr))

    # output = pil_image.fromarray(denormalize(preds).permute(1, 2, 0).byte().cpu().numpy())
    # # output.save(args.image_file.replace('.', '_rdn_x{}.'.format(args.scale)))
    # output.show("rdn")

    # canvas = pil_image.new("RGB", (1920, 1080), (255, 255, 255))
    # canvas.paste(bicubic, (299, 284))
    # canvas.paste(output, (1110, 284))
    # canvas.show()
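
    # A minimal sanity check, not part of the original script: reload the saved
    # trace and confirm it reproduces the eager model's output (the path assumes
    # the default INP_W/INP_H of 256).
    # reloaded = torch.jit.load("../2_compile/fmodel/rdn_2x_256x256.pt")
    # with torch.no_grad():
    #     print('trace matches eager:',
    #           torch.allclose(reloaded(lr).squeeze(0), preds, atol=1e-5))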