优化比例夸张图片的处理
This commit is contained in:
9
main.py
9
main.py
@@ -17,16 +17,13 @@ if __name__ == '__main__':
|
|||||||
log = logging.getLogger()
|
log = logging.getLogger()
|
||||||
|
|
||||||
# 崩溃后的重试次数
|
# 崩溃后的重试次数
|
||||||
retry_time = RETRY_TIME + 1
|
for _ in range(RETRY_TIME + 1):
|
||||||
for _ in range(retry_time):
|
|
||||||
try:
|
try:
|
||||||
log.info("照片审核开始")
|
log.info("【照片审核关键信息抽取】开始")
|
||||||
main()
|
main()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(traceback.format_exc())
|
log.error(traceback.format_exc())
|
||||||
if SEND_ERROR_EMAIL:
|
if SEND_ERROR_EMAIL:
|
||||||
send_an_error_email(program_name='照片审核关键信息抽取脚本', error_name=repr(e),
|
send_an_error_email(program_name='照片审核关键信息抽取脚本', error_name=repr(e), error_detail=traceback.format_exc())
|
||||||
error_detail=traceback.format_exc())
|
|
||||||
# 释放显存
|
# 释放显存
|
||||||
paddle.device.cuda.empty_cache()
|
paddle.device.cuda.empty_cache()
|
||||||
continue
|
|
||||||
|
|||||||
@@ -1,12 +1,19 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from time import sleep
|
import tempfile
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
import paddle
|
import paddle
|
||||||
from sqlalchemy import update
|
import requests
|
||||||
|
|
||||||
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
from time import sleep
|
||||||
|
from sqlalchemy import update
|
||||||
from config.keys import PATIENT_NAME, ADMISSION_DATE, DISCHARGE_DATE, MEDICAL_EXPENSES, PERSONAL_CASH_PAYMENT, \
|
from config.keys import PATIENT_NAME, ADMISSION_DATE, DISCHARGE_DATE, MEDICAL_EXPENSES, PERSONAL_CASH_PAYMENT, \
|
||||||
PERSONAL_ACCOUNT_PAYMENT, PERSONAL_FUNDED_AMOUNT, MEDICAL_INSURANCE_TYPE, HOSPITAL, DEPARTMENT, DOCTOR
|
PERSONAL_ACCOUNT_PAYMENT, PERSONAL_FUNDED_AMOUNT, MEDICAL_INSURANCE_TYPE, HOSPITAL, DEPARTMENT, DOCTOR
|
||||||
from config.mysql import MysqlSession
|
from config.mysql import MysqlSession
|
||||||
@@ -25,7 +32,55 @@ from photo_review.util.data_util import handle_date, handle_decimal, handle_depa
|
|||||||
from photo_review.util.util import get_default_datetime
|
from photo_review.util.util import get_default_datetime
|
||||||
from ucloud import ucloud
|
from ucloud import ucloud
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
||||||
|
# 获取图片
|
||||||
|
def open_image_from_url(url):
|
||||||
|
# 发送HTTP请求获取图片数据
|
||||||
|
response = requests.get(url)
|
||||||
|
# 将响应内容转化为BytesIO对象,以便PIL处理
|
||||||
|
image_stream = BytesIO(response.content)
|
||||||
|
# 使用PIL的Image.open方法打开图像
|
||||||
|
image = Image.open(image_stream)
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
# 分割大图片
|
||||||
|
def split_image(img_path, max_ratio=2.82, best_ration=1.41, overlap=0.05):
|
||||||
|
split_result = []
|
||||||
|
# 打开图片
|
||||||
|
img = open_image_from_url(img_path)
|
||||||
|
# 获取图片的宽度和高度
|
||||||
|
width, height = img.size
|
||||||
|
# 计算宽高比
|
||||||
|
ratio = max(width, height) / min(width, height)
|
||||||
|
# 检查是否需要裁剪
|
||||||
|
if ratio > max_ratio:
|
||||||
|
# 确定裁剪的尺寸,保持长宽比,以较短边为基准
|
||||||
|
new_ratio = best_ration - overlap
|
||||||
|
if width < height: # 高度是较长边
|
||||||
|
for i in range(math.ceil(height / (width * new_ratio))):
|
||||||
|
offset = round(width * new_ratio * i)
|
||||||
|
cropped_img = img.crop((0, offset, width, round(offset + width * best_ration)))
|
||||||
|
# 统一转为RGB,这样可以正确保存为jpg格式
|
||||||
|
cropped_img = cropped_img.convert("RGB")
|
||||||
|
split_result.append({"img": cropped_img, "x_offset": 0, "y_offset": offset})
|
||||||
|
else: # 宽度是较长边
|
||||||
|
for i in range(math.ceil(width / (height * new_ratio))):
|
||||||
|
offset = round(height * new_ratio * i)
|
||||||
|
cropped_img = img.crop((offset, 0, round(offset + height * best_ration), height))
|
||||||
|
# 统一转为RGB,这样可以正确保存为jpg格式
|
||||||
|
cropped_img = cropped_img.convert("RGB")
|
||||||
|
split_result.append({"img": cropped_img, "x_offset": offset, "y_offset": 0})
|
||||||
|
else:
|
||||||
|
split_result.append({"img": img, "x_offset": 0, "y_offset": 0})
|
||||||
|
return split_result
|
||||||
|
|
||||||
|
|
||||||
|
# 合并信息抽取结果
|
||||||
|
def merge_result(result1, result2):
|
||||||
|
for key in result2:
|
||||||
|
result1[key] = result1.get(key, []) + result2[key]
|
||||||
|
return result1
|
||||||
|
|
||||||
|
|
||||||
# 关键信息提取
|
# 关键信息提取
|
||||||
@@ -36,16 +91,27 @@ def information_extraction(ie, phrecs):
|
|||||||
for phrec in phrecs:
|
for phrec in phrecs:
|
||||||
pic_path = ucloud.get_private_url(phrec.cfjaddress)
|
pic_path = ucloud.get_private_url(phrec.cfjaddress)
|
||||||
if pic_path:
|
if pic_path:
|
||||||
docs.append({"doc": pic_path})
|
split_result = split_image(pic_path)
|
||||||
|
for img in split_result:
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
|
||||||
|
img["img"].save(temp_file.name)
|
||||||
|
docs.append({"doc": temp_file.name})
|
||||||
doc_phrecs.append(phrec)
|
doc_phrecs.append(phrec)
|
||||||
if not docs:
|
if not docs:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
ie_results = []
|
||||||
try:
|
try:
|
||||||
ie_results = ie(docs)
|
ie_results = ie(docs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(e)
|
logging.error(e)
|
||||||
return result
|
return result
|
||||||
|
finally:
|
||||||
|
for temp_file in docs:
|
||||||
|
try:
|
||||||
|
os.remove(temp_file["doc"])
|
||||||
|
except Exception as e:
|
||||||
|
logging.info(f"删除临时文件 {temp_file['doc']} 时出错: {e}")
|
||||||
|
|
||||||
now = get_default_datetime()
|
now = get_default_datetime()
|
||||||
for i in range(len(ie_results)):
|
for i in range(len(ie_results)):
|
||||||
@@ -61,7 +127,7 @@ def information_extraction(ie, phrecs):
|
|||||||
session.commit()
|
session.commit()
|
||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
result.update(ie_result)
|
result = merge_result(result, ie_result)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -71,15 +137,16 @@ def get_best_value_in_keys(source, keys):
|
|||||||
# 最终结果
|
# 最终结果
|
||||||
result = None
|
result = None
|
||||||
# 最大可能性
|
# 最大可能性
|
||||||
most_probability = 0
|
best_probability = 0
|
||||||
for key in keys:
|
for key in keys:
|
||||||
values = source.get(key)
|
values = source.get(key)
|
||||||
if values:
|
if values:
|
||||||
for value in values:
|
for value in values:
|
||||||
text = value.get("text")
|
text = value.get("text")
|
||||||
probability = value.get("probability")
|
probability = value.get("probability")
|
||||||
if text and probability > most_probability:
|
if text and probability > best_probability:
|
||||||
result = text
|
result = text
|
||||||
|
best_probability = probability
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@@ -89,10 +156,12 @@ def get_values_of_keys(source, keys):
|
|||||||
for key in keys:
|
for key in keys:
|
||||||
value = source.get(key)
|
value = source.get(key)
|
||||||
if value:
|
if value:
|
||||||
value = value[0].get("text")
|
for v in value:
|
||||||
if value:
|
v = v.get("text")
|
||||||
result.append(value)
|
if v:
|
||||||
return result
|
result.append(v)
|
||||||
|
# 去重
|
||||||
|
return list(set(result))
|
||||||
|
|
||||||
|
|
||||||
def save_or_update_ie(table, pk_phhd, data):
|
def save_or_update_ie(table, pk_phhd, data):
|
||||||
@@ -214,10 +283,8 @@ def main():
|
|||||||
# 持续检测新案子
|
# 持续检测新案子
|
||||||
while 1:
|
while 1:
|
||||||
session = MysqlSession()
|
session = MysqlSession()
|
||||||
phhds = session.query(ZxPhhd.pk_phhd) \
|
# 查询需要识别的案子
|
||||||
.filter(ZxPhhd.exsuccess_flag == '1') \
|
phhds = session.query(ZxPhhd.pk_phhd).filter(ZxPhhd.exsuccess_flag == '1').limit(PHHD_BATCH_SIZE).all()
|
||||||
.limit(PHHD_BATCH_SIZE) \
|
|
||||||
.all()
|
|
||||||
session.close()
|
session.close()
|
||||||
if phhds:
|
if phhds:
|
||||||
for phhd in phhds:
|
for phhd in phhds:
|
||||||
@@ -226,14 +293,14 @@ def main():
|
|||||||
|
|
||||||
# 识别完成更新标识
|
# 识别完成更新标识
|
||||||
session = MysqlSession()
|
session = MysqlSession()
|
||||||
stmt = (update(ZxPhhd).where(ZxPhhd.pk_phhd == pk_phhd).values(exsuccess_flag=8))
|
update_flag = (update(ZxPhhd).where(ZxPhhd.pk_phhd == pk_phhd).values(exsuccess_flag=8))
|
||||||
session.execute(stmt)
|
session.execute(update_flag)
|
||||||
session.commit()
|
session.commit()
|
||||||
session.close()
|
session.close()
|
||||||
|
# 完成一个案子释放显存
|
||||||
paddle.device.cuda.empty_cache()
|
paddle.device.cuda.empty_cache()
|
||||||
else:
|
else:
|
||||||
# 没有查询到新案子,等待一段时间后再查
|
# 没有查询到新案子,等待一段时间后再查
|
||||||
sleep_minutes = SLEEP_MINUTES
|
|
||||||
log = logging.getLogger()
|
log = logging.getLogger()
|
||||||
log.info(f"暂未查询到新案子,等待{sleep_minutes}分钟...")
|
log.info(f"暂未查询到新案子,等待{SLEEP_MINUTES}分钟...")
|
||||||
sleep(sleep_minutes * 60)
|
sleep(SLEEP_MINUTES * 60)
|
||||||
|
|||||||
Reference in New Issue
Block a user