97 lines
3.0 KiB
Python
97 lines
3.0 KiB
Python
import functools
import os

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from dewarp.models import get_model
from dewarp.utils import convert_state_dict
|
|
|
|
# Device used for inference: the GPU when CUDA is available, otherwise the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
|
|
|
|
|
def unwarp(img, bm):
    """Resample ``img`` through the sampling grid ``bm``.

    Args:
        img: NumPy image array laid out as (rows, cols, channels); it is
            divided by 255 before sampling.
        bm: 4-D torch tensor whose dim 1 carries the two grid channels. Only
            the first batch entry is used. Presumably the values follow the
            normalized [-1, 1] convention of ``F.grid_sample`` -- confirm
            against the model that produces it.

    Returns:
        A float64 NumPy array shaped (rows, cols, channels) holding the
        resampled image, still scaled by 1/255.
    """
    rows, cols = img.shape[0], img.shape[1]

    # Move the grid channels last and keep only the first batch entry.
    channels_last = bm.transpose(1, 2).transpose(2, 3)
    flow = channels_last.detach().cpu().numpy()[0]

    # Smooth each grid channel with a 3x3 box filter, then stretch it to the
    # image resolution. cv2.resize takes its target size as (width, height).
    planes = [
        cv2.resize(cv2.blur(flow[:, :, idx], (3, 3)), (cols, rows))
        for idx in range(2)
    ]
    grid = torch.from_numpy(np.stack(planes, axis=-1)[np.newaxis]).double()

    # Image as a float64 tensor with a leading batch dim and channels first.
    pixels = (img.astype(float) / 255.0).transpose((2, 0, 1))
    source = torch.from_numpy(np.expand_dims(pixels, 0)).double()

    sampled = F.grid_sample(input=source, grid=grid, align_corners=True)

    # Drop the batch dim and return to channels-last layout.
    return sampled[0].numpy().transpose((1, 2, 0))
|
|
|
|
|
|
# Checkpoint locations; the network architecture name is the part of the file
# name before the first underscore.
_WC_MODEL_PATH = "model/dewarp_model/unetnc_doc3d.pkl"
_BM_MODEL_PATH = "model/dewarp_model/dnetccnl_doc3d.pkl"

# Output channel counts passed to get_model for each network.
_WC_N_CLASSES = 3
_BM_N_CLASSES = 2

# Spatial sizes fed to the first and second network respectively.
_WC_IMG_SIZE = (256, 256)
_BM_IMG_SIZE = (128, 128)


@functools.lru_cache(maxsize=None)
def _load_model(model_path, n_classes):
    """Build the network named by ``model_path`` and load its checkpoint.

    The result is cached per ``(model_path, n_classes)`` so the checkpoint is
    read from disk only once per process instead of on every call.

    Args:
        model_path: Path to a checkpoint containing a ``'model_state'`` entry.
        n_classes: Number of output channels passed to ``get_model``.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.
    """
    model_name = os.path.basename(model_path).partition('_')[0]
    model = get_model(model_name, n_classes, in_channels=3)
    # map_location makes the same call work on CPU-only and CUDA hosts.
    checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(convert_state_dict(checkpoint['model_state']))
    model.to(DEVICE)
    model.eval()
    return model


def dewarp_image(image):
    """Dewarp a BGR image with the two-stage network pipeline.

    The image is resized to 256x256 and passed through the first network;
    that output is clamped to [0, 1], resized to 128x128 and passed through
    the second network, whose output is used by ``unwarp`` as the sampling
    grid for the full-resolution image.

    Args:
        image: Image array in OpenCV BGR channel order.

    Returns:
        The dewarped image as a ``uint8`` BGR array with the same height and
        width as ``image``.
    """
    imgorg = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # The first network is fed the channel-reversed (i.e. BGR again) copy of
    # the resized image, scaled to [0, 1].
    net_input = cv2.resize(imgorg, _WC_IMG_SIZE)[:, :, ::-1]
    net_input = net_input.astype(float) / 255.0
    net_input = np.expand_dims(net_input.transpose(2, 0, 1), 0)  # HWC -> 1xCxHxW
    net_input = torch.from_numpy(net_input).float().to(DEVICE)

    wc_model = _load_model(_WC_MODEL_PATH, _WC_N_CLASSES)
    bm_model = _load_model(_BM_MODEL_PATH, _BM_N_CLASSES)

    htan = nn.Hardtanh(0, 1.0)
    with torch.no_grad():
        pred_wc = htan(wc_model(net_input))
        bm_input = F.interpolate(pred_wc, _BM_IMG_SIZE)
        outputs_bm = bm_model(bm_input)

    # Resample the full-resolution RGB image through the predicted grid.
    uwpred = unwarp(imgorg, outputs_bm)

    # Round instead of truncating: float round-off such as 199.99999 would
    # otherwise darken a pixel by one level.
    uwpred = np.rint(np.clip(uwpred * 255.0, 0.0, 255.0)).astype(np.uint8)
    return cv2.cvtColor(uwpred, cv2.COLOR_RGB2BGR)
|