dewarpNet矫正扭曲

This commit is contained in:
2024-08-12 08:38:17 +08:00
parent 6fd5c059c2
commit 4fabb1a1e9
6 changed files with 469 additions and 0 deletions

96
dewarp/dewarp.py Normal file
View File

@@ -0,0 +1,96 @@
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from dewarp.models import get_model
from dewarp.utils import convert_state_dict
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def unwarp(img, bm):
w, h = img.shape[0], img.shape[1]
bm = bm.transpose(1, 2).transpose(2, 3).detach().cpu().numpy()[0, :, :, :]
bm0 = cv2.blur(bm[:, :, 0], (3, 3))
bm1 = cv2.blur(bm[:, :, 1], (3, 3))
bm0 = cv2.resize(bm0, (h, w))
bm1 = cv2.resize(bm1, (h, w))
bm = np.stack([bm0, bm1], axis=-1)
bm = np.expand_dims(bm, 0)
bm = torch.from_numpy(bm).double()
img = img.astype(float) / 255.0
img = img.transpose((2, 0, 1))
img = np.expand_dims(img, 0)
img = torch.from_numpy(img).double()
res = F.grid_sample(input=img, grid=bm, align_corners=True)
res = res[0].numpy().transpose((1, 2, 0))
return res
def dewarp_image(image):
wc_model_path = "model/dewarp_model/unetnc_doc3d.pkl"
bm_model_path = "model/dewarp_model/dnetccnl_doc3d.pkl"
wc_model_file_name = os.path.split(wc_model_path)[1]
wc_model_name = wc_model_file_name[:wc_model_file_name.find('_')]
bm_model_file_name = os.path.split(bm_model_path)[1]
bm_model_name = bm_model_file_name[:bm_model_file_name.find('_')]
wc_n_classes = 3
bm_n_classes = 2
wc_img_size = (256, 256)
bm_img_size = (128, 128)
# Setup image
imgorg = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
img = cv2.resize(imgorg, wc_img_size)
img = img[:, :, ::-1]
img = img.astype(float) / 255.0
img = img.transpose(2, 0, 1) # NHWC -> NCHW
img = np.expand_dims(img, 0)
img = torch.from_numpy(img).float()
# Predict
htan = nn.Hardtanh(0, 1.0)
wc_model = get_model(wc_model_name, wc_n_classes, in_channels=3)
if DEVICE.type == 'cpu':
wc_state = convert_state_dict(torch.load(wc_model_path, map_location='cpu', weights_only=True)['model_state'])
else:
wc_state = convert_state_dict(torch.load(wc_model_path, weights_only=True)['model_state'])
wc_model.load_state_dict(wc_state)
wc_model.eval()
bm_model = get_model(bm_model_name, bm_n_classes, in_channels=3)
if DEVICE.type == 'cpu':
bm_state = convert_state_dict(torch.load(bm_model_path, map_location='cpu', weights_only=True)['model_state'])
else:
bm_state = convert_state_dict(torch.load(bm_model_path, weights_only=True)['model_state'])
bm_model.load_state_dict(bm_state)
bm_model.eval()
if torch.cuda.is_available():
wc_model.cuda()
bm_model.cuda()
images = Variable(img.cuda())
else:
images = Variable(img)
with torch.no_grad():
wc_outputs = wc_model(images)
pred_wc = htan(wc_outputs)
bm_input = F.interpolate(pred_wc, bm_img_size)
outputs_bm = bm_model(bm_input)
# call unwarp
uwpred = unwarp(imgorg, outputs_bm)
uwpred = (uwpred * 255).astype(np.uint8)
return cv2.cvtColor(uwpred, cv2.COLOR_RGB2BGR)