diff --git a/doc_dewarp/__init__.py b/doc_dewarp/__init__.py index e69de29..3f477b4 100644 --- a/doc_dewarp/__init__.py +++ b/doc_dewarp/__init__.py @@ -0,0 +1,10 @@ +import paddle + +from doc_dewarp.GeoTr import GeoTr + +model_path = "model/dewarp_model/best.ckpt" +checkpoint = paddle.load(model_path) +state_dict = checkpoint["model"] +DEWARP = GeoTr() +DEWARP.set_state_dict(state_dict) +DEWARP.eval() diff --git a/doc_dewarp/dewarp.py b/doc_dewarp/dewarp.py index 1decad9..34548bd 100644 --- a/doc_dewarp/dewarp.py +++ b/doc_dewarp/dewarp.py @@ -1,23 +1,15 @@ import cv2 import paddle -from .GeoTr import GeoTr +from . import DEWARP from .utils import to_tensor, to_image def dewarp_image(image): - model_path = "model/dewarp_model/best.ckpt" - - checkpoint = paddle.load(model_path) - state_dict = checkpoint["model"] - model = GeoTr() - model.set_state_dict(state_dict) - model.eval() - img = cv2.resize(image, (288, 288)) x = to_tensor(img) y = to_tensor(image) - bm = model(x) + bm = DEWARP(x) bm = paddle.nn.functional.interpolate( bm, y.shape[2:], mode="bilinear", align_corners=False ) diff --git a/docker-compose.yml b/docker-compose.yml index 7f26ee8..d3599f0 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,6 +1,6 @@ x-env: &template - image: fcb_photo_review:1.13.9 + image: fcb_photo_review:1.13.10 restart: always services: