Files
fcb_photo_review/visual_model_test/visual_model_test.py
2024-07-24 14:02:55 +08:00

161 lines
6.6 KiB
Python

# 可视化的模型对比测试
import os
import re
import tempfile
import time
from pprint import pprint
import cv2
from paddlenlp import Taskflow
from paddlenlp.utils.doc_parser import DocParser
from paddleocr import PaddleOCR
from ucloud import ufile
from util import image_util, util
def _split_image_name(image):
    """Return ``(name, extension)`` of the last path component of *image*.

    Anything after the first ``?`` (e.g. a URL query string) is ignored and
    both ``/`` and ``\\`` are treated as path separators. A component without
    a dot keeps its whole text as the name and gets the extension ``"jpg"``.
    """
    base = re.split(r'[\\/]', image.split("?")[0])[-1]
    name, dot, ext = base.rpartition(".")
    if not dot:
        return base, "jpg"
    return name, ext


def write_visual_result(image, angle=0, layout=None, result=None):
    """Draw *layout* and/or *result* onto *image* and save them under ``./img_result/``.

    Files written (``<name>``/``<ext>`` come from the last path component of
    *image*):
      * ``./img_result/<name>_layout.<ext>`` when *layout* is truthy;
      * ``./img_result/<name>_result.<ext>`` when *result* is truthy.

    Args:
        image: Path or URL of the source image, read with ``image_util.read``.
            A ``?query`` suffix is ignored when deriving the output name.
        angle: When non-zero, the image is rotated with ``image_util.rotate``
            by this amount before drawing.
        layout: Passed to ``DocParser.write_image_with_results(layout=...)``
            (and printed).
        result: Passed to ``DocParser.write_image_with_results(result=...)``
            (and printed).
    """
    img_name, img_type = _split_image_name(image)

    img_array = image_util.read(image)
    if angle != 0:
        img_array = image_util.rotate(img_array, angle)

    result_dir = "./img_result/"
    # Make sure the output directory exists before anything is saved into it.
    os.makedirs(result_dir, exist_ok=True)

    # DocParser.write_image_with_results is given a file path, so the
    # (possibly rotated) array goes to a temporary .jpg first. The handle is
    # closed before cv2 writes to the path: reopening an open
    # NamedTemporaryFile by name is not supported on every platform.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
        temp_path = temp_file.name
    try:
        cv2.imwrite(temp_path, img_array)
        if layout:
            print(layout)
            DocParser.write_image_with_results(
                temp_path,
                layout=layout,
                save_path=f"{result_dir}{img_name}_layout.{img_type}")
        if result:
            print(result)
            DocParser.write_image_with_results(
                temp_path,
                result=result,
                save_path=f"{result_dir}{img_name}_result.{img_type}")
    finally:
        # Remove the temp file even when drawing or saving raised.
        os.remove(temp_path)
def _write_temp_jpg(image, temp_paths):
    """Write *image* to a new temporary ``.jpg`` file and return its path.

    The path is appended to *temp_paths* before writing, so the caller can
    still delete the file if ``cv2.imwrite`` fails.
    """
    # Close the handle first so cv2 can open the path on every platform.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
        temp_path = temp_file.name
    temp_paths.append(temp_path)
    cv2.imwrite(temp_path, image)
    return temp_path


def _visualize_ocr(test_img, temp_paths):
    """OCR every piece of *test_img* and draw the combined layout on the original image."""
    imgs = image_util.split(test_img)
    # Rotation detection (image_util.parse_rotation_angles) is disabled for
    # OCR; every piece is processed at angle 0.
    angle = 0
    # Build the OCR engine once and reuse it for every piece instead of
    # constructing a new PaddleOCR per piece.
    ocr = PaddleOCR(det_db_box_thresh=0.3, det_limit_side_len=2048)
    layout = []
    for img in imgs:
        rotated_img = image_util.rotate(img["img"], angle)
        rotated_img, offset_x, offset_y = image_util.expand_to_a4_size(rotated_img, True)
        temp_path = _write_temp_jpg(rotated_img, temp_paths)
        # Fold the offset returned by expand_to_a4_size into the piece's own
        # offset; presumably this maps boxes found on the expanded image back
        # onto test_img, where the layout is drawn — confirm with image_util.
        img["x_offset"] -= offset_x
        img["y_offset"] -= offset_y
        parsed_doc = util.get_ocr_layout(ocr, temp_path)
        x_off, y_off = img["x_offset"], img["y_offset"]
        if x_off or y_off:
            for p in parsed_doc:
                # box indices 0/2 are x coordinates and 1/3 are y coordinates.
                box = p[0]
                box[0] += x_off
                box[1] += y_off
                box[2] += x_off
                box[3] += y_off
        layout += parsed_doc
    write_visual_result(test_img, angle, layout=layout)


def _visualize_ie(test_img, task_path, schema, temp_paths):
    """Run UIE-X extraction on the pieces of *test_img* and draw the first piece's result.

    Raises:
        ValueError: if ``image_util.split`` yields no pieces.
    """
    docs = []
    angles = []
    for img in image_util.split(test_img):
        angle = int(image_util.parse_rotation_angles(img["img"])[0])
        rotated_img = image_util.rotate(img["img"], angle)
        docs.append({"doc": _write_temp_jpg(rotated_img, temp_paths)})
        angles.append(angle)
    if not docs:
        raise ValueError(f"image_util.split returned no images for {test_img!r}")
    my_ie = Taskflow("information_extraction", schema=schema, model="uie-x-base", task_path=task_path,
                     layout_analysis=False)
    my_results = my_ie(docs)
    # Only the first piece's result is drawn, so rotate by that same piece's angle.
    write_visual_result(test_img, angles[0], result=my_results[0])


def visual_model_test(model_type, test_img, task_path, schema):
    """Run one model on *test_img* and save a visualization of its output.

    For ``model_type == "ocr"`` the image is split with ``image_util.split``.
    Each piece is expanded with ``image_util.expand_to_a4_size`` and OCR'd
    with one shared PaddleOCR engine. The combined layout is drawn via
    ``write_visual_result``.

    For any other model type, each piece is rotated by the angle from
    ``image_util.parse_rotation_angles``. The pieces are then passed to a
    ``uie-x-base`` information-extraction Taskflow built from *task_path* and
    *schema*, and the first piece's result is drawn.

    Intermediate images go to temporary ``.jpg`` files. These are always
    deleted before returning, including when an error is raised.

    Args:
        model_type: ``"ocr"`` or any information-extraction model type.
        test_img: Path or URL of the image to test.
        task_path: Taskflow ``task_path`` (unused for OCR).
        schema: Taskflow extraction schema (unused for OCR).
    """
    temp_files_paths = []
    try:
        if model_type == "ocr":
            _visualize_ocr(test_img, temp_files_paths)
        else:
            _visualize_ie(test_img, task_path, schema, temp_files_paths)
    finally:
        # Clean up (delete) the temp files once they are no longer needed;
        # a failed delete is reported but does not abort the cleanup.
        for path in temp_files_paths:
            try:
                os.remove(path)
                print(f"临时文件 {path} 已删除")
            except OSError as e:
                print(f"删除临时文件 {path} 时出错: {e}")
def batch_test(test_imgs, task_path, schema):
    """Run UIE-X information extraction over several images at once and print the results.

    Args:
        test_imgs: Iterable of image paths (or URLs) to extract from.
        task_path: Taskflow ``task_path`` of the fine-tuned model.
        schema: Extraction schema passed to Taskflow.
    """
    # Batched extraction takes a list of {"doc": <path>} dicts,
    # e.g. ie([{"doc": "./data/6.jpg"}, {"doc": "./data/7.jpg"}]).
    batch = [{"doc": path} for path in test_imgs]
    extractor = Taskflow(
        "information_extraction",
        schema=schema,
        model="uie-x-base",
        task_path=task_path,
        layout_analysis=False,
        batch_size=16,
    )
    pprint(extractor(batch))
def main(model_type, pic_name=None):
    """Run ``visual_model_test`` for *model_type* and print the elapsed time.

    Args:
        model_type: One of ``"ocr"``, ``"settlement"``, ``"discharge"``,
            ``"cost"``, ``"cost_detail"``. Any other value prints an error
            message and returns without running anything.
        pic_name: Optional name resolved to the test image URL with
            ``ufile.get_private_url`` (presumably a UFile object key — confirm).
            When omitted, the default sample image for the model type is used.
    """
    start_time = time.perf_counter()
    # model_type -> (task_path, default test image, extraction schema).
    # Built per call so every run gets fresh schema objects.
    configs = {
        "ocr": (None, "../test_img/IMG_20240723_131247.jpg", None),
        "settlement": (
            "../model/settlement_list_model",
            "img/PH20240511000638_1_094306_1.jpg",
            ["患者姓名", "入院日期", "出院日期", "费用总额", "个人现金支付", "个人账户支付", "自费金额",
             "医保类型", "住院号", "医保结算单号码", "大写总额"],
        ),
        "discharge": (
            "../model/discharge_record_model",
            "img/PH20240401000003_3_001938_2.jpg",
            ["医院", "科室", "患者姓名", "入院日期", "出院日期", "主治医生", "住院号", "年龄"],
        ),
        "cost": (
            "../model/cost_list_model",
            "img/PH20240511000648_4_094542_2.jpg",
            ["患者姓名", "入院日期", "出院日期", "费用总额"],
        ),
        "cost_detail": (
            "../model/cost_list_detail_model",
            "img/PH20240511000648_4_094542_2.jpg",
            {"名称": ["类别", "规格", "单价", "数量", "金额"]},
        ),
    }
    config = configs.get(model_type)
    if config is None:
        print("请输入正确的类型!")
        return
    task_path, default_img, schema = config
    test_img_path = ufile.get_private_url(pic_name) if pic_name else default_img
    visual_model_test(model_type, test_img_path, task_path, schema)
    # perf_counter is monotonic, so the duration is unaffected by system clock
    # changes; plain print avoids pprint's quoted repr of the string.
    print(f"处理时长:{time.perf_counter() - start_time}")
if __name__ == '__main__':
    import sys

    # Usage: python visual_model_test.py [model_type] [pic_name]
    #   model_type: ocr (default), settlement, discharge, cost or cost_detail
    #   pic_name:   optional; forwarded to main(), which resolves it with
    #               ufile.get_private_url instead of using the sample image
    main(*(sys.argv[1:3] or ["ocr"]))
    # Manual check of write_visual_result with a hand-written layout box:
    # write_visual_result("img/PH20240428000832_1_093844_2.jpg", layout=[([508.0975609756094,
    #                                                                       659.7073170731707,
    #                                                                       1000,
    #                                                                       745.756097560976], 'lay', 'figure')])