145 lines
5.6 KiB
Python
145 lines
5.6 KiB
Python
# 可视化的模型对比测试
|
|
import os
|
|
import re
|
|
import sys
|
|
import tempfile
|
|
import time
|
|
from pprint import pprint
|
|
|
|
import cv2
|
|
|
|
from photo_review.photo_review import split_image
|
|
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
from paddlenlp import Taskflow
|
|
from paddlenlp.utils.doc_parser import DocParser
|
|
from ucloud import ucloud
|
|
|
|
|
|
def write_visual_result(image, layout=None, result=None):
    """Draw parsed layout boxes and/or extraction results onto ``image``.

    Output images are written under ``./img_result/`` and named after the
    source image: ``<name>_layout.<ext>`` for ``layout`` and
    ``<name>_result.<ext>`` for ``result``.  Each argument that is truthy is
    printed to stdout and then rendered with
    ``DocParser.write_image_with_results``; falsy arguments are skipped.

    Args:
        image: Local path or URL of the source image.  Everything from the
            first ``?`` on (e.g. a URL query string) is ignored when deriving
            the output file name; both ``/`` and ``\\`` count as separators.
        layout: Layout entries, as found in ``DocParser.parse(...)["layout"]``.
        result: A single information-extraction result from ``Taskflow``.
    """
    # Keep only the bare file name: drop any query string, then the directory part.
    file_name = re.split(r'[\\/]', image.split("?")[0])[-1]
    # splitext keeps the whole name as the stem when there is no extension, so the
    # output is never a nameless "_layout.jpg"; a missing/empty extension -> "jpg".
    img_name, ext = os.path.splitext(file_name)
    img_type = ext[1:] or "jpg"
    # NOTE(review): assumes ./img_result/ exists or is created by DocParser — confirm.
    save_prefix = "./img_result/" + img_name

    if layout:
        print(layout)
        DocParser.write_image_with_results(
            image,
            layout=layout,
            save_path=save_prefix + "_layout." + img_type)

    if result:
        print(result)
        DocParser.write_image_with_results(
            image,
            result=result,
            save_path=save_prefix + "_result." + img_type)
|
|
|
|
|
|
def _write_temp_jpg(image, temp_paths):
    """Write ``image`` with ``cv2.imwrite`` to a new temporary ``.jpg`` file.

    The path is appended to ``temp_paths`` *before* writing, so the caller
    can delete it even if the write fails.  Returns the temporary path.

    Raises:
        OSError: if ``cv2.imwrite`` reports failure.
    """
    # Create and immediately close the file so cv2 writes it through its own
    # handle instead of while Python still holds it open.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
        temp_path = temp_file.name
    temp_paths.append(temp_path)
    # cv2.imwrite signals failure with a False return value rather than raising.
    if not cv2.imwrite(temp_path, image):
        raise OSError(f"cv2.imwrite failed to write {temp_path}")
    return temp_path


def _shift_layout(layout, x_offset, y_offset):
    """Translate every box in ``layout`` in place by ``(x_offset, y_offset)``.

    Each entry's first element is the box ``[x1, y1, x2, y2]`` and must be
    mutable (x offset is added to indices 0/2, y offset to indices 1/3).
    """
    if not (x_offset or y_offset):
        return
    for entry in layout:
        box = entry[0]
        box[0] += x_offset
        box[1] += y_offset
        box[2] += x_offset
        box[3] += y_offset


def _remove_temp_files(paths):
    """Best-effort deletion of ``paths``; each outcome is printed, nothing is raised."""
    for path in paths:
        try:
            os.remove(path)
            print(f"临时文件 {path} 已删除")
        except OSError as e:
            print(f"删除临时文件 {path} 时出错: {e}")


def visual_model_test(model_type, test_img, task_path, schema):
    """Run one model on ``test_img`` and save a visualisation of its output.

    The image is cut into pieces by ``split_image`` (each piece a dict with
    ``"img"``, ``"x_offset"`` and ``"y_offset"``).  Every piece is written to
    a temporary ``.jpg`` so the paddlenlp APIs can read it from disk.

    * ``model_type == "ocr"``:
      - Each piece is parsed with ``DocParser(layout_analysis=False)`` and
        ``expand_to_a4_size=True``.
      - Its boxes are shifted by the piece's offsets (presumably the piece's
        position in the full image).
      - The merged layout is drawn with ``write_visual_result``.
    * any other ``model_type``:
      - All pieces are passed to a ``uie-x-base`` information-extraction
        ``Taskflow`` built from ``schema`` and ``task_path``.
      - The first piece's result is drawn.

    Temporary files are removed afterwards, also when an error is raised.

    Args:
        model_type: ``"ocr"``; any other value selects the extraction path.
        test_img: Local path or URL of the image under test.
        task_path: Fine-tuned model directory for ``Taskflow`` (unused for OCR).
        schema: Extraction schema for ``Taskflow`` (unused for OCR).

    Raises:
        OSError: if a piece cannot be written to its temporary file.
    """
    temp_files_paths = []
    try:
        pieces = split_image(test_img)
        if model_type == "ocr":
            doc_parser = DocParser(layout_analysis=False)
            layout = []
            for piece in pieces:
                temp_path = _write_temp_jpg(piece["img"], temp_files_paths)
                parsed_doc = doc_parser.parse({"doc": temp_path}, expand_to_a4_size=True)
                _shift_layout(parsed_doc["layout"], piece["x_offset"], piece["y_offset"])
                layout += parsed_doc["layout"]
            write_visual_result(test_img, layout=layout)
        else:
            docs = [{"doc": _write_temp_jpg(piece["img"], temp_files_paths)} for piece in pieces]
            my_ie = Taskflow("information_extraction", schema=schema, model="uie-x-base",
                             task_path=task_path, layout_analysis=False)
            my_results = my_ie(docs)
            # NOTE(review): only the first piece's result is drawn, with no offset
            # applied; results for the remaining pieces are discarded — confirm intent.
            write_visual_result(test_img, result=my_results[0])
    finally:
        # Clean up the temporary files even if splitting, parsing or extraction failed.
        _remove_temp_files(temp_files_paths)
|
|
|
|
|
|
def batch_test(test_imgs, task_path, schema):
    """Run batched information extraction over several images and pretty-print the results.

    Args:
        test_imgs: Iterable of image paths, each used as a ``{"doc": ...}`` entry.
        task_path: Directory of the fine-tuned ``uie-x-base`` model.
        schema: Extraction schema passed to ``Taskflow``.
    """
    # Batched input shape: ie([{"doc": "./data/6.jpg"}, {"doc": "./data/7.jpg"}])
    batch = [{"doc": img_path} for img_path in test_imgs]
    extractor = Taskflow(
        "information_extraction",
        schema=schema,
        model="uie-x-base",
        task_path=task_path,
        layout_analysis=False,
        batch_size=16,
    )
    pprint(extractor(batch))
|
|
|
|
|
|
def main(model_type, pic_name=None):
    """Run ``visual_model_test`` for one model type and print the elapsed time.

    Args:
        model_type: One of ``"ocr"``, ``"settlement"``, ``"discharge"``,
            ``"cost"`` or ``"cost_detail"``.  Any other value prints an error
            message and returns without running anything.
        pic_name: Optional name passed to ``ucloud.get_private_url``; when
            truthy, the returned URL is used as the test image instead of the
            model's sample image.  Sample images and model directories are
            relative paths, resolved against the current working directory.
    """
    # perf_counter is monotonic, so the measured duration cannot jump with clock changes.
    start_time = time.perf_counter()

    # model_type -> (fine-tuned model dir, default sample image, extraction schema).
    # Built per call so no caller ever shares a mutable schema object.
    configs = {
        "ocr": (None, "img/PH20240428000832_1_093844_2.jpg", None),
        # Settlement list: patient name, admission/discharge date, total cost,
        # personal cash / personal account payment, self-paid amount, insurance type.
        "settlement": (
            "../config/model/settlement_list_model",
            "img/PH20240511000638_1_094306_1.jpg",
            ["患者姓名", "入院日期", "出院日期", "费用总额", "个人现金支付", "个人账户支付", "自费金额", "医保类型"],
        ),
        # Discharge record: hospital, department, patient name, admission/discharge date, attending doctor.
        "discharge": (
            "../config/model/discharge_record_model",
            "img/PH20240401000003_3_001938_2.jpg",
            ["医院", "科室", "患者姓名", "入院日期", "出院日期", "主治医生"],
        ),
        # Cost list: patient name, admission/discharge date, total cost.
        "cost": (
            "../config/model/cost_list_model",
            "img/PH20240511000648_4_094542_2.jpg",
            ["患者姓名", "入院日期", "出院日期", "费用总额"],
        ),
        # Cost details: item name -> category, specification, unit price, quantity, amount.
        "cost_detail": (
            "../config/model/cost_list_detail_model",
            "img/PH20240511000648_4_094542_2.jpg",
            {"名称": ["类别", "规格", "单价", "数量", "金额"]},
        ),
    }

    if model_type not in configs:
        print("请输入正确的类型!")
        return

    task_path, default_img, schema = configs[model_type]
    test_img_path = ucloud.get_private_url(pic_name) if pic_name else default_img
    visual_model_test(model_type, test_img_path, task_path, schema)

    elapsed = time.perf_counter() - start_time
    # Plain print: pprint would wrap the message in repr quotes.
    print(f"处理时长:{elapsed}秒")
|
|
|
|
|
|
if __name__ == '__main__':
    import argparse

    # Usage: python <this script> [model_type] [--pic PIC_NAME]
    # With no arguments this keeps the previous behaviour: main("settlement").
    parser = argparse.ArgumentParser(description="Visual model comparison test")
    parser.add_argument(
        "model_type",
        nargs="?",
        default="settlement",
        choices=["ocr", "settlement", "discharge", "cost", "cost_detail"],
        help="model to test (default: settlement)",
    )
    parser.add_argument(
        "--pic",
        dest="pic_name",
        default=None,
        help="image name passed to ucloud.get_private_url; "
             "defaults to the model's sample image",
    )
    args = parser.parse_args()
    main(args.model_type, args.pic_name)
|