将去扭曲模型转为onnx格式

This commit is contained in:
2024-08-29 14:50:17 +08:00
parent f0ae23a5e3
commit 876cd493d2
2 changed files with 11 additions and 14 deletions

View File

@@ -1,10 +1,4 @@
import paddle
from onnxruntime import InferenceSession
from doc_dewarp.GeoTr import GeoTr
model_path = "model/dewarp_model/best.ckpt"
checkpoint = paddle.load(model_path)
state_dict = checkpoint["model"]
DEWARP = GeoTr()
DEWARP.set_state_dict(state_dict)
DEWARP.eval()
DOC_TR = InferenceSession("model/dewarp_model/doc_tr_pp.onnx",
providers=["CUDAExecutionProvider"], provider_options=[{"device_id": 0}])