将去扭曲模型转为onnx格式
This commit is contained in:
@@ -1,10 +1,4 @@
|
|||||||
import paddle
|
from onnxruntime import InferenceSession
|
||||||
|
|
||||||
from doc_dewarp.GeoTr import GeoTr
|
DOC_TR = InferenceSession("model/dewarp_model/doc_tr_pp.onnx",
|
||||||
|
providers=["CUDAExecutionProvider"], provider_options=[{"device_id": 0}])
|
||||||
model_path = "model/dewarp_model/best.ckpt"
|
|
||||||
checkpoint = paddle.load(model_path)
|
|
||||||
state_dict = checkpoint["model"]
|
|
||||||
DEWARP = GeoTr()
|
|
||||||
DEWARP.set_state_dict(state_dict)
|
|
||||||
DEWARP.eval()
|
|
||||||
|
|||||||
@@ -1,18 +1,21 @@
|
|||||||
import cv2
|
import cv2
|
||||||
|
import numpy as np
|
||||||
import paddle
|
import paddle
|
||||||
|
|
||||||
from . import DEWARP
|
from . import DOC_TR
|
||||||
from .utils import to_tensor, to_image
|
from .utils import to_tensor, to_image
|
||||||
|
|
||||||
|
|
||||||
def dewarp_image(image):
|
def dewarp_image(image):
|
||||||
img = cv2.resize(image, (288, 288))
|
img = cv2.resize(image, (288, 288)).astype(np.float32)
|
||||||
x = to_tensor(img)
|
|
||||||
y = to_tensor(image)
|
y = to_tensor(image)
|
||||||
bm = DEWARP(x)
|
|
||||||
|
img = np.transpose(img, (2, 0, 1))
|
||||||
|
bm = DOC_TR.run(None, {"image": img[None,]})[0]
|
||||||
|
bm = paddle.to_tensor(bm)
|
||||||
bm = paddle.nn.functional.interpolate(
|
bm = paddle.nn.functional.interpolate(
|
||||||
bm, y.shape[2:], mode="bilinear", align_corners=False
|
bm, y.shape[2:], mode="bilinear", align_corners=False
|
||||||
)
|
)
|
||||||
bm_nhwc = bm.transpose([0, 2, 3, 1])
|
bm_nhwc = np.transpose(bm, (0, 2, 3, 1))
|
||||||
out = paddle.nn.functional.grid_sample(y, (bm_nhwc / 288 - 0.5) * 2)
|
out = paddle.nn.functional.grid_sample(y, (bm_nhwc / 288 - 0.5) * 2)
|
||||||
return to_image(out)
|
return to_image(out)
|
||||||
|
|||||||
Reference in New Issue
Block a user