优化比例夸张图片的处理

This commit is contained in:
2024-06-20 14:45:15 +08:00
parent a142a62774
commit b9c7200234
2 changed files with 91 additions and 27 deletions

View File

@@ -17,16 +17,13 @@ if __name__ == '__main__':
log = logging.getLogger() log = logging.getLogger()
# 崩溃后的重试次数 # 崩溃后的重试次数
retry_time = RETRY_TIME + 1 for _ in range(RETRY_TIME + 1):
for _ in range(retry_time):
try: try:
log.info("照片审核开始") log.info("照片审核关键信息抽取】开始")
main() main()
except Exception as e: except Exception as e:
log.error(traceback.format_exc()) log.error(traceback.format_exc())
if SEND_ERROR_EMAIL: if SEND_ERROR_EMAIL:
send_an_error_email(program_name='照片审核关键信息抽取脚本', error_name=repr(e), send_an_error_email(program_name='照片审核关键信息抽取脚本', error_name=repr(e), error_detail=traceback.format_exc())
error_detail=traceback.format_exc())
# 释放显存 # 释放显存
paddle.device.cuda.empty_cache() paddle.device.cuda.empty_cache()
continue

View File

@@ -1,12 +1,19 @@
import json import json
import logging import logging
import math
import os import os
import sys import sys
from time import sleep import tempfile
from io import BytesIO
import paddle import paddle
from sqlalchemy import update import requests
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from PIL import Image
from time import sleep
from sqlalchemy import update
from config.keys import PATIENT_NAME, ADMISSION_DATE, DISCHARGE_DATE, MEDICAL_EXPENSES, PERSONAL_CASH_PAYMENT, \ from config.keys import PATIENT_NAME, ADMISSION_DATE, DISCHARGE_DATE, MEDICAL_EXPENSES, PERSONAL_CASH_PAYMENT, \
PERSONAL_ACCOUNT_PAYMENT, PERSONAL_FUNDED_AMOUNT, MEDICAL_INSURANCE_TYPE, HOSPITAL, DEPARTMENT, DOCTOR PERSONAL_ACCOUNT_PAYMENT, PERSONAL_FUNDED_AMOUNT, MEDICAL_INSURANCE_TYPE, HOSPITAL, DEPARTMENT, DOCTOR
from config.mysql import MysqlSession from config.mysql import MysqlSession
@@ -25,7 +32,55 @@ from photo_review.util.data_util import handle_date, handle_decimal, handle_depa
from photo_review.util.util import get_default_datetime from photo_review.util.util import get_default_datetime
from ucloud import ucloud from ucloud import ucloud
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Fetch an image over HTTP
def open_image_from_url(url, timeout=30):
    """Download the image at ``url`` and open it with PIL.

    Args:
        url: HTTP(S) address of the image.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout a stalled connection would block the caller indefinitely.

    Returns:
        The PIL image opened from the response body.

    Raises:
        requests.RequestException: on connection failure, timeout, or a
            4xx/5xx response status.
    """
    response = requests.get(url, timeout=timeout)
    # Fail on HTTP errors here instead of handing an error page to PIL.
    response.raise_for_status()
    # PIL needs a file-like object, so wrap the downloaded bytes in memory.
    image_stream = BytesIO(response.content)
    return Image.open(image_stream)
# Split an image with an extreme aspect ratio into overlapping tiles
def split_image(img_path, max_ratio=2.82, best_ration=1.41, overlap=0.05):
    """Download an image and cut it into tiles when it is too elongated.

    Args:
        img_path: URL of the image; it is fetched with ``open_image_from_url``.
        max_ratio: Long-side / short-side ratio above which the image is split.
        best_ration: Long-side / short-side ratio used for every tile.
        overlap: Amount, as a fraction of the short side, by which two
            consecutive tiles overlap.

    Returns:
        A list of dicts with the keys ``"img"`` (an RGB PIL image),
        ``"x_offset"`` and ``"y_offset"`` (position of the tile's top-left
        corner inside the source image). An image that does not exceed
        ``max_ratio`` yields a single entry with both offsets set to 0.
    """
    img = open_image_from_url(img_path)
    width, height = img.size
    short_side = min(width, height)
    long_side = max(width, height)

    if long_side / short_side <= max_ratio:
        # Convert here as well: callers save every entry as JPEG, which
        # cannot store modes such as RGBA or P.
        return [{"img": img.convert("RGB"), "x_offset": 0, "y_offset": 0}]

    # Tiles advance by slightly less than their own length so that
    # neighbouring tiles share a strip of the image.
    step = short_side * (best_ration - overlap)
    tile_length = short_side * best_ration
    height_is_longer = width < height

    split_result = []
    for i in range(math.ceil(long_side / step)):
        offset = round(step * i)
        tile_end = round(offset + tile_length)
        # NOTE: the crop box of the last tile can reach past the image edge;
        # this is unchanged from the original implementation.
        if height_is_longer:
            box = (0, offset, width, tile_end)
            x_offset, y_offset = 0, offset
        else:
            box = (offset, 0, tile_end, height)
            x_offset, y_offset = offset, 0
        # Always convert to RGB so the tile can be saved as JPEG.
        tile = img.crop(box).convert("RGB")
        split_result.append({"img": tile, "x_offset": x_offset, "y_offset": y_offset})
    return split_result
# Merge two information-extraction results
def merge_result(result1, result2):
    """Append the entries of ``result2`` to ``result1``, key by key.

    For every key in ``result2`` the list stored there is concatenated after
    the list already held by ``result1`` (or after an empty list when the key
    is new). ``result1`` is updated in place and also returned.
    """
    for field, extra_items in result2.items():
        existing_items = result1.get(field, [])
        # Concatenation builds a new list, so neither input list is mutated.
        result1[field] = existing_items + extra_items
    return result1
# 关键信息提取 # 关键信息提取
@@ -36,16 +91,27 @@ def information_extraction(ie, phrecs):
for phrec in phrecs: for phrec in phrecs:
pic_path = ucloud.get_private_url(phrec.cfjaddress) pic_path = ucloud.get_private_url(phrec.cfjaddress)
if pic_path: if pic_path:
docs.append({"doc": pic_path}) split_result = split_image(pic_path)
for img in split_result:
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
img["img"].save(temp_file.name)
docs.append({"doc": temp_file.name})
doc_phrecs.append(phrec) doc_phrecs.append(phrec)
if not docs: if not docs:
return result return result
ie_results = []
try: try:
ie_results = ie(docs) ie_results = ie(docs)
except Exception as e: except Exception as e:
logging.error(e) logging.error(e)
return result return result
finally:
for temp_file in docs:
try:
os.remove(temp_file["doc"])
except Exception as e:
logging.info(f"删除临时文件 {temp_file['doc']} 时出错: {e}")
now = get_default_datetime() now = get_default_datetime()
for i in range(len(ie_results)): for i in range(len(ie_results)):
@@ -61,7 +127,7 @@ def information_extraction(ie, phrecs):
session.commit() session.commit()
session.close() session.close()
result.update(ie_result) result = merge_result(result, ie_result)
return result return result
@@ -71,15 +137,16 @@ def get_best_value_in_keys(source, keys):
# 最终结果 # 最终结果
result = None result = None
# 最大可能性 # 最大可能性
most_probability = 0 best_probability = 0
for key in keys: for key in keys:
values = source.get(key) values = source.get(key)
if values: if values:
for value in values: for value in values:
text = value.get("text") text = value.get("text")
probability = value.get("probability") probability = value.get("probability")
if text and probability > most_probability: if text and probability > best_probability:
result = text result = text
best_probability = probability
return result return result
@@ -89,10 +156,12 @@ def get_values_of_keys(source, keys):
for key in keys: for key in keys:
value = source.get(key) value = source.get(key)
if value: if value:
value = value[0].get("text") for v in value:
if value: v = v.get("text")
result.append(value) if v:
return result result.append(v)
# 去重
return list(set(result))
def save_or_update_ie(table, pk_phhd, data): def save_or_update_ie(table, pk_phhd, data):
@@ -214,10 +283,8 @@ def main():
# 持续检测新案子 # 持续检测新案子
while 1: while 1:
session = MysqlSession() session = MysqlSession()
phhds = session.query(ZxPhhd.pk_phhd) \ # 查询需要识别的案子
.filter(ZxPhhd.exsuccess_flag == '1') \ phhds = session.query(ZxPhhd.pk_phhd).filter(ZxPhhd.exsuccess_flag == '1').limit(PHHD_BATCH_SIZE).all()
.limit(PHHD_BATCH_SIZE) \
.all()
session.close() session.close()
if phhds: if phhds:
for phhd in phhds: for phhd in phhds:
@@ -226,14 +293,14 @@ def main():
# 识别完成更新标识 # 识别完成更新标识
session = MysqlSession() session = MysqlSession()
stmt = (update(ZxPhhd).where(ZxPhhd.pk_phhd == pk_phhd).values(exsuccess_flag=8)) update_flag = (update(ZxPhhd).where(ZxPhhd.pk_phhd == pk_phhd).values(exsuccess_flag=8))
session.execute(stmt) session.execute(update_flag)
session.commit() session.commit()
session.close() session.close()
# 完成一个案子释放显存
paddle.device.cuda.empty_cache() paddle.device.cuda.empty_cache()
else: else:
# 没有查询到新案子,等待一段时间后再查 # 没有查询到新案子,等待一段时间后再查
sleep_minutes = SLEEP_MINUTES
log = logging.getLogger() log = logging.getLogger()
log.info(f"暂未查询到新案子,等待{sleep_minutes}分钟...") log.info(f"暂未查询到新案子,等待{SLEEP_MINUTES}分钟...")
sleep(sleep_minutes * 60) sleep(SLEEP_MINUTES * 60)