diff --git a/photo_review/photo_review.py b/photo_review/photo_review.py index 4534634..bec0b21 100644 --- a/photo_review/photo_review.py +++ b/photo_review/photo_review.py @@ -112,6 +112,18 @@ def get_image_rotation_angle(img): return angle +# 获取图片旋转角度 +def get_image_rotation_angles(img): + angles = ['0', '90'] + model = paddleclas.PaddleClas(model_name="text_image_orientation") + result = model.predict(input_data=img) + try: + angles = next(result)[0]["label_names"] + except Exception as e: + logging.error("获取图片旋转角度失败", exc_info=e) + return angles + + # 旋转图片 def rotate_image(img, angle): if angle == 0: @@ -152,7 +164,7 @@ def get_ocr_layout(ocr, img_path): return True layout = [] - ocr_result = ocr.ocr(img_path) + ocr_result = ocr.ocr(img_path, cls=False) ocr_result = ocr_result[0] if not ocr_result: return layout @@ -193,7 +205,7 @@ def information_extraction(ie, phrecs): result = {} # 同一批图的标识 identity = int(time.time()) - ocr = PaddleOCR(use_angle_cls=True, lang="ch", show_log=False) + ocr = PaddleOCR(use_angle_cls=False, lang="ch", show_log=False) for phrec in phrecs: pic_path = ucloud.get_private_url(phrec.cfjaddress) if not pic_path: @@ -201,41 +213,27 @@ def information_extraction(ie, phrecs): split_result = split_image(pic_path) for img in split_result: - ie_result1 = ie_temp_image(ie, ocr, img["img"]) - if not ie_result1 or len(ie_result1) < len(ie.kwargs.get("schema")): - rotated_img = rotate_image(img["img"], 90) - ie_result2 = ie_temp_image(ie, ocr, rotated_img) - if not (ie_result1 or ie_result2): - continue - elif not ie_result1: - ie_result = ie_result2 - angle = 90 - elif not ie_result2: - ie_result = ie_result1 - angle = 0 - elif len(ie_result2) > len(ie_result1): - ie_result = ie_result2 - angle = 90 - else: - ie_result = ie_result1 - angle = 0 - else: - ie_result = ie_result1 - angle = 0 + angles = get_image_rotation_angles(img["img"]) + rotated_img = rotate_image(img["img"], int(angles[0])) + ie_results = [{"result": ie_temp_image(ie, ocr, rotated_img), "angle": angles[0]}] + if not ie_results[0] or len(ie_results[0]) < len(ie.kwargs.get("schema")): + rotated_img = rotate_image(img["img"], int(angles[1])) + ie_results.append({"result": ie_temp_image(ie, ocr, rotated_img), "angle": angles[1]}) now = get_default_datetime() - result_json = json.dumps(ie_result, ensure_ascii=False) - if len(result_json) > 5000: - result_json = result_json[:5000] - session = MysqlSession() - zx_ocr = ZxOcr(pk_phhd=phrec.pk_phhd, pk_phrec=phrec.pk_phrec, id=identity, cfjaddress=phrec.cfjaddress, - content=result_json, rotation_angle=angle, x_offset=img["x_offset"], - y_offset=img["y_offset"], create_time=now, update_time=now) - session.add(zx_ocr) - session.commit() - session.close() + for ie_result in ie_results: + result_json = json.dumps(ie_result["result"], ensure_ascii=False) + if len(result_json) > 5000: + result_json = result_json[:5000] + session = MysqlSession() + zx_ocr = ZxOcr(pk_phhd=phrec.pk_phhd, pk_phrec=phrec.pk_phrec, id=identity, cfjaddress=phrec.cfjaddress, + content=result_json, rotation_angle=ie_result["angle"], x_offset=img["x_offset"], + y_offset=img["y_offset"], create_time=now, update_time=now) + session.add(zx_ocr) + session.commit() + session.close() - result = merge_result(result, ie_result) + result = merge_result(result, ie_result["result"]) return result @@ -292,7 +290,7 @@ def save_or_update_ie(table, pk_phhd, data): session.close() -def photo_review(pk_phhd): +def photo_review(pk_phhd, task_flows): settlement_list = [] discharge_record = [] cost_list = [] @@ -310,9 +308,7 @@ def photo_review(pk_phhd): elif phrec.cRectype == "4": cost_list.append(phrec) - settlement_list_ie_result = information_extraction( - Taskflow("information_extraction", schema=SETTLEMENT_LIST_SCHEMA, model="uie-x-base", - task_path="config/model/settlement_list_model", layout_analysis=LAYOUT_ANALYSIS), settlement_list) + settlement_list_ie_result = information_extraction(task_flows[0], settlement_list) settlement_data = { "pk_phhd": pk_phhd, "name": handle_name(get_best_value_in_keys(settlement_list_ie_result, PATIENT_NAME)), @@ -338,9 +334,7 @@ def photo_review(pk_phhd): settlement_data["personal_funded_amount"] = handle_decimal(settlement_data["personal_funded_amount_str"]) save_or_update_ie(ZxIeSettlement, pk_phhd, settlement_data) - discharge_record_ie_result = information_extraction( - Taskflow("information_extraction", schema=DISCHARGE_RECORD_SCHEMA, model="uie-x-base", - task_path="config/model/discharge_record_model", layout_analysis=LAYOUT_ANALYSIS), discharge_record) + discharge_record_ie_result = information_extraction(task_flows[1], discharge_record) discharge_data = { "pk_phhd": pk_phhd, "hospital": handle_hospital(get_best_value_in_keys(discharge_record_ie_result, HOSPITAL)), @@ -377,9 +371,7 @@ def photo_review(pk_phhd): discharge_data["department"] = ylks.name save_or_update_ie(ZxIeDischarge, pk_phhd, discharge_data) - cost_list_ie_result = information_extraction( - Taskflow("information_extraction", schema=COST_LIST_SCHEMA, model="uie-x-base", - task_path="config/model/cost_list_model", layout_analysis=LAYOUT_ANALYSIS), cost_list) + cost_list_ie_result = information_extraction(task_flows[2], cost_list) cost_data = { "pk_phhd": pk_phhd, "name": handle_name(get_best_value_in_keys(cost_list_ie_result, PATIENT_NAME)), @@ -401,10 +393,18 @@ def main(): phhds = session.query(ZxPhhd.pk_phhd).filter(ZxPhhd.exsuccess_flag == '1').limit(PHHD_BATCH_SIZE).all() session.close() if phhds: + ie_task_flows = [ + Taskflow("information_extraction", schema=SETTLEMENT_LIST_SCHEMA, model="uie-x-base", + task_path="config/model/settlement_list_model", layout_analysis=LAYOUT_ANALYSIS), + Taskflow("information_extraction", schema=DISCHARGE_RECORD_SCHEMA, model="uie-x-base", + task_path="config/model/discharge_record_model", layout_analysis=LAYOUT_ANALYSIS), + Taskflow("information_extraction", schema=COST_LIST_SCHEMA, model="uie-x-base", + task_path="config/model/cost_list_model", layout_analysis=LAYOUT_ANALYSIS), + ] for phhd in phhds: pk_phhd = phhd.pk_phhd logging.info(f"开始识别:{pk_phhd}") - photo_review(pk_phhd) + photo_review(pk_phhd, ie_task_flows) # 识别完成更新标识 session = MysqlSession()