优化批量处理能力
This commit is contained in:
@@ -9,3 +9,6 @@ SLEEP_MINUTES = 5
|
|||||||
|
|
||||||
# 是否发送报错邮件
|
# 是否发送报错邮件
|
||||||
SEND_ERROR_EMAIL = True
|
SEND_ERROR_EMAIL = True
|
||||||
|
|
||||||
|
# 信息抽取批量处理大小
|
||||||
|
IE_BATCH_SIZE = 32
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from sqlalchemy import update
|
|||||||
from config.keys import PATIENT_NAME, ADMISSION_DATE, DISCHARGE_DATE, MEDICAL_EXPENSES, PERSONAL_CASH_PAYMENT, \
|
from config.keys import PATIENT_NAME, ADMISSION_DATE, DISCHARGE_DATE, MEDICAL_EXPENSES, PERSONAL_CASH_PAYMENT, \
|
||||||
PERSONAL_ACCOUNT_PAYMENT, PERSONAL_FUNDED_AMOUNT, MEDICAL_INSURANCE_TYPE, HOSPITAL, DEPARTMENT, DOCTOR
|
PERSONAL_ACCOUNT_PAYMENT, PERSONAL_FUNDED_AMOUNT, MEDICAL_INSURANCE_TYPE, HOSPITAL, DEPARTMENT, DOCTOR
|
||||||
from config.mysql import MysqlSession
|
from config.mysql import MysqlSession
|
||||||
from config.photo_review import PHHD_BATCH_SIZE, SLEEP_MINUTES
|
from config.photo_review import PHHD_BATCH_SIZE, SLEEP_MINUTES, IE_BATCH_SIZE
|
||||||
from photo_review.entity.bd_yljg import BdYljg
|
from photo_review.entity.bd_yljg import BdYljg
|
||||||
from photo_review.entity.bd_ylks import BdYlks
|
from photo_review.entity.bd_ylks import BdYlks
|
||||||
from photo_review.entity.zx_ie_cost import ZxIeCost
|
from photo_review.entity.zx_ie_cost import ZxIeCost
|
||||||
@@ -24,20 +24,26 @@ from photo_review.util.util import get_default_datetime
|
|||||||
|
|
||||||
# 关键信息提取
|
# 关键信息提取
|
||||||
def information_extraction(schema, phrecs, task_path):
|
def information_extraction(schema, phrecs, task_path):
|
||||||
results = {}
|
result = {}
|
||||||
|
docs = []
|
||||||
|
doc_phrecs = []
|
||||||
for phrec in phrecs:
|
for phrec in phrecs:
|
||||||
pic_path = get_private_url(phrec.cfjaddress)
|
pic_path = get_private_url(phrec.cfjaddress)
|
||||||
if pic_path:
|
if pic_path:
|
||||||
ie = Taskflow("information_extraction", schema=schema, model="uie-x-base", task_path=task_path,
|
docs.append({"doc": pic_path})
|
||||||
layout_analysis=True)
|
doc_phrecs.append(phrec)
|
||||||
# 批量抽取写法:(ie([{"doc": "./data/6.jpg"}, {"doc": "./data/7.jpg"}])
|
|
||||||
result = ie({"doc": pic_path})
|
|
||||||
|
|
||||||
result_json = json.dumps(result, ensure_ascii=False)
|
ie = Taskflow("information_extraction", schema=schema, model="uie-x-base", task_path=task_path,
|
||||||
|
layout_analysis=True, batch_size=IE_BATCH_SIZE)
|
||||||
|
ie_results = ie(docs)
|
||||||
|
|
||||||
|
now = get_default_datetime()
|
||||||
|
for i in range(len(ie_results)):
|
||||||
|
ie_result = ie_results[i]
|
||||||
|
phrec = doc_phrecs[i]
|
||||||
|
result_json = json.dumps(ie_result, ensure_ascii=False)
|
||||||
if len(result_json) > 5000:
|
if len(result_json) > 5000:
|
||||||
result_json = result_json[:5000]
|
result_json = result_json[:5000]
|
||||||
# 提取完保存每张图片的结果
|
|
||||||
now = get_default_datetime()
|
|
||||||
session = MysqlSession()
|
session = MysqlSession()
|
||||||
zx_ocr = ZxOcr(pk_phhd=phrec.pk_phhd, pk_phrec=phrec.pk_phrec, cfjaddress=phrec.cfjaddress,
|
zx_ocr = ZxOcr(pk_phhd=phrec.pk_phhd, pk_phrec=phrec.pk_phrec, cfjaddress=phrec.cfjaddress,
|
||||||
content=result_json, create_time=now, update_time=now)
|
content=result_json, create_time=now, update_time=now)
|
||||||
@@ -45,8 +51,9 @@ def information_extraction(schema, phrecs, task_path):
|
|||||||
session.commit()
|
session.commit()
|
||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
results.update(result[0])
|
result.update(ie_result)
|
||||||
return results
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
# 从keys中获取准确率最高的value
|
# 从keys中获取准确率最高的value
|
||||||
|
|||||||
@@ -23,5 +23,5 @@ def get_private_url(key):
|
|||||||
# url = get_ufile_handler.public_download_url(bucket, key)
|
# url = get_ufile_handler.public_download_url(bucket, key)
|
||||||
|
|
||||||
# 获取私有空间下载url, expires为下载链接有效期,单位为秒
|
# 获取私有空间下载url, expires为下载链接有效期,单位为秒
|
||||||
url = get_ufile_handler.private_download_url(bucket, key, expires=300)
|
url = get_ufile_handler.private_download_url(bucket, key, expires=3600)
|
||||||
return url
|
return url
|
||||||
|
|||||||
@@ -49,6 +49,17 @@ def visual_model_test(model_type, test_img, task_path, schema):
|
|||||||
write_visual_result(test_img, result=my_results[0])
|
write_visual_result(test_img, result=my_results[0])
|
||||||
|
|
||||||
|
|
||||||
|
def batch_test(test_imgs, task_path, schema):
|
||||||
|
docs = []
|
||||||
|
for test_img in test_imgs:
|
||||||
|
docs.append({"doc": test_img})
|
||||||
|
my_ie = Taskflow("information_extraction", schema=schema, model="uie-x-base", task_path=task_path,
|
||||||
|
layout_analysis=True, batch_size=16)
|
||||||
|
# 批量抽取写法:(ie([{"doc": "./data/6.jpg"}, {"doc": "./data/7.jpg"}])
|
||||||
|
my_results = my_ie(docs)
|
||||||
|
pprint(my_results)
|
||||||
|
|
||||||
|
|
||||||
def main(model_type, pic_name=None):
|
def main(model_type, pic_name=None):
|
||||||
# 开始时间
|
# 开始时间
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@@ -60,7 +71,7 @@ def main(model_type, pic_name=None):
|
|||||||
elif model_type == "settlement":
|
elif model_type == "settlement":
|
||||||
task_path = "../config/model/settlement_list_model"
|
task_path = "../config/model/settlement_list_model"
|
||||||
test_img_path = get_private_url(pic_name) if pic_name else "img/PH20240511000638_1_094306_1.jpg"
|
test_img_path = get_private_url(pic_name) if pic_name else "img/PH20240511000638_1_094306_1.jpg"
|
||||||
schema = ["姓名", "入院日期", "出院日期", "费用总额", "个人现金支付", "个人账户支付", "自费", "医保类型"]
|
schema = ["患者姓名", "入院日期", "出院日期", "费用总额", "个人现金支付", "个人账户支付", "自费金额", "医保类型"]
|
||||||
elif model_type == "discharge":
|
elif model_type == "discharge":
|
||||||
task_path = "../config/model/discharge_record_model"
|
task_path = "../config/model/discharge_record_model"
|
||||||
test_img_path = get_private_url(pic_name) if pic_name else "img/PH20240401000003_3_001938_2.jpg"
|
test_img_path = get_private_url(pic_name) if pic_name else "img/PH20240401000003_3_001938_2.jpg"
|
||||||
|
|||||||
Reference in New Issue
Block a user