# Viewer metadata (not part of the program): 287 lines, 9.7 KiB, Python
# Export presets. Each entry names the TorchScript file to write, the
# checkpoint to load and the YAML option file for one model variant.
_PRESETS = {
    'selfattn': {
        'traced_model': R"../2_compile/fmodel/Selfmodel-A-selfattn-1c.pt",
        'weights': R'../weights/selfmodel_A_attn.pth',
        'cfg': R'../0_nafnet_lite/options/test/GoPro/Selfmodel_A_selfattn.yml',
    },
    'b18': {
        'traced_model': R"../2_compile/fmodel/Selfmodel-A_b18.pt",
        'weights': R'../weights/selfmodel_A_b18.pth',
        'cfg': R'../0_nafnet_lite/options/test/GoPro/Selfmodel_A_b18.yml',
    },
}

# The preset in effect. Previously both presets were assigned one after the
# other and the later ('b18') silently overrode the earlier one; the choice is
# now explicit and the resulting values are unchanged.
_ACTIVE_PRESET = 'b18'

TRACED_MODEL_PATH = _PRESETS[_ACTIVE_PRESET]['traced_model']
WEIGHTS_PATH = _PRESETS[_ACTIVE_PRESET]['weights']
CFG_FILE = _PRESETS[_ACTIVE_PRESET]['cfg']

# Default input image; it was only ever assigned once, so it is shared by
# both presets.
INPUT_IMG = R"../0_nafnet_lite/GOPR0384_11_00-000001.png"
|
|
|
|
import sys
|
|
sys.path.append(R"../0_nafnet_lite")
|
|
import argparse
|
|
import torch
|
|
import numpy as np
|
|
import random
|
|
import yaml
|
|
import torch.distributed as dist
|
|
from os import path as osp
|
|
from collections import OrderedDict
|
|
|
|
from basicsr.models import create_model
|
|
from basicsr.utils import FileClient, imfrombytes, img2tensor, padding, tensor2img, imwrite
|
|
from basicsr.models.image_restoration_model import ImageRestorationModel
|
|
from basicsr.models.archs.Baseline_arch import Baseline
|
|
from basicsr.models.archs.NAFNet_arch import NAFNet
|
|
from ptflops import get_model_complexity_info
|
|
|
|
def new_test(self):
    """Run ``self.net_g`` over ``self.lq`` in minibatches and store the result.

    Replacement for ``ImageRestorationModel.test``. Reads ``self.net_g``,
    ``self.lq`` and ``self.opt['val']`` (optional key ``max_minibatch``,
    defaulting to the whole batch) and writes the concatenated CPU predictions
    to ``self.output``. The network is switched to eval mode for the run and
    back to train mode afterwards.

    When the input height (dimension 2) is exactly 720, the input is extended
    with 16 zero rows to a height of 736 before the forward pass, the
    prediction is cropped back to 720 rows, and the network is traced on the
    first padded minibatch and saved to ``TRACED_MODEL_PATH``.
    (Presumably 736 is needed so the height divides evenly through the
    network's down-sampling stages -- TODO confirm.)

    Differences from the previous implementation:
      * ``self.lq`` is no longer overwritten with the padded tensor. Doing so
        inside the loop meant that, with more than one minibatch, later
        batches skipped the crop and ``torch.cat`` failed on mismatched
        heights.
      * List outputs are unwrapped before cropping.
      * The padding uses the dtype and device of ``self.lq``.
    """
    self.net_g.eval()
    with torch.no_grad():
        lq = self.lq
        n = len(lq)
        m = self.opt['val'].get('max_minibatch', n)

        # Decide once, before batching, whether this input needs padding.
        needs_pad = lq.shape[2] == 720
        if needs_pad:
            pad = torch.zeros(lq.shape[0], lq.shape[1], 16, lq.shape[3],
                              dtype=lq.dtype, device=lq.device)
            lq = torch.cat((lq, pad), dim=2)

        outs = []
        traced = False
        i = 0
        while i < n:
            j = min(i + m, n)
            batch = lq[i:j]
            pred = self.net_g(batch)
            if isinstance(pred, list):
                # Multi-output networks: keep the last entry only.
                pred = pred[-1]

            if needs_pad:
                # Drop the rows that correspond to the zero padding.
                pred = pred[:, :, :720, :]
                if not traced:
                    # Alternative export path kept for reference:
                    # torch.onnx.export(self.net_g, batch,
                    #                   "Selfmodel-A-selfattn-1c.onnx",
                    #                   opset_version=11)
                    torch.jit.save(torch.jit.trace(self.net_g, batch),
                                   TRACED_MODEL_PATH)
                    print(rf'Model saved in {TRACED_MODEL_PATH}!')
                    traced = True

            outs.append(pred.detach().cpu())
            i = j

        self.output = torch.cat(outs, dim=0)
    self.net_g.train()
|
|
|
|
# Monkey-patch: make every ImageRestorationModel use new_test as its test()
# method, so that model.test() in main() runs the variant defined in this file.
ImageRestorationModel.test = new_test
|
|
|
|
def ordered_yaml():
    """Support OrderedDict for yaml.

    Registers, on the selected PyYAML classes, a constructor that builds every
    YAML mapping as an ``OrderedDict`` and a representer that dumps an
    ``OrderedDict`` as a plain mapping. The C-accelerated ``CLoader`` /
    ``CDumper`` are used when importable, otherwise the pure-Python classes.

    Returns:
        tuple: ``(Loader, Dumper)`` classes with the hooks registered.
    """
    try:
        from yaml import CDumper as Dumper
        from yaml import CLoader as Loader
    except ImportError:
        from yaml import Dumper, Loader

    _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG

    def dict_representer(dumper, data):
        return dumper.represent_dict(data.items())

    def dict_constructor(loader, node):
        return OrderedDict(loader.construct_pairs(node))

    Dumper.add_representer(OrderedDict, dict_representer)
    Loader.add_constructor(_mapping_tag, dict_constructor)
    return Loader, Dumper
|
|
|
|
def parse(opt_path, is_train=True):
    """Parse option file.

    Loads the YAML file at ``opt_path`` and fills in derived entries:
    per-dataset ``phase`` / ``scale`` and user-expanded data roots,
    user-expanded ``resume_state`` / ``pretrain_network`` paths,
    ``opt['path']['root']``, and the experiment (training) or results
    (testing) directories.

    Args:
        opt_path (str): Option file path.
        is_train (bool): Indicate whether in training or not. Default: True.

    Returns:
        (dict): Options.
    """
    with open(opt_path, mode='r') as f:
        Loader, _ = ordered_yaml()
        # NOTE(review): this is PyYAML's full loader, which is not safe for
        # untrusted input -- only feed it option files you trust.
        opt = yaml.load(f, Loader=Loader)

    opt['is_train'] = is_train

    # datasets
    if 'datasets' in opt:
        for phase, dataset in opt['datasets'].items():
            # for several datasets, e.g., test_1, test_2
            dataset['phase'] = phase.split('_')[0]
            if 'scale' in opt:
                dataset['scale'] = opt['scale']
            for root_key in ('dataroot_gt', 'dataroot_lq'):
                if dataset.get(root_key) is not None:
                    dataset[root_key] = osp.expanduser(dataset[root_key])

    # paths
    paths = opt['path']
    for key, val in paths.items():
        if val is None:
            continue
        if 'resume_state' in key or 'pretrain_network' in key:
            # Only existing keys are reassigned, so mutating during
            # iteration is safe here.
            paths[key] = osp.expanduser(val)

    # Three path components above this file's own path.
    paths['root'] = osp.abspath(
        osp.join(__file__, osp.pardir, osp.pardir, osp.pardir))

    if is_train:
        experiments_root = osp.join(paths['root'], 'experiments', opt['name'])
        paths['experiments_root'] = experiments_root
        paths['models'] = osp.join(experiments_root, 'models')
        paths['training_states'] = osp.join(experiments_root,
                                            'training_states')
        paths['log'] = experiments_root
        paths['visualization'] = osp.join(experiments_root, 'visualization')

        # change some options for debug mode
        if 'debug' in opt['name']:
            if 'val' in opt:
                opt['val']['val_freq'] = 8
            opt['logger']['print_freq'] = 1
            opt['logger']['save_checkpoint_freq'] = 8
    else:  # test
        results_root = osp.join(paths['root'], 'results', opt['name'])
        paths['results_root'] = results_root
        paths['log'] = results_root
        paths['visualization'] = osp.join(results_root, 'visualization')

    return opt
|
|
|
|
def get_dist_info():
    """Return ``(rank, world_size)`` of the current process.

    Falls back to ``(0, 1)`` when ``torch.distributed`` is unavailable or has
    not been initialized.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank(), dist.get_world_size()
    return 0, 1
|
|
|
|
def init_dist(launcher, backend='nccl', **kwargs):
    """Initialise distributed execution for the given launcher.

    Sets the multiprocessing start method to ``'spawn'`` if none has been
    chosen yet, then dispatches to the launcher-specific initialiser.

    Args:
        launcher (str): ``'pytorch'`` or ``'slurm'``.
        backend (str): Backend name forwarded to the initialiser.
            Default: ``'nccl'``.
        **kwargs: Extra arguments forwarded to the initialiser.

    Raises:
        ValueError: If ``launcher`` is neither ``'pytorch'`` nor ``'slurm'``.
    """
    # ``mp`` was referenced without ever being imported in this file, which
    # made every call fail with NameError; import it where it is used.
    import torch.multiprocessing as mp

    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')

    # NOTE(review): _init_dist_pytorch / _init_dist_slurm are not defined or
    # imported in the visible part of this file; they must be brought into
    # scope before the 'pytorch' / 'slurm' launchers can work.
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
    elif launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
    else:
        raise ValueError(f'Invalid launcher type: {launcher}')
|
|
|
|
def set_random_seed(seed):
    """Set random seeds.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU generator, current
    CUDA device, and all CUDA devices) with the same value.

    Args:
        seed (int): Seed value passed to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
|
|
|
|
def parse_options(is_train=True):
    """Build the option dict from the command line and the YAML option file.

    Parses the CLI arguments (all optional; defaults come from the
    module-level ``CFG_FILE``, ``INPUT_IMG`` and ``WEIGHTS_PATH``), loads the
    option file with ``parse``, configures distributed mode, seeds the random
    generators and records the image and weight paths in the options.

    Args:
        is_train (bool): Forwarded to ``parse``. Default: True.

    Returns:
        dict: Options, extended with ``dist``, ``rank``, ``world_size``,
        ``manual_seed``, ``img_path`` and ``path.pretrain_network_g``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--opt', type=str, default=CFG_FILE, required=False, help='Path to option YAML file.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)

    parser.add_argument('--input_path', type=str, default=INPUT_IMG, required=False, help='The path to the input image. For single image inference only.')
    parser.add_argument('--output_path', type=str, default=R'deblur1280_attn_result.png', required=False, help='The path to the output image. For single image inference only.')
    parser.add_argument('--weight_path', type=str, default=WEIGHTS_PATH, required=False, help='The path to weights file.')

    args = parser.parse_args()
    opt = parse(args.opt, is_train=is_train)

    # distributed settings
    if args.launcher == 'none':
        opt['dist'] = False
        print('Disable distributed.', flush=True)
    else:
        opt['dist'] = True
        if args.launcher == 'slurm' and 'dist_params' in opt:
            init_dist(args.launcher, **opt['dist_params'])
        else:
            init_dist(args.launcher)
            print('init dist .. ', args.launcher)

    opt['rank'], opt['world_size'] = get_dist_info()

    # random seed: draw one if the option file does not fix it, and offset it
    # by the rank so that each process gets a different stream.
    seed = opt.get('manual_seed')
    if seed is None:
        seed = random.randint(1, 10000)
        opt['manual_seed'] = seed
    set_random_seed(seed + opt['rank'])

    if args.input_path is not None and args.output_path is not None:
        opt['img_path'] = {
            'input_img': args.input_path,
            'output_img': args.output_path
        }

    # The CLI weight path always replaces whatever the option file specifies.
    opt['path']['pretrain_network_g'] = args.weight_path

    return opt
|
|
|
|
def main():
    """Run single-image inference and report model statistics.

    Steps: parse options, read and convert the input image, build the model,
    print its parameter count, print FLOPs/params of a reference ``Baseline``
    network, feed the image and call ``model.test()``.

    NOTE(review): the ``output_img`` option is not used here; no result image
    is written by this function.

    Raises:
        Exception: If the input image bytes cannot be decoded.
    """
    # parse options, set distributed setting, set random seed
    opt = parse_options(is_train=False)
    opt['num_gpu'] = torch.cuda.device_count()

    img_path = opt['img_path'].get('input_img')

    ## 1. read image
    file_client = FileClient('disk')
    img_bytes = file_client.get(img_path, None)
    try:
        img = imfrombytes(img_bytes, float32=True)
    except Exception as err:
        # Was a bare ``except:``, which also intercepted KeyboardInterrupt /
        # SystemExit and discarded the original cause.
        raise Exception("path {} not working".format(img_path)) from err

    img = img2tensor(img, bgr2rgb=True, float32=True)

    ## 2. run inference
    opt['dist'] = False
    model = create_model(opt)

    total_parameters = sum(
        param.numel() for param in model.net_g.parameters())
    print(f"Total Parameters(M): {total_parameters/1000/1000}")

    # FLOPs / params are measured on a separately constructed Baseline network
    # with the fixed configuration below, not on ``model.net_g``.
    img_channel = 3
    width = 16
    dw_expand = 1
    ffn_expand = 1
    enc_blks = [1, 1, 1, 1]
    middle_blk_num = 1
    dec_blks = [1, 1, 1, 1]
    net = Baseline(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num,
                   enc_blk_nums=enc_blks, dec_blk_nums=dec_blks, dw_expand=dw_expand,
                   ffn_expand=ffn_expand)

    inp_shape = (3, 736, 1280)
    print(inp_shape)

    flops, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=False)

    print('flops :', flops)
    print('params :', params)

    model.feed_data(data={'lq': img.unsqueeze(dim=0)})

    if model.opt['val'].get('grids', False):
        model.grids()
    model.test()
|
|
|
|
# Script entry point: run the inference/export pipeline only when executed
# directly, not when imported.
if __name__ == '__main__':
    main()
|