DocTr去扭曲
This commit is contained in:
69
doc_dewarp/predict.py
Normal file
69
doc_dewarp/predict.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import argparse
|
||||
|
||||
import cv2
|
||||
import paddle
|
||||
|
||||
from GeoTr import GeoTr
|
||||
from utils import to_image, to_tensor
|
||||
|
||||
|
||||
def run(args):
|
||||
image_path = args.image
|
||||
model_path = args.model
|
||||
output_path = args.output
|
||||
|
||||
checkpoint = paddle.load(model_path)
|
||||
state_dict = checkpoint["model"]
|
||||
model = GeoTr()
|
||||
model.set_state_dict(state_dict)
|
||||
model.eval()
|
||||
|
||||
img_org = cv2.imread(image_path)
|
||||
img = cv2.resize(img_org, (288, 288))
|
||||
x = to_tensor(img)
|
||||
y = to_tensor(img_org)
|
||||
bm = model(x)
|
||||
bm = paddle.nn.functional.interpolate(
|
||||
bm, y.shape[2:], mode="bilinear", align_corners=False
|
||||
)
|
||||
bm_nhwc = bm.transpose([0, 2, 3, 1])
|
||||
out = paddle.nn.functional.grid_sample(y, (bm_nhwc / 288 - 0.5) * 2)
|
||||
out_image = to_image(out)
|
||||
cv2.imwrite(output_path, out_image)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="predict")
|
||||
|
||||
parser.add_argument(
|
||||
"--image",
|
||||
"-i",
|
||||
nargs="?",
|
||||
type=str,
|
||||
default="",
|
||||
help="The path of image",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
"-m",
|
||||
nargs="?",
|
||||
type=str,
|
||||
default="",
|
||||
help="The path of model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
"-o",
|
||||
nargs="?",
|
||||
type=str,
|
||||
default="",
|
||||
help="The path of output",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print(args)
|
||||
|
||||
run(args)
|
||||
Reference in New Issue
Block a user