Optimize the photo recognition feature architecture

2024-07-17 14:15:02 +08:00
parent 909720cec9
commit dd895f98b5
3 changed files with 24 additions and 175 deletions

View File

@@ -3,7 +3,7 @@ import traceback
 from auto_email.error_email import send_error_email
 from log import LOGGING_CONFIG
-from photo_review import photo_review, SEND_ERROR_EMAIL
+from photo_review import photo_review, SEND_ERROR_EMAIL, RETRY_TIME
 # The project must be started from this file; otherwise the relative paths used in the code may cause errors
 if __name__ == '__main__':
@@ -11,7 +11,7 @@ if __name__ == '__main__':
     logging.config.dictConfig(LOGGING_CONFIG)
     # Number of retries after a crash
-    for _ in range(2):
+    for _ in range(RETRY_TIME + 1):
         try:
             logging.info(f"{program_name}】开始运行")
             photo_review.main()
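The number of restart attempts is now driven by RETRY_TIME, imported from the photo_review package, instead of the hard-coded range(2). A minimal, self-contained sketch of that pattern, assuming RETRY_TIME keeps the old default of 2 and that a failed run is simply logged and retried; the real definition and error handling are not shown in this diff:

import logging

# Assumed: RETRY_TIME is exported from photo_review/__init__.py; the value 2 mirrors the old hard-coded loop.
RETRY_TIME = 2


def run_with_retries(entry_point, retries=RETRY_TIME):
    """Run entry_point once, then retry up to `retries` more times if it raises."""
    for attempt in range(retries + 1):
        try:
            entry_point()
            return
        except Exception:
            logging.exception("run %d of %d crashed", attempt + 1, retries + 1)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    run_with_retries(lambda: print("photo_review.main() would run here"))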

View File

@@ -60,10 +60,12 @@ COST_LIST_SCHEMA = PATIENT_NAME + ADMISSION_DATE + DISCHARGE_DATE + MEDICAL_EXPE
 Model configuration
 """
 SETTLEMENT_IE = Taskflow("information_extraction", schema=SETTLEMENT_LIST_SCHEMA, model="uie-x-base",
-                         task_path="config/model/settlement_list_model", layout_analysis=False, precision='fp16')
+                         task_path="config/model/settlement_list_model", layout_analysis=LAYOUT_ANALYSIS,
+                         precision='fp16')
 DISCHARGE_IE = Taskflow("information_extraction", schema=DISCHARGE_RECORD_SCHEMA, model="uie-x-base",
-                        task_path="config/model/discharge_record_model", layout_analysis=False, precision='fp16')
+                        task_path="config/model/discharge_record_model", layout_analysis=LAYOUT_ANALYSIS,
+                        precision='fp16')
 COST_IE = Taskflow("information_extraction", schema=COST_LIST_SCHEMA, model="uie-x-base", device_id=1,
-                   task_path="config/model/cost_list_model", layout_analysis=False, precision='fp16')
+                   task_path="config/model/cost_list_model", layout_analysis=LAYOUT_ANALYSIS, precision='fp16')
 OCR = PaddleOCR(use_angle_cls=False, lang="ch", show_log=False, gpu_id=1)
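All three uie-x-base pipelines now read a shared LAYOUT_ANALYSIS switch instead of a hard-coded layout_analysis=False. A small sketch of how those shared settings could be collected in one place; the LAYOUT_ANALYSIS value, the uie_kwargs helper and the example arguments are assumptions, only the parameter names come from the diff:

# Assumed shared switch; the diff references LAYOUT_ANALYSIS, but its definition sits outside this hunk.
LAYOUT_ANALYSIS = False


def uie_kwargs(schema, task_path):
    """Collect the keyword arguments shared by the three information_extraction pipelines."""
    return {
        "schema": schema,
        "model": "uie-x-base",
        "task_path": task_path,
        "layout_analysis": LAYOUT_ANALYSIS,
        "precision": "fp16",
    }


# Usage would then look like:
# SETTLEMENT_IE = Taskflow("information_extraction", **uie_kwargs(SETTLEMENT_LIST_SCHEMA,
#                                                                 "config/model/settlement_list_model"))
print(uie_kwargs(["example field"], "config/model/settlement_list_model"))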

View File

@@ -1,89 +1,26 @@
 import json
 import logging
-import math
 import os
-import sys
 import tempfile
 import time
-import urllib.request
 from time import sleep
 import cv2
-import numpy as np
-import paddleclas
 import requests
+from sqlalchemy import update
-sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 from db import MysqlSession
 from db.mysql import BdYljg, BdYlks, ZxOcr, ZxIeCost, ZxIeDischarge, ZxIeSettlement, ZxPhhd, ZxPhrec
 from photo_review import PATIENT_NAME, ADMISSION_DATE, DISCHARGE_DATE, MEDICAL_EXPENSES, PERSONAL_CASH_PAYMENT, \
     PERSONAL_ACCOUNT_PAYMENT, PERSONAL_FUNDED_AMOUNT, MEDICAL_INSURANCE_TYPE, HOSPITAL, DEPARTMENT, DOCTOR, \
     ADMISSION_ID, SETTLEMENT_ID, AGE, OCR, SETTLEMENT_IE, DISCHARGE_IE, COST_IE, PHHD_BATCH_SIZE, SLEEP_MINUTES
-from sqlalchemy import update
 from ucloud import ucloud
+from util import image_util, util
 from util.data_util import handle_date, handle_decimal, parse_department, handle_name, \
     handle_insurance_type, handle_original_data, handle_hospital, handle_department, handle_id, handle_age
 from util.util import get_default_datetime
-# Load an image
-def open_image(img_path):
-    if img_path.startswith("http"):
-        # Send an HTTP request and fetch the image data
-        resp = urllib.request.urlopen(img_path)
-        # Read the data as a byte stream
-        image_data = resp.read()
-        # Convert the byte stream into a NumPy array
-        image_np = np.frombuffer(image_data, np.uint8)
-        # Decode the NumPy array into an OpenCV image
-        image = cv2.imdecode(image_np, cv2.IMREAD_COLOR)
-    else:
-        image = cv2.imread(img_path)
-    return image
-# Split a large image
-def split_image(img_path, max_ratio=1.41, best_ration=1.41, overlap=0.05):
-    split_result = []
-    # Open the image
-    img = open_image(img_path)
-    # Get the image width and height
-    height, width = img.shape[:2]
-    # Compute the aspect ratio
-    ratio = max(width, height) / min(width, height)
-    # Check whether cropping is needed
-    if ratio > max_ratio:
-        # Determine the crop size, keeping the aspect ratio and using the shorter side as the base
-        new_ratio = best_ration - overlap
-        if width < height:
-            # The height is the longer side
-            cropped_width = width * best_ration
-            for i in range(math.ceil(height / (width * new_ratio))):
-                offset = round(width * new_ratio * i)
-                # Slicing takes the form [y1:y2, x1:x2]
-                cropped_img = img[offset:round(offset + cropped_width), 0:width]
-                split_result.append({"img": cropped_img, "x_offset": 0, "y_offset": offset})
-            # Pad the short part of the last crop with black
-            last_img = split_result[-1]["img"]
-            split_result[-1]["img"] = cv2.copyMakeBorder(last_img, 0, round(cropped_width - last_img.shape[0]), 0, 0,
-                                                         cv2.BORDER_CONSTANT, value=(0, 0, 0))
-        else:
-            # The width is the longer side
-            cropped_height = height * best_ration
-            for i in range(math.ceil(width / (height * new_ratio))):
-                offset = round(height * new_ratio * i)
-                cropped_img = img[0:height, offset:round(offset + cropped_height)]
-                split_result.append({"img": cropped_img, "x_offset": offset, "y_offset": 0})
-            # Pad the short part of the last crop with black
-            last_img = split_result[-1]["img"]
-            split_result[-1]["img"] = cv2.copyMakeBorder(last_img, 0, 0, 0, round(cropped_height - last_img.shape[1]),
-                                                         cv2.BORDER_CONSTANT, value=(0, 0, 0))
-    else:
-        split_result.append({"img": img, "x_offset": 0, "y_offset": 0})
-    return split_result
 # Merge information extraction results
 def merge_result(result1, result2):
     for key in result2:
@@ -91,91 +28,13 @@ def merge_result(result1, result2):
     return result1
-# Get the image rotation angle
-def get_image_rotation_angle(img):
-    angle = 0
-    model = paddleclas.PaddleClas(model_name="text_image_orientation")
-    result = model.predict(input_data=img)
-    try:
-        angle = int(next(result)[0]["label_names"][0])
-    except Exception as e:
-        logging.error("获取图片旋转角度失败", exc_info=e)
-    return angle
-# Get the image rotation angle
-def get_image_rotation_angles(img):
-    angles = ['0', '90']
-    model = paddleclas.PaddleClas(model_name="text_image_orientation")
-    result = model.predict(input_data=img)
-    try:
-        angles = next(result)[0]["label_names"]
-    except Exception as e:
-        logging.error("获取图片旋转角度失败", exc_info=e)
-    return angles
-# Rotate an image
-def rotate_image(img, angle):
-    if angle == 0:
-        return img
-    height, width, _ = img.shape
-    if angle == 180:
-        new_width = width
-        new_height = height
-    else:
-        new_width = height
-        new_height = width
-    # Rotate around the image centre
-    # Arguments: rotation centre, rotation angle, scale
-    matrix = cv2.getRotationMatrix2D((width / 2, height / 2), angle, 1)
-    # Translate after rotating
-    matrix[0, 2] += (new_width - width) / 2
-    matrix[1, 2] += (new_height - height) / 2
-    # Arguments: source image, rotation matrix, output image width and height
-    rotated = cv2.warpAffine(img, matrix, (new_width, new_height))
-    return rotated
-# Run OCR on the image and convert each box into a two-point rectangle
-def get_ocr_layout(ocr, img_path):
-    def _get_box(old_box):
-        new_box = [
-            min(old_box[0][0], old_box[3][0]),  # x1
-            min(old_box[0][1], old_box[1][1]),  # y1
-            max(old_box[1][0], old_box[2][0]),  # x2
-            max(old_box[2][1], old_box[3][1]),  # y2
-        ]
-        return new_box
-    def _normal_box(box_data):
-        # Ensure the height and width of bbox are greater than zero
-        if box_data[3] - box_data[1] < 0 or box_data[2] - box_data[0] < 0:
-            return False
-        return True
-    layout = []
-    ocr_result = ocr.ocr(img_path, cls=False)
-    ocr_result = ocr_result[0]
-    if not ocr_result:
-        return layout
-    for segment in ocr_result:
-        box = segment[0]
-        box = _get_box(box)
-        if not _normal_box(box):
-            continue
-        text = segment[1][0]
-        layout.append((box, text))
-    return layout
 def ie_temp_image(ie, ocr, image):
     with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
         cv2.imwrite(temp_file.name, image)
     ie_result = []
     try:
-        layout = get_ocr_layout(ocr, temp_file.name)
+        layout = util.get_ocr_layout(ocr, temp_file.name)
         if not layout:
             # No recognition result
             ie_result = []
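The orientation and OCR-layout helpers removed in this hunk are likewise reached through the new modules (image_util.parse_rotation_angles, image_util.rotate and util.get_ocr_layout in the hunks below). A condensed sketch of what those relocated helpers presumably contain; the module layout and docstrings are assumptions, the logic mirrors the deleted functions:

# Assumed contents of the relocated helpers, condensed from the functions deleted above.
import logging

import cv2
import paddleclas


def parse_rotation_angles(img):
    """Return the text-orientation labels (e.g. ['0', '90']) predicted for an image."""
    angles = ['0', '90']
    model = paddleclas.PaddleClas(model_name="text_image_orientation")
    try:
        angles = next(model.predict(input_data=img))[0]["label_names"]
    except Exception as e:
        logging.error("failed to detect the image rotation angle", exc_info=e)
    return angles


def rotate(img, angle):
    """Rotate an image by 0/90/180/270 degrees around its centre, resizing the canvas."""
    if angle == 0:
        return img
    height, width = img.shape[:2]
    new_width, new_height = (width, height) if angle == 180 else (height, width)
    matrix = cv2.getRotationMatrix2D((width / 2, height / 2), angle, 1)
    matrix[0, 2] += (new_width - width) / 2   # shift so the rotated image stays in frame
    matrix[1, 2] += (new_height - height) / 2
    return cv2.warpAffine(img, matrix, (new_width, new_height))


def get_ocr_layout(ocr, img_path):
    """Run PaddleOCR and return [(box, text), ...] with boxes as [x1, y1, x2, y2]."""
    layout = []
    result = ocr.ocr(img_path, cls=False)[0]
    if not result:
        return layout
    for box, (text, _score) in result:
        x1 = min(box[0][0], box[3][0])
        y1 = min(box[0][1], box[1][1])
        x2 = max(box[1][0], box[2][0])
        y2 = max(box[2][1], box[3][1])
        if x2 - x1 >= 0 and y2 - y1 >= 0:     # drop degenerate boxes
            layout.append(([x1, y1, x2, y2], text))
    return layout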
@@ -213,17 +72,17 @@ def information_extraction(ie, phrecs):
     # Identifier shared by images from the same batch
     identity = int(time.time())
     for phrec in phrecs:
-        pic_path = ucloud.get_private_url(phrec.cfjaddress)
-        if not pic_path:
+        img_path = ucloud.get_private_url(phrec.cfjaddress)
+        if not img_path:
             continue
-        split_result = split_image(pic_path)
-        for img in split_result:
-            angles = get_image_rotation_angles(img["img"])
-            rotated_img = rotate_image(img["img"], int(angles[0]))
+        split_results = image_util.split(img_path)
+        for split_result in split_results:
+            angles = image_util.parse_rotation_angles(split_result["img"])
+            rotated_img = image_util.rotate(split_result["img"], int(angles[0]))
             ie_results = [{"result": ie_temp_image(ie, OCR, rotated_img), "angle": angles[0]}]
             if not ie_results[0] or len(ie_results[0]) < len(ie.kwargs.get("schema")):
-                rotated_img = rotate_image(img["img"], int(angles[1]))
+                rotated_img = image_util.rotate(split_result["img"], int(angles[1]))
                 ie_results.append({"result": ie_temp_image(ie, OCR, rotated_img), "angle": angles[1]})
             now = get_default_datetime()
@@ -233,8 +92,9 @@ def information_extraction(ie, phrecs):
             result_json = result_json[:5000]
             session = MysqlSession()
             zx_ocr = ZxOcr(pk_phhd=phrec.pk_phhd, pk_phrec=phrec.pk_phrec, id=identity, cfjaddress=phrec.cfjaddress,
-                           content=result_json, rotation_angle=ie_result["angle"], x_offset=img["x_offset"],
-                           y_offset=img["y_offset"], create_time=now, update_time=now)
+                           content=result_json, rotation_angle=ie_result["angle"],
+                           x_offset=split_result["x_offset"], y_offset=split_result["y_offset"], create_time=now,
+                           update_time=now)
             session.add(zx_ocr)
             session.commit()
             session.close()
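Each ZxOcr row keeps the tile offsets (x_offset, y_offset) and the rotation angle that produced its extraction result, presumably so that box coordinates can later be mapped back onto the original photo. A tiny illustrative helper for that mapping; the function name is hypothetical and not part of the commit, and compensating for rotation_angle is left out:

def to_original_coords(box, x_offset, y_offset):
    """Translate an OCR box from tile coordinates back into the original image.

    `box` is [x1, y1, x2, y2]; the offsets are the values persisted on the ZxOcr row.
    Undoing the stored rotation_angle would be an additional step and is omitted here.
    """
    x1, y1, x2, y2 = box
    return [x1 + x_offset, y1 + y_offset, x2 + x_offset, y2 + y_offset]


# A tile cut at y_offset=800 maps its boxes 800 px further down in the source photo.
print(to_original_coords([10, 20, 110, 60], x_offset=0, y_offset=800))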
@@ -383,20 +243,15 @@ def cost_task(pk_phhd, cost_list):
     save_or_update_ie(ZxIeCost, pk_phhd, cost_data)
-def settlement_and_discharge_task(pk_phhd, settlement_list, discharge_record):
-    settlement_task(pk_phhd, settlement_list)
-    discharge_task(pk_phhd, discharge_record)
 def photo_review(pk_phhd):
     settlement_list = []
     discharge_record = []
     cost_list = []
     session = MysqlSession()
-    phrecs = session.query(ZxPhrec.pk_phrec, ZxPhrec.pk_phhd, ZxPhrec.cRectype, ZxPhrec.cfjaddress) \
-        .filter(ZxPhrec.pk_phhd == pk_phhd) \
-        .all()
+    phrecs = session.query(ZxPhrec.pk_phrec, ZxPhrec.pk_phhd, ZxPhrec.cRectype, ZxPhrec.cfjaddress).filter(
+        ZxPhrec.pk_phhd == pk_phhd
+    ).all()
     session.close()
     for phrec in phrecs:
         if phrec.cRectype == "1":
@@ -406,21 +261,14 @@ def photo_review(pk_phhd):
         elif phrec.cRectype == "4":
             cost_list.append(phrec)
-    # with concurrent.futures.ThreadPoolExecutor() as executor:
-    #     executor.submit(settlement_task, pk_phhd, settlement_list)
-    #     executor.submit(discharge_task, pk_phhd, discharge_record)
-    #     # executor.submit(settlement_and_discharge_task, pk_phhd, settlement_list, discharge_record)
-    #     executor.submit(cost_task, pk_phhd, cost_list)
     settlement_task(pk_phhd, settlement_list)
     discharge_task(pk_phhd, discharge_record)
     cost_task(pk_phhd, cost_list)
 def main():
-    # Keep checking for new cases
     while 1:
         session = MysqlSession()
-        # Query the cases that need to be recognized
         phhds = session.query(ZxPhhd.pk_phhd).filter(ZxPhhd.exsuccess_flag == '1').limit(PHHD_BATCH_SIZE).all()
         # Set the status to "recognition in progress"
         pk_phhd_values = [phhd.pk_phhd for phhd in phhds]
@@ -442,6 +290,5 @@ def main():
             session.close()
         else:
             # No new cases found; wait for a while before querying again
-            log = logging.getLogger()
-            log.info(f"暂未查询到新案子,等待{SLEEP_MINUTES}分钟...")
+            logging.info(f"暂未查询到需要识别的案子,等待{SLEEP_MINUTES}分钟...")
         sleep(SLEEP_MINUTES * 60)
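main() keeps polling ZxPhhd rows whose exsuccess_flag is '1' in batches of PHHD_BATCH_SIZE and, per the comment above, flips their status before handing them to photo_review(). A self-contained sketch of that batch-claim step using the imported sqlalchemy update construct; the ZxPhhd stand-in model, the in-memory SQLite backend and the target flag value '2' are assumptions for illustration only:

from sqlalchemy import Column, Integer, String, create_engine, update
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class ZxPhhd(Base):  # minimal stand-in for the real db.mysql.ZxPhhd model
    __tablename__ = "zx_phhd"
    pk_phhd = Column(Integer, primary_key=True)
    exsuccess_flag = Column(String(1))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([ZxPhhd(pk_phhd=i, exsuccess_flag="1") for i in range(3)])
    session.commit()

    # Same shape as the polling loop: fetch a batch of cases waiting for recognition...
    phhds = session.query(ZxPhhd.pk_phhd).filter(ZxPhhd.exsuccess_flag == "1").limit(2).all()
    pk_phhd_values = [phhd.pk_phhd for phhd in phhds]

    # ...then mark them as "in progress" in a single UPDATE before processing them.
    session.execute(update(ZxPhhd).where(ZxPhhd.pk_phhd.in_(pk_phhd_values)).values(exsuccess_flag="2"))
    session.commit()
    print(sorted((r.pk_phhd, r.exsuccess_flag) for r in session.query(ZxPhhd)))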