406 lines
16 KiB
Python
406 lines
16 KiB
Python
import concurrent.futures
|
||
import logging
|
||
import math
|
||
import os
|
||
import sys
|
||
import tempfile
|
||
import time
|
||
import urllib.request
|
||
|
||
import cv2
|
||
import numpy as np
|
||
import paddleclas
|
||
import requests
|
||
|
||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
from time import sleep
|
||
from sqlalchemy import update
|
||
from config.keys import PATIENT_NAME, ADMISSION_DATE, DISCHARGE_DATE, MEDICAL_EXPENSES, PERSONAL_CASH_PAYMENT, \
|
||
PERSONAL_ACCOUNT_PAYMENT, PERSONAL_FUNDED_AMOUNT, MEDICAL_INSURANCE_TYPE, HOSPITAL, DEPARTMENT, DOCTOR
|
||
from config.mysql import MysqlSession
|
||
from config.photo_review import PHHD_BATCH_SIZE, SLEEP_MINUTES
|
||
from photo_review.entity.bd_yljg import BdYljg
|
||
from photo_review.entity.bd_ylks import BdYlks
|
||
from photo_review.entity.zx_ie_cost import ZxIeCost
|
||
from photo_review.entity.zx_ie_discharge import ZxIeDischarge
|
||
from photo_review.entity.zx_ie_settlement import ZxIeSettlement
|
||
from photo_review.entity.zx_phhd import ZxPhhd
|
||
from photo_review.entity.zx_phrec import ZxPhrec
|
||
from photo_review.util.data_util import handle_date, handle_decimal, parse_department, handle_name, \
|
||
handle_insurance_type, handle_original_data, handle_hospital, handle_department
|
||
from photo_review.util.util import get_default_datetime
|
||
from photo_review.enumeration.task import TaskEnum
|
||
|
||
|
||
# 获取图片
|
||
def open_image(img_path):
    """Load an image from a local path or an HTTP(S) URL via OpenCV.

    Args:
        img_path: Local file path, or a URL (any string starting with
            ``"http"``).

    Returns:
        Whatever ``cv2.imdecode`` (remote source, decoded with
        ``cv2.IMREAD_COLOR``) or ``cv2.imread`` (local source) returns.
    """
    if not img_path.startswith("http"):
        return cv2.imread(img_path)

    # Use the response as a context manager so the connection is always
    # closed, even if reading the body fails.
    with urllib.request.urlopen(img_path) as resp:
        image_data = resp.read()
    # View the downloaded bytes as a uint8 buffer so OpenCV can decode them.
    image_np = np.frombuffer(image_data, np.uint8)
    return cv2.imdecode(image_np, cv2.IMREAD_COLOR)
|
||
|
||
|
||
# 分割大图片
|
||
def split_image(img_path, max_ratio=2.82, best_ration=1.41, overlap=0.05):
    """Cut an overly elongated image into overlapping tiles.

    The image is loaded with ``open_image``. If long side / short side is not
    greater than ``max_ratio`` the image is returned as a single tile.
    Otherwise it is sliced along its long side: every tile spans
    ``best_ration`` short-side lengths, and consecutive tiles start
    ``best_ration - overlap`` short-side lengths apart. The last tile is
    padded with black so that it has the full tile size.

    Args:
        img_path: Path or URL accepted by ``open_image``.
        max_ratio: Aspect ratio above which the image is split.
        best_ration: Tile long side expressed in short-side lengths.
        overlap: Overlap between neighbouring tiles, in short-side lengths.

    Returns:
        A list of dicts with keys ``"img"`` (the tile), ``"x_offset"`` and
        ``"y_offset"`` (the tile's top-left corner in the source image).
    """
    img = open_image(img_path)
    height, width = img.shape[:2]

    # Images within the allowed aspect ratio are passed through untouched.
    if max(width, height) / min(width, height) <= max_ratio:
        return [{"img": img, "x_offset": 0, "y_offset": 0}]

    tiles = []
    stride_ratio = best_ration - overlap
    if width < height:
        # Height is the long side: slice along the vertical axis.
        tile_len = width * best_ration
        stride = width * stride_ratio
        for index in range(math.ceil(height / stride)):
            top = round(stride * index)
            # Slicing is [y1:y2, x1:x2].
            tile = img[top:round(top + tile_len), 0:width]
            tiles.append({"img": tile, "x_offset": 0, "y_offset": top})
        # The last tile may run past the image; pad the missing rows in black.
        tail = tiles[-1]["img"]
        tiles[-1]["img"] = cv2.copyMakeBorder(tail, 0, round(tile_len - tail.shape[0]), 0, 0,
                                              cv2.BORDER_CONSTANT, value=(0, 0, 0))
    else:
        # Width is the long side: slice along the horizontal axis.
        tile_len = height * best_ration
        stride = height * stride_ratio
        for index in range(math.ceil(width / stride)):
            left = round(stride * index)
            tile = img[0:height, left:round(left + tile_len)]
            tiles.append({"img": tile, "x_offset": left, "y_offset": 0})
        # The last tile may run past the image; pad the missing columns in black.
        tail = tiles[-1]["img"]
        tiles[-1]["img"] = cv2.copyMakeBorder(tail, 0, 0, 0, round(tile_len - tail.shape[1]),
                                              cv2.BORDER_CONSTANT, value=(0, 0, 0))
    return tiles
|
||
|
||
|
||
# 合并信息抽取结果
|
||
def merge_result(result1, result2):
    """Merge the lists of ``result2`` into ``result1`` key by key.

    For every key of ``result2`` the list stored under that key is appended
    to the list already held by ``result1`` (or to an empty list when the key
    is new). ``result1`` is modified in place.

    Args:
        result1: Dict of lists that receives the merged values.
        result2: Dict of lists to merge in.

    Returns:
        ``result1`` itself.
    """
    for key, extra in result2.items():
        existing = result1[key] if key in result1 else []
        # Build a new list rather than extending, so the old list is untouched.
        result1[key] = existing + extra
    return result1
|
||
|
||
|
||
# 获取图片旋转角度
|
||
def get_image_rotation_angle(img):
    """Predict the rotation angle of an image with PaddleClas.

    A ``text_image_orientation`` PaddleClas model is created and run on
    ``img``; the first entry of ``label_names`` in the first prediction is
    converted to ``int``.

    Args:
        img: Input handed to ``PaddleClas.predict`` as ``input_data``.

    Returns:
        The predicted angle as an int, or ``0`` when the prediction cannot be
        read (the error is logged).
    """
    classifier = paddleclas.PaddleClas(model_name="text_image_orientation")
    predictions = classifier.predict(input_data=img)
    try:
        return int(next(predictions)[0]["label_names"][0])
    except Exception as err:
        logging.error("获取图片旋转角度失败", exc_info=err)
        return 0
|
||
|
||
|
||
# 获取图片旋转角度
|
||
def get_image_rotation_angles(img):
    """Predict the candidate rotation angles of an image with PaddleClas.

    A ``text_image_orientation`` PaddleClas model is created and run on
    ``img``; the ``label_names`` of the first prediction are returned as-is.

    Args:
        img: Input handed to ``PaddleClas.predict`` as ``input_data``.

    Returns:
        The ``label_names`` value of the first prediction, or
        ``['0', '90']`` when the prediction cannot be read (the error is
        logged).
    """
    classifier = paddleclas.PaddleClas(model_name="text_image_orientation")
    predictions = classifier.predict(input_data=img)
    try:
        return next(predictions)[0]["label_names"]
    except Exception as err:
        logging.error("获取图片旋转角度失败", exc_info=err)
        return ['0', '90']
|
||
|
||
|
||
# 旋转图片
|
||
def rotate_image(img, angle):
    """Rotate an image about its centre by ``angle`` degrees.

    The output canvas keeps the original size for 180 degrees and swaps width
    and height for every other non-zero angle.

    Args:
        img: Image array whose first two dimensions are (height, width).
            Both 2-D and multi-channel arrays are accepted.
        angle: Rotation in degrees, forwarded to ``cv2.getRotationMatrix2D``.

    Returns:
        ``img`` itself when ``angle`` is 0, otherwise the rotated image.
    """
    if angle == 0:
        return img

    # Read only the first two dimensions so images without a channel axis
    # work as well.
    height, width = img.shape[:2]
    if angle == 180:
        new_width, new_height = width, height
    else:
        new_width, new_height = height, width

    # Rotate around the image centre (arguments: centre, degrees, scale).
    matrix = cv2.getRotationMatrix2D((width / 2, height / 2), angle, 1)
    # Shift the result so it is centred on the new canvas.
    matrix[0, 2] += (new_width - width) / 2
    matrix[1, 2] += (new_height - height) / 2
    # Arguments: source image, affine matrix, output (width, height).
    return cv2.warpAffine(img, matrix, (new_width, new_height))
|
||
|
||
|
||
# 获取图片OCR,并将其box转为两点矩形框
|
||
def get_ocr_layout(ocr, img_path):
    """Run OCR on an image and return ``(box, text)`` pairs.

    Each four-point box from the OCR result is reduced to a two-point
    rectangle ``[x1, y1, x2, y2]``. Segments whose rectangle has a negative
    width or height are dropped.

    Args:
        ocr: Object exposing ``ocr(img_path, cls=False)``; only the first
            element of its return value is used.
        img_path: Image path passed through to ``ocr.ocr``.

    Returns:
        A list of ``(box, text)`` tuples; empty when nothing was recognised.
    """

    def _to_rect(quad):
        # Collapse the four corner points into [x1, y1, x2, y2].
        return [
            min(quad[0][0], quad[3][0]),  # x1
            min(quad[0][1], quad[1][1]),  # y1
            max(quad[1][0], quad[2][0]),  # x2
            max(quad[2][1], quad[3][1]),  # y2
        ]

    def _is_valid(rect):
        # Reject rectangles whose height or width is negative.
        return not (rect[3] - rect[1] < 0 or rect[2] - rect[0] < 0)

    segments = ocr.ocr(img_path, cls=False)[0]
    if not segments:
        return []

    layout = []
    for segment in segments:
        rect = _to_rect(segment[0])
        if _is_valid(rect):
            layout.append((rect, segment[1][0]))
    return layout
|
||
|
||
|
||
def request_ie_result(task_enum, phrec, identity):
    """POST one photo record to the information-extraction service.

    Args:
        task_enum: Task descriptor providing ``request_url()`` and
            ``schema()``.
        phrec: Photo record providing ``cfjaddress``, ``pk_phhd`` and
            ``pk_phrec``.
        identity: Identifier shared by all images of the same batch.

    Returns:
        The ``"data"`` field of the JSON response.

    Raises:
        Exception: If the service answers with a status code other than 200.
    """
    response = requests.post(
        task_enum.request_url(),
        json={
            "image_name": phrec.cfjaddress,
            "schema": task_enum.schema(),
            "pk_phhd": phrec.pk_phhd,
            "pk_phrec": phrec.pk_phrec,
            "identity": identity,
        },
    )

    if response.status_code != 200:
        raise Exception(f"请求信息抽取结果失败,状态码:{response.status_code}")
    return response.json()["data"]
|
||
|
||
|
||
def ie_temp_image(ie, ocr, image):
    """Run OCR and information extraction on an in-memory image.

    The image is written to a temporary ``.jpg`` file, OCR'd with
    ``get_ocr_layout`` and, when any text was found, passed to ``ie``. The
    temporary file is removed afterwards.

    Args:
        ie: Callable taking ``{"doc": path, "layout": layout}``; the first
            element of its return value is used.
        ocr: OCR engine forwarded to ``get_ocr_layout``.
        image: Image to write with ``cv2.imwrite``.

    Returns:
        The extraction result, or ``[]`` when nothing was recognised or
        extraction failed (the error is logged).
    """
    # delete=False keeps the file on disk after the handle is closed so that
    # it can be reopened by path below; it is removed manually in `finally`.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
        cv2.imwrite(temp_file.name, image)

    extracted = []
    try:
        layout = get_ocr_layout(ocr, temp_file.name)
        # No OCR output means there is nothing to extract from.
        if layout:
            extracted = ie({"doc": temp_file.name, "layout": layout})[0]
    except Exception as err:
        logging.error("信息抽取时出错", exc_info=err)
    finally:
        try:
            os.remove(temp_file.name)
        except Exception as err:
            logging.info(f"删除临时文件 {temp_file.name} 时出错", exc_info=err)
    return extracted
|
||
|
||
|
||
# 关键信息提取
|
||
def information_extraction(task_enum, phrecs):
    """Extract key information from a batch of photo records.

    Every record is sent through ``request_ie_result`` and the per-image
    results are combined with ``merge_result``.

    Args:
        task_enum: Task descriptor forwarded to ``request_ie_result``.
        phrecs: Iterable of photo records belonging to one case.

    Returns:
        The merged extraction result; ``{}`` when ``phrecs`` is empty.
    """
    # One identifier (current Unix time in seconds) shared by the whole batch.
    batch_identity = int(time.time())
    merged = {}
    for phrec in phrecs:
        merged = merge_result(merged, request_ie_result(task_enum, phrec, batch_identity))

    return merged
|
||
|
||
|
||
# 从keys中获取准确率最高的value
|
||
def get_best_value_in_keys(source, keys):
    """Return the text with the highest probability among several keys.

    Args:
        source: Dict mapping a key to a list of candidate dicts, each with an
            optional ``"text"`` and ``"probability"`` entry.
        keys: Keys of ``source`` to search.

    Returns:
        The non-empty text whose probability is the largest value strictly
        above 0, or ``None`` when no candidate qualifies.
    """
    best_text = None
    best_probability = 0
    for key in keys:
        # Missing keys and empty candidate lists are simply skipped.
        for candidate in source.get(key) or ():
            text = candidate.get("text")
            # A candidate without a probability counts as 0; comparing None
            # with a number would raise TypeError.
            probability = candidate.get("probability") or 0
            if text and probability > best_probability:
                best_text, best_probability = text, probability
    return best_text
|
||
|
||
|
||
# 从keys中获取所有value组成list
|
||
def get_values_of_keys(source, keys):
    """Collect the distinct non-empty texts found under several keys.

    Args:
        source: Dict mapping a key to a list of candidate dicts with an
            optional ``"text"`` entry.
        keys: Keys of ``source`` to read.

    Returns:
        A de-duplicated list of texts; its order is not defined because the
        values pass through a ``set``.
    """
    texts = (
        candidate.get("text")
        for key in keys
        for candidate in (source.get(key) or ())
    )
    # The set removes duplicates; falsy texts (None, "") are dropped.
    return list({text for text in texts if text})
|
||
|
||
|
||
def save_or_update_ie(table, pk_phhd, data):
    """Insert or update the extraction row of a case.

    Entries of ``data`` whose value is ``None`` or ``""`` are ignored. If a
    row with the given ``pk_phhd`` exists, its attributes are overwritten
    with the remaining entries and ``update_time`` is refreshed; otherwise a
    new row is added with both ``create_time`` and ``update_time`` set.

    Args:
        table: ORM entity class to write to.
        pk_phhd: Case primary key used to look up an existing row.
        data: Mapping of column name to value.

    Raises:
        Exception: Any error from the query or commit is re-raised after the
            transaction has been rolled back.
    """
    data = {k: v for k, v in data.items() if v is not None and v != ""}
    obj = table(**data)
    session = MysqlSession()
    try:
        db_data = session.query(table).filter_by(pk_phhd=pk_phhd).one_or_none()
        now = get_default_datetime()
        if db_data:
            # Update the existing row.
            db_data.update_time = now
            for k, v in data.items():
                setattr(db_data, k, v)
        else:
            # Insert a new row.
            obj.create_time = now
            obj.update_time = now
            session.add(obj)
        session.commit()
    except Exception:
        # Leave no half-applied transaction behind.
        session.rollback()
        raise
    finally:
        # Always release the session, on failure too.
        session.close()
|
||
|
||
|
||
def settlement_task(pk_phhd, settlement_list):
    """Extract settlement-sheet fields for a case and persist them.

    Runs information extraction with ``TaskEnum.SETTLEMENT``, keeps the
    best-scoring value for each field, derives parsed date and decimal
    columns from the raw ``*_str`` values and stores everything through
    ``save_or_update_ie`` in ``ZxIeSettlement``.

    Args:
        pk_phhd: Case primary key.
        settlement_list: Photo records of the settlement sheet.
    """
    ie_result = information_extraction(TaskEnum.SETTLEMENT, settlement_list)

    def best(keys):
        # Best-scoring extracted text for the given schema keys.
        return get_best_value_in_keys(ie_result, keys)

    settlement_data = {
        "pk_phhd": pk_phhd,
        "name": handle_name(best(PATIENT_NAME)),
        "admission_date_str": handle_original_data(best(ADMISSION_DATE)),
        "discharge_date_str": handle_original_data(best(DISCHARGE_DATE)),
        "medical_expenses_str": handle_original_data(best(MEDICAL_EXPENSES)),
        "personal_cash_payment_str": handle_original_data(best(PERSONAL_CASH_PAYMENT)),
        "personal_account_payment_str": handle_original_data(best(PERSONAL_ACCOUNT_PAYMENT)),
        "personal_funded_amount_str": handle_original_data(best(PERSONAL_FUNDED_AMOUNT)),
        "medical_insurance_type": handle_insurance_type(best(MEDICAL_INSURANCE_TYPE)),
    }

    # Each parsed column is derived once from its raw "<field>_str" value.
    for field in ("admission_date", "discharge_date"):
        settlement_data[field] = handle_date(settlement_data[f"{field}_str"])
    for field in ("medical_expenses", "personal_cash_payment",
                  "personal_account_payment", "personal_funded_amount"):
        settlement_data[field] = handle_decimal(settlement_data[f"{field}_str"])

    save_or_update_ie(ZxIeSettlement, pk_phhd, settlement_data)
|
||
|
||
|
||
def discharge_task(pk_phhd, discharge_record):
    """Extract discharge-record fields for a case and persist them.

    Runs information extraction with ``TaskEnum.DISCHARGE``, keeps the
    best-scoring value for each field and derives the parsed date columns.
    If any extracted hospital or department name matches a row in ``BdYljg``
    or ``BdYlks``, that row's primary key and name replace the extracted
    value. The result is stored through ``save_or_update_ie`` in
    ``ZxIeDischarge``.

    Args:
        pk_phhd: Case primary key.
        discharge_record: Photo records of the discharge record.
    """
    ie_result = information_extraction(TaskEnum.DISCHARGE, discharge_record)

    def best(keys):
        # Best-scoring extracted text for the given schema keys.
        return get_best_value_in_keys(ie_result, keys)

    discharge_data = {
        "pk_phhd": pk_phhd,
        "hospital": handle_hospital(best(HOSPITAL)),
        "department": handle_department(best(DEPARTMENT)),
        "name": handle_name(best(PATIENT_NAME)),
        "admission_date_str": handle_original_data(best(ADMISSION_DATE)),
        "discharge_date_str": handle_original_data(best(DISCHARGE_DATE)),
        "doctor": handle_name(best(DOCTOR)),
    }
    discharge_data["admission_date"] = handle_date(discharge_data["admission_date_str"])
    discharge_data["discharge_date"] = handle_date(discharge_data["discharge_date_str"])

    hospital_value = get_values_of_keys(ie_result, HOSPITAL)
    if hospital_value:
        session = MysqlSession()
        try:
            yljg = session.query(BdYljg.pk_yljg, BdYljg.name) \
                .filter(BdYljg.name.in_(hospital_value)).limit(1).one_or_none()
        finally:
            # Release the session even when the lookup fails.
            session.close()
        if yljg:
            # Prefer the name stored in the database over the extracted text.
            discharge_data["pk_yljg"] = yljg.pk_yljg
            discharge_data["hospital"] = yljg.name

    department_value = get_values_of_keys(ie_result, DEPARTMENT)
    if department_value:
        department_values = []
        for dept in department_value:
            department_values += parse_department(dept)
        # De-duplicate the parsed department names.
        department_values = list(set(department_values))
        if department_values:
            session = MysqlSession()
            try:
                ylks = session.query(BdYlks.pk_ylks, BdYlks.name) \
                    .filter(BdYlks.name.in_(department_values)).limit(1).one_or_none()
            finally:
                # Release the session even when the lookup fails.
                session.close()
            if ylks:
                # Prefer the name stored in the database over the extracted text.
                discharge_data["pk_ylks"] = ylks.pk_ylks
                discharge_data["department"] = ylks.name

    save_or_update_ie(ZxIeDischarge, pk_phhd, discharge_data)
|
||
|
||
|
||
def cost_task(pk_phhd, cost_list):
    """Extract cost-list fields for a case and persist them.

    Runs information extraction with ``TaskEnum.COST``, keeps the
    best-scoring value for each field, derives the parsed date and decimal
    columns and stores everything through ``save_or_update_ie`` in
    ``ZxIeCost``.

    Args:
        pk_phhd: Case primary key.
        cost_list: Photo records of the cost list.
    """
    ie_result = information_extraction(TaskEnum.COST, cost_list)

    patient_name = handle_name(get_best_value_in_keys(ie_result, PATIENT_NAME))
    admission_str = handle_original_data(get_best_value_in_keys(ie_result, ADMISSION_DATE))
    discharge_str = handle_original_data(get_best_value_in_keys(ie_result, DISCHARGE_DATE))
    expenses_str = handle_original_data(get_best_value_in_keys(ie_result, MEDICAL_EXPENSES))

    # Raw strings come first, followed by the values parsed from them.
    cost_data = {
        "pk_phhd": pk_phhd,
        "name": patient_name,
        "admission_date_str": admission_str,
        "discharge_date_str": discharge_str,
        "medical_expenses_str": expenses_str,
        "admission_date": handle_date(admission_str),
        "discharge_date": handle_date(discharge_str),
        "medical_expenses": handle_decimal(expenses_str),
    }
    save_or_update_ie(ZxIeCost, pk_phhd, cost_data)
|
||
|
||
|
||
def photo_review(pk_phhd):
    """Run all extraction tasks for one case.

    The case's photo records are loaded from ``ZxPhrec`` and grouped by
    ``cRectype``: ``"1"`` goes to ``settlement_task``, ``"3"`` to
    ``discharge_task`` and ``"4"`` to ``cost_task``; other types are ignored.
    The three tasks are submitted to a process pool and awaited. A task that
    raises is logged and does not abort the others.

    Args:
        pk_phhd: Case primary key.
    """
    session = MysqlSession()
    try:
        phrecs = session.query(ZxPhrec.pk_phrec, ZxPhrec.pk_phhd, ZxPhrec.cRectype, ZxPhrec.cfjaddress) \
            .filter(ZxPhrec.pk_phhd == pk_phhd) \
            .all()
    finally:
        # Release the session even when the query fails.
        session.close()

    settlement_list, discharge_record, cost_list = [], [], []
    buckets = {"1": settlement_list, "3": discharge_record, "4": cost_list}
    for phrec in phrecs:
        bucket = buckets.get(phrec.cRectype)
        if bucket is not None:
            bucket.append(phrec)

    with concurrent.futures.ProcessPoolExecutor() as executor:
        futures = {
            executor.submit(settlement_task, pk_phhd, settlement_list): "settlement_task",
            executor.submit(discharge_task, pk_phhd, discharge_record): "discharge_task",
            executor.submit(cost_task, pk_phhd, cost_list): "cost_task",
        }
        # A discarded future hides any exception raised in the worker, so
        # inspect each one and log failures instead of losing them.
        for future in concurrent.futures.as_completed(futures):
            error = future.exception()
            if error is not None:
                logging.error(f"{futures[future]} 执行失败,pk_phhd:{pk_phhd}", exc_info=error)
|
||
|
||
|
||
def main():
    """Poll for new cases forever and run photo review on each of them.

    Each round loads up to ``PHHD_BATCH_SIZE`` cases whose
    ``exsuccess_flag`` equals ``'1'``. Every case is processed with
    ``photo_review`` and its ``exsuccess_flag`` is then set to 8. When no
    case is found, the loop sleeps for ``SLEEP_MINUTES`` minutes before
    polling again.
    """
    while True:
        # Look up the cases waiting for recognition.
        session = MysqlSession()
        phhds = session.query(ZxPhhd.pk_phhd).filter(ZxPhhd.exsuccess_flag == '1').limit(PHHD_BATCH_SIZE).all()
        session.close()

        if not phhds:
            # Nothing to do yet: wait a while before asking again.
            logging.getLogger().info(f"暂未查询到新案子,等待{SLEEP_MINUTES}分钟...")
            sleep(SLEEP_MINUTES * 60)
            continue

        for phhd in phhds:
            pk_phhd = phhd.pk_phhd
            logging.info(f"开始识别:{pk_phhd}")
            photo_review(pk_phhd)

            # Mark the case as processed once recognition has finished.
            session = MysqlSession()
            session.execute(update(ZxPhhd).where(ZxPhhd.pk_phhd == pk_phhd).values(exsuccess_flag=8))
            session.commit()
            session.close()
|