mdz/pytorch/strongReid/1_scripts/utils/data_infer.py

278 lines
9.7 KiB
Python

# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from collections import defaultdict
from torch.utils.data.sampler import Sampler
import copy
import numpy as np
import random
import torch
import torchvision.transforms as T
import os.path as osp
from PIL import Image
import glob
import re
class RandomIdentitySampler(Sampler):
    """
    Randomly sample N identities, then for each identity,
    randomly sample K instances, therefore batch size is N*K.

    Args:
    - data_source (list): list of (img_path, pid, camid).
    - batch_size (int): number of examples in a batch.
    - num_instances (int): number of instances per identity in a batch.

    Raises:
    - ValueError: if ``batch_size // num_instances`` is smaller than 1,
      i.e. a batch cannot hold even one identity.
    """

    def __init__(self, data_source, batch_size, num_instances):
        self.data_source = data_source
        self.batch_size = batch_size
        self.num_instances = num_instances
        self.num_pids_per_batch = self.batch_size // self.num_instances
        # With zero identities per batch the loop in __iter__ would never
        # consume anything and therefore never terminate; fail early instead.
        if self.num_pids_per_batch < 1:
            raise ValueError(
                "batch_size ({}) must be at least num_instances ({})".format(
                    batch_size, num_instances
                )
            )

        # pid -> list of positions in data_source carrying that pid
        self.index_dic = defaultdict(list)
        for index, (_, pid, _) in enumerate(self.data_source):
            self.index_dic[pid].append(index)
        self.pids = list(self.index_dic.keys())

        # estimate number of examples in an epoch; identities with fewer than
        # num_instances images are padded up to exactly one full group
        self.length = 0
        for pid in self.pids:
            num = max(len(self.index_dic[pid]), self.num_instances)
            self.length += num - num % self.num_instances

    def __iter__(self):
        """Return an iterator over the indices of one freshly shuffled epoch.

        Also updates ``self.length`` to the exact number of indices produced.
        """
        # pid -> list of index groups, each group holding num_instances indices
        batch_idxs_dict = defaultdict(list)
        for pid in self.pids:
            idxs = list(self.index_dic[pid])
            if len(idxs) < self.num_instances:
                # Too few images: pad by sampling with replacement. tolist()
                # keeps the emitted indices plain Python ints.
                idxs = np.random.choice(
                    idxs, size=self.num_instances, replace=True
                ).tolist()
            random.shuffle(idxs)
            # Keep complete groups only; a trailing partial group is dropped.
            last_start = len(idxs) - self.num_instances
            for start in range(0, last_start + 1, self.num_instances):
                batch_idxs_dict[pid].append(idxs[start:start + self.num_instances])

        avai_pids = list(self.pids)
        final_idxs = []
        # Stop once fewer identities remain than are needed for a full batch.
        while len(avai_pids) >= self.num_pids_per_batch:
            selected_pids = random.sample(avai_pids, self.num_pids_per_batch)
            for pid in selected_pids:
                final_idxs.extend(batch_idxs_dict[pid].pop(0))
                if not batch_idxs_dict[pid]:
                    avai_pids.remove(pid)

        self.length = len(final_idxs)
        return iter(final_idxs)

    def __len__(self):
        """Return the estimated (or, after iteration, exact) epoch length."""
        return self.length
class BaseDataset(object):
    """
    Base class of reid dataset: shared helpers for summarising sample lists
    """

    def get_imagedata_info(self, data):
        """Summarise an image-level sample list.

        Args:
            data: sized iterable of (img_path, pid, camid) triples.

        Returns:
            tuple: (number of distinct pids, number of samples,
            number of distinct camids).
        """
        seen_pids, seen_cams = set(), set()
        for _, pid, camid in data:
            seen_pids.add(pid)
            seen_cams.add(camid)
        return len(seen_pids), len(data), len(seen_cams)

    def get_videodata_info(self, data, return_tracklet_stats=False):
        """Summarise a tracklet-level sample list.

        Args:
            data: sized iterable of (img_paths, pid, camid) triples.
            return_tracklet_stats (bool): when true, append the list of
                per-tracklet lengths (``len(img_paths)``) to the result.

        Returns:
            tuple: (number of distinct pids, number of tracklets,
            number of distinct camids[, tracklet lengths]).
        """
        seen_pids, seen_cams = set(), set()
        tracklet_lengths = []
        for img_paths, pid, camid in data:
            seen_pids.add(pid)
            seen_cams.add(camid)
            tracklet_lengths.append(len(img_paths))

        summary = (len(seen_pids), len(data), len(seen_cams))
        if not return_tracklet_stats:
            return summary
        return summary + (tracklet_lengths,)

    def print_dataset_statistics(self):
        """Print a dataset summary; subclasses must provide this."""
        raise NotImplementedError
class BaseImageDataset(BaseDataset):
    """
    Base class of image reid dataset
    """

    def print_dataset_statistics(self, train, query, gallery):
        """Print an id / image / camera count table for the three subsets.

        Args:
            train, query, gallery: lists of (img_path, pid, camid) triples,
                each summarised with ``get_imagedata_info``.
        """
        # Gather every row before printing so a failure leaves no partial table.
        train_row = self.get_imagedata_info(train)
        query_row = self.get_imagedata_info(query)
        gallery_row = self.get_imagedata_info(gallery)

        rule = " ----------------------------------------"
        print("Dataset statistics:")
        print(rule)
        print(" subset | # ids | # images | # cameras")
        print(rule)
        # Each row is (num_pids, num_imgs, num_cams), matching the placeholders.
        print(" train | {:5d} | {:8d} | {:9d}".format(*train_row))
        print(" query | {:5d} | {:8d} | {:9d}".format(*query_row))
        print(" gallery | {:5d} | {:8d} | {:9d}".format(*gallery_row))
        print(rule)
def read_image(img_path, max_retries=None):
    """Read an image as RGB, retrying when the read raises IOError.

    Retrying can avoid IOError incurred by heavy IO process.

    Args:
        img_path (str): path of the image file.
        max_retries (int, optional): number of retries allowed after the
            first failed attempt. ``None`` (the default) retries until the
            read succeeds, which never returns for a permanently unreadable
            file; pass a number to bound the loop.

    Returns:
        The image opened with PIL and converted to 'RGB'.

    Raises:
        IOError: if ``img_path`` does not exist, or if ``max_retries`` is
            set and every attempt failed (the last error is re-raised).
    """
    if not osp.exists(img_path):
        raise IOError("{} does not exist".format(img_path))

    failures = 0
    while True:
        try:
            return Image.open(img_path).convert('RGB')
        except IOError:
            failures += 1
            if max_retries is not None and failures > max_retries:
                raise
            print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
class ImageDataset(Dataset):
    """Image Person ReID Dataset

    Wraps a list of (img_path, pid, camid) triples and loads each image on
    access, optionally passing it through ``transform``.
    """

    def __init__(self, dataset, transform=None):
        self.transform = transform
        self.dataset = dataset

    def __len__(self):
        """Return the number of wrapped samples."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return (image, pid, camid, img_path) for the sample at ``index``."""
        path, pid, camid = self.dataset[index]
        image = read_image(path)
        if self.transform is None:
            return image, pid, camid, path
        return self.transform(image), pid, camid, path
class Market1501(BaseImageDataset):
    """
    Market1501

    Reference:
    Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.
    URL: http://www.liangzheng.org/Project/project_reid.html

    Dataset statistics:
    # identities: 1501 (+1 for background)
    # images: 12936 (train) + 3368 (query) + 15913 (gallery)

    This loader only reads the 'query' and 'gallery' sub-directories of the
    dataset directory; no training split is loaded.
    """
    dataset_dir = ''

    def __init__(self, root='/home/haoluo/data', verbose=True, **kwargs):
        """Index the query and gallery images found under ``root``.

        Args:
            root (str): directory that is joined with ``dataset_dir``.
            verbose (bool): accepted for interface compatibility; not used.
            **kwargs: accepted for interface compatibility; not used.

        Raises:
            RuntimeError: if the dataset, query or gallery directory is missing.
        """
        super(Market1501, self).__init__()
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.query_dir = osp.join(self.dataset_dir, 'query')
        self.gallery_dir = osp.join(self.dataset_dir, 'gallery')

        self._check_before_run()

        self.query = self._process_dir(self.query_dir, relabel=False)
        self.gallery = self._process_dir(self.gallery_dir, relabel=False)

        self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        for required_dir in (self.dataset_dir, self.query_dir, self.gallery_dir):
            if not osp.exists(required_dir):
                raise RuntimeError("'{}' is not available".format(required_dir))

    def _process_dir(self, dir_path, relabel=False):
        """Build the (img_path, pid, camid) list for the '*.jpg' files in a directory.

        Args:
            dir_path (str): directory to scan (non-recursive).
            relabel (bool): when true, replace each pid with a label
                enumerated over the pids found in this directory.

        Returns:
            list: (img_path, pid, camid) tuples with camid shifted to start at 0.

        Raises:
            ValueError: if a file name does not contain '<pid>_c<camid>'.
        """
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d)')

        parsed = []
        for img_path in img_paths:
            # Match on the file name only: searching the full path could pick
            # up a directory component that happens to look like '<digits>_c<digit>'.
            match = pattern.search(osp.basename(img_path))
            if match is None:
                raise ValueError(
                    "'{}' does not follow the '<pid>_c<camid>' naming scheme".format(img_path)
                )
            pid, camid = map(int, match.groups())
            if pid == -1:
                continue  # junk images are just ignored
            parsed.append((img_path, pid, camid))

        pid2label = {}
        if relabel:
            pid_container = {pid for _, pid, _ in parsed}
            pid2label = {pid: label for label, pid in enumerate(pid_container)}

        dataset = []
        for img_path, pid, camid in parsed:
            assert 0 <= pid <= 1501  # pid == 0 means background
            assert 1 <= camid <= 6
            camid -= 1  # index starts from 0
            if relabel:
                pid = pid2label[pid]
            dataset.append((img_path, pid, camid))
        return dataset
# Registry of dataset name -> dataset class; init_dataset() looks names up here.
__factory = {'market1501': Market1501}
def train_collate_fn(batch):
    """Collate (img, pid, camid, img_path) samples for training.

    Returns:
        tuple: (images stacked along a new first dimension,
        pids as an int64 tensor). camids and paths are discarded.
    """
    images, identities, _, _ = zip(*batch)
    labels = torch.tensor(identities, dtype=torch.int64)
    stacked = torch.stack(images, dim=0)
    return stacked, labels
def val_collate_fn(batch):
    """Collate (img, pid, camid, img_path) samples for validation.

    Images are returned as a plain tuple rather than a stacked tensor;
    the image paths are discarded.

    Returns:
        tuple: (images, pids, camids), each a tuple with one entry per sample.
    """
    images, identities, cameras, _paths = zip(*batch)
    return images, identities, cameras
def init_dataset(name, *args, **kwargs):
    """Instantiate the dataset class registered under ``name``.

    Args:
        name: key into the module-level dataset registry.
        *args, **kwargs: forwarded unchanged to the dataset constructor.

    Raises:
        KeyError: if ``name`` is not a registered dataset.
    """
    if name in __factory:
        return __factory[name](*args, **kwargs)
    raise KeyError("Unknown datasets: {}".format(name))
# Per-channel normalisation (presumably the ImageNet statistics - confirm).
# It is defined here but is NOT part of `transform` below.
normalize_transform = T.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Validation pipeline: resize only. T.ToTensor() and normalize_transform are
# commented out of the pipeline, so the output is whatever T.Resize returns
# rather than a normalised tensor.
transform = T.Compose([T.Resize([256, 128])])
def make_val_data_loader(cfg):
    """Build the validation DataLoader over query images followed by gallery images.

    Args:
        cfg: config object; reads ``DATALOADER.NUM_WORKERS``,
            ``DATASETS.NAMES``, ``DATASETS.ROOT_DIR`` and
            ``TEST.IMS_PER_BATCH``.

    Returns:
        tuple: (val_loader, number of query samples, number of gallery pids).
    """
    val_transforms = transform
    num_workers = cfg.DATALOADER.NUM_WORKERS

    # NOTE: multiple datasets are not supported; DATASETS.NAMES is forwarded
    # as-is as a single registry key. (The former if/else on its length ran
    # the same call in both branches.)
    dataset = init_dataset(cfg.DATASETS.NAMES, root=cfg.DATASETS.ROOT_DIR)

    num_classes = dataset.num_gallery_pids
    # Query samples come first, so the first len(dataset.query) items of the
    # loader are queries and the rest are gallery samples.
    val_set = ImageDataset(dataset.query + dataset.gallery, val_transforms)
    val_loader = DataLoader(
        val_set,
        batch_size=cfg.TEST.IMS_PER_BATCH,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=val_collate_fn,
    )
    return val_loader, len(dataset.query), num_classes