# fcb_photo_review/photo_review/photo_review.py
# Snapshot: 2024-07-01 14:57:25 +08:00 (418 lines, 17 KiB, Python)
#
# Note: this file contains ambiguous Unicode characters that might be confused
# with other, similar-looking characters. If that is intentional, this warning
# can be ignored.

import functools
import json
import logging
import math
import os
import sys
import tempfile
import time
import urllib.request
from time import sleep

import cv2
import numpy as np
import paddleclas
from paddlenlp import Taskflow
from paddleocr import PaddleOCR
from sqlalchemy import update

# Make the parent directory importable so the project packages below resolve.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from config.keys import PATIENT_NAME, ADMISSION_DATE, DISCHARGE_DATE, MEDICAL_EXPENSES, PERSONAL_CASH_PAYMENT, \
    PERSONAL_ACCOUNT_PAYMENT, PERSONAL_FUNDED_AMOUNT, MEDICAL_INSURANCE_TYPE, HOSPITAL, DEPARTMENT, DOCTOR, \
    SETTLEMENT_LIST_SCHEMA, DISCHARGE_RECORD_SCHEMA, COST_LIST_SCHEMA
from config.mysql import MysqlSession
from config.photo_review import PHHD_BATCH_SIZE, SLEEP_MINUTES, LAYOUT_ANALYSIS
from photo_review.entity.bd_yljg import BdYljg
from photo_review.entity.bd_ylks import BdYlks
from photo_review.entity.zx_ie_cost import ZxIeCost
from photo_review.entity.zx_ie_discharge import ZxIeDischarge
from photo_review.entity.zx_ie_settlement import ZxIeSettlement
from photo_review.entity.zx_ocr import ZxOcr
from photo_review.entity.zx_phhd import ZxPhhd
from photo_review.entity.zx_phrec import ZxPhrec
from photo_review.util.data_util import handle_date, handle_decimal, parse_department, handle_name, \
    handle_insurance_type, handle_original_data, handle_hospital, handle_department
from photo_review.util.util import get_default_datetime
from ucloud import ucloud
# 获取图片
def open_image(img_path):
    """Load an image as an OpenCV array from a local path or an HTTP(S) URL.

    Args:
        img_path: Local file path, or a URL starting with ``"http"``.

    Returns:
        The decoded image (colour, via ``cv2.IMREAD_COLOR`` for URLs), or
        ``None`` when OpenCV cannot read/decode it (``cv2.imread`` /
        ``cv2.imdecode`` return ``None`` in that case).
    """
    if img_path.startswith("http"):
        # Read the whole body and close the HTTP response immediately instead
        # of leaving the connection to the garbage collector.
        with urllib.request.urlopen(img_path) as resp:
            image_data = resp.read()
        # Wrap the raw bytes without copying and decode them with OpenCV.
        image_np = np.frombuffer(image_data, np.uint8)
        return cv2.imdecode(image_np, cv2.IMREAD_COLOR)
    return cv2.imread(img_path)
# 分割大图片
def split_image(img_path, max_ratio=2.82, best_ration=1.41, overlap=0.05):
    """Split a very elongated image into overlapping tiles along its long side.

    If ``long_side / short_side`` exceeds ``max_ratio``, the image is cut
    along its long side into tiles whose length is ``short_side * best_ration``.
    Consecutive tiles start ``short_side * (best_ration - overlap)`` pixels
    apart, so neighbouring tiles overlap. The last tile is padded with black
    (zeros) so every tile has the same size. Otherwise the whole image is
    returned as a single tile.

    Args:
        img_path: Path or ``http`` URL understood by ``open_image``, or an
            already-decoded image array (``numpy.ndarray``).
        max_ratio: Long/short side ratio above which the image is split.
        best_ration: Tile length as a multiple of the short side (parameter
            name kept unchanged for backward compatibility).
        overlap: Subtracted from ``best_ration`` to get the stride between
            tile starts, in multiples of the short side.

    Returns:
        A list of ``{"img": tile, "x_offset": int, "y_offset": int}`` dicts;
        the offsets give the tile's top-left corner in the original image.

    Raises:
        ValueError: If the image could not be loaded or is empty.
    """
    img = img_path if isinstance(img_path, np.ndarray) else open_image(img_path)
    if img is None or img.size == 0:
        raise ValueError(f"failed to load image: {img_path!r}")
    height, width = img.shape[:2]
    if max(width, height) / min(width, height) <= max_ratio:
        return [{"img": img, "x_offset": 0, "y_offset": 0}]
    if width < height:
        # Portrait: the height is the long side, so cut along rows (axis 0).
        return _split_along_axis(img, 0, width, height, best_ration, overlap)
    # Landscape: the width is the long side, so cut along columns (axis 1).
    return _split_along_axis(img, 1, height, width, best_ration, overlap)


def _split_along_axis(img, axis, short_side, long_side, best_ration, overlap):
    """Cut ``img`` into overlapping, equally sized tiles along ``axis`` (0 = rows, 1 = columns)."""
    stride = short_side * (best_ration - overlap)
    tile_len = short_side * best_ration
    tiles = []
    for i in range(math.ceil(long_side / stride)):
        start = round(stride * i)
        if start >= long_side:
            # Rounding pushed this start past the edge; the previous tile
            # already reaches the end, and an empty tile would be useless.
            break
        end = round(start + tile_len)
        if axis == 0:
            tiles.append({"img": img[start:end], "x_offset": 0, "y_offset": start})
        else:
            tiles.append({"img": img[:, start:end], "x_offset": start, "y_offset": 0})
    # The last tile may be cut short by the image edge: pad it with black up to
    # the full tile length so all tiles share one size.
    last = tiles[-1]["img"]
    pad = max(0, round(tile_len - last.shape[axis]))
    pad_width = [(0, 0)] * last.ndim
    pad_width[axis] = (0, pad)
    tiles[-1]["img"] = np.pad(last, pad_width, mode="constant", constant_values=0)
    return tiles
# 合并信息抽取结果
def merge_result(result1, result2):
    """Concatenate the lists in ``result2`` onto the matching keys of ``result1``.

    ``result1`` is updated in place (each touched key is rebound to a new,
    concatenated list) and is also returned for convenience. Keys only present
    in ``result2`` are added.
    """
    for name in result2:
        previous = result1.get(name, [])
        result1[name] = previous + result2[name]
    return result1
# 获取图片旋转角度
def get_image_rotation_angle(img):
    """Return the top-ranked rotation angle of ``img`` as an int, or 0 on failure.

    Delegates to ``get_image_rotation_angles`` so both functions share one
    classifier code path, then converts its first label (e.g. ``'90'``) to int.
    """
    angles = get_image_rotation_angles(img)
    try:
        return int(angles[0])
    except (IndexError, TypeError, ValueError) as e:
        # Same best-effort behaviour as before: log and fall back to 0 degrees.
        logging.error("获取图片旋转角度失败", exc_info=e)
        return 0
# 获取图片旋转角度
def get_image_rotation_angles(img):
    """Return the orientation labels the text-orientation classifier predicts for ``img``.

    Labels are angle strings (callers convert them with ``int()``); callers
    treat index 0 as the most likely angle and index 1 as the runner-up.
    If reading the prediction fails, the error is logged and the default
    ``['0', '90']`` is returned.
    """
    angles = ['0', '90']
    # Reuse one classifier for every call instead of rebuilding it per image.
    model = _get_text_orientation_model()
    result = model.predict(input_data=img)
    try:
        angles = next(result)[0]["label_names"]
    except Exception as e:
        logging.error("获取图片旋转角度失败", exc_info=e)
    return angles


@functools.lru_cache(maxsize=1)
def _get_text_orientation_model():
    """Build the PaddleClas ``text_image_orientation`` classifier once per process."""
    return paddleclas.PaddleClas(model_name="text_image_orientation")
# 旋转图片
def rotate_image(img, angle):
    """Rotate ``img`` counter-clockwise by ``angle`` degrees.

    Multiples of 90 degrees are done with ``numpy.rot90``, which is an exact
    pixel permutation (no interpolation, no lost rows/columns). For 90/270 the
    output has width and height swapped. Any other angle falls back to
    ``cv2.warpAffine`` on a canvas with width and height swapped, as before.

    Args:
        img: Image array, grayscale ``(h, w)`` or multi-channel ``(h, w, c)``.
        angle: Rotation in degrees; values are normalised modulo 360.

    Returns:
        The rotated image; ``img`` itself when the angle is a multiple of 360.
    """
    angle = angle % 360
    if angle == 0:
        return img
    if angle % 90 == 0:
        # np.rot90 rotates from the first axis towards the second, which is
        # counter-clockwise for a (rows, cols[, channels]) image; copy to a
        # contiguous array because OpenCV rejects some non-contiguous views.
        return np.ascontiguousarray(np.rot90(img, k=int(angle // 90)))
    height, width = img.shape[:2]
    new_width, new_height = height, width
    # Rotate around the pixel-grid centre ((w - 1) / 2, (h - 1) / 2); using
    # (w / 2, h / 2) shifts the result by one pixel.
    center = ((width - 1) / 2, (height - 1) / 2)
    matrix = cv2.getRotationMatrix2D(center, angle, 1)
    # Translate so the rotated image is centred on the new canvas.
    matrix[0, 2] += (new_width - width) / 2
    matrix[1, 2] += (new_height - height) / 2
    return cv2.warpAffine(img, matrix, (new_width, new_height))
# 获取图片OCR并将其box转为两点矩形框
def get_ocr_layout(ocr, img_path):
    """Run OCR on an image file and return its text lines as ``(box, text)`` pairs.

    Each OCR quadrilateral is reduced to an axis-aligned rectangle
    ``[x1, y1, x2, y2]``. Rectangles with a negative width or height are
    dropped; zero-sized ones are kept.

    Args:
        ocr: OCR engine exposing ``ocr(path, cls=...)`` (e.g. ``PaddleOCR``).
        img_path: Path of the image file to recognise.

    Returns:
        A list of ``(box, text)`` tuples; empty when OCR found nothing.
    """
    def _to_rect(quad):
        # Points are read as 0 = top-left, 1 = top-right, 2 = bottom-right,
        # 3 = bottom-left (assumed PaddleOCR order).
        left = min(quad[0][0], quad[3][0])
        top = min(quad[0][1], quad[1][1])
        right = max(quad[1][0], quad[2][0])
        bottom = max(quad[2][1], quad[3][1])
        return [left, top, right, bottom]

    def _has_valid_size(rect):
        return not (rect[3] - rect[1] < 0 or rect[2] - rect[0] < 0)

    # Only one image is passed, so only the first page of results matters.
    segments = ocr.ocr(img_path, cls=False)[0]
    if not segments:
        return []
    pairs = []
    for points, recognition in ((seg[0], seg) for seg in segments):
        rect = _to_rect(points)
        if _has_valid_size(rect):
            pairs.append((rect, recognition[1][0]))
    return pairs
def ie_temp_image(ie, ocr, image):
    """Run OCR + information extraction on an in-memory image.

    The image is written to a temporary ``.jpg`` file because both
    ``get_ocr_layout`` and the ``ie`` Taskflow are called with a file path.
    The temporary file is always removed afterwards.

    Args:
        ie: Information-extraction Taskflow, called with ``{"doc", "layout"}``.
        ocr: OCR engine passed to ``get_ocr_layout``.
        image: Image array to process.

    Returns:
        ``ie(...)[0]`` on success, or ``[]`` when OCR finds no text or any step
        fails (failures are logged).
    """
    # Only reserve a unique file name here and close our handle before
    # writing: an open NamedTemporaryFile cannot be reopened by name on Windows.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
        temp_path = temp_file.name
    ie_result = []
    try:
        # imwrite is inside the try so a failed write is logged and the temp
        # file is still removed below.
        if not cv2.imwrite(temp_path, image):
            raise ValueError(f"cv2.imwrite failed for {temp_path}")
        layout = get_ocr_layout(ocr, temp_path)
        if layout:
            ie_result = ie({"doc": temp_path, "layout": layout})[0]
        # else: no text recognised, keep the empty result.
    except Exception as e:
        logging.error("信息抽取时出错", exc_info=e)
    finally:
        try:
            os.remove(temp_path)
        except OSError as e:
            logging.info(f"删除临时文件 {temp_path} 时出错", exc_info=e)
    return ie_result
# 关键信息提取
def information_extraction(ie, phrecs):
    """Extract key information from a group of photos with OCR + UIE.

    For every photo record:
    1. Resolve its private URL and split elongated images into tiles.
    2. Predict each tile's orientation and run ``ie_temp_image`` on the tile
       rotated by the top-ranked angle.
    3. If that result is empty or has fewer keys than the Taskflow schema,
       also process the tile at the second-ranked angle.

    Every attempt is stored as a ``ZxOcr`` row and merged into the result.

    Args:
        ie: PaddleNLP information-extraction Taskflow; ``ie.kwargs["schema"]``
            is used to judge whether a result is complete.
        phrecs: Photo records exposing ``pk_phhd``, ``pk_phrec`` and
            ``cfjaddress``.

    Returns:
        dict mapping extracted keys to the concatenated lists of values from
        all photos, tiles and angles (see ``merge_result``).
    """
    result = {}
    # Identifier shared by every ZxOcr row written during this call (one batch).
    identity = int(time.time())
    ocr = PaddleOCR(use_angle_cls=False, lang="ch", show_log=False)
    for phrec in phrecs:
        pic_path = ucloud.get_private_url(phrec.cfjaddress)
        if not pic_path:
            continue
        for img in split_image(pic_path):
            angles = get_image_rotation_angles(img["img"])
            ie_results = [_ie_at_angle(ie, ocr, img["img"], angles[0])]
            # Judge completeness on the extraction result itself, not on its
            # {"result", "angle"} wrapper (which always has exactly 2 keys).
            first_result = ie_results[0]["result"]
            incomplete = not first_result or len(first_result) < len(ie.kwargs.get("schema"))
            if incomplete and len(angles) > 1:
                ie_results.append(_ie_at_angle(ie, ocr, img["img"], angles[1]))
            now = get_default_datetime()
            for ie_result in ie_results:
                _save_ocr_row(phrec, identity, img, ie_result, now)
                result = merge_result(result, ie_result["result"])
    return result


def _ie_at_angle(ie, ocr, image, angle):
    """Run ``ie_temp_image`` on ``image`` rotated by ``angle`` (an angle label such as '90')."""
    rotated = rotate_image(image, int(angle))
    return {"result": ie_temp_image(ie, ocr, rotated), "angle": angle}


def _save_ocr_row(phrec, identity, img, ie_result, now):
    """Persist one extraction attempt for one tile as a ``ZxOcr`` row."""
    # Content is capped at 5000 characters.
    # NOTE(review): truncation can cut the JSON mid-value; presumably the limit
    # matches the column size - confirm.
    result_json = json.dumps(ie_result["result"], ensure_ascii=False)[:5000]
    session = MysqlSession()
    try:
        session.add(ZxOcr(pk_phhd=phrec.pk_phhd, pk_phrec=phrec.pk_phrec, id=identity,
                          cfjaddress=phrec.cfjaddress, content=result_json,
                          rotation_angle=ie_result["angle"], x_offset=img["x_offset"],
                          y_offset=img["y_offset"], create_time=now, update_time=now))
        session.commit()
    finally:
        session.close()
# 从keys中获取准确率最高的value
def get_best_value_in_keys(source, keys):
    """Return the non-empty ``text`` with the highest ``probability`` under ``keys``.

    Args:
        source: Extraction result mapping keys to lists of dicts carrying
            ``"text"`` and ``"probability"`` entries (or a falsy value).
        keys: Keys to search; missing keys are skipped.

    Returns:
        The winning text, or ``None`` if no candidate has a positive
        probability. On ties, the first candidate encountered wins.
    """
    best_text, best_score = None, 0
    for key in keys:
        for candidate in source.get(key) or ():
            text = candidate.get("text")
            if not text:
                continue
            score = candidate.get("probability")
            if score > best_score:
                best_text, best_score = text, score
    return best_text
# 从keys中获取所有value组成list
def get_values_of_keys(source, keys):
    """Collect the distinct non-empty ``text`` values stored under ``keys``.

    Args:
        source: Extraction result mapping keys to lists of dicts with a
            ``"text"`` entry (or a falsy value).
        keys: Keys to look up; missing keys are ignored.

    Returns:
        A list of unique texts in first-seen order, so the output does not
        depend on string-hash randomisation.
    """
    texts = []
    for key in keys:
        for value in source.get(key) or ():
            text = value.get("text")
            if text:
                texts.append(text)
    # dict preserves insertion order: de-duplicate while keeping first occurrence.
    return list(dict.fromkeys(texts))
def save_or_update_ie(table, pk_phhd, data):
    """Insert or update the extraction row of ``table`` for case ``pk_phhd``.

    ``None`` and empty-string values are dropped first, so an update never
    overwrites an existing column with an empty value. ``create_time`` is set
    only on insert; ``update_time`` is refreshed in both cases.

    Args:
        table: Mapped entity class (e.g. ``ZxIeSettlement``) with a
            ``pk_phhd`` column.
        pk_phhd: Case key used to find an existing row.
        data: Column name -> value mapping to write.
    """
    data = {k: v for k, v in data.items() if v is not None and v != ""}
    session = MysqlSession()
    try:
        db_data = session.query(table).filter_by(pk_phhd=pk_phhd).one_or_none()
        now = get_default_datetime()
        if db_data is not None:
            # Update the existing row in place.
            db_data.update_time = now
            for k, v in data.items():
                setattr(db_data, k, v)
        else:
            # Insert a new row.
            obj = table(**data)
            obj.create_time = now
            obj.update_time = now
            session.add(obj)
        session.commit()
    finally:
        # Always release the session, also when the query or commit fails.
        session.close()
def photo_review(pk_phhd):
    """Run OCR-based information extraction for one case and store the results.

    The case's photos (``ZxPhrec`` rows) are grouped by ``cRectype``:
    "1" -> settlement list, "3" -> discharge record, "4" -> cost list; other
    types are ignored. Each group is processed with its own UIE-X model. The
    normalised fields are upserted into ``ZxIeSettlement``, ``ZxIeDischarge``
    and ``ZxIeCost``, in that order.

    Args:
        pk_phhd: Primary key of the case (``ZxPhhd``).
    """
    session = MysqlSession()
    try:
        phrecs = (
            session.query(ZxPhrec.pk_phrec, ZxPhrec.pk_phhd, ZxPhrec.cRectype, ZxPhrec.cfjaddress)
            .filter(ZxPhrec.pk_phhd == pk_phhd)
            .all()
        )
    finally:
        session.close()

    settlement_list = [phrec for phrec in phrecs if phrec.cRectype == "1"]
    discharge_record = [phrec for phrec in phrecs if phrec.cRectype == "3"]
    cost_list = [phrec for phrec in phrecs if phrec.cRectype == "4"]

    _review_settlement(pk_phhd, settlement_list)
    _review_discharge(pk_phhd, discharge_record)
    _review_cost(pk_phhd, cost_list)


def _create_ie(schema, task_path):
    """Build a UIE-X information-extraction Taskflow for one document type."""
    return Taskflow("information_extraction", schema=schema, model="uie-x-base",
                    task_path=task_path, layout_analysis=LAYOUT_ANALYSIS)


def _find_by_names(id_column, name_column, names):
    """Return one ``(id, name)`` row whose name is in ``names``, or ``None``."""
    session = MysqlSession()
    try:
        return (
            session.query(id_column, name_column)
            .filter(name_column.in_(names))
            .limit(1)
            .one_or_none()
        )
    finally:
        session.close()


def _review_settlement(pk_phhd, phrecs):
    """Extract settlement-list fields from ``phrecs`` and upsert them into ``ZxIeSettlement``."""
    ie_result = information_extraction(
        _create_ie(SETTLEMENT_LIST_SCHEMA, "config/model/settlement_list_model"), phrecs)
    data = {
        "pk_phhd": pk_phhd,
        "name": handle_name(get_best_value_in_keys(ie_result, PATIENT_NAME)),
        "admission_date_str": handle_original_data(get_best_value_in_keys(ie_result, ADMISSION_DATE)),
        "discharge_date_str": handle_original_data(get_best_value_in_keys(ie_result, DISCHARGE_DATE)),
        "medical_expenses_str": handle_original_data(get_best_value_in_keys(ie_result, MEDICAL_EXPENSES)),
        "personal_cash_payment_str": handle_original_data(
            get_best_value_in_keys(ie_result, PERSONAL_CASH_PAYMENT)),
        "personal_account_payment_str": handle_original_data(
            get_best_value_in_keys(ie_result, PERSONAL_ACCOUNT_PAYMENT)),
        "personal_funded_amount_str": handle_original_data(
            get_best_value_in_keys(ie_result, PERSONAL_FUNDED_AMOUNT)),
        "medical_insurance_type": handle_insurance_type(
            get_best_value_in_keys(ie_result, MEDICAL_INSURANCE_TYPE)),
    }
    # Parsed counterparts of the raw "*_str" fields.
    data["admission_date"] = handle_date(data["admission_date_str"])
    data["discharge_date"] = handle_date(data["discharge_date_str"])
    for field in ("medical_expenses", "personal_cash_payment",
                  "personal_account_payment", "personal_funded_amount"):
        data[field] = handle_decimal(data[field + "_str"])
    save_or_update_ie(ZxIeSettlement, pk_phhd, data)


def _review_discharge(pk_phhd, phrecs):
    """Extract discharge-record fields from ``phrecs`` and upsert them into ``ZxIeDischarge``."""
    ie_result = information_extraction(
        _create_ie(DISCHARGE_RECORD_SCHEMA, "config/model/discharge_record_model"), phrecs)
    data = {
        "pk_phhd": pk_phhd,
        "hospital": handle_hospital(get_best_value_in_keys(ie_result, HOSPITAL)),
        "department": handle_department(get_best_value_in_keys(ie_result, DEPARTMENT)),
        "name": handle_name(get_best_value_in_keys(ie_result, PATIENT_NAME)),
        "admission_date_str": handle_original_data(get_best_value_in_keys(ie_result, ADMISSION_DATE)),
        "discharge_date_str": handle_original_data(get_best_value_in_keys(ie_result, DISCHARGE_DATE)),
        "doctor": handle_name(get_best_value_in_keys(ie_result, DOCTOR)),
    }
    data["admission_date"] = handle_date(data["admission_date_str"])
    data["discharge_date"] = handle_date(data["discharge_date_str"])

    # If any extracted hospital name exactly matches a BdYljg row, use that
    # row's key and canonical name.
    hospital_values = get_values_of_keys(ie_result, HOSPITAL)
    if hospital_values:
        yljg = _find_by_names(BdYljg.pk_yljg, BdYljg.name, hospital_values)
        if yljg:
            data["pk_yljg"] = yljg.pk_yljg
            data["hospital"] = yljg.name

    # Same for departments, matching against the names parse_department
    # derives from each extracted value.
    department_values = []
    for dept in get_values_of_keys(ie_result, DEPARTMENT):
        department_values += parse_department(dept)
    department_values = list(set(department_values))
    if department_values:
        ylks = _find_by_names(BdYlks.pk_ylks, BdYlks.name, department_values)
        if ylks:
            data["pk_ylks"] = ylks.pk_ylks
            data["department"] = ylks.name
    save_or_update_ie(ZxIeDischarge, pk_phhd, data)


def _review_cost(pk_phhd, phrecs):
    """Extract cost-list fields from ``phrecs`` and upsert them into ``ZxIeCost``."""
    ie_result = information_extraction(
        _create_ie(COST_LIST_SCHEMA, "config/model/cost_list_model"), phrecs)
    data = {
        "pk_phhd": pk_phhd,
        "name": handle_name(get_best_value_in_keys(ie_result, PATIENT_NAME)),
        "admission_date_str": handle_original_data(get_best_value_in_keys(ie_result, ADMISSION_DATE)),
        "discharge_date_str": handle_original_data(get_best_value_in_keys(ie_result, DISCHARGE_DATE)),
        "medical_expenses_str": handle_original_data(get_best_value_in_keys(ie_result, MEDICAL_EXPENSES)),
    }
    data["admission_date"] = handle_date(data["admission_date_str"])
    data["discharge_date"] = handle_date(data["discharge_date_str"])
    data["medical_expenses"] = handle_decimal(data["medical_expenses_str"])
    save_or_update_ie(ZxIeCost, pk_phhd, data)
def main():
    """Poll for cases awaiting review and process them, forever.

    Cases with ``exsuccess_flag == '1'`` are fetched in batches of at most
    ``PHHD_BATCH_SIZE``. After ``photo_review`` finishes a case, its flag is
    set to 8. When nothing is pending, the loop sleeps ``SLEEP_MINUTES``
    minutes before polling again.
    """
    while True:
        # Fetch the next batch of cases that still need recognition.
        session = MysqlSession()
        pending = session.query(ZxPhhd.pk_phhd).filter(ZxPhhd.exsuccess_flag == '1').limit(PHHD_BATCH_SIZE).all()
        session.close()

        if not pending:
            # Nothing to do yet: wait, then poll again.
            logging.getLogger().info(f"暂未查询到新案子,等待{SLEEP_MINUTES}分钟...")
            sleep(SLEEP_MINUTES * 60)
            continue

        for row in pending:
            case_id = row.pk_phhd
            logging.info(f"开始识别:{case_id}")
            photo_review(case_id)
            # Mark the case as done so it is not picked up again.
            session = MysqlSession()
            mark_done = update(ZxPhhd).where(ZxPhhd.pk_phhd == case_id).values(exsuccess_flag=8)
            session.execute(mark_done)
            session.commit()
            session.close()