# File: mdz/pytorch/edsr/1_scripts/3_sim_infer.py
# (File-viewer metadata from the original listing: 49 lines, 1.6 KiB, Python.)

from icraft.xir import *
from icraft.xrt import *
from icraft.host_backend import *
from icraft.buyibackend import *
import numpy as np
from typing import List
import cv2
import os
def run(network: Network, input: List[Tensor]) -> List[Tensor]:
    """Run *network* forward on the host backend and return its outputs.

    A session is created from ``network.view(0)`` using ``HostBackend`` on
    ``HostDevice.Default()``. The session is applied and then run forward
    on *input*.

    NOTE(review): ``network.view(0)`` is presumably a view over the whole
    network starting at index 0. Confirm against the icraft documentation.

    Args:
        network: The network to execute.
        input: Input tensors passed straight to ``Session.forward``.

    Returns:
        The tensors returned by ``Session.forward``.
    """
    # Arguments are built in the same order the original call evaluated them.
    backend_types = [HostBackend]
    net_view = network.view(0)
    device_list = [HostDevice.Default()]
    sess = Session.Create(backend_types, net_view, device_list)
    sess.apply()
    # Forward pass.
    return sess.forward(input)
# Input resolution (height x width, in pixels) fed to the network. It matches
# the "160x240" in the compiled model file names below.
height = 160
width = 240

# Directory that receives the output image. Like every path in this script,
# it is relative to the current working directory.
dst_dir = '../3_deploy/modelzoo/edsr/io/test'
# os.makedirs also creates missing parent directories, which os.mkdir does not.
# exist_ok=True makes this a no-op when the directory already exists.
os.makedirs(dst_dir, exist_ok=True)

# Quantized network produced by the compile step. The JSON file is loaded as
# the network and the RAW file as its parameters.
GENERATED_JSON_FILE = "../3_deploy/modelzoo/edsr/imodel/8/edsr_gelu_160x240_quantized.json"
GENERATED_RAW_FILE = "../3_deploy/modelzoo/edsr/imodel/8/edsr_gelu_160x240_quantized.raw"
# Load the network produced by instruction generation. The JSON file defines
# the network and the RAW file holds its parameters.
generated_network = Network.CreateFromJsonFile(GENERATED_JSON_FILE)
generated_network.loadParamsFromFile(GENERATED_RAW_FILE)

# Test image. It must already be exactly height x width pixels.
img_path = R"../2_compile/qtset/sr/3096x2.png"
img = cv2.imread(img_path)
# cv2.imread returns None, rather than raising, when the file is missing or
# cannot be decoded. Fail here with a clear message.
if img is None:
    raise OSError(f"cannot read input image: {img_path}")
# OpenCV decodes images in BGR order, so swap to RGB. RGB is presumably the
# network's expected channel order; confirm. COLOR_BGR2RGB is the same swap
# as COLOR_RGB2BGR, but this name states the intent.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# Validate the shape instead of blindly reshaping. A reshape would also
# "succeed" on a transposed (width x height) image and silently scramble it.
if img.shape != (height, width, 3):
    raise ValueError(
        f"expected a {height}x{width} 3-channel image, got shape {img.shape}"
    )
# Add a leading batch axis to get (1, height, width, 3), i.e. NHWC.
input_array = img.astype(np.float32).reshape(1, height, width, 3)
input_tensor = Tensor(input_array, Layout("NHWC"))
# Run the network. On failure, report the error and stop. Previously the
# error was swallowed, and the next statement then hit a NameError because
# `generated_output` was never assigned.
try:
    generated_output = run(generated_network, [input_tensor])
except InternalError as i:
    print(i)
    raise

# Convert the first output tensor to a NumPy array once and reuse it.
output_array = np.array(generated_output[0])
print(output_array.shape)

gen_img = output_array.astype(np.float32)
# Drop the leading batch axis (it must have size 1). The rest is treated as
# an H x W x 3 image. This assumes NHWC output; confirm.
gen_img = np.squeeze(gen_img, axis=0)
# Swap channels back to OpenCV's BGR order for saving. The network output is
# presumably RGB, matching the swap applied to the input image.
gen_img = cv2.cvtColor(gen_img, cv2.COLOR_RGB2BGR)
# PNG stores integer samples. Round and saturate to 8 bits explicitly instead
# of relying on OpenCV's implicit float -> uint8 conversion inside imwrite.
# NOTE(review): assumes the outputs use the same 0-255 scale as the raw input
# pixels; confirm.
gen_img = np.clip(np.rint(gen_img), 0, 255).astype(np.uint8)

dst_path = dst_dir + '/fly_8null_tensor_quantized45.png'
# cv2.imwrite reports failure by returning False rather than raising.
if not cv2.imwrite(dst_path, gen_img):
    raise OSError(f"failed to write output image: {dst_path}")