From 876cd493d22d6789ad2615ba30b67216b1ca5d24 Mon Sep 17 00:00:00 2001 From: liuyebo <1515783401@qq.com> Date: Thu, 29 Aug 2024 14:50:17 +0800 Subject: [PATCH] =?UTF-8?q?=E5=B0=86=E5=8E=BB=E6=89=AD=E6=9B=B2=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E8=BD=AC=E4=B8=BAonnx=E6=A0=BC=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- doc_dewarp/__init__.py | 12 +++--------- doc_dewarp/dewarp.py | 13 ++++++++----- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/doc_dewarp/__init__.py b/doc_dewarp/__init__.py index 3f477b4..c98d39c 100644 --- a/doc_dewarp/__init__.py +++ b/doc_dewarp/__init__.py @@ -1,10 +1,4 @@ -import paddle +from onnxruntime import InferenceSession -from doc_dewarp.GeoTr import GeoTr - -model_path = "model/dewarp_model/best.ckpt" -checkpoint = paddle.load(model_path) -state_dict = checkpoint["model"] -DEWARP = GeoTr() -DEWARP.set_state_dict(state_dict) -DEWARP.eval() +DOC_TR = InferenceSession("model/dewarp_model/doc_tr_pp.onnx", + providers=["CUDAExecutionProvider"], provider_options=[{"device_id": 0}]) diff --git a/doc_dewarp/dewarp.py b/doc_dewarp/dewarp.py index 34548bd..923f56f 100644 --- a/doc_dewarp/dewarp.py +++ b/doc_dewarp/dewarp.py @@ -1,18 +1,21 @@ import cv2 +import numpy as np import paddle -from . import DEWARP +from . import DOC_TR from .utils import to_tensor, to_image def dewarp_image(image): - img = cv2.resize(image, (288, 288)) - x = to_tensor(img) + img = cv2.resize(image, (288, 288)).astype(np.float32) y = to_tensor(image) - bm = DEWARP(x) + + img = np.transpose(img, (2, 0, 1)) + bm = DOC_TR.run(None, {"image": img[None,]})[0] + bm = paddle.to_tensor(bm) bm = paddle.nn.functional.interpolate( bm, y.shape[2:], mode="bilinear", align_corners=False ) - bm_nhwc = bm.transpose([0, 2, 3, 1]) + bm_nhwc = np.transpose(bm, (0, 2, 3, 1)) out = paddle.nn.functional.grid_sample(y, (bm_nhwc / 288 - 0.5) * 2) return to_image(out)