mdz/pytorch/detr/1_scripts/1_save.py

318 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import sys
sys.path.append(R"../0_detr")
from typing import Optional, List
from torch import nn, Tensor
import argparse
import datetime
import json
import math
import random
import time
from pathlib import Path
import matplotlib.pyplot as plt
import cv2
from PIL import Image
import requests
import numpy as np
import torch
from torch.utils.data import DataLoader, DistributedSampler
import torchvision.transforms as T
import datasets
import util.misc as utils
from datasets import build_dataset, get_coco_api_from_dataset
from engine import evaluate, train_one_epoch
from models import build_model
from models.position_encoding import PositionEmbeddingSine
from models.transformer import Transformer,build_transformer,TransformerDecoder
# Script settings. Paths are relative to the current working directory.
WEIGHT = '../weights/detr-r50-e632da11.pth'  # default for --resume
IMG_PATH = '../2_compile/qtset/detr/000000000632.jpg'  # demo image that is run and traced
# Fixed network input size; also embedded in the exported ONNX file name.
IMG_H, IMG_W = 640, 640
# When True, detect() exports the model to ONNX after the forward pass.
TRACE = True
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from icraft_models.detr_icraft_1in import build as build_icraft
# Rebind build_model (originally imported from the `models` package above) to the
# iCraft variant's builder. main() then calls it as build_model(args, tgt, mask),
# i.e. with the decoder target and padding mask supplied from outside.
build_model = build_icraft
# def forward_icraft(self, mask): # revised
# mask = mask
# assert mask is not None
# not_mask = ~mask
# y_embed = not_mask.cumsum(1, dtype=torch.float32)
# x_embed = not_mask.cumsum(2, dtype=torch.float32)
# if self.normalize:
# eps = 1e-6
# y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
# x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
# dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=mask.device)
# dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
# pos_x = x_embed[:, :, :, None] / dim_t
# pos_y = y_embed[:, :, :, None] / dim_t
# pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
# pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
# pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
# return pos
# PositionEmbeddingSine.forward = forward_icraft
def forward_icraft2(self, src, tgt, query_embed, pos_embed):
    """Replacement for ``Transformer.forward`` used by the iCraft export.

    The decoder input ``tgt`` is supplied by the caller and no key-padding masks
    are passed to the encoder or decoder. ``query_embed`` and ``pos_embed`` are
    forwarded unchanged (no reshaping happens here), so they are assumed to
    already be in the layout the encoder/decoder expect -- TODO confirm against
    the icraft build.

    Args:
        src: feature map, unpacked as (batch, channels, height, width).
        tgt: initial decoder input.
        query_embed: passed to the decoder as ``query_pos``.
        pos_embed: passed to both encoder and decoder as ``pos``.

    Returns:
        Tuple of (decoder output with dims 0 and 1 swapped,
        encoder memory reshaped back to (batch, channels, height, width)).
    """
    batch, channels, height, width = src.shape
    # (B, C, H, W) -> (B, C, H*W) -> (H*W, B, C): sequence-first tokens.
    tokens = src.flatten(2).permute(2, 0, 1)
    memory = self.encoder(tokens, src_key_padding_mask=None, pos=pos_embed)
    decoded = self.decoder(
        tgt,
        memory,
        memory_key_padding_mask=None,
        pos=pos_embed,
        query_pos=query_embed,
    )
    # Undo the flattening so the memory comes back as a feature map.
    memory_map = memory.permute(1, 2, 0).view(batch, channels, height, width)
    return decoded.transpose(0, 1), memory_map
# Monkey-patch Transformer so every instance uses forward_icraft2: the decoder
# input tgt becomes an explicit argument and no key-padding masks are passed.
Transformer.forward=forward_icraft2
def forward_icraft3(self, tgt, memory,
                    tgt_mask: Optional[Tensor] = None,
                    memory_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
    """Replacement for ``TransformerDecoder.forward`` used by the iCraft export.

    Runs every layer in ``self.layers`` over ``tgt`` (attending to ``memory``),
    forwarding all masks and positional embeddings unchanged. If ``self.norm`` is
    set, it is applied to the final output.

    Returns:
        With ``self.return_intermediate``: ``torch.stack`` of the normed output of
        every layer, whose last entry is the final normed output.
        Otherwise: the final output as-is. No leading layer dimension is added;
        an earlier variant returned ``output.unsqueeze(0)`` ("icraft 3.0 1213").
    """
    hidden = tgt
    per_layer = []
    for decoder_layer in self.layers:
        hidden = decoder_layer(
            hidden,
            memory,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
            pos=pos,
            query_pos=query_pos,
        )
        if self.return_intermediate:
            per_layer.append(self.norm(hidden))

    if self.norm is not None:
        hidden = self.norm(hidden)
        if self.return_intermediate:
            # Swap the last recorded entry for the final normed output.
            per_layer.pop()
            per_layer.append(hidden)

    if not self.return_intermediate:
        return hidden
    return torch.stack(per_layer)
# Monkey-patch TransformerDecoder so every instance uses forward_icraft3, which
# returns the final output without a leading layer dimension when
# return_intermediate is False.
TransformerDecoder.forward = forward_icraft3
# Box colours for plot_results: RGB triples in [0, 1], one per detection.
# plot_results repeats this list, so colours recur every six boxes.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]
# COCO category names, indexed by the model's class id (plot_results looks up
# CLASSES[argmax]). 'N/A' entries are placeholders for ids with no name.
CLASSES = [
    # 0-9
    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat',
    # 10-19
    'traffic light', 'fire hydrant', 'N/A', 'stop sign', 'parking meter',
    'bench', 'bird', 'cat', 'dog', 'horse',
    # 20-29
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A',
    'backpack', 'umbrella', 'N/A',
    # 30-39
    'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
    'sports ball', 'kite', 'baseball bat',
    # 40-49
    'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
    'N/A', 'wine glass', 'cup', 'fork', 'knife',
    # 50-59
    'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli',
    'carrot', 'hot dog', 'pizza',
    # 60-69
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A',
    'dining table', 'N/A', 'N/A',
    # 70-79
    'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven',
    # 80-89
    'toaster', 'sink', 'refrigerator', 'N/A', 'book', 'clock', 'vase',
    'scissors', 'teddy bear', 'hair drier',
    # 90
    'toothbrush',
]
# Preprocessing for the demo/export run:
#   1. resize to exactly IMG_H x IMG_W. A two-element size makes T.Resize ignore
#      the aspect ratio, which gives the traced model a fixed input shape;
#   2. convert the PIL image to a float tensor;
#   3. normalise with the standard ImageNet mean/std.
# The original DETR demo used T.Resize(800) here (shorter side scaled to 800).
transform = T.Compose(
    [
        T.Resize([IMG_H, IMG_W]),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
def box_cxcywh_to_xyxy(x):
    """Convert boxes from centre form (cx, cy, w, h) to corners (x0, y0, x1, y1).

    The four box fields are read along dim 1 (``x.unbind(1)``), so ``x`` is
    expected to be (N, 4); the result has the same shape.
    """
    cx, cy, width, height = x.unbind(1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = (cx - half_w, cy - half_h, cx + half_w, cy + half_h)
    return torch.stack(corners, dim=1)
def rescale_bboxes(out_bbox, size):
    """Convert normalised (cx, cy, w, h) boxes to corner boxes in pixels.

    Args:
        out_bbox: (N, 4) tensor of boxes in centre form, with coordinates in
            [0, 1] relative to the image.
        size: (width, height) of the target image, as given by PIL ``Image.size``.

    Returns:
        (N, 4) tensor of (x0, y0, x1, y1) boxes in pixel units.
    """
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    # Build the scale on the boxes' device so outputs on CUDA do not raise a
    # CPU/GPU device-mismatch error.
    scale = torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32, device=b.device)
    return b * scale
def detect(im, samples, model, transform, threshold=0.7):
    """Run the model, optionally export it to ONNX, and keep confident detections.

    Args:
        im: original PIL image. Only ``im.size`` (width, height) is used, to scale
            the normalised boxes back to pixels.
        samples: preprocessed input batch, passed straight to ``model``.
        model: callable returning ``(outputs_class, outputs_coord)``.
        transform: unused; kept so existing call sites keep working.
        threshold: a query is kept when its highest class probability (after
            dropping the last class column) exceeds this value. Default 0.7.

    Returns:
        ``(probas, boxes)`` for the kept queries of batch element 0:
        - ``probas``: class probabilities without the last column;
        - ``boxes``: (x0, y0, x1, y1) in pixels of ``im``.

    Side effects:
        When ``TRACE`` is true, writes ``../2_compile/fmodel/detr_<H>x<W>.onnx``
        (creating the directory if needed) and prints its path. Always prints
        ``outputs_coord``.
    """
    # Inference only: skip building an autograd graph for the forward pass.
    with torch.no_grad():
        outputs_class, outputs_coord = model(samples)
    if TRACE:
        onnx_path = "../2_compile/fmodel/detr_" + str(IMG_H) + "x" + str(IMG_W) + ".onnx"
        # Ensure the output directory exists so the export does not fail on a
        # fresh checkout.
        Path(onnx_path).parent.mkdir(parents=True, exist_ok=True)
        torch.onnx.export(model, samples, onnx_path)
        print("Trace Done ! Traced model is saved to " + onnx_path)
    # Softmax over classes for batch element 0, then drop the last column
    # (DETR's "no object" class). At prediction time the outputs are used
    # directly; this is unrelated to the Hungarian matching.
    probas = outputs_class.softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > threshold
    print(outputs_coord)
    # Convert boxes from normalised [0, 1] coordinates to pixels of the original image.
    bboxes_scaled = rescale_bboxes(outputs_coord[0, keep], im.size)
    return probas[keep], bboxes_scaled
def _str2bool(value):
    """Parse a command-line boolean value.

    ``type=bool`` is an argparse trap: any non-empty string is truthy, so
    ``bool("False")`` is ``True`` and such options could never be switched off.
    This parser accepts the usual spellings instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if normalized in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args_parser():
    """Return the parent parser (no ``-h``) holding the DETR options.

    The resulting namespace is passed to ``build_model``. main() also reads a
    few fields directly, such as ``resume``, ``seed`` and ``device``.
    """
    parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
    parser.add_argument('--lr', default=1e-4, type=float)
    parser.add_argument('--lr_backbone', default=1e-5, type=float)
    parser.add_argument('--batch_size', default=2, type=int)
    parser.add_argument('--weight_decay', default=1e-4, type=float)
    parser.add_argument('--epochs', default=300, type=int)
    parser.add_argument('--lr_drop', default=200, type=int)
    parser.add_argument('--clip_max_norm', default=0.1, type=float,
                        help='gradient clipping max norm')

    # Model parameters
    parser.add_argument('--frozen_weights', type=str, default=None,
                        help="Path to the pretrained model. If set, only the mask head will be trained")
    # * Backbone
    parser.add_argument('--backbone', default='resnet50', type=str,
                        help="Name of the convolutional backbone to use")
    parser.add_argument('--dilation', action='store_true',
                        help="If true, we replace stride with dilation in the last convolutional block (DC5)")
    parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
                        help="Type of positional embedding to use on top of the image features")

    # * Transformer
    parser.add_argument('--enc_layers', default=6, type=int,
                        help="Number of encoding layers in the transformer")
    parser.add_argument('--dec_layers', default=6, type=int,
                        help="Number of decoding layers in the transformer")
    parser.add_argument('--dim_feedforward', default=2048, type=int,
                        help="Intermediate size of the feedforward layers in the transformer blocks")
    parser.add_argument('--hidden_dim', default=256, type=int,
                        help="Size of the embeddings (dimension of the transformer)")
    parser.add_argument('--dropout', default=0.1, type=float,
                        help="Dropout applied in the transformer")
    parser.add_argument('--nheads', default=8, type=int,
                        help="Number of attention heads inside the transformer's attentions")
    parser.add_argument('--num_queries', default=100, type=int,
                        help="Number of query slots")
    parser.add_argument('--pre_norm', action='store_true')

    # * Segmentation
    parser.add_argument('--masks', action='store_true',
                        help="Train segmentation head if the flag is provided")

    # Loss
    # NOTE: despite its name, the value given here is stored directly in
    # args.aux_loss (default False = auxiliary losses off). It previously used
    # type=bool, which turned every non-empty value, including "False", into True.
    parser.add_argument('--no_aux_loss', dest='aux_loss', default=False, type=_str2bool,
                        help="Disables auxiliary decoding losses (loss at each layer)")
    # * Matcher
    parser.add_argument('--set_cost_class', default=1, type=float,
                        help="Class coefficient in the matching cost")
    parser.add_argument('--set_cost_bbox', default=5, type=float,
                        help="L1 box coefficient in the matching cost")
    parser.add_argument('--set_cost_giou', default=2, type=float,
                        help="giou box coefficient in the matching cost")
    # * Loss coefficients
    parser.add_argument('--mask_loss_coef', default=1, type=float)
    parser.add_argument('--dice_loss_coef', default=1, type=float)
    parser.add_argument('--bbox_loss_coef', default=5, type=float)
    parser.add_argument('--giou_loss_coef', default=2, type=float)
    parser.add_argument('--eos_coef', default=0.1, type=float,
                        help="Relative classification weight of the no-object class")

    # dataset parameters
    parser.add_argument('--dataset_file', default='coco')
    # A UNC path needs two leading backslashes. The previous literal '\\192...'
    # evaluated to a single backslash, so a raw string is used instead.
    parser.add_argument('--coco_path', default=r'\\192.168.125.235\nb2042\Dataset\COCO2017', type=str)
    parser.add_argument('--coco_panoptic_path', type=str)
    parser.add_argument('--remove_difficult', action='store_true')

    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cpu',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--resume', default=WEIGHT, help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    # Parsed with _str2bool so that "--eval False" actually yields False.
    parser.add_argument('--eval', default=True, type=_str2bool)
    parser.add_argument('--num_workers', default=2, type=int)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    return parser
def plot_results(pil_img, prob, boxes):
    """Display ``pil_img`` with one labelled rectangle per detection.

    Args:
        pil_img: image to draw on.
        prob: per-detection class probabilities. The label shows the argmax class
            name from CLASSES and its probability.
        boxes: matching (xmin, ymin, xmax, ymax) boxes in image pixels.

    Colours come from COLORS, repeated; the figure is shown with plt.show().
    """
    plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    axes = plt.gca()
    palette = COLORS * 100
    for scores, box, colour in zip(prob, boxes.tolist(), palette):
        x0, y0, x1, y1 = box
        frame = plt.Rectangle((x0, y0), x1 - x0, y1 - y0,
                              fill=False, color=colour, linewidth=3)
        axes.add_patch(frame)
        best = scores.argmax()
        label = f'{CLASSES[best]}: {scores[best]:0.2f}'
        axes.text(x0, y0, label, fontsize=15,
                  bbox={'facecolor': 'yellow', 'alpha': 0.5})
    plt.axis('off')
    plt.show()
    # To write the figure to disk instead, call plt.savefig("res.png") before plt.show().
def main(args):
    """Build the iCraft DETR model, load weights, run the demo image and plot detections.

    Args:
        args: namespace from get_args_parser(). ``resume`` (checkpoint path or
            https URL), ``seed``, ``num_queries`` and ``hidden_dim`` are read here,
            and the whole namespace is passed to build_model.
    """
    utils.init_distributed_mode(args)
    print("git:\n {}\n".format(utils.get_sha()))

    if args.frozen_weights is not None:
        assert args.masks, "Frozen training is meant for segmentation only"
    print(args)
    # Only parsed here; the model is not moved to it (no model.to(device) call).
    device = torch.device(args.device)

    # Fix the seed for reproducibility.
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # convert('RGB') guarantees 3 channels for the mean/std Normalize, even for
    # grayscale, palette or RGBA files.
    im = Image.open(IMG_PATH).convert('RGB')
    # Preprocess into a batch of one image.
    img = transform(im).unsqueeze(0)
    print(img.shape)
    # Size guard inherited from the DETR demo.
    assert img.shape[-2] <= 1600 and img.shape[-1] <= 1600, 'demo model only supports images up to 1600 pixels on each side'

    # Decoder input passed explicitly to the patched Transformer.forward:
    # zeros of shape (num_queries, 1, hidden_dim), i.e. (100, 1, 256) with the
    # default options. It must agree with the model built from the same args.
    tgt = torch.zeros([args.num_queries, 1, args.hidden_dim]).to(img.device)
    # All-False padding mask at 1/32 of the input resolution (backbone stride 32
    # assumed). It could also be supplied from outside.
    mask = torch.zeros(
        (1, math.ceil(img.shape[-2] / 32), math.ceil(img.shape[-1] / 32)),
        dtype=torch.bool,
        device=img.device,
    )
    model, postprocessors = build_model(args, tgt, mask)

    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])

    model.eval()
    scores, boxes = detect(im, img, model, transform)
    plot_results(im, scores, boxes)
if __name__ == '__main__':
    # The script's CLI is the shared DETR option set (get_args_parser) plus -h.
    args = argparse.ArgumentParser(
        'DETR training and evaluation script',
        parents=[get_args_parser()],
    ).parse_args()
    # Create the output directory up front when one was requested.
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)