mdz/pytorch/yolov7/1_scripts/2_save_infer.py

# -*- coding:utf-8 -*-
# This script loads a TorchScript (traced) YOLOv7 model and runs inference on a single image.
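#
# Example invocation (a sketch using the argparse defaults below; run from the
# 1_scripts directory so the relative paths resolve):
#   python 2_save_infer.py --model yolov7 \
#       --model_pt ../2_compile/fmodel/yolov7_640x640.pt \
#       --source ../0_yolov7/inference/images/bus.jpg --imgsz 640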
import cv2
import os
import numpy as np
import torch
import argparse
import sys
sys.path.append(R"../0_yolov7")
from utils.datasets import letterbox
from utils.general import non_max_suppression, scale_coords
from utils.plots import plot_one_box
# Draw detection boxes and class-name labels (with per-class colors) on the image.
def vis(img, boxes, scores, cls_ids, conf=0.5, class_names=None):
    for i in range(len(boxes)):
        box = boxes[i]
        cls_id = int(cls_ids[i])
        score = scores[i]
        if score < conf:
            continue
        x0 = int(box[0])
        y0 = int(box[1])
        x1 = int(box[2])
        y1 = int(box[3])

        color = (_COLORS[cls_id] * 255).astype(np.uint8).tolist()
        text = '{}:{:.1f}%'.format(class_names[cls_id], score * 100)
        txt_color = (0, 0, 0) if np.mean(_COLORS[cls_id]) > 0.5 else (255, 255, 255)
        font = cv2.FONT_HERSHEY_SIMPLEX
        txt_size = cv2.getTextSize(text, font, 0.4, 1)[0]
        cv2.rectangle(img, (x0, y0), (x1, y1), color, 2)

        txt_bk_color = (_COLORS[cls_id] * 255 * 0.7).astype(np.uint8).tolist()
        cv2.rectangle(
            img,
            (x0, y0 + 1),
            (x0 + txt_size[0] + 1, y0 + int(1.5 * txt_size[1])),
            txt_bk_color,
            -1
        )
        cv2.putText(img, text, (x0, y0 + txt_size[1]), font, 0.4, txt_color, thickness=1)
    return img

_COLORS = np.array(
    [
        0.000, 0.447, 0.741,
        0.850, 0.325, 0.098,
        0.929, 0.694, 0.125,
        0.494, 0.184, 0.556,
        0.466, 0.674, 0.188,
        0.301, 0.745, 0.933,
        0.635, 0.078, 0.184,
        0.300, 0.300, 0.300,
        0.600, 0.600, 0.600,
        1.000, 0.000, 0.000,
        1.000, 0.500, 0.000,
        0.749, 0.749, 0.000,
        0.000, 1.000, 0.000,
        0.000, 0.000, 1.000,
        0.667, 0.000, 1.000,
        0.333, 0.333, 0.000,
        0.333, 0.667, 0.000,
        0.333, 1.000, 0.000,
        0.667, 0.333, 0.000,
        0.667, 0.667, 0.000,
        0.667, 1.000, 0.000,
        1.000, 0.333, 0.000,
        1.000, 0.667, 0.000,
        1.000, 1.000, 0.000,
        0.000, 0.333, 0.500,
        0.000, 0.667, 0.500,
        0.000, 1.000, 0.500,
        0.333, 0.000, 0.500,
        0.333, 0.333, 0.500,
        0.333, 0.667, 0.500,
        0.333, 1.000, 0.500,
        0.667, 0.000, 0.500,
        0.667, 0.333, 0.500,
        0.667, 0.667, 0.500,
        0.667, 1.000, 0.500,
        1.000, 0.000, 0.500,
        1.000, 0.333, 0.500,
        1.000, 0.667, 0.500,
        1.000, 1.000, 0.500,
        0.000, 0.333, 1.000,
        0.000, 0.667, 1.000,
        0.000, 1.000, 1.000,
        0.333, 0.000, 1.000,
        0.333, 0.333, 1.000,
        0.333, 0.667, 1.000,
        0.333, 1.000, 1.000,
        0.667, 0.000, 1.000,
        0.667, 0.333, 1.000,
        0.667, 0.667, 1.000,
        0.667, 1.000, 1.000,
        1.000, 0.000, 1.000,
        1.000, 0.333, 1.000,
        1.000, 0.667, 1.000,
        0.333, 0.000, 0.000,
        0.500, 0.000, 0.000,
        0.667, 0.000, 0.000,
        0.833, 0.000, 0.000,
        1.000, 0.000, 0.000,
        0.000, 0.167, 0.000,
        0.000, 0.333, 0.000,
        0.000, 0.500, 0.000,
        0.000, 0.667, 0.000,
        0.000, 0.833, 0.000,
        0.000, 1.000, 0.000,
        0.000, 0.000, 0.167,
        0.000, 0.000, 0.333,
        0.000, 0.000, 0.500,
        0.000, 0.000, 0.667,
        0.000, 0.000, 0.833,
        0.000, 0.000, 1.000,
        0.000, 0.000, 0.000,
        0.143, 0.143, 0.143,
        0.286, 0.286, 0.286,
        0.429, 0.429, 0.429,
        0.571, 0.571, 0.571,
        0.714, 0.714, 0.714,
        0.857, 0.857, 0.857,
        0.000, 0.447, 0.741,
        0.314, 0.717, 0.741,
        0.50, 0.5, 0
    ]
).astype(np.float32).reshape(-1, 3)
COCO_CLASSES = (
    "person",
    "bicycle",
    "car",
    "motorcycle",
    "airplane",
    "bus",
    "train",
    "truck",
    "boat",
    "traffic light",
    "fire hydrant",
    "stop sign",
    "parking meter",
    "bench",
    "bird",
    "cat",
    "dog",
    "horse",
    "sheep",
    "cow",
    "elephant",
    "bear",
    "zebra",
    "giraffe",
    "backpack",
    "umbrella",
    "handbag",
    "tie",
    "suitcase",
    "frisbee",
    "skis",
    "snowboard",
    "sports ball",
    "kite",
    "baseball bat",
    "baseball glove",
    "skateboard",
    "surfboard",
    "tennis racket",
    "bottle",
    "wine glass",
    "cup",
    "fork",
    "knife",
    "spoon",
    "bowl",
    "banana",
    "apple",
    "sandwich",
    "orange",
    "broccoli",
    "carrot",
    "hot dog",
    "pizza",
    "donut",
    "cake",
    "chair",
    "couch",
    "potted plant",
    "bed",
    "dining table",
    "toilet",
    "tv",
    "laptop",
    "mouse",
    "remote",
    "keyboard",
    "cell phone",
    "microwave",
    "oven",
    "toaster",
    "sink",
    "refrigerator",
    "book",
    "clock",
    "vase",
    "scissors",
    "teddy bear",
    "hair drier",
    "toothbrush",
)

def make_grid(nx=20, ny=20, i=0, na=3):
    # Build the cell-offset grid and per-anchor size grid for detection layer i.
    # Note: `anchors` is read from module scope (defined in __main__ below).
    shape = 1, na, ny, nx, 2  # grid shape
    yv, xv = torch.meshgrid(torch.arange(ny), torch.arange(nx))
    grid = torch.stack((xv, yv), 2).expand(shape) - 0.5  # add grid offset, i.e. y = 2.0 * x - 0.5
    anchor_grid = anchors[i].view((1, na, 1, 1, 2)).expand(shape)
    return grid, anchor_grid
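
# Shape sanity check (a sketch): for the stride-8 head of a 640x640 input,
#   g, ag = make_grid(nx=80, ny=80, i=0, na=3)
# both g and ag have shape (1, 3, 80, 80, 2).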

def pred_one_image(image_path, model_path, test_size, stride, anchors):
    img_raw = cv2.imread(image_path)
    img = letterbox(img_raw, new_shape=test_size, stride=32, auto=False)[0]
    img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
    img = np.ascontiguousarray(img)
    im = torch.from_numpy(img).float().unsqueeze(0)
    im /= 255
    # Load the traced (TorchScript) model
    model = torch.jit.load(model_path)
    print('model load Done')
    outputs = model(im)
    print('model inference Done')
    nc = 80  # the COCO dataset has 80 classes
    nl = len(anchors)  # number of detection layers
    na = 3  # number of anchors per layer
    no = nc + 5  # outputs per anchor: box (4) + objectness (1) + class scores (nc)
    z = []
    grid = [torch.zeros(1)] * nl
    anchor_grid = [torch.zeros(1)] * nl
    for i in range(nl):
        bs, _, ny, nx = outputs[i].shape
        outputs[i] = outputs[i].view(bs, na, no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
        grid[i], anchor_grid[i] = make_grid(nx, ny, i, na)
        y = outputs[i].sigmoid()
        y[..., 0:2] = (y[..., 0:2] * 2 + grid[i]) * stride[i]  # xy
        y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * anchor_grid[i]  # wh
        z.append(y.view(bs, -1, no))
    pred = torch.cat(z, 1)  # torch.Size([1, 25200, 85]) for a 640x640 input
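    # Worked example of the decode loop above (a sketch, using this script's
    # stride-8 yolov7 anchor (12, 16) at grid cell (gx, gy) = (10, 20); note that
    # grid[i] already carries the -0.5 offset):
    #   x = (2*sigmoid(tx) - 0.5 + 10) * 8
    #   y = (2*sigmoid(ty) - 0.5 + 20) * 8
    #   w = (2*sigmoid(tw))**2 * 12
    #   h = (2*sigmoid(th))**2 * 16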
    # NMS
    conf_thres = 0.25
    iou_thres = 0.45
    # non_max_suppression returns one (n, 6) tensor per image: [x1, y1, x2, y2, conf, cls]
    pred = non_max_suppression(pred, conf_thres, iou_thres, classes=None, agnostic=False)
    # Display the result
    det = pred[0]
    det[:, :4] = scale_coords(im.shape[2:], det[:, :4], img_raw.shape).round()
    result_image = vis(img_raw, boxes=det[:, :4], scores=det[:, 4], cls_ids=det[:, 5],
                       conf=conf_thres, class_names=COCO_CLASSES)
    cv2.imshow('result', result_image)
    cv2.waitKey(0)
    print('Show result image Done')

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='yolov7', help='model')
    parser.add_argument('--model_pt', type=str, default=R'../2_compile/fmodel/yolov7_640x640.pt', help='torchscript model path')
    parser.add_argument('--source', type=str, default=R'../0_yolov7/inference/images/bus.jpg', help='image path')
    parser.add_argument('--imgsz', nargs='+', type=int, default=[640], help='image size')
    opt = parser.parse_args()

    if opt.model in ['yolov7', 'yolov7x', 'yolov7-tiny', 'yolov7-tiny-silu']:
        stride = [8, 16, 32]
    elif opt.model in ['yolov7-w6', 'yolov7-d6', 'yolov7-e6', 'yolov7-e6e']:
        stride = [8, 16, 32, 64]

    if opt.model in ['yolov7', 'yolov7x']:
        anchors = torch.tensor([[[12,16], [19,36], [40,28]], [[36,75], [76,55], [72,146]], [[142,110], [192,243], [459,401]]])
    elif opt.model in ['yolov7-tiny', 'yolov7-tiny-silu']:
        anchors = torch.tensor([[[10,13], [16,30], [33,23]], [[30,61], [62,45], [59,119]], [[116,90], [156,198], [373,326]]])
    elif opt.model in ['yolov7-w6', 'yolov7-d6', 'yolov7-e6', 'yolov7-e6e']:
        anchors = torch.tensor([[[19,27], [44,40], [38,94]], [[96, 68], [86, 152], [180, 137]], [[140, 301], [303, 264], [238, 542]], [[436, 615], [739, 380], [925, 792]]])

    opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1  # expand a single size to (h, w)
    test_size = tuple(opt.imgsz)

    if os.path.isfile(opt.source):
        pred_one_image(opt.source, opt.model_pt, test_size, stride, anchors)
    elif os.path.isdir(opt.source):
        for image_file in os.listdir(opt.source):
            image_path = os.path.join(opt.source, image_file)
            pred_one_image(image_path, opt.model_pt, test_size, stride, anchors)