# Python script (49 lines, 1.6 KiB)
from icraft.xir import *
|
|
from icraft.xrt import *
|
|
from icraft.host_backend import *
|
|
from icraft.buyibackend import *
|
|
import numpy as np
|
|
from typing import List
|
|
import cv2
|
|
import os
|
|
|
|
def run(network: Network, input: List[Tensor]) -> List[Tensor]:
    """Execute ``network`` once on the host backend and return its outputs.

    A new session is created for ``network.view(0)`` with ``HostBackend`` on
    the default host device; the session is applied and then run forward on
    ``input``.

    Args:
        network: The iCraft network to execute.
        input: Input tensors passed unchanged to ``Session.forward``.

    Returns:
        The tensors returned by ``Session.forward``.
    """
    net_view = network.view(0)
    sess = Session.Create([HostBackend], net_view, [HostDevice.Default()])
    sess.apply()
    # Forward pass.
    return sess.forward(input)
# Network input size in pixels (rows x columns); matches the "160x240" model files.
height = 160
width = 240

# Directory that receives the generated output image.
dst_dir = '../3_deploy/modelzoo/rfdn/io/test'
# makedirs also creates missing parent directories, and exist_ok=True replaces
# the racy exists()/mkdir() pair with a single call.
os.makedirs(dst_dir, exist_ok=True)

# Parsed RFDN network: structure (JSON) and parameters (RAW).
GENERATED_JSON_FILE = "../3_deploy/modelzoo/rfdn/imodel/16/rfdn_160x240_parsed.json"
GENERATED_RAW_FILE = "../3_deploy/modelzoo/rfdn/imodel/16/rfdn_160x240_parsed.raw"

# Load the network produced by instruction generation, then load its
# parameters into the same object.
generated_network = Network.CreateFromJsonFile(GENERATED_JSON_FILE)
generated_network.loadParamsFromFile(GENERATED_RAW_FILE)
# Test input image; it must already be exactly height x width pixels.
img_path = R"../2_compile/qtset/sr/3096x2.png"

img = cv2.imread(img_path)
# cv2.imread reports a missing/unreadable file by returning None rather than
# raising, so fail here with a clear message instead of inside cvtColor.
if img is None:
    raise FileNotFoundError(f"cannot read image: {img_path}")
# The reshape below does not resize: a differently sized image would either
# fail with an opaque error or, if it has the same pixel count (e.g. rotated),
# be silently scrambled. Check the size explicitly.
if img.shape != (height, width, 3):
    raise ValueError(
        f"expected image of shape {(height, width, 3)}, got {img.shape}: {img_path}"
    )
# Swap the channel order from imread's BGR to RGB
# (same conversion code as the COLOR_RGB2BGR flag).
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# Add a batch axis: float32 array of shape (1, height, width, 3), i.e. NHWC.
img_ = img.astype(np.float32).reshape(1, height, width, 3)

input_tensor = Tensor(img_, Layout("NHWC"))
try:
    generated_output = run(generated_network, [input_tensor])
except InternalError as err:
    # Report the iCraft error, then re-raise: continuing would only fail
    # below with a misleading NameError on `generated_output`.
    print(err)
    raise

print(np.array(generated_output[0]).shape)

# First output tensor as a float32 array with the leading (batch) axis removed.
# NOTE(review): assumes the output layout is NHWC with batch size 1 — confirm.
gen_img = np.squeeze(np.array(generated_output[0]).astype(np.float32), axis=0)
# Swap channel order back (the input was channel-swapped before inference)
# so cv2.imwrite receives BGR data.
gen_img = cv2.cvtColor(gen_img, cv2.COLOR_RGB2BGR)

dst_path = os.path.join(dst_dir, 'fly_16null_tensor_parsed.png')
# NOTE(review): gen_img is float32; confirm that cv2.imwrite's handling of
# float data for PNG output (conversion to 8-bit) is what is intended.
# cv2.imwrite reports failure through its boolean return value, not an exception.
if not cv2.imwrite(dst_path, gen_img):
    raise OSError(f"failed to write output image: {dst_path}")