mdz/pytorch/yolov9_pan/1_scripts/5_test.py

253 lines
8.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import PIL.Image as Image
import numpy as np
from collections import defaultdict
import os
from tqdm import tqdm
import cv2
# RGB通道会有变换所以0通道为画图用1通道只有实例2通道为计算用
# img = Image.open(pred_path)
# img.show()
class PQStatCat():
    """Running panoptic-quality counters for a single category.

    Attributes (all start at zero):
        iou: accumulated IoU of matched (true-positive) segment pairs.
        tp: number of true positives.
        fp: number of false positives.
        fn: number of false negatives.
    """

    def __init__(self):
        self.iou, self.tp, self.fp, self.fn = 0.0, 0, 0, 0

    def __iadd__(self, pq_stat_cat):
        """Add another category's counters into this one, in place; return self."""
        for field in ('iou', 'tp', 'fp', 'fn'):
            setattr(self, field, getattr(self, field) + getattr(pq_stat_cat, field))
        return self
class PQStat():
    """Per-category panoptic-quality counters, keyed by category id."""

    def __init__(self):
        # Category id -> PQStatCat; missing categories are created on first access.
        self.pq_per_cat = defaultdict(PQStatCat)

    def __getitem__(self, i):
        """Return the counters of category ``i``, creating empty ones if absent."""
        return self.pq_per_cat[i]

    def __iadd__(self, pq_stat):
        """Add every category's counters from ``pq_stat`` into this object."""
        for label, pq_stat_cat in pq_stat.pq_per_cat.items():
            self.pq_per_cat[label] += pq_stat_cat
        return self

    def pq_average(self, categories, isthing):
        """Average PQ, SQ and RQ over a set of categories.

        Per category: PQ = sum(IoU) / (TP + FP/2 + FN/2), SQ = sum(IoU) / TP
        (0 when TP == 0) and RQ = TP / (TP + FP/2 + FN/2).

        Args:
            categories: iterable of category ids to consider; id 0 is always
                skipped.
            isthing: None to use every category, True for "thing" ids only
                (id < 92), False for "stuff" ids only (id >= 92).

        Returns:
            ``(summary, per_class_results)``. ``summary`` holds the means
            ``'pq'``, ``'sq'``, ``'rq'`` over the categories that have at least
            one TP/FP/FN, and ``'n'``, the number of such categories; the means
            are 0.0 when ``n`` is 0. ``per_class_results`` maps every
            considered category id to its own ``'pq'``/``'sq'``/``'rq'`` (all
            0.0 for categories without counts).
        """
        pq, sq, rq, n = 0, 0, 0, 0
        per_class_results = {}
        for label in categories:
            if label == 0:
                continue
            if isthing is not None and isthing != (label < 92):
                continue
            # .get() instead of [] so querying does not insert empty entries
            # into the defaultdict.
            stat = self.pq_per_cat.get(label)
            if stat is None or stat.tp + stat.fp + stat.fn == 0:
                per_class_results[label] = {'pq': 0.0, 'sq': 0.0, 'rq': 0.0}
                continue
            n += 1
            denominator = stat.tp + 0.5 * stat.fp + 0.5 * stat.fn
            pq_class = stat.iou / denominator
            sq_class = stat.iou / stat.tp if stat.tp != 0 else 0
            rq_class = stat.tp / denominator
            per_class_results[label] = {'pq': pq_class, 'sq': sq_class, 'rq': rq_class}
            pq += pq_class
            sq += sq_class
            rq += rq_class
        if n == 0:
            # No category in this subset has any counts: avoid dividing by zero.
            return {'pq': 0.0, 'sq': 0.0, 'rq': 0.0, 'n': 0}, per_class_results
        return {'pq': pq / n, 'sq': sq / n, 'rq': rq / n, 'n': n}, per_class_results
# Label value of void/unlabeled pixels. valuation() removes prediction pixels
# that fall on VOID ground truth from the IoU union, and does not count a
# prediction as a false positive when more than half of it lies on VOID.
VOID: int = 0
# Shared accumulator: valuation() adds each image's counts into this object
# and returns it.
pq_stat: "PQStat" = PQStat()
def valuation(pred_path, gt_path):
    """Add one image pair's panoptic-quality counts to the module-level ``pq_stat``.

    Only channel 2 of each image is read. It is treated as a per-pixel
    category-id map, so every distinct non-VOID value forms one segment whose
    category is that value.

    A ground-truth segment and a predicted segment of the same category form a
    true positive when their IoU exceeds 0.5. Prediction pixels lying on VOID
    ground truth are left out of the union. Unmatched ground-truth segments are
    false negatives. Unmatched predicted segments are false positives, unless
    more than half of their pixels lie on VOID ground truth.

    Args:
        pred_path: path of the predicted label image.
        gt_path: path of the ground-truth label image. It is assumed, not
            checked, to have the same size as the prediction.

    Returns:
        The module-level ``pq_stat`` object, updated in place. Every call adds
        into that same object, so callers must not add the return value back
        into ``pq_stat``.
    """
    # Context managers close the image files. [:, :, 2] requires images with
    # at least three channels.
    with Image.open(gt_path) as gt_img:
        pan_gt = np.array(gt_img, dtype=np.uint8)[:, :, 2]
    with Image.open(pred_path) as pred_img:
        pan_pred = np.array(pred_img, dtype=np.uint8)[:, :, 2]

    def _segments(label_map):
        # Segment id -> pixel area and category; the id doubles as the category.
        # VOID is not a segment (it can otherwise produce a zero IoU union).
        ids, areas = np.unique(label_map, return_counts=True)
        return {seg_id: {'area': area, 'category_id': seg_id}
                for seg_id, area in zip(ids, areas) if seg_id != VOID}

    gt_segms = _segments(pan_gt)
    pred_segms = _segments(pan_pred)

    # Encode each (gt, pred) pixel pair as one integer to count overlaps.
    # Widen first: uint8 * 256 does not fit in uint8 (NumPy >= 2 raises
    # OverflowError instead of silently promoting).
    pan_gt_pred = pan_gt.astype(np.int64) * 256 + pan_pred
    labels, labels_cnt = np.unique(pan_gt_pred, return_counts=True)
    gt_pred_map = {(label // 256, label % 256): intersection
                   for label, intersection in zip(labels, labels_cnt)}

    # Count all matched pairs.
    gt_matched = set()
    pred_matched = set()
    for (gt_label, pred_label), intersection in gt_pred_map.items():
        if gt_label not in gt_segms or pred_label not in pred_segms:
            continue
        gt_info = gt_segms[gt_label]
        pred_info = pred_segms[pred_label]
        if gt_info['category_id'] != pred_info['category_id']:
            continue
        # Prediction pixels on VOID ground truth are excluded from the union.
        union = (pred_info['area'] + gt_info['area'] - intersection
                 - gt_pred_map.get((VOID, pred_label), 0))
        iou = intersection / union
        if iou > 0.5:
            stat = pq_stat[gt_info['category_id']]
            stat.tp += 1
            stat.iou += iou
            gt_matched.add(gt_label)
            pred_matched.add(pred_label)

    # Count false negatives: ground-truth segments without a match.
    for gt_label, gt_info in gt_segms.items():
        if gt_label in gt_matched:
            continue
        pq_stat[gt_info['category_id']].fn += 1

    # Count false positives: predicted segments without a match.
    for pred_label, pred_info in pred_segms.items():
        if pred_label in pred_matched:
            continue
        # Ignore predictions of which more than half lies on VOID ground truth.
        void_overlap = gt_pred_map.get((VOID, pred_label), 0)
        if void_overlap / pred_info['area'] > 0.5:
            continue
        pq_stat[pred_info['category_id']].fp += 1
    return pq_stat
# "Thing" category ids: 1..90 minus ten unused ids, leaving 80 ids. These
# appear to be the COCO thing ids, given the coco_val ground-truth folder.
all_instances_ids = [
    cat_id for cat_id in range(1, 91)
    if cat_id not in {12, 26, 29, 30, 45, 66, 68, 69, 71, 83}
]
# "Stuff" category ids 92..182, then 183 ("other") and 0 ("unlabeled").
all_stuff_ids = list(range(92, 183)) + [183, 0]
# Every category id considered when averaging PQ (things first, then stuff).
labels_gt = all_instances_ids + all_stuff_ids
def get_corresponding_pairs(pred_path, gt_path):
    """Pair each prediction image with its ground-truth PNG.

    A prediction file ``<stem><ext>`` (ext one of .jpg/.jpeg/.png/.gif, in any
    case) is paired with ``<stem without leading zeros>.png`` in ``gt_path``.
    For example, ``000000000139.jpg`` pairs with ``139.png``. Predictions with
    no matching ground-truth file are skipped.

    Args:
        pred_path: directory holding the prediction images.
        gt_path: directory holding the ground-truth PNGs.

    Returns:
        List of ``(prediction_file, ground_truth_file)`` path tuples, sorted by
        prediction filename. Sorting makes slicing the list (e.g. taking the
        first N pairs) reproducible across file systems.
    """
    image_exts = ('.jpg', '.jpeg', '.png', '.gif')
    pairs = []
    # os.listdir() returns entries in arbitrary order; sort for determinism.
    for pred_file in sorted(os.listdir(pred_path)):
        stem, ext = os.path.splitext(pred_file)
        if ext.lower() not in image_exts:
            continue
        # Strip leading zeros and add the .png suffix; an all-zero stem maps
        # to "0" instead of an empty name.
        gt_file = os.path.join(gt_path, (stem.lstrip('0') or '0') + '.png')
        if os.path.exists(gt_file):
            pairs.append((os.path.join(pred_path, pred_file), gt_file))
    return pairs
def calculate_and_accumulate(pq_stat, results, labels_gt):
    """Add n-weighted PQ/SQ/RQ sums for the All/Things/Stuff subsets into ``results``.

    For each subset, ``pq_stat.pq_average(labels_gt, isthing=...)`` is called.
    Its mean ``pq``, ``sq`` and ``rq`` are multiplied by its category count
    ``n`` and added to ``results[name]``; ``n`` itself is added unchanged.
    Missing ``results`` entries are first initialised to zeros. Dividing an
    accumulated sum by the accumulated ``n`` gives an n-weighted mean.

    ``results`` is modified in place; nothing is returned.
    """
    subsets = {"All": None, "Things": True, "Stuff": False}
    for name in subsets:
        summary, _ = pq_stat.pq_average(labels_gt, isthing=subsets[name])
        totals = results.setdefault(name, {'pq': 0, 'sq': 0, 'rq': 0, 'n': 0})
        weight = summary['n']
        for key in ('pq', 'sq', 'rq'):
            totals[key] += summary[key] * weight
        totals['n'] += weight
if __name__ == "__main__":
    # Prediction label maps written by the deployed model, and the ground-truth
    # PNGs they are paired with by get_corresponding_pairs().
    pred_path = '../3_deploy/modelzoo/yolov9_pan/io/output_16'
    gt_path = '../1_scripts/coco_val'
    # Evaluate at most this many image pairs.
    max_images = 500

    # Sort so the [:max_images] subset does not depend on directory order.
    corresponding_pairs = sorted(get_corresponding_pairs(pred_path, gt_path))
    if not corresponding_pairs:
        raise SystemExit(
            f"No prediction/ground-truth pairs found under {pred_path!r} and {gt_path!r}")

    # valuation() accumulates each image's counts into the module-level
    # `pq_stat` and returns that same object. Its return value must therefore
    # NOT be added back into `pq_stat`: `pq_stat += pq_stat` would double all
    # counts after every image, weighting early images exponentially more.
    for pred_file, gt_file in tqdm(corresponding_pairs[:max_images], desc="loading"):
        pq_stat = valuation(pred_file, gt_file)

    metrics = [("All", None), ("Things", True), ("Stuff", False)]
    results = {}
    for name, isthing in metrics:
        results[name], per_class_results = pq_stat.pq_average(labels_gt, isthing=isthing)
        if name == 'All':
            results['per_class'] = per_class_results

    print("-" * 41)
    print("{:14s}| {:>5s} {:>5s} {:>5s} {:>5s}".format("", "PQ", "SQ", "RQ", "N"))
    for name, _ in metrics:
        print("{:14s}| {:5.1f} {:5.1f} {:5.1f} {:5d}".format(
            name,
            100 * results[name]['pq'],
            100 * results[name]['sq'],
            100 * results[name]['rq'],
            results[name]['n'],
        ))