将去扭曲模型转为onnx格式

This commit is contained in:
2024-08-29 14:50:17 +08:00
parent f0ae23a5e3
commit 876cd493d2
2 changed files with 11 additions and 14 deletions

View File

@@ -1,10 +1,4 @@
import paddle from onnxruntime import InferenceSession
from doc_dewarp.GeoTr import GeoTr DOC_TR = InferenceSession("model/dewarp_model/doc_tr_pp.onnx",
providers=["CUDAExecutionProvider"], provider_options=[{"device_id": 0}])
model_path = "model/dewarp_model/best.ckpt"
checkpoint = paddle.load(model_path)
state_dict = checkpoint["model"]
DEWARP = GeoTr()
DEWARP.set_state_dict(state_dict)
DEWARP.eval()

View File

@@ -1,18 +1,21 @@
import cv2 import cv2
import numpy as np
import paddle import paddle
from . import DEWARP from . import DOC_TR
from .utils import to_tensor, to_image from .utils import to_tensor, to_image
def dewarp_image(image): def dewarp_image(image):
img = cv2.resize(image, (288, 288)) img = cv2.resize(image, (288, 288)).astype(np.float32)
x = to_tensor(img)
y = to_tensor(image) y = to_tensor(image)
bm = DEWARP(x)
img = np.transpose(img, (2, 0, 1))
bm = DOC_TR.run(None, {"image": img[None,]})[0]
bm = paddle.to_tensor(bm)
bm = paddle.nn.functional.interpolate( bm = paddle.nn.functional.interpolate(
bm, y.shape[2:], mode="bilinear", align_corners=False bm, y.shape[2:], mode="bilinear", align_corners=False
) )
bm_nhwc = bm.transpose([0, 2, 3, 1]) bm_nhwc = np.transpose(bm, (0, 2, 3, 1))
out = paddle.nn.functional.grid_sample(y, (bm_nhwc / 288 - 0.5) * 2) out = paddle.nn.functional.grid_sample(y, (bm_nhwc / 288 - 0.5) * 2)
return to_image(out) return to_image(out)