dewarpNet矫正扭曲
This commit is contained in:
96
dewarp/dewarp.py
Normal file
96
dewarp/dewarp.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Variable
|
||||
|
||||
from dewarp.models import get_model
|
||||
from dewarp.utils import convert_state_dict
|
||||
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
|
||||
def unwarp(img, bm):
|
||||
w, h = img.shape[0], img.shape[1]
|
||||
bm = bm.transpose(1, 2).transpose(2, 3).detach().cpu().numpy()[0, :, :, :]
|
||||
bm0 = cv2.blur(bm[:, :, 0], (3, 3))
|
||||
bm1 = cv2.blur(bm[:, :, 1], (3, 3))
|
||||
bm0 = cv2.resize(bm0, (h, w))
|
||||
bm1 = cv2.resize(bm1, (h, w))
|
||||
bm = np.stack([bm0, bm1], axis=-1)
|
||||
bm = np.expand_dims(bm, 0)
|
||||
bm = torch.from_numpy(bm).double()
|
||||
|
||||
img = img.astype(float) / 255.0
|
||||
img = img.transpose((2, 0, 1))
|
||||
img = np.expand_dims(img, 0)
|
||||
img = torch.from_numpy(img).double()
|
||||
|
||||
res = F.grid_sample(input=img, grid=bm, align_corners=True)
|
||||
res = res[0].numpy().transpose((1, 2, 0))
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def dewarp_image(image):
|
||||
wc_model_path = "model/dewarp_model/unetnc_doc3d.pkl"
|
||||
bm_model_path = "model/dewarp_model/dnetccnl_doc3d.pkl"
|
||||
wc_model_file_name = os.path.split(wc_model_path)[1]
|
||||
wc_model_name = wc_model_file_name[:wc_model_file_name.find('_')]
|
||||
|
||||
bm_model_file_name = os.path.split(bm_model_path)[1]
|
||||
bm_model_name = bm_model_file_name[:bm_model_file_name.find('_')]
|
||||
|
||||
wc_n_classes = 3
|
||||
bm_n_classes = 2
|
||||
|
||||
wc_img_size = (256, 256)
|
||||
bm_img_size = (128, 128)
|
||||
|
||||
# Setup image
|
||||
imgorg = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
img = cv2.resize(imgorg, wc_img_size)
|
||||
|
||||
img = img[:, :, ::-1]
|
||||
img = img.astype(float) / 255.0
|
||||
img = img.transpose(2, 0, 1) # NHWC -> NCHW
|
||||
img = np.expand_dims(img, 0)
|
||||
img = torch.from_numpy(img).float()
|
||||
|
||||
# Predict
|
||||
htan = nn.Hardtanh(0, 1.0)
|
||||
wc_model = get_model(wc_model_name, wc_n_classes, in_channels=3)
|
||||
if DEVICE.type == 'cpu':
|
||||
wc_state = convert_state_dict(torch.load(wc_model_path, map_location='cpu', weights_only=True)['model_state'])
|
||||
else:
|
||||
wc_state = convert_state_dict(torch.load(wc_model_path, weights_only=True)['model_state'])
|
||||
wc_model.load_state_dict(wc_state)
|
||||
wc_model.eval()
|
||||
bm_model = get_model(bm_model_name, bm_n_classes, in_channels=3)
|
||||
if DEVICE.type == 'cpu':
|
||||
bm_state = convert_state_dict(torch.load(bm_model_path, map_location='cpu', weights_only=True)['model_state'])
|
||||
else:
|
||||
bm_state = convert_state_dict(torch.load(bm_model_path, weights_only=True)['model_state'])
|
||||
bm_model.load_state_dict(bm_state)
|
||||
bm_model.eval()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
wc_model.cuda()
|
||||
bm_model.cuda()
|
||||
images = Variable(img.cuda())
|
||||
else:
|
||||
images = Variable(img)
|
||||
|
||||
with torch.no_grad():
|
||||
wc_outputs = wc_model(images)
|
||||
pred_wc = htan(wc_outputs)
|
||||
bm_input = F.interpolate(pred_wc, bm_img_size)
|
||||
outputs_bm = bm_model(bm_input)
|
||||
|
||||
# call unwarp
|
||||
uwpred = unwarp(imgorg, outputs_bm)
|
||||
uwpred = (uwpred * 255).astype(np.uint8)
|
||||
return cv2.cvtColor(uwpred, cv2.COLOR_RGB2BGR)
|
||||
Reference in New Issue
Block a user