import functools
import os

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable  # noqa: F401  (deprecated no-op; no longer used here, kept for importers)

from dewarp.models import get_model
from dewarp.utils import convert_state_dict

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Default checkpoint locations, resolved relative to the current working directory.
WC_MODEL_PATH = "model/dewarp_model/unetnc_doc3d.pkl"
BM_MODEL_PATH = "model/dewarp_model/dnetccnl_doc3d.pkl"

WC_N_CLASSES = 3  # output channels of the first (WC) network
BM_N_CLASSES = 2  # output channels of the second network: an (x, y) sampling grid
WC_IMG_SIZE = (256, 256)  # cv2.resize dsize for the WC network input
BM_IMG_SIZE = (128, 128)  # F.interpolate size for the BM network input


def unwarp(img, bm):
    """Resample *img* through the sampling grid *bm* predicted by the BM network.

    Args:
        img: image array indexed as ``img[row, col, channel]`` with values in
            0..255 (presumably uint8 RGB -- confirm against caller).
        bm: tensor shaped ``(1, 2, H', W')``. After blurring and resizing it is
            used as the ``grid`` argument of ``F.grid_sample``, so it holds
            normalized ``[-1, 1]`` sampling coordinates.

    Returns:
        float64 array of shape ``(rows, cols, channels)``. Values are in [0, 1]
        because the input is scaled by 1/255 and sampled with zero padding.
    """
    rows, cols = img.shape[0], img.shape[1]

    # (1, 2, H', W') -> (H', W', 2) numpy array, one plane per grid coordinate.
    grid_np = bm.permute(0, 2, 3, 1).detach().cpu().numpy()[0]
    # Smooth each coordinate plane, then upsample it to full image resolution.
    # cv2.resize takes dsize as (width, height) == (cols, rows).
    planes = [
        cv2.resize(cv2.blur(grid_np[:, :, c], (3, 3)), (cols, rows))
        for c in range(2)
    ]
    grid = torch.from_numpy(np.stack(planes, axis=-1)[np.newaxis]).double()

    src = img.astype(float) / 255.0
    src = torch.from_numpy(src.transpose(2, 0, 1)[np.newaxis]).double()  # HWC -> 1CHW

    res = F.grid_sample(input=src, grid=grid, align_corners=True)
    return res[0].numpy().transpose(1, 2, 0)  # CHW -> HWC


def _model_name_from_path(model_path):
    """Return the architecture name: the checkpoint file name up to its first '_'."""
    return os.path.basename(model_path).partition('_')[0]


@functools.lru_cache(maxsize=8)
def _load_model(model_path, n_classes):
    """Build the network named by *model_path*, load its weights, and return it on DEVICE in eval mode.

    Results are cached per ``(model_path, n_classes)``, so repeated calls do not
    re-read the checkpoint from disk. A checkpoint rewritten on disk is therefore
    not picked up until the process restarts or ``_load_model.cache_clear()`` runs.
    """
    model = get_model(_model_name_from_path(model_path), n_classes, in_channels=3)
    # Always deserialize onto the CPU and move the whole model afterwards. This
    # works regardless of which device the checkpoint was saved from.
    checkpoint = torch.load(model_path, map_location='cpu', weights_only=True)
    model.load_state_dict(convert_state_dict(checkpoint['model_state']))
    model.to(DEVICE)
    model.eval()
    return model


def dewarp_image(image, wc_model_path=WC_MODEL_PATH, bm_model_path=BM_MODEL_PATH):
    """Flatten a warped document image with the two-stage WC -> BM pipeline.

    Args:
        image: 3-channel BGR image array (it is converted with ``COLOR_BGR2RGB``).
        wc_model_path: checkpoint of the first (3-channel output) network. Its
            file-name prefix before the first ``_`` selects the architecture
            passed to ``get_model``.
        bm_model_path: checkpoint of the second (2-channel grid) network, named
            by the same rule.

    Returns:
        uint8 BGR image with the same height and width as *image*.
    """
    # Preprocess first, so an invalid image fails before any model is loaded.
    imgorg = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    img = cv2.resize(imgorg, WC_IMG_SIZE)
    # Reverse channel order (RGB -> BGR) for the network input, matching the
    # original preprocessing. Presumably this is the order the WC model was
    # trained on -- confirm.
    img = img[:, :, ::-1]
    img = img.astype(float) / 255.0
    img = img.transpose(2, 0, 1)  # HWC -> CHW
    images = torch.from_numpy(np.expand_dims(img, 0)).float().to(DEVICE)

    wc_model = _load_model(wc_model_path, WC_N_CLASSES)
    bm_model = _load_model(bm_model_path, BM_N_CLASSES)

    htan = nn.Hardtanh(0, 1.0)  # clamps WC predictions to [0, 1]
    with torch.no_grad():
        pred_wc = htan(wc_model(images))
        bm_input = F.interpolate(pred_wc, BM_IMG_SIZE)
        outputs_bm = bm_model(bm_input)

    # Warp the full-resolution RGB image, not the resized network input.
    uwpred = unwarp(imgorg, outputs_bm)
    uwpred = (uwpred * 255).astype(np.uint8)
    return cv2.cvtColor(uwpred, cv2.COLOR_RGB2BGR)