202 lines
9.2 KiB
Python
202 lines
9.2 KiB
Python
#!/usr/bin/env python
|
|
import sys
|
|
sys.path.append(R"../0_spynet")
|
|
import getopt
|
|
import math
|
|
import numpy
|
|
import PIL
|
|
import PIL.Image
|
|
import sys
|
|
import torch
|
|
from flo2img import *
|
|
##########################################################
|
|
|
|
# Resolution the network is configured for (height, width in pixels).
# The trailing numbers are kept from the original file; presumably other
# sizes the author tried.
INP_H = 320  # 544 352 320 480 448
INP_W = 544  # 960 640 576 864 832

# Inference only: turn autograd off globally so no graph is recorded.
torch.set_grad_enabled(False)

# cuDNN is switched OFF here. The comment that used to accompany this line
# said "make sure to use cudnn", which contradicts the value assigned.
# NOTE(review): confirm that disabling cuDNN is intended.
torch.backends.cudnn.enabled = False
|
|
##########################################################
|
|
# Checkpoint file read with torch.load when the network is constructed.
WEIGHT_PATH = "../weights/network-sintel-final.pytorch"

# Default frame pair; each path can be overridden with --one / --two.
args_strOne = '../0_spynet/images/one.png'
# args_strOne = './images/one.png'
args_strTwo = '../0_spynet/images/two.png'
# args_strTwo = './images/two.png'

# NOTE: this runs at import time; getopt raises GetoptError for any option
# other than --one= / --two=.
for strOption, strArg in getopt.getopt(sys.argv[1:], '', ['one=', 'two='])[0]:
    if strArg == '':
        continue  # an empty value keeps the default path
    # end

    if strOption == '--one':
        args_strOne = strArg  # path to the first frame
    # end

    if strOption == '--two':
        args_strTwo = strArg  # path to the second frame
    # end
# end
|
|
|
|
##########################################################
|
|
# One [horizontal, vertical] scale pair per pyramid level, coarsest level
# (1/32 of the input resolution) first and full resolution last. Each entry is
# 2 / (size - 1); backwarp() multiplies the flow channels by these factors
# before adding the flow to a grid that spans [-1, 1].
backwarp_fix = [
    [2.0 / (INP_W / intStride - 1.0), 2.0 / (INP_H / intStride - 1.0)]
    for intStride in (32, 16, 8, 4, 2, 1)
]

# Cache of base sampling grids, filled lazily by backwarp().
backwarp_tenGrid = {}
|
|
|
|
def backwarp(tenInput, tenFlow, intLevel):
    """Sample ``tenInput`` at grid positions displaced by ``tenFlow``.

    Args:
        tenInput: 4-D tensor handed to ``grid_sample`` as ``input``.
        tenFlow: 4-D flow tensor with exactly two channels along dim 1;
            channel 0 is scaled by the horizontal factor and channel 1 by the
            vertical factor. Dims 2 and 3 give the grid height and width.
        intLevel: index into ``backwarp_fix`` selecting the scale pair.

    Returns:
        The output of ``grid_sample`` (bilinear, border padding,
        ``align_corners=True``).
    """
    # The base grid must live on the same device and use the same dtype as the
    # flow it is added to, so both are part of the cache key. (Previously the
    # grid was always a default-dtype CPU tensor keyed by shape only.)
    objKey = (tuple(tenFlow.shape), tenFlow.dtype, tenFlow.device)

    tenGrid = backwarp_tenGrid.get(objKey)

    if tenGrid is None:
        intHeight = tenFlow.shape[2]
        intWidth = tenFlow.shape[3]

        # Normalised x / y coordinates in [-1, 1], one value per column / row.
        tenHor = torch.linspace(-1.0, 1.0, intWidth, dtype=tenFlow.dtype, device=tenFlow.device).view(1, 1, 1, -1).repeat(1, 1, intHeight, 1)
        tenVer = torch.linspace(-1.0, 1.0, intHeight, dtype=tenFlow.dtype, device=tenFlow.device).view(1, 1, -1, 1).repeat(1, 1, 1, intWidth)

        tenGrid = torch.cat([tenHor, tenVer], 1)
        backwarp_tenGrid[objKey] = tenGrid
    # end

    # Scale each flow channel with the per-level factor from backwarp_fix.
    fltScaleHor, fltScaleVer = backwarp_fix[intLevel]
    tenFlowHor, tenFlowVer = tenFlow.split(split_size=1, dim=1)
    tenFlow = torch.cat([tenFlowHor * fltScaleHor, tenFlowVer * fltScaleVer], 1)

    # grid_sample expects the grid as (N, H, W, 2), hence the permute.
    return torch.nn.functional.grid_sample(input=tenInput, grid=(tenGrid + tenFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='border', align_corners=True)
# end
|
|
|
|
##########################################################
|
|
|
|
class Network(torch.nn.Module):
    """Coarse-to-fine flow estimator built from a pyramid of small CNNs.

    The two input frames are repeatedly halved with average pooling. Starting
    at the coarsest level, each ``Basic`` network refines the flow upsampled
    from the previous level. The module is configured for inputs of
    ``INP_H`` x ``INP_W``: the initial flow tensor and the final resize both use
    these constants.

    Weights are loaded from ``WEIGHT_PATH`` when the module is constructed.
    """

    def __init__(self):
        super().__init__()

        class Preprocess(torch.nn.Module):
            """Reverse the channel order, then normalise each channel."""

            def __init__(self):
                super().__init__()
            # end

            def forward(self, tenInput):
                # Reverse the order of the channels (dim 1).
                tenInput = tenInput.flip([1])

                tenMean = torch.tensor(data=[0.485, 0.456, 0.406], dtype=tenInput.dtype, device=tenInput.device).view(1, 3, 1, 1)
                tenScale = torch.tensor(data=[1.0 / 0.229, 1.0 / 0.224, 1.0 / 0.225], dtype=tenInput.dtype, device=tenInput.device).view(1, 3, 1, 1)

                return (tenInput - tenMean) * tenScale
            # end
        # end

        class Basic(torch.nn.Module):
            """Five 7x7 convolutions mapping 8 input channels to a 2-channel flow update."""

            def __init__(self, intLevel):
                # intLevel is accepted for symmetry with the pyramid but is not used.
                super().__init__()

                self.netBasic = torch.nn.Sequential(
                    torch.nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3),
                    torch.nn.ReLU(inplace=False),
                    torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3),
                    torch.nn.ReLU(inplace=False),
                    torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3),
                    torch.nn.ReLU(inplace=False),
                    torch.nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3),
                    torch.nn.ReLU(inplace=False),
                    torch.nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3)
                )
            # end

            def forward(self, tenInput):
                return self.netBasic(tenInput)
            # end
        # end

        self.netPreprocess = Preprocess()

        # Zero flow used as the starting point at the coarsest level (1/32 resolution).
        self.tenUpsampled_init = torch.zeros(1, 2, int(INP_H / 32), int(INP_W / 32))

        self.netBasic = torch.nn.ModuleList([Basic(intLevel) for intLevel in range(6)])

        # map_location='cpu' lets a checkpoint that was saved from GPU tensors load
        # on a machine without CUDA; load_state_dict copies the values into this
        # module's own parameters afterwards.
        objCheckpoint = torch.load(WEIGHT_PATH, map_location='cpu')

        # Checkpoint keys use the prefix 'module' where this class uses 'net'.
        # NOTE(review): strict=False silently ignores keys that do not match.
        self.load_state_dict({strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in objCheckpoint.items()}, strict=False)
    # end

    def forward(self, tenOne, tenTwo):
        """Estimate the flow between two preprocessed frames.

        Args:
            tenOne: first frame, 4-D tensor (N, 3, H, W).
            tenTwo: second frame, same shape as ``tenOne``.

        Returns:
            ``(tenFlow, translation)`` where ``tenFlow`` is resized to
            ``(INP_H, INP_W)`` and ``translation`` is its mean over the two
            spatial dimensions.
        """
        tenOne = [tenOne]
        tenTwo = [tenTwo]

        # Build the pyramid, coarsest level at index 0. There is one Basic network
        # per level, so the frames may be halved at most len(netBasic) - 1 times;
        # allowing one more halving would create a level with no network.
        for _ in range(len(self.netBasic) - 1):
            if tenOne[0].shape[2] > 32 or tenOne[0].shape[3] > 32:
                tenOne.insert(0, torch.nn.functional.avg_pool2d(input=tenOne[0], kernel_size=2, stride=2, count_include_pad=False))
                tenTwo.insert(0, torch.nn.functional.avg_pool2d(input=tenTwo[0], kernel_size=2, stride=2, count_include_pad=False))
            # end
        # end

        tenFlow = None

        # Go through the levels from the smallest scale to the largest.
        for intLevel in range(len(tenOne)):
            if intLevel == 0:
                # Start from the zero flow, matched to the input's device, dtype
                # and batch size.
                tenUpsampled = self.tenUpsampled_init.to(device=tenOne[0].device, dtype=tenOne[0].dtype)
                tenUpsampled = tenUpsampled.expand(tenOne[0].shape[0], -1, -1, -1)

            else:
                # Doubling the resolution also doubles the displacement in pixels.
                tenUpsampled = torch.nn.functional.interpolate(input=tenFlow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0

            # end

            # Add one replicated row / column if the flow is smaller than the frame.
            if tenUpsampled.shape[2] != tenOne[intLevel].shape[2]: tenUpsampled = torch.nn.functional.pad(input=tenUpsampled, pad=[0, 0, 0, 1], mode='replicate')
            if tenUpsampled.shape[3] != tenOne[intLevel].shape[3]: tenUpsampled = torch.nn.functional.pad(input=tenUpsampled, pad=[0, 1, 0, 0], mode='replicate')

            # The network predicts a residual that is added to the upsampled flow.
            tenFlow = self.netBasic[intLevel](torch.cat([tenOne[intLevel], backwarp(tenInput=tenTwo[intLevel], tenFlow=tenUpsampled, intLevel=intLevel), tenUpsampled], 1)) + tenUpsampled
        # end

        tenFlow = torch.nn.functional.interpolate(tenFlow, size=(INP_H, INP_W), mode='bilinear', align_corners=False)

        # Mean flow over height and width: one (x, y) pair per batch element.
        translation = torch.mean(tenFlow, dim=[2, 3])

        return tenFlow, translation
    # end
# end
|
|
|
|
# Module-wide Network instance; stays None until estimate() builds it on first use.
netNetwork = None
|
|
|
|
##########################################################
|
|
|
|
def estimate(tenOne, tenTwo):
    """Estimate the flow from ``tenOne`` to ``tenTwo``.

    The module-level ``netNetwork`` is created (in eval mode) on first use.

    Args:
        tenOne: first frame as a 3-D tensor (3, height, width).
        tenTwo: second frame, same height and width as ``tenOne``.

    Returns:
        CPU tensor (2, height, width) holding the flow at the resolution of the
        input frames.

    Raises:
        AssertionError: if the two frames differ in height or width.

    NOTE(review): Network is configured for INP_H x INP_W, so the frame size
    rounded up to a multiple of 32 presumably has to equal that size.
    """
    global netNetwork

    if netNetwork is None:
        netNetwork = Network().eval()
    # end

    assert(tenOne.shape[1] == tenTwo.shape[1])
    assert(tenOne.shape[2] == tenTwo.shape[2])

    intWidth = tenOne.shape[2]
    intHeight = tenOne.shape[1]

    # assert(intWidth == 1024) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue
    # assert(intHeight == 416) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue

    tenPreprocessedOne = tenOne.view(1, 3, intHeight, intWidth)
    tenPreprocessedTwo = tenTwo.view(1, 3, intHeight, intWidth)

    # Round both dimensions up to the next multiple of 32.
    intPreprocessedWidth = int(math.ceil(intWidth / 32.0)) * 32
    intPreprocessedHeight = int(math.ceil(intHeight / 32.0)) * 32

    tenPreprocessedOne = torch.nn.functional.interpolate(input=tenPreprocessedOne, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False)
    tenPreprocessedTwo = torch.nn.functional.interpolate(input=tenPreprocessedTwo, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False)

    tenFirst1 = netNetwork.netPreprocess(tenPreprocessedOne)
    tenSecond1 = netNetwork.netPreprocess(tenPreprocessedTwo)

    # tenFirst1.permute(0, 2, 3, 1).cpu().detach().numpy().astype(numpy.float32).tofile('./ftmp/tenFirst_'+str(INP_H)+'_'+str(INP_W)+'.ftmp')
    # tenSecond1.permute(0, 2, 3, 1).cpu().detach().numpy().astype(numpy.float32).tofile('./ftmp/tenSecond_'+str(INP_H)+'_'+str(INP_W)+'.ftmp')

    tenFlow = netNetwork(tenFirst1, tenSecond1)[0]

    # Bring the flow field back to the resolution of the input frames. Without
    # this step the vectors below are rescaled to the input size while the
    # field itself stays at the padded resolution.
    if tenFlow.shape[2] != intHeight or tenFlow.shape[3] != intWidth:
        tenFlow = torch.nn.functional.interpolate(input=tenFlow, size=(intHeight, intWidth), mode='bilinear', align_corners=False)
    # end

    # Rescale the displacement vectors from padded-size pixels to input-size pixels.
    tenFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth)
    tenFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight)

    return tenFlow[0, :, :, :].cpu()
# end
|
|
|
|
##########################################################
|
|
|
|
if __name__ == '__main__':
    # Script entry point: load the two frames, estimate the flow, display it.
    # convert('RGB') guarantees exactly three channels; estimate() reshapes each
    # frame to (1, 3, H, W), which fails for RGBA, greyscale or palette images.
    img1 = PIL.Image.open(args_strOne).convert('RGB').resize((INP_W, INP_H))
    img2 = PIL.Image.open(args_strTwo).convert('RGB').resize((INP_W, INP_H))

    # HWC uint8 -> reversed channel order -> CHW float32 in [0, 1].
    tenOne = torch.FloatTensor(numpy.ascontiguousarray(numpy.array(img1)[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0)))
    tenTwo = torch.FloatTensor(numpy.ascontiguousarray(numpy.array(img2)[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0)))

    tenOutput = estimate(tenOne, tenTwo)

    # Disabled writer for the result as a .flo file, kept from the original:
    # objOutput = open('./out1.flo', 'wb')
    # numpy.array([ 80, 73, 69, 72 ], numpy.uint8).tofile(objOutput)
    # numpy.array([ tenOutput.shape[2], tenOutput.shape[1] ], numpy.int32).tofile(objOutput)
    # numpy.array(tenOutput.numpy().transpose(1, 2, 0), numpy.float32).tofile(objOutput)
    # objOutput.close()

    # flow_to_image and plt are not imported by name in this file; presumably
    # both come from `from flo2img import *` - verify.
    image = flow_to_image(tenOutput.detach().numpy().transpose(1, 2, 0))

    plt.imshow(image)
    plt.show()
# end