优化批量处理能力

This commit is contained in:
2024-06-06 10:39:34 +08:00
parent adbf93eca1
commit 9a35071b93
4 changed files with 42 additions and 21 deletions

View File

@@ -9,3 +9,6 @@ SLEEP_MINUTES = 5
# 是否发送报错邮件 # 是否发送报错邮件
SEND_ERROR_EMAIL = True SEND_ERROR_EMAIL = True
# 信息抽取批量处理大小
IE_BATCH_SIZE = 32

View File

@@ -8,7 +8,7 @@ from sqlalchemy import update
from config.keys import PATIENT_NAME, ADMISSION_DATE, DISCHARGE_DATE, MEDICAL_EXPENSES, PERSONAL_CASH_PAYMENT, \ from config.keys import PATIENT_NAME, ADMISSION_DATE, DISCHARGE_DATE, MEDICAL_EXPENSES, PERSONAL_CASH_PAYMENT, \
PERSONAL_ACCOUNT_PAYMENT, PERSONAL_FUNDED_AMOUNT, MEDICAL_INSURANCE_TYPE, HOSPITAL, DEPARTMENT, DOCTOR PERSONAL_ACCOUNT_PAYMENT, PERSONAL_FUNDED_AMOUNT, MEDICAL_INSURANCE_TYPE, HOSPITAL, DEPARTMENT, DOCTOR
from config.mysql import MysqlSession from config.mysql import MysqlSession
from config.photo_review import PHHD_BATCH_SIZE, SLEEP_MINUTES from config.photo_review import PHHD_BATCH_SIZE, SLEEP_MINUTES, IE_BATCH_SIZE
from photo_review.entity.bd_yljg import BdYljg from photo_review.entity.bd_yljg import BdYljg
from photo_review.entity.bd_ylks import BdYlks from photo_review.entity.bd_ylks import BdYlks
from photo_review.entity.zx_ie_cost import ZxIeCost from photo_review.entity.zx_ie_cost import ZxIeCost
@@ -24,20 +24,26 @@ from photo_review.util.util import get_default_datetime
# 关键信息提取 # 关键信息提取
def information_extraction(schema, phrecs, task_path): def information_extraction(schema, phrecs, task_path):
results = {} result = {}
docs = []
doc_phrecs = []
for phrec in phrecs: for phrec in phrecs:
pic_path = get_private_url(phrec.cfjaddress) pic_path = get_private_url(phrec.cfjaddress)
if pic_path: if pic_path:
ie = Taskflow("information_extraction", schema=schema, model="uie-x-base", task_path=task_path, docs.append({"doc": pic_path})
layout_analysis=True) doc_phrecs.append(phrec)
# 批量抽取写法:(ie([{"doc": "./data/6.jpg"}, {"doc": "./data/7.jpg"}])
result = ie({"doc": pic_path})
result_json = json.dumps(result, ensure_ascii=False) ie = Taskflow("information_extraction", schema=schema, model="uie-x-base", task_path=task_path,
layout_analysis=True, batch_size=IE_BATCH_SIZE)
ie_results = ie(docs)
now = get_default_datetime()
for i in range(len(ie_results)):
ie_result = ie_results[i]
phrec = doc_phrecs[i]
result_json = json.dumps(ie_result, ensure_ascii=False)
if len(result_json) > 5000: if len(result_json) > 5000:
result_json = result_json[:5000] result_json = result_json[:5000]
# 提取完保存每张图片的结果
now = get_default_datetime()
session = MysqlSession() session = MysqlSession()
zx_ocr = ZxOcr(pk_phhd=phrec.pk_phhd, pk_phrec=phrec.pk_phrec, cfjaddress=phrec.cfjaddress, zx_ocr = ZxOcr(pk_phhd=phrec.pk_phhd, pk_phrec=phrec.pk_phrec, cfjaddress=phrec.cfjaddress,
content=result_json, create_time=now, update_time=now) content=result_json, create_time=now, update_time=now)
@@ -45,8 +51,9 @@ def information_extraction(schema, phrecs, task_path):
session.commit() session.commit()
session.close() session.close()
results.update(result[0]) result.update(ie_result)
return results
return result
# 从keys中获取准确率最高的value # 从keys中获取准确率最高的value

View File

@@ -23,5 +23,5 @@ def get_private_url(key):
# url = get_ufile_handler.public_download_url(bucket, key) # url = get_ufile_handler.public_download_url(bucket, key)
# 获取私有空间下载url, expires为下载链接有效期单位为秒 # 获取私有空间下载url, expires为下载链接有效期单位为秒
url = get_ufile_handler.private_download_url(bucket, key, expires=300) url = get_ufile_handler.private_download_url(bucket, key, expires=3600)
return url return url

View File

@@ -49,6 +49,17 @@ def visual_model_test(model_type, test_img, task_path, schema):
write_visual_result(test_img, result=my_results[0]) write_visual_result(test_img, result=my_results[0])
def batch_test(test_imgs, task_path, schema):
docs = []
for test_img in test_imgs:
docs.append({"doc": test_img})
my_ie = Taskflow("information_extraction", schema=schema, model="uie-x-base", task_path=task_path,
layout_analysis=True, batch_size=16)
# 批量抽取写法:(ie([{"doc": "./data/6.jpg"}, {"doc": "./data/7.jpg"}])
my_results = my_ie(docs)
pprint(my_results)
def main(model_type, pic_name=None): def main(model_type, pic_name=None):
# 开始时间 # 开始时间
start_time = time.time() start_time = time.time()
@@ -60,7 +71,7 @@ def main(model_type, pic_name=None):
elif model_type == "settlement": elif model_type == "settlement":
task_path = "../config/model/settlement_list_model" task_path = "../config/model/settlement_list_model"
test_img_path = get_private_url(pic_name) if pic_name else "img/PH20240511000638_1_094306_1.jpg" test_img_path = get_private_url(pic_name) if pic_name else "img/PH20240511000638_1_094306_1.jpg"
schema = ["姓名", "入院日期", "出院日期", "费用总额", "个人现金支付", "个人账户支付", "自费", "医保类型"] schema = ["患者姓名", "入院日期", "出院日期", "费用总额", "个人现金支付", "个人账户支付", "自费金额", "医保类型"]
elif model_type == "discharge": elif model_type == "discharge":
task_path = "../config/model/discharge_record_model" task_path = "../config/model/discharge_record_model"
test_img_path = get_private_url(pic_name) if pic_name else "img/PH20240401000003_3_001938_2.jpg" test_img_path = get_private_url(pic_name) if pic_name else "img/PH20240401000003_3_001938_2.jpg"