优化去扭曲模型的声明
This commit is contained in:
@@ -0,0 +1,10 @@
|
||||
import paddle
|
||||
|
||||
from doc_dewarp.GeoTr import GeoTr
|
||||
|
||||
model_path = "model/dewarp_model/best.ckpt"
|
||||
checkpoint = paddle.load(model_path)
|
||||
state_dict = checkpoint["model"]
|
||||
DEWARP = GeoTr()
|
||||
DEWARP.set_state_dict(state_dict)
|
||||
DEWARP.eval()
|
||||
|
||||
@@ -1,23 +1,15 @@
|
||||
import cv2
|
||||
import paddle
|
||||
|
||||
from .GeoTr import GeoTr
|
||||
from . import DEWARP
|
||||
from .utils import to_tensor, to_image
|
||||
|
||||
|
||||
def dewarp_image(image):
|
||||
model_path = "model/dewarp_model/best.ckpt"
|
||||
|
||||
checkpoint = paddle.load(model_path)
|
||||
state_dict = checkpoint["model"]
|
||||
model = GeoTr()
|
||||
model.set_state_dict(state_dict)
|
||||
model.eval()
|
||||
|
||||
img = cv2.resize(image, (288, 288))
|
||||
x = to_tensor(img)
|
||||
y = to_tensor(image)
|
||||
bm = model(x)
|
||||
bm = DEWARP(x)
|
||||
bm = paddle.nn.functional.interpolate(
|
||||
bm, y.shape[2:], mode="bilinear", align_corners=False
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user