# mdz/pytorch/mimo-unet/1_scripts/3_sim_infer.py
#
# 48 lines
# 1.5 KiB
# Python

from icraft.xir import *
from icraft.xrt import *
from icraft.host_backend import *
from icraft.buyibackend import *
import numpy as np
from typing import List
import cv2
import os
def run(network: Network, input: List[Tensor]) -> List[Tensor]:
    """Execute ``network`` on the host backend and return its outputs.

    A session is created over ``network.view(0)`` using ``HostBackend`` and
    the default host device, the session is applied, and a forward pass is
    run on ``input``.

    Args:
        network: The network to execute.
        input: Input tensors handed to the forward pass.

    Returns:
        The tensors returned by ``session.forward``.
    """
    backends = [HostBackend]
    network_view = network.view(0)
    devices = [HostDevice.Default()]
    session = Session.Create(backends, network_view, devices)
    session.apply()
    # Forward pass.
    return session.forward(input)
# Input resolution fed to the network; matches the "720x1280" in the model
# file names below.
height = 720
width = 1280
# Directory the simulation result image is written to.
dst_dir = '../3_deploy/modelzoo/mimo-unet/io'
# makedirs(exist_ok=True) also creates missing parent directories and avoids
# the check-then-create race of os.path.exists + os.mkdir.
os.makedirs(dst_dir, exist_ok=True)
# Both artifacts of the parsed model share one path stem: the .json file is
# loaded as the network, the .raw file supplies its parameters.
_PARSED_MODEL_STEM = "../3_deploy/modelzoo/mimo-unet/imodel/8/mimo-unet_720x1280_parsed"
GENERATED_JSON_FILE = _PARSED_MODEL_STEM + ".json"
GENERATED_RAW_FILE = _PARSED_MODEL_STEM + ".raw"

# Load the generated network, then attach its parameters.
generated_network = Network.CreateFromJsonFile(GENERATED_JSON_FILE)
generated_network.loadParamsFromFile(GENERATED_RAW_FILE)
# Test frame to run through the simulator.
img_path = "../3_deploy/modelzoo/mimo-unet/io/input/GOPR0384_11_00-000001.png"
img = cv2.imread(img_path)
# cv2.imread signals a missing/unreadable file by returning None, not raising.
if img is None:
    raise FileNotFoundError(f"could not read input image: {img_path}")
# The model expects exactly height x width x 3. Check the shape explicitly:
# a plain reshape would "succeed" on any image with the same pixel count
# (e.g. a transposed 1280x720 frame) and silently scramble it.
if img.shape != (height, width, 3):
    raise ValueError(
        f"expected a {height}x{width} 3-channel image, got shape {img.shape}"
    )
# NOTE(review): pixels are fed in OpenCV's native channel order (BGR); an
# explicit channel swap was previously tried here and left disabled — confirm
# this matches the channel order the model was parsed/quantized with.
# Add the leading batch dimension to get an NHWC float32 batch of one.
img_ = img.astype(np.float32).reshape(1, height, width, 3)
input_tensor = Tensor(img_, Layout("NHWC"))
try:
    generated_output = run(generated_network, [input_tensor])
except InternalError as err:
    # Show the simulator's message, then re-raise. Swallowing it here would
    # let the script continue and crash with a misleading NameError on
    # `generated_output`.
    print(err)
    raise

# NOTE(review): output index 2 is assumed to be the restored image at full
# resolution — confirm against the parsed network's output list.
gen_img = np.array(generated_output[2]).astype(np.float32)
print(gen_img.shape)
# Drop the batch dimension.
gen_img = np.squeeze(gen_img, axis=0)
# Output values are presumably in [0, 1]; scale to 8-bit. Clip and convert to
# uint8 explicitly rather than handing float32 data to cv2.imwrite.
gen_img = np.clip((gen_img * 255).round(), 0, 255).astype(np.uint8)
dst_path = os.path.join(dst_dir, 'sim_parsed.png')
# cv2.imwrite reports failure through its return value, not an exception.
if not cv2.imwrite(dst_path, gen_img):
    raise OSError(f"failed to write simulation output: {dst_path}")