因显存溢出频发,现改为单张图片识别

This commit is contained in:
2024-06-26 10:54:38 +08:00
parent bcd8713ee3
commit 1240c4a884
2 changed files with 34 additions and 48 deletions

View File

@@ -14,9 +14,6 @@ SLEEP_MINUTES = 5
# 是否发送报错邮件
SEND_ERROR_EMAIL = True
# 信息抽取批量处理大小
IE_BATCH_SIZE = 4
# 是否开启布局分析
LAYOUT_ANALYSIS = False
@@ -25,14 +22,12 @@ CUDA_VISIBLE_DEVICES = "1"
# 基本医保结算单
SETTLEMENT_IE = Taskflow("information_extraction", schema=SETTLEMENT_LIST_SCHEMA, model="uie-x-base",
task_path="config/model/settlement_list_model", layout_analysis=LAYOUT_ANALYSIS,
batch_size=IE_BATCH_SIZE)
task_path="config/model/settlement_list_model", layout_analysis=LAYOUT_ANALYSIS)
# 出院记录
DISCHARGE_IE = Taskflow("information_extraction", schema=DISCHARGE_RECORD_SCHEMA, model="uie-x-base",
task_path="config/model/discharge_record_model", layout_analysis=LAYOUT_ANALYSIS,
batch_size=IE_BATCH_SIZE)
task_path="config/model/discharge_record_model", layout_analysis=LAYOUT_ANALYSIS)
# 费用清单
COST_IE = Taskflow("information_extraction", schema=COST_LIST_SCHEMA, model="uie-x-base",
task_path="config/model/cost_list_model", layout_analysis=LAYOUT_ANALYSIS, batch_size=IE_BATCH_SIZE)
task_path="config/model/cost_list_model", layout_analysis=LAYOUT_ANALYSIS)

View File

@@ -133,48 +133,39 @@ def rotate_image(img, angle):
# 关键信息提取
def information_extraction(ie, phrecs):
result = {}
docs = []
doc_phrecs = []
for phrec in phrecs:
pic_path = ucloud.get_private_url(phrec.cfjaddress)
if pic_path:
if not pic_path:
continue
split_result = split_image(pic_path)
# 同一张图的标识
identity = int(time.time())
for img in split_result:
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
angle = get_image_rotation_angle(img["img"])
rotated_img = rotate_image(img["img"], angle)
cv2.imwrite(temp_file.name, rotated_img)
docs.append({"doc": temp_file.name})
doc_phrecs.append({"phrec": phrec, "rotation_angle": angle, "x_offset": img["x_offset"], "y_offset": img["y_offset"]})
if not docs:
return result
ie_results = []
ie_result = []
try:
ie_results = ie(docs)
ie_result = ie({"doc": temp_file.name})[0]
except Exception as e:
logging.error("信息抽取时出错:", e)
return result
finally:
for temp_file in docs:
try:
os.remove(temp_file["doc"])
os.remove(temp_file.name)
except Exception as e:
logging.info(f"删除临时文件 {temp_file['doc']} 时出错: {e}")
logging.info(f"删除临时文件 {temp_file.name} 时出错: {e}")
now = get_default_datetime()
id = int(time.time())
for i in range(len(ie_results)):
ie_result = ie_results[i]
doc_phrec = doc_phrecs[i]
phrec = doc_phrec["phrec"]
result_json = json.dumps(ie_result, ensure_ascii=False)
if len(result_json) > 5000:
result_json = result_json[:5000]
session = MysqlSession()
zx_ocr = ZxOcr(pk_phhd=phrec.pk_phhd, pk_phrec=phrec.pk_phrec, id=id, cfjaddress=phrec.cfjaddress,
content=result_json, rotation_angle=doc_phrec["rotation_angle"], x_offset=doc_phrec["x_offset"],
y_offset=doc_phrec["y_offset"], create_time=now, update_time=now)
zx_ocr = ZxOcr(pk_phhd=phrec.pk_phhd, pk_phrec=phrec.pk_phrec, id=identity, cfjaddress=phrec.cfjaddress,
content=result_json, rotation_angle=angle, x_offset=img["x_offset"], y_offset=img["y_offset"],
create_time=now, update_time=now)
session.add(zx_ocr)
session.commit()
session.close()