import functools

import cv2
import paddle

from .GeoTr import GeoTr
from .utils import to_tensor, to_image

# Default checkpoint location; as a relative path it is resolved against the
# current working directory of the process.
DEFAULT_MODEL_PATH = "model/dewarp_model/best.ckpt"

# Side length of the square image fed to the model. The predicted backward map
# is expressed in pixel coordinates of this resized image, so the same value is
# used to normalise the map before sampling.
MODEL_INPUT_SIZE = 288


@functools.lru_cache(maxsize=4)
def _load_model(model_path):
    """Load a GeoTr model from ``model_path`` and switch it to eval mode.

    Cached per path so that repeated calls to ``dewarp_image`` do not re-read
    the checkpoint and rebuild the network on every invocation.

    Args:
        model_path: Path to a checkpoint produced by ``paddle.save`` whose
            ``"model"`` entry holds the GeoTr state dict.

    Returns:
        The loaded ``GeoTr`` instance in eval mode.
    """
    checkpoint = paddle.load(model_path)
    model = GeoTr()
    model.set_state_dict(checkpoint["model"])
    model.eval()
    return model


def dewarp_image(image, model_path=DEFAULT_MODEL_PATH):
    """Geometrically rectify ``image`` using a GeoTr backward-map model.

    The image is resized to ``MODEL_INPUT_SIZE`` x ``MODEL_INPUT_SIZE`` and run
    through the model to predict a backward map. The map is upsampled to the
    original resolution and used with ``grid_sample`` to resample the
    full-resolution image.

    Args:
        image: Input image. Presumably an OpenCV-style ``H x W x C`` numpy
            array (it is passed to ``cv2.resize``); the exact layout expected
            by ``to_tensor`` should be confirmed against ``.utils``.
        model_path: Checkpoint to load. Defaults to ``DEFAULT_MODEL_PATH``.
            The loaded model is cached per path.

    Returns:
        Whatever ``to_image`` produces from the resampled tensor.
    """
    model = _load_model(model_path)
    size = MODEL_INPUT_SIZE

    # Inference only: avoid recording an autograd graph for the forward pass.
    with paddle.no_grad():
        x = to_tensor(cv2.resize(image, (size, size)))
        y = to_tensor(image)

        # Assumes to_tensor yields NCHW, so shape[2:] is (H, W) -- TODO confirm.
        bm = model(x)
        bm = paddle.nn.functional.interpolate(
            bm, y.shape[2:], mode="bilinear", align_corners=False
        )

        # grid_sample wants the sampling grid with coordinates in the last axis.
        bm_nhwc = bm.transpose([0, 2, 3, 1])

        # Map pixel coordinates in [0, size] to grid_sample's [-1, 1] range.
        grid = (bm_nhwc / size - 0.5) * 2
        out = paddle.nn.functional.grid_sample(y, grid)

    return to_image(out)