# -*- coding:utf-8 -*-
# This script loads a TorchScript model and runs inference on a single image.
import cv2
import os
import numpy as np
import torch
import argparse
import sys
sys.path.append(R"../0_yolov7")
from utils.datasets import letterbox
from utils.general import non_max_suppression, scale_coords
from utils.plots import plot_one_box


def vis(img, boxes, scores, cls_ids, conf=0.5, class_names=None):
    """Draw detection boxes, class names, and scores on the image."""
    for i in range(len(boxes)):
        box = boxes[i]
        cls_id = int(cls_ids[i])
        score = scores[i]
        if score < conf:
            continue
        x0 = int(box[0])
        y0 = int(box[1])
        x1 = int(box[2])
        y1 = int(box[3])

        color = (_COLORS[cls_id] * 255).astype(np.uint8).tolist()
        text = '{}:{:.1f}%'.format(class_names[cls_id], score * 100)
        txt_color = (0, 0, 0) if np.mean(_COLORS[cls_id]) > 0.5 else (255, 255, 255)
        font = cv2.FONT_HERSHEY_SIMPLEX

        txt_size = cv2.getTextSize(text, font, 0.4, 1)[0]
        cv2.rectangle(img, (x0, y0), (x1, y1), color, 2)

        # Filled background rectangle behind the label text
        txt_bk_color = (_COLORS[cls_id] * 255 * 0.7).astype(np.uint8).tolist()
        cv2.rectangle(
            img,
            (x0, y0 + 1),
            (x0 + txt_size[0] + 1, y0 + int(1.5 * txt_size[1])),
            txt_bk_color,
            -1
        )
        cv2.putText(img, text, (x0, y0 + txt_size[1]), font, 0.4, txt_color, thickness=1)

    return img

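
# A minimal usage sketch for vis() (hypothetical values, for illustration only):
#   dets = np.array([[50, 60, 200, 300]])  # one box in xyxy pixel coordinates
#   drawn = vis(img, boxes=dets, scores=np.array([0.9]), cls_ids=np.array([0]),
#               conf=0.5, class_names=COCO_CLASSES)  # labels the box "person:90.0%"
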
# Per-class RGB colors (values in [0, 1]) used by vis(), one per COCO class.
_COLORS = np.array(
    [
        0.000, 0.447, 0.741,
        0.850, 0.325, 0.098,
        0.929, 0.694, 0.125,
        0.494, 0.184, 0.556,
        0.466, 0.674, 0.188,
        0.301, 0.745, 0.933,
        0.635, 0.078, 0.184,
        0.300, 0.300, 0.300,
        0.600, 0.600, 0.600,
        1.000, 0.000, 0.000,
        1.000, 0.500, 0.000,
        0.749, 0.749, 0.000,
        0.000, 1.000, 0.000,
        0.000, 0.000, 1.000,
        0.667, 0.000, 1.000,
        0.333, 0.333, 0.000,
        0.333, 0.667, 0.000,
        0.333, 1.000, 0.000,
        0.667, 0.333, 0.000,
        0.667, 0.667, 0.000,
        0.667, 1.000, 0.000,
        1.000, 0.333, 0.000,
        1.000, 0.667, 0.000,
        1.000, 1.000, 0.000,
        0.000, 0.333, 0.500,
        0.000, 0.667, 0.500,
        0.000, 1.000, 0.500,
        0.333, 0.000, 0.500,
        0.333, 0.333, 0.500,
        0.333, 0.667, 0.500,
        0.333, 1.000, 0.500,
        0.667, 0.000, 0.500,
        0.667, 0.333, 0.500,
        0.667, 0.667, 0.500,
        0.667, 1.000, 0.500,
        1.000, 0.000, 0.500,
        1.000, 0.333, 0.500,
        1.000, 0.667, 0.500,
        1.000, 1.000, 0.500,
        0.000, 0.333, 1.000,
        0.000, 0.667, 1.000,
        0.000, 1.000, 1.000,
        0.333, 0.000, 1.000,
        0.333, 0.333, 1.000,
        0.333, 0.667, 1.000,
        0.333, 1.000, 1.000,
        0.667, 0.000, 1.000,
        0.667, 0.333, 1.000,
        0.667, 0.667, 1.000,
        0.667, 1.000, 1.000,
        1.000, 0.000, 1.000,
        1.000, 0.333, 1.000,
        1.000, 0.667, 1.000,
        0.333, 0.000, 0.000,
        0.500, 0.000, 0.000,
        0.667, 0.000, 0.000,
        0.833, 0.000, 0.000,
        1.000, 0.000, 0.000,
        0.000, 0.167, 0.000,
        0.000, 0.333, 0.000,
        0.000, 0.500, 0.000,
        0.000, 0.667, 0.000,
        0.000, 0.833, 0.000,
        0.000, 1.000, 0.000,
        0.000, 0.000, 0.167,
        0.000, 0.000, 0.333,
        0.000, 0.000, 0.500,
        0.000, 0.000, 0.667,
        0.000, 0.000, 0.833,
        0.000, 0.000, 1.000,
        0.000, 0.000, 0.000,
        0.143, 0.143, 0.143,
        0.286, 0.286, 0.286,
        0.429, 0.429, 0.429,
        0.571, 0.571, 0.571,
        0.714, 0.714, 0.714,
        0.857, 0.857, 0.857,
        0.000, 0.447, 0.741,
        0.314, 0.717, 0.741,
        0.500, 0.500, 0.000
    ]
).astype(np.float32).reshape(-1, 3)

# The 80 COCO class names, index-aligned with the model's class ids.
COCO_CLASSES = (
    "person",
    "bicycle",
    "car",
    "motorcycle",
    "airplane",
    "bus",
    "train",
    "truck",
    "boat",
    "traffic light",
    "fire hydrant",
    "stop sign",
    "parking meter",
    "bench",
    "bird",
    "cat",
    "dog",
    "horse",
    "sheep",
    "cow",
    "elephant",
    "bear",
    "zebra",
    "giraffe",
    "backpack",
    "umbrella",
    "handbag",
    "tie",
    "suitcase",
    "frisbee",
    "skis",
    "snowboard",
    "sports ball",
    "kite",
    "baseball bat",
    "baseball glove",
    "skateboard",
    "surfboard",
    "tennis racket",
    "bottle",
    "wine glass",
    "cup",
    "fork",
    "knife",
    "spoon",
    "bowl",
    "banana",
    "apple",
    "sandwich",
    "orange",
    "broccoli",
    "carrot",
    "hot dog",
    "pizza",
    "donut",
    "cake",
    "chair",
    "couch",
    "potted plant",
    "bed",
    "dining table",
    "toilet",
    "tv",
    "laptop",
    "mouse",
    "remote",
    "keyboard",
    "cell phone",
    "microwave",
    "oven",
    "toaster",
    "sink",
    "refrigerator",
    "book",
    "clock",
    "vase",
    "scissors",
    "teddy bear",
    "hair drier",
    "toothbrush",
)


def make_grid(anchors, nx=20, ny=20, i=0, na=3):
    """Build the cell-offset grid and the anchor grid for detection layer i."""
    shape = 1, na, ny, nx, 2  # grid shape
    yv, xv = torch.meshgrid(torch.arange(ny), torch.arange(nx))
    grid = torch.stack((xv, yv), 2).expand(shape) - 0.5  # add grid offset, i.e. y = 2.0 * x - 0.5
    anchor_grid = anchors[i].view((1, na, 1, 1, 2)).expand(shape)
    return grid, anchor_grid

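
# Shape sketch for make_grid (assuming a 640x640 input, where layer 0 has
# stride 8 and therefore an 80x80 feature map):
#   grid, anchor_grid = make_grid(anchors, nx=80, ny=80, i=0, na=3)
#   # both have shape torch.Size([1, 3, 80, 80, 2])
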
def pred_one_image(image_path, model_path, test_size, stride, anchors):
    img_raw = cv2.imread(image_path)
    img = letterbox(img_raw, new_shape=test_size, stride=32, auto=False)[0]
    img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
    img = np.ascontiguousarray(img)
    im = torch.from_numpy(img).float().unsqueeze(0)
    im /= 255  # 0-255 to 0.0-1.0

    # Load the traced (TorchScript) model
    model = torch.jit.load(model_path)
    print('model load Done')
    outputs = model(im)  # raw feature maps, one per detection layer
    print('model inference Done')

    nc = 80  # number of classes in the COCO dataset
    nl = len(anchors)  # number of detection layers
    na = 3  # number of anchors per layer
    no = nc + 5  # outputs per anchor: 4 box + 1 objectness + nc class scores

    z = []
    grid = [torch.zeros(1)] * nl
    anchor_grid = [torch.zeros(1)] * nl

    for i in range(nl):
        bs, _, ny, nx = outputs[i].shape
        x = outputs[i].view(bs, na, no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

        grid[i], anchor_grid[i] = make_grid(anchors, nx, ny, i, na)
        y = x.sigmoid()
        y[..., 0:2] = (y[..., 0:2] * 2 + grid[i]) * stride[i]  # xy
        y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * anchor_grid[i]  # wh
        z.append(y.view(bs, -1, no))

    pred = torch.cat(z, 1)  # torch.Size([1, 25200, 85]) for a 640x640 input

    # NMS
    conf_thres = 0.25
    iou_thres = 0.45
    pred = non_max_suppression(pred, conf_thres, iou_thres, classes=None, agnostic=False)

    # Rescale boxes to the original image and display the result
    det = pred[0]
    det[:, :4] = scale_coords(im.shape[2:], det[:, :4], img_raw.shape).round()

    result_image = vis(img_raw, boxes=det[:, :4], scores=det[:, 4], cls_ids=det[:, 5], conf=conf_thres, class_names=COCO_CLASSES)
    cv2.imshow(" ", result_image)
    cv2.waitKey(0)
    print('Show result image Done')

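
# Usage sketch for pred_one_image (repo-default paths; `anchors` as defined
# for the 'yolov7' variant in the __main__ block below):
#   pred_one_image(R'../0_yolov7/inference/images/bus.jpg',
#                  R'../2_compile/fmodel/yolov7_640x640.pt',
#                  test_size=(640, 640), stride=[8, 16, 32], anchors=anchors)
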
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='yolov7', help='model')
    parser.add_argument('--model_pt', type=str, default=R'../2_compile/fmodel/yolov7_640x640.pt', help='torchscript model path')
    parser.add_argument('--source', type=str, default=R'../0_yolov7/inference/images/bus.jpg', help='image path')
    parser.add_argument('--imgsz', nargs='+', type=int, default=[640], help='image size')
    opt = parser.parse_args()

    # Detection-layer strides for each model variant
    if opt.model in ['yolov7', 'yolov7x', 'yolov7-tiny', 'yolov7-tiny-silu']:
        stride = [8, 16, 32]
    elif opt.model in ['yolov7-w6', 'yolov7-d6', 'yolov7-e6', 'yolov7-e6e']:
        stride = [8, 16, 32, 64]

    # Anchor boxes (w, h) per detection layer for each model variant
    if opt.model in ['yolov7', 'yolov7x']:
        anchors = torch.tensor([[[12,16], [19,36], [40,28]], [[36,75], [76,55], [72,146]], [[142,110], [192,243], [459,401]]])
    elif opt.model in ['yolov7-tiny', 'yolov7-tiny-silu']:
        anchors = torch.tensor([[[10,13], [16,30], [33,23]], [[30,61], [62,45], [59,119]], [[116,90], [156,198], [373,326]]])
    elif opt.model in ['yolov7-w6', 'yolov7-d6', 'yolov7-e6', 'yolov7-e6e']:
        anchors = torch.tensor([[[19,27], [44,40], [38,94]], [[96, 68], [86, 152], [180, 137]], [[140, 301], [303, 264], [238, 542]], [[436, 615], [739, 380], [925, 792]]])

    opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1  # expand [640] to [640, 640]
    test_size = tuple(opt.imgsz)

    if os.path.isfile(opt.source):
        pred_one_image(opt.source, opt.model_pt, test_size, stride, anchors)
    elif os.path.isdir(opt.source):
        image_list = os.listdir(opt.source)
        for image_file in image_list:
            image_path = os.path.join(opt.source, image_file)
            pred_one_image(image_path, opt.model_pt, test_size, stride, anchors)
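
# Example invocations (a sketch; the script filename is illustrative and the
# paths assume the repo layout used by the defaults above):
#   python infer_torchscript.py --model yolov7 --imgsz 640 \
#       --model_pt ../2_compile/fmodel/yolov7_640x640.pt \
#       --source ../0_yolov7/inference/images/bus.jpg
#   python infer_torchscript.py --source ../0_yolov7/inference/images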