优化批量处理能力
This commit is contained in:
@@ -9,3 +9,6 @@ SLEEP_MINUTES = 5
|
||||
|
||||
# 是否发送报错邮件
|
||||
SEND_ERROR_EMAIL = True
|
||||
|
||||
# 信息抽取批量处理大小
|
||||
IE_BATCH_SIZE = 32
|
||||
|
||||
@@ -8,7 +8,7 @@ from sqlalchemy import update
|
||||
from config.keys import PATIENT_NAME, ADMISSION_DATE, DISCHARGE_DATE, MEDICAL_EXPENSES, PERSONAL_CASH_PAYMENT, \
|
||||
PERSONAL_ACCOUNT_PAYMENT, PERSONAL_FUNDED_AMOUNT, MEDICAL_INSURANCE_TYPE, HOSPITAL, DEPARTMENT, DOCTOR
|
||||
from config.mysql import MysqlSession
|
||||
from config.photo_review import PHHD_BATCH_SIZE, SLEEP_MINUTES
|
||||
from config.photo_review import PHHD_BATCH_SIZE, SLEEP_MINUTES, IE_BATCH_SIZE
|
||||
from photo_review.entity.bd_yljg import BdYljg
|
||||
from photo_review.entity.bd_ylks import BdYlks
|
||||
from photo_review.entity.zx_ie_cost import ZxIeCost
|
||||
@@ -24,29 +24,36 @@ from photo_review.util.util import get_default_datetime
|
||||
|
||||
# 关键信息提取
|
||||
def information_extraction(schema, phrecs, task_path):
|
||||
results = {}
|
||||
result = {}
|
||||
docs = []
|
||||
doc_phrecs = []
|
||||
for phrec in phrecs:
|
||||
pic_path = get_private_url(phrec.cfjaddress)
|
||||
if pic_path:
|
||||
ie = Taskflow("information_extraction", schema=schema, model="uie-x-base", task_path=task_path,
|
||||
layout_analysis=True)
|
||||
# 批量抽取写法:(ie([{"doc": "./data/6.jpg"}, {"doc": "./data/7.jpg"}])
|
||||
result = ie({"doc": pic_path})
|
||||
docs.append({"doc": pic_path})
|
||||
doc_phrecs.append(phrec)
|
||||
|
||||
result_json = json.dumps(result, ensure_ascii=False)
|
||||
if len(result_json) > 5000:
|
||||
result_json = result_json[:5000]
|
||||
# 提取完保存每张图片的结果
|
||||
now = get_default_datetime()
|
||||
session = MysqlSession()
|
||||
zx_ocr = ZxOcr(pk_phhd=phrec.pk_phhd, pk_phrec=phrec.pk_phrec, cfjaddress=phrec.cfjaddress,
|
||||
content=result_json, create_time=now, update_time=now)
|
||||
session.add(zx_ocr)
|
||||
session.commit()
|
||||
session.close()
|
||||
ie = Taskflow("information_extraction", schema=schema, model="uie-x-base", task_path=task_path,
|
||||
layout_analysis=True, batch_size=IE_BATCH_SIZE)
|
||||
ie_results = ie(docs)
|
||||
|
||||
results.update(result[0])
|
||||
return results
|
||||
now = get_default_datetime()
|
||||
for i in range(len(ie_results)):
|
||||
ie_result = ie_results[i]
|
||||
phrec = doc_phrecs[i]
|
||||
result_json = json.dumps(ie_result, ensure_ascii=False)
|
||||
if len(result_json) > 5000:
|
||||
result_json = result_json[:5000]
|
||||
session = MysqlSession()
|
||||
zx_ocr = ZxOcr(pk_phhd=phrec.pk_phhd, pk_phrec=phrec.pk_phrec, cfjaddress=phrec.cfjaddress,
|
||||
content=result_json, create_time=now, update_time=now)
|
||||
session.add(zx_ocr)
|
||||
session.commit()
|
||||
session.close()
|
||||
|
||||
result.update(ie_result)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# 从keys中获取准确率最高的value
|
||||
|
||||
@@ -23,5 +23,5 @@ def get_private_url(key):
|
||||
# url = get_ufile_handler.public_download_url(bucket, key)
|
||||
|
||||
# 获取私有空间下载url, expires为下载链接有效期,单位为秒
|
||||
url = get_ufile_handler.private_download_url(bucket, key, expires=300)
|
||||
url = get_ufile_handler.private_download_url(bucket, key, expires=3600)
|
||||
return url
|
||||
|
||||
@@ -49,6 +49,17 @@ def visual_model_test(model_type, test_img, task_path, schema):
|
||||
write_visual_result(test_img, result=my_results[0])
|
||||
|
||||
|
||||
def batch_test(test_imgs, task_path, schema):
|
||||
docs = []
|
||||
for test_img in test_imgs:
|
||||
docs.append({"doc": test_img})
|
||||
my_ie = Taskflow("information_extraction", schema=schema, model="uie-x-base", task_path=task_path,
|
||||
layout_analysis=True, batch_size=16)
|
||||
# 批量抽取写法:(ie([{"doc": "./data/6.jpg"}, {"doc": "./data/7.jpg"}])
|
||||
my_results = my_ie(docs)
|
||||
pprint(my_results)
|
||||
|
||||
|
||||
def main(model_type, pic_name=None):
|
||||
# 开始时间
|
||||
start_time = time.time()
|
||||
@@ -60,7 +71,7 @@ def main(model_type, pic_name=None):
|
||||
elif model_type == "settlement":
|
||||
task_path = "../config/model/settlement_list_model"
|
||||
test_img_path = get_private_url(pic_name) if pic_name else "img/PH20240511000638_1_094306_1.jpg"
|
||||
schema = ["姓名", "入院日期", "出院日期", "费用总额", "个人现金支付", "个人账户支付", "自费", "医保类型"]
|
||||
schema = ["患者姓名", "入院日期", "出院日期", "费用总额", "个人现金支付", "个人账户支付", "自费金额", "医保类型"]
|
||||
elif model_type == "discharge":
|
||||
task_path = "../config/model/discharge_record_model"
|
||||
test_img_path = get_private_url(pic_name) if pic_name else "img/PH20240401000003_3_001938_2.jpg"
|
||||
|
||||
Reference in New Issue
Block a user