diff --git a/visual_model_test/visual_model_test.py b/visual_model_test/visual_model_test.py index 80c0e59..fa06737 100644 --- a/visual_model_test/visual_model_test.py +++ b/visual_model_test/visual_model_test.py @@ -2,13 +2,16 @@ import os import re import sys +import tempfile import time from pprint import pprint -from paddlenlp import Taskflow -from paddlenlp.utils.doc_parser import DocParser +from photo_review.photo_review import split_image sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from paddlenlp import Taskflow +from paddlenlp.utils.doc_parser import DocParser from ucloud import ucloud @@ -39,22 +42,55 @@ def write_visual_result(image, layout=None, result=None): def visual_model_test(model_type, test_img, task_path, schema): if model_type == "ocr": - doc_parser = DocParser(layout_analysis=True) - parsed_doc = doc_parser.parse({"doc": test_img}) - write_visual_result(test_img, layout=parsed_doc["layout"]) + imgs = split_image(test_img) + layout = [] + temp_files_paths = [] + doc_parser = DocParser(layout_analysis=False) + for img in imgs: + with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file: + img["img"].save(temp_file.name) + temp_files_paths.append(temp_file.name) + parsed_doc = doc_parser.parse({"doc": temp_file.name}, expand_to_a4_size=True) + if img["x_offset"] or img["y_offset"]: + for p in parsed_doc["layout"]: + box = p[0] + box[0] += img["x_offset"] + box[1] += img["y_offset"] + box[2] += img["x_offset"] + box[3] += img["y_offset"] + layout += parsed_doc["layout"] + + write_visual_result(test_img, layout=layout) else: + docs = [] + split_result = split_image(test_img) + temp_files_paths = [] + for img in split_result: + with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file: + img["img"].save(temp_file.name) + temp_files_paths.append(temp_file.name) + docs.append({"doc": temp_file.name}) + my_ie = Taskflow("information_extraction", schema=schema, model="uie-x-base", task_path=task_path, layout_analysis=False) - my_results = my_ie({"doc": test_img}) + my_results = my_ie(docs) write_visual_result(test_img, result=my_results[0]) + # 使用完临时文件后,记得清理(删除)它们 + for path in temp_files_paths: + try: + os.remove(path) + print(f"临时文件 {path} 已删除") + except Exception as e: + print(f"删除临时文件 {path} 时出错: {e}") + def batch_test(test_imgs, task_path, schema): docs = [] for test_img in test_imgs: docs.append({"doc": test_img}) my_ie = Taskflow("information_extraction", schema=schema, model="uie-x-base", task_path=task_path, - layout_analysis=True, batch_size=16) + layout_analysis=False, batch_size=16) # 批量抽取写法:(ie([{"doc": "./data/6.jpg"}, {"doc": "./data/7.jpg"}]) my_results = my_ie(docs) pprint(my_results)