mdz/pytorch/deeplabv3plus/1_scripts/2_save_infer.py

# Paths to the test image and the traced (TorchScript) model
IMAGE_PATH = "../2_compile/qtset/voc2012/2012_004266.jpg"
TRACED_MODEL_PATH = '../2_compile/fmodel/deeplab-resnet-513x513.pt'

import os
import sys
sys.path.append(R"../0_deeplabv3plus")

from modeling.deeplab import *
import argparse
from dataloaders import make_data_loader
from tqdm import tqdm
from dataloaders import custom_transforms as tr
from torchvision import transforms
from PIL import Image
import torch
import numpy as np


def voc_cmap(N=256, normalized=False):
    """Build the PASCAL VOC color map: N rows of (R, G, B)."""
    def bitget(byteval, idx):
        return ((byteval & (1 << idx)) != 0)

    dtype = 'float32' if normalized else 'uint8'
    cmap = np.zeros((N, 3), dtype=dtype)
    for i in range(N):
        r = g = b = 0
        c = i
        # Spread the three lowest bits of the class index across the R, G, B channels
        for j in range(8):
            r = r | (bitget(c, 0) << 7 - j)
            g = g | (bitget(c, 1) << 7 - j)
            b = b | (bitget(c, 2) << 7 - j)
            c = c >> 3
        cmap[i] = np.array([r, g, b])

    cmap = cmap / 255 if normalized else cmap
    return cmap


cmap = voc_cmap()


def decode_target(mask):
    """decode semantic mask to RGB image"""
    return cmap[mask]
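
# Illustrative only (not part of the original script): decode_target maps an array of
# class indices to RGB colors by fancy-indexing into the VOC color map, e.g.
#   demo_mask = np.array([[0, 15], [8, 1]])   # hypothetical class-id mask
#   decode_target(demo_mask)                  # -> uint8 array of shape (2, 2, 3)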

if __name__ == "__main__":
    # Load the traced TorchScript model
    model = torch.jit.load(TRACED_MODEL_PATH)

    # Preprocessing: scale/crop to 513x513, normalize with ImageNet statistics, convert to tensor
    composed_transforms = transforms.Compose([
        tr.FixScaleCrop(crop_size=513),
        tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        tr.ToTensor()])

    _img = Image.open(IMAGE_PATH).convert('RGB')
    sample = {'image': _img, 'label': _img}  # the transforms expect an image/label pair
    convert_img = composed_transforms(sample)

    # Forward pass on a batch of one; output holds per-class scores of shape (1, num_classes, H, W)
    output = model(convert_img['image'].unsqueeze(0))
    print(output.shape)

    # Per-pixel argmax over classes, then map class indices to VOC colors
    preds = output.max(1)[1].detach().cpu().numpy()
    colorized_preds = decode_target(preds).astype('uint8')
    colorized_preds = Image.fromarray(colorized_preds[0])  # to PIL Image
    colorized_preds.show()
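    # Optional addition (not in the original script): the colorized mask could also be
    # written to disk; the filename below is illustrative.
    # colorized_preds.save("pred_2012_004266.png")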