27 lines
675 B
Python
27 lines
675 B
Python
import cv2
|
|
import paddle
|
|
|
|
from .GeoTr import GeoTr
|
|
from .utils import to_tensor, to_image
|
|
|
|
|
|
# Side length (in pixels) of the square image fed to the network. The same
# value is used below to rescale the network output before resampling.
_MODEL_INPUT_SIZE = 288

# Networks already loaded, keyed by checkpoint path, so that repeated calls do
# not re-read the checkpoint from disk and rebuild GeoTr every time.
_MODEL_CACHE = {}


def _load_model(model_path):
    """Return an eval-mode GeoTr loaded from *model_path*, cached per path."""
    model = _MODEL_CACHE.get(model_path)
    if model is None:
        checkpoint = paddle.load(model_path)
        model = GeoTr()
        # The checkpoint stores the network weights under the "model" key.
        model.set_state_dict(checkpoint["model"])
        model.eval()
        _MODEL_CACHE[model_path] = model
    return model


def dewarp_image(image, model_path="model/dewarp_model/best.ckpt"):
    """Dewarp *image* with the GeoTr network and return the resampled image.

    The image is resized to 288x288 and passed through the network; the
    network output is bilinearly upsampled to the spatial size of the
    full-resolution tensor and used as the sampling grid for
    ``paddle.nn.functional.grid_sample`` on that full-resolution tensor.

    Args:
        image: Input image accepted by ``cv2.resize`` and ``to_tensor``
            (presumably an H x W x C array as returned by ``cv2.imread`` --
            TODO confirm against ``utils.to_tensor``).
        model_path: Path of the checkpoint to load. Defaults to the path the
            function has always used; it is resolved relative to the current
            working directory.

    Returns:
        Whatever ``to_image`` produces from the resampled tensor (presumably
        an image array in the same layout as the input -- TODO confirm).
    """
    model = _load_model(model_path)

    # Pure inference: skip autograd bookkeeping for the forward pass and the
    # resampling so no gradient graph is kept alive.
    with paddle.no_grad():
        resized = cv2.resize(image, (_MODEL_INPUT_SIZE, _MODEL_INPUT_SIZE))
        x = to_tensor(resized)
        y = to_tensor(image)

        bm = model(x)
        # Bring the network output up to the spatial size of the original
        # image tensor (y.shape[2:]).
        bm = paddle.nn.functional.interpolate(
            bm, y.shape[2:], mode="bilinear", align_corners=False
        )
        # grid_sample takes the grid with the coordinate axis last and values
        # in [-1, 1]; this maps the range [0, 288] onto [-1, 1] (assumes the
        # network emits coordinates in 288-pixel units -- TODO confirm).
        grid = (bm.transpose([0, 2, 3, 1]) / _MODEL_INPUT_SIZE - 0.5) * 2
        out = paddle.nn.functional.grid_sample(y, grid)

    return to_image(out)
|