diff --git a/.gitignore b/.gitignore index d3e9a22..7a2d598 100644 --- a/.gitignore +++ b/.gitignore @@ -148,4 +148,5 @@ services/paddle_services/model *.log.*-*-* ### Tmp Files -/tmp_img \ No newline at end of file +/tmp_img +/test_img \ No newline at end of file diff --git a/auto_generator.py b/auto_generator.py deleted file mode 100644 index 897b0aa..0000000 --- a/auto_generator.py +++ /dev/null @@ -1,15 +0,0 @@ -# 自动生成数据库表和sqlalchemy对应的Model -import subprocess - -from db import DB_URL - -if __name__ == '__main__': - table = input("请输入表名:") - out_file = f"db/{table}.py" - command = f"sqlacodegen {DB_URL} --outfile={out_file} --tables={table}" - - try: - subprocess.run(command, shell=True, check=True) - print(f"{table}.py文件生成成功!请检查并复制到合适的文件中!") - except Exception as e: - print(f"生成{table}.py文件时发生错误: {e}") diff --git a/log/__init__.py b/log/__init__.py index 8a95796..f587dfd 100644 --- a/log/__init__.py +++ b/log/__init__.py @@ -7,9 +7,9 @@ PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) HOSTNAME = socket.gethostname() # 检测日志文件的路径是否存在,不存在则创建 LOG_PATHS = [ - f"log/{HOSTNAME}/ucloud", - f"log/{HOSTNAME}/error", - f"log/{HOSTNAME}/qr", + os.path.join(PROJECT_ROOT, 'log', HOSTNAME, 'ucloud'), + os.path.join(PROJECT_ROOT, 'log', HOSTNAME, 'error'), + os.path.join(PROJECT_ROOT, 'log', HOSTNAME, 'qr'), ] for path in LOG_PATHS: if not os.path.exists(path): diff --git a/auto_email/__init__.py b/my_email/__init__.py similarity index 70% rename from auto_email/__init__.py rename to my_email/__init__.py index a13e8f0..788f2c7 100644 --- a/auto_email/__init__.py +++ b/my_email/__init__.py @@ -8,13 +8,13 @@ MAX_WAIT_TIME = 3 # 程序异常短信配置 ERROR_EMAIL_CONFIG = { # SMTP服务器地址 - "smtp_server": "smtp.163.com", + 'smtp_server': 'smtp.163.com', # 连接SMTP的端口 - "port": 994, + 'port': 994, # 发件人邮箱地址,请确保开启了SMTP邮件服务! - "sender": "EchoLiu618@163.com", + 'sender': 'EchoLiu618@163.com', # 授权码--用于登录第三方邮件客户端的专用密码,不是邮箱密码 - "authorization_code": "OKPQLIIVLVGRZYVH", + 'authorization_code': 'OKPQLIIVLVGRZYVH', # 收件人邮箱地址 - "receivers": ["1515783401@qq.com"], + 'receivers': ['1515783401@qq.com'], } diff --git a/auto_email/error_email.py b/my_email/error_email.py similarity index 85% rename from auto_email/error_email.py rename to my_email/error_email.py index 16dd93b..e3ad799 100644 --- a/auto_email/error_email.py +++ b/my_email/error_email.py @@ -5,18 +5,18 @@ from email.mime.text import MIMEText from tenacity import retry, stop_after_attempt, wait_random -from auto_email import ERROR_EMAIL_CONFIG, TRY_TIMES, MIN_WAIT_TIME, MAX_WAIT_TIME from log import HOSTNAME +from my_email import ERROR_EMAIL_CONFIG, TRY_TIMES, MIN_WAIT_TIME, MAX_WAIT_TIME @retry(stop=stop_after_attempt(TRY_TIMES), wait=wait_random(MIN_WAIT_TIME, MAX_WAIT_TIME), reraise=True, - after=lambda x: logging.warning("发送邮件失败!")) + after=lambda x: logging.warning('发送邮件失败!')) def send_email(email_config, massage): - smtp_server = email_config["smtp_server"] - port = email_config["port"] - sender = email_config["sender"] - authorization_code = email_config["authorization_code"] - receivers = email_config["receivers"] + smtp_server = email_config['smtp_server'] + port = email_config['port'] + sender = email_config['sender'] + authorization_code = email_config['authorization_code'] + receivers = email_config['receivers'] mail = smtplib.SMTP_SSL(smtp_server, port) # 连接SMTP服务 mail.login(sender, authorization_code) # 登录到SMTP服务 mail.sendmail(sender, receivers, massage.as_string()) # 发送邮件 @@ -34,13 +34,13 @@ def send_error_email(program_name, error_name, error_detail): """ # SMTP 服务器配置 - sender = ERROR_EMAIL_CONFIG["sender"] - receivers = ERROR_EMAIL_CONFIG["receivers"] + sender = ERROR_EMAIL_CONFIG['sender'] + receivers = ERROR_EMAIL_CONFIG['receivers'] # 获取程序出错的时间 - error_time = datetime.datetime.strftime(datetime.datetime.today(), "%Y-%m-%d %H:%M:%S:%f") + error_time = datetime.datetime.strftime(datetime.datetime.today(), '%Y-%m-%d %H:%M:%S:%f') # 邮件内容 - subject = f"【程序异常提醒】{program_name}({HOSTNAME}) {error_time}" # 邮件的标题 + subject = f'【程序异常提醒】{program_name}({HOSTNAME}) {error_time}' # 邮件的标题 content = f'''

程序运行异常通知

diff --git a/photo_mask.py b/photo_mask.py index 317b8ad..fa186fd 100644 --- a/photo_mask.py +++ b/photo_mask.py @@ -5,35 +5,35 @@ from time import sleep from sqlalchemy import update -from auto_email.error_email import send_error_email +from my_email.error_email import send_error_email from db import MysqlSession from db.mysql import ZxPhhd from log import LOGGING_CONFIG from photo_mask import auto_photo_mask, SEND_ERROR_EMAIL if __name__ == '__main__': - program_name = "照片审核自动涂抹脚本" + program_name = '照片审核自动涂抹脚本' logging.config.dictConfig(LOGGING_CONFIG) parser = argparse.ArgumentParser() - parser.add_argument("--clean", default=False, type=bool, help="是否将涂抹中的案子改为待涂抹状态") + parser.add_argument('--clean', default=False, type=bool, help='是否将涂抹中的案子改为待涂抹状态') args = parser.parse_args() if args.clean: # 主要用于启动时,清除仍在涂抹中的案子 session = MysqlSession() - update_flag = (update(ZxPhhd).where(ZxPhhd.paint_flag == "2").values(paint_flag="1")) + update_flag = (update(ZxPhhd).where(ZxPhhd.paint_flag == '2').values(paint_flag='1')) session.execute(update_flag) session.commit() session.close() - logging.info("已释放残余的涂抹案子!") + logging.info('已释放残余的涂抹案子!') else: sleep(5) try: - logging.info(f"【{program_name}】开始运行") + logging.info(f'【{program_name}】开始运行') auto_photo_mask.main() except Exception as e: - error_logger = logging.getLogger("error") + error_logger = logging.getLogger('error') error_logger.error(traceback.format_exc()) if SEND_ERROR_EMAIL: send_error_email(program_name, repr(e), traceback.format_exc()) diff --git a/photo_review.py b/photo_review.py index 1866edb..235644f 100644 --- a/photo_review.py +++ b/photo_review.py @@ -5,7 +5,7 @@ from time import sleep from sqlalchemy import update -from auto_email.error_email import send_error_email +from my_email.error_email import send_error_email from db import MysqlSession from db.mysql import ZxPhhd from log import LOGGING_CONFIG diff --git a/services/__init__.py b/services/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/services/paddle_services/log/__init__.py b/services/paddle_services/log/__init__.py index d2e9472..3b517cb 100644 --- a/services/paddle_services/log/__init__.py +++ b/services/paddle_services/log/__init__.py @@ -5,7 +5,7 @@ import socket HOSTNAME = socket.gethostname() # 检测日志文件的路径是否存在,不存在则创建 LOG_PATHS = [ - f"log/{HOSTNAME}/error", + f'log/{HOSTNAME}/error', ] for path in LOG_PATHS: if not os.path.exists(path): diff --git a/services/paddle_services/utils.py b/services/paddle_services/utils.py index 6e646cb..27c921a 100644 --- a/services/paddle_services/utils.py +++ b/services/paddle_services/utils.py @@ -14,7 +14,7 @@ def process_request(func): result = func(*args, **kwargs) return jsonify(result), 200 except Exception as e: - logging.getLogger("error").error(e, exc_info=e) + logging.getLogger('error').error(e, exc_info=e) return jsonify({'error': str(e)}), 500 return wrapper diff --git a/tool/auto_generator.py b/tool/auto_generator.py new file mode 100644 index 0000000..6c84164 --- /dev/null +++ b/tool/auto_generator.py @@ -0,0 +1,19 @@ +""" +自动生成数据库表和sqlalchemy对应的Model +""" +import os.path +import subprocess + +from db import DB_URL +from log import PROJECT_ROOT + +if __name__ == '__main__': + table = input('请输入表名:') + out_file = os.path.join(PROJECT_ROOT, 'db', f'{table}.py') + command = f'sqlacodegen {DB_URL} --outfile={out_file} --tables={table}' + + try: + subprocess.run(command, shell=True, check=True) + print(f'{table}.py文件生成成功!请检查并复制到合适的文件中!') + except Exception as e: + print(f'生成{table}.py文件时发生错误: {e}') diff --git a/ucloud/__init__.py b/ucloud/__init__.py index f3d0d28..8ee02dc 100644 --- a/ucloud/__init__.py +++ b/ucloud/__init__.py @@ -1,15 +1,15 @@ from ufile import config # 公钥 -PUBLIC_KEY = "4Z7QYI7qml36QRjcCjKrls7aHl1R6H6uq" +PUBLIC_KEY = '4Z7QYI7qml36QRjcCjKrls7aHl1R6H6uq' # 私钥 -PRIVATE_KEY = "FIdW1Kev1Ge3K7GHXzSLyGG1wTnaG6LE9BxmIVubcCaG" +PRIVATE_KEY = 'FIdW1Kev1Ge3K7GHXzSLyGG1wTnaG6LE9BxmIVubcCaG' # 桶 -BUCKET = "drg100" +BUCKET = 'drg100' # 上传后缀 -UPLOAD_SUFFIX = ".cn-sh2.ufileos.com" +UPLOAD_SUFFIX = '.cn-sh2.ufileos.com' # 下载后缀 -DOWNLOAD_SUFFIX = ".cn-sh2.ufileos.com" +DOWNLOAD_SUFFIX = '.cn-sh2.ufileos.com' # 私空间文件地址过期时间(秒) PRIVATE_EXPIRES = 3600 diff --git a/ucloud/ufile.py b/ucloud/ufile.py index 391bda0..e8cede0 100644 --- a/ucloud/ufile.py +++ b/ucloud/ufile.py @@ -17,10 +17,10 @@ def get_private_url(key, bucket=BUCKET): # 判断文件是否存在 _, resp = UFILE_HANDLER.head_file(bucket, key) if resp.status_code == -1: - UCLOUD_LOGGER.warning(f"查询({key})时uCloud连接失败!") - raise ConnectionError("uCloud连接失败") + UCLOUD_LOGGER.warning(f'查询({key})时uCloud连接失败!') + raise ConnectionError('uCloud连接失败') if resp.status_code != 200: - UCLOUD_LOGGER.warning(f"[{bucket}]中未找到({key})! status: {resp.status_code} error: {resp.error}") + UCLOUD_LOGGER.warning(f'[{bucket}]中未找到({key})! status: {resp.status_code} error: {resp.error}') return None # 获取公有空间下载url @@ -37,11 +37,11 @@ def copy_file(source_bucket, source_key, target_bucket, target_key): # 复制文件 _, resp = UFILE_HANDLER.copy(target_bucket, target_key, source_bucket, source_key) if resp.status_code == -1: - UCLOUD_LOGGER.warning(f"复制({source_key})时uCloud连接失败!") + UCLOUD_LOGGER.warning(f'复制({source_key})时uCloud连接失败!') return False if resp.status_code != 200: UCLOUD_LOGGER.warning( - f"将({source_key})从[{source_bucket}]拷贝到[{target_bucket}]失败! status: {resp.status_code} error: {resp.error}" + f'将({source_key})从[{source_bucket}]拷贝到[{target_bucket}]失败! status: {resp.status_code} error: {resp.error}' ) return False return True @@ -53,9 +53,9 @@ def upload_file(key, file_path, bucket=BUCKET): # 普通上传文件至云空间 _, resp = UFILE_HANDLER.putfile(bucket, key, file_path, header=None) if resp.status_code == -1: - UCLOUD_LOGGER.warning(f"上传({key})时uCloud连接失败!") + UCLOUD_LOGGER.warning(f'上传({key})时uCloud连接失败!') return False if resp.status_code != 200: - UCLOUD_LOGGER.warning(f"上传({key})失败! status: {resp.status_code} error: {resp.error}") + UCLOUD_LOGGER.warning(f'上传({key})失败! status: {resp.status_code} error: {resp.error}') return False return True diff --git a/util/common_util.py b/util/common_util.py index 091b123..7719d61 100644 --- a/util/common_util.py +++ b/util/common_util.py @@ -62,9 +62,9 @@ def delete_temp_file(temp_files): for file in temp_files: try: os.remove(file) - logging.info(f"临时文件 {file} 已删除") + logging.info(f'临时文件 {file} 已删除') except Exception as e: - logging.warning(f"删除临时文件 {file} 时出错: {e}") + logging.warning(f'删除临时文件 {file} 时出错: {e}') def zoom_rectangle(rectangle, ratio): @@ -83,40 +83,40 @@ def zoom_rectangle(rectangle, ratio): def chinese_to_money_unit(chinese): - if chinese in ["拾", "十"]: + if chinese in ['拾', '十']: return 10, False - elif chinese in ["佰", "百"]: + elif chinese in ['佰', '百']: return 100, False - elif chinese in ["仟", "千"]: + elif chinese in ['仟', '千']: return 1000, False - elif chinese == "万": + elif chinese == '万': return 10000, True - elif chinese == "亿": + elif chinese == '亿': return 100000000, True else: return None, False def chinese_char_to_number(chinese): - if chinese == "零": + if chinese == '零': return 0 - elif chinese in ["一", "壹"]: + elif chinese in ['一', '壹']: return 1 - elif chinese in ["二", "贰"]: + elif chinese in ['二', '贰']: return 2 - elif chinese in ["三", "叁"]: + elif chinese in ['三', '叁']: return 3 - elif chinese in ["四", "肆"]: + elif chinese in ['四', '肆']: return 4 - elif chinese in ["五", "伍"]: + elif chinese in ['五', '伍']: return 5 - elif chinese in ["六", "陆"]: + elif chinese in ['六', '陆']: return 6 - elif chinese in ["七", "柒"]: + elif chinese in ['七', '柒']: return 7 - elif chinese in ["八", "捌"]: + elif chinese in ['八', '捌']: return 8 - elif chinese in ["九", "玖"]: + elif chinese in ['九', '玖']: return 9 else: return -1 @@ -137,12 +137,12 @@ def chinese_to_number(chinese): section += number * (unit[0] / 10) unit = [None, False] elif number > 0: - raise ValueError(f"{chinese} has bad number '{chinese[i - 1]}{c}' at: {i}") + raise ValueError(f"'{chinese} has bad number '{chinese[i - 1]}{c}' at: {i}'") number = num else: unit = chinese_to_money_unit(c) if unit[0] is None: - raise ValueError(f"{chinese} has unknown unit '{c}' at: {i}") + raise ValueError(f"'{chinese} has unknown unit '{c}' at: {i}'") if unit[1]: section = (section + number) * unit[0] result += section @@ -163,14 +163,14 @@ def chinese_to_number(chinese): def chinese_money_to_number(chinese_money_amount): if string_util.blank(chinese_money_amount): return None - yi = chinese_money_amount.find("元") + yi = chinese_money_amount.find('元') if yi == -1: - yi = chinese_money_amount.find("圆") - ji = chinese_money_amount.find("角") - fi = chinese_money_amount.find("分") + yi = chinese_money_amount.find('圆') + ji = chinese_money_amount.find('角') + fi = chinese_money_amount.find('分') if yi == -1 and ji == -1 and fi == -1: - raise ValueError(f"无法解析: {chinese_money_amount}") + raise ValueError(f'无法解析: {chinese_money_amount}') y_str = None if yi > 0: diff --git a/util/data_util.py b/util/data_util.py index 94e80a8..b1e7608 100644 --- a/util/data_util.py +++ b/util/data_util.py @@ -8,20 +8,20 @@ from util import common_util # 处理金额类数据 def handle_decimal(string): if not string: - return "" + return '' string = re.sub(r'[^0-9.]', '', string) if not string: - return "" - if "." not in string: + return '' + if '.' not in string: if len(string) > 2: - result = string[:-2] + "." + string[-2:] + result = string[:-2] + '.' + string[-2:] else: result = string else: front, back = string.rsplit('.', 1) - front = front.replace(".", "") + front = front.replace('.', '') if back: - back = "." + back[:2] + back = '.' + back[:2] result = front + back return result[:16] @@ -32,7 +32,7 @@ def parse_money(capital_num, num): money = common_util.chinese_money_to_number(capital_num) return capital_num, money except Exception as e: - logging.warning("大写金额解析失败", exc_info=e) + logging.warning('大写金额解析失败', exc_info=e) return num, handle_decimal(num) @@ -40,17 +40,17 @@ def parse_money(capital_num, num): # 处理日期类数据 def handle_date(string): if not string: - return "" + return '' - string = string.replace("年", "-").replace("月", "-").replace("日", "").replace("/", "-").replace(".", "-") + string = string.replace('年', '-').replace('月', '-').replace('日', '').replace('/', '-').replace('.', '-') string = re.sub(r'[^0-9-]', '', string) - string = string.strip("-") - if "-" in string: - dash_count = string.count("-") + string = string.strip('-') + if '-' in string: + dash_count = string.count('-') if dash_count > 2: - third_dash_index = string.find("-", string.find("-", string.find("-") + 1) + 1) + third_dash_index = string.find('-', string.find('-', string.find('-') + 1) + 1) string = string[:third_dash_index] - day = string[string.rindex("-") + 1:] + day = string[string.rindex('-') + 1:] if len(day) > 2: string = string[:2 - len(day)] else: @@ -58,7 +58,7 @@ def handle_date(string): string = string[:8] if len(string) < 6: - return "" + return '' # 定义可能的日期格式 formats = [ @@ -78,23 +78,23 @@ def handle_date(string): date = datetime.strptime(string, fmt) # 限定日期的年份范围 if 2000 < date.year < 2100: - return date.strftime("%Y-%m-%d") + return date.strftime('%Y-%m-%d') continue except ValueError: continue - return "" + return '' def handle_hospital(string): if not string: - return "" + return '' return string[:255] def handle_department(string): if not string: - return "" + return '' return string[:255] @@ -103,12 +103,12 @@ def parse_department(string): if not string: return result - string = string.replace(")", "").replace(")", "").replace("(", " ").replace("(", " ") # 去除括号 + string = string.replace(')', '').replace(')', '').replace('(', ' ').replace('(', ' ') # 去除括号 string = re.sub(r'[^⺀-鿿 ]', '', string) # 去除非汉字字符,除了空格 string = re.sub(r'[一二三四五六七八九十]', '', string) # 去除中文数字 - string = string.replace("病区", "").replace("病", "") # 去除常见的无意义词 - string = string.replace("科", " ") # 分离科室 - departments = string.strip().split(" ") + string = string.replace('病区', '').replace('病', '') # 去除常见的无意义词 + string = string.replace('科', ' ') # 分离科室 + departments = string.strip().split(' ') for department in departments: if department: result.append(department) @@ -118,33 +118,33 @@ def parse_department(string): # 处理姓名类数据 def handle_name(string): if not string: - return "" + return '' return re.sub(r'[^⺀-鿿·]', '', string)[:30] # 处理医保类型数据 def handle_insurance_type(string): if not string: - return "" - worker_insurance_keys = ["社保", "城保", "职", "退休"] - villager_insurance_keys = ["农保", "居民"] - migrant_worker_insurance_keys = ["农民工"] - no_insurance_keys = ["自费", "全费"] + return '' + worker_insurance_keys = ['社保', '城保', '职', '退休'] + villager_insurance_keys = ['农保', '居民'] + migrant_worker_insurance_keys = ['农民工'] + no_insurance_keys = ['自费', '全费'] if any(key in string for key in worker_insurance_keys): - return "职工医保" + return '职工医保' if any(key in string for key in villager_insurance_keys): - return "居民医保" + return '居民医保' if any(key in string for key in migrant_worker_insurance_keys): - return "农民工医保" + return '农民工医保' if any(key in string for key in no_insurance_keys): - return "无医保" - return "其他" + return '无医保' + return '其他' # 处理原始数据 def handle_original_data(string): if not string: - return "" + return '' # 防止过长存入数据库失败 return string[:255] @@ -152,7 +152,7 @@ def handle_original_data(string): # 处理id类数据 def handle_id(string): if not string: - return "" + return '' # 防止过长存入数据库失败 return string[:50] @@ -160,8 +160,8 @@ def handle_id(string): # 处理年龄类数据 def handle_age(string): if not string: - return "" - string = string.split("岁")[0] + return '' + string = string.split('岁')[0] num = re.sub(r'\D', '', string) return num[-3:] @@ -173,8 +173,8 @@ def parse_hospital(string): return result string = common_util.traditional_to_simple_chinese(string) - string_without_brackets = string.replace(")", "").replace(")", "").replace("(", " ").replace("(", " ") - string_without_company = string_without_brackets.replace("有限公司", "") - split_hospitals = string_without_company.replace("医院", "医院 ") - result += split_hospitals.strip().split(" ") + string_without_brackets = string.replace(')', '').replace(')', '').replace('(', ' ').replace('(', ' ') + string_without_company = string_without_brackets.replace('有限公司', '') + split_hospitals = string_without_company.replace('医院', '医院 ') + result += split_hospitals.strip().split(' ') return result diff --git a/util/html_util.py b/util/html_util.py index 613a3a6..dc0997f 100644 --- a/util/html_util.py +++ b/util/html_util.py @@ -15,7 +15,7 @@ def get_jsczt_id_base(url): if response.status_code != 200: raise Exception(f'请求江苏省财政票据失败!状态码: {response.status_code}') soup = BeautifulSoup(response.text, 'html.parser') - hidden_input = soup.find('input', {'name': "idBase"}) + hidden_input = soup.find('input', {'name': 'idBase'}) if hidden_input: # 获取隐藏字段的值 value = hidden_input.get('value') diff --git a/util/image_util.py b/util/image_util.py index 5a9d4bb..d86fb61 100644 --- a/util/image_util.py +++ b/util/image_util.py @@ -13,14 +13,14 @@ from log import PROJECT_ROOT @retry(stop=stop_after_attempt(3), wait=wait_random(1, 3), reraise=True, - after=lambda x: logging.warning("获取图片失败!")) + after=lambda x: logging.warning('获取图片失败!')) def read(image_path): """ 从网络或本地读取图片 :param image_path: 网络或本地路径 :return: NumPy数组形式的图片 """ - if image_path.startswith("http"): + if image_path.startswith('http'): # 发送HTTP请求并获取图像数据 resp = urllib.request.urlopen(image_path, timeout=60) # 将数据读取为字节流 @@ -76,35 +76,35 @@ def split(image, ratio=1.414, overlap=0.05, x_compensation=3): for i in range(math.ceil(height / step)): offset = round(step * i) cropped_img = capture(image, [0, offset, width, offset + new_img_height]) - split_result.append({"img": cropped_img, "x_offset": 0, "y_offset": offset}) + split_result.append({'img': cropped_img, 'x_offset': 0, 'y_offset': offset}) elif wh_ratio > ratio: # 横向过长 new_img_width = height * ratio step = height * (ratio - overlap * x_compensation) # 一般文字是横向的,所以横向截取时增大重叠部分 for i in range(math.ceil(width / step)): offset = round(step * i) cropped_img = capture(image, [offset, 0, offset + new_img_width, width]) - split_result.append({"img": cropped_img, "x_offset": offset, "y_offset": 0}) + split_result.append({'img': cropped_img, 'x_offset': offset, 'y_offset': 0}) else: - split_result.append({"img": image, "x_offset": 0, "y_offset": 0}) + split_result.append({'img': image, 'x_offset': 0, 'y_offset': 0}) return split_result def parse_rotation_angles(image): """ - 判断图片旋转角度,逆时针旋转该角度后为正。可能值["0", "90", "180", "270"] + 判断图片旋转角度,逆时针旋转该角度后为正。可能值['0', '90', '180', '270'] :param image: 图片NumPy数组或文件路径 :return: 最有可能的两个角度 """ angles = ['0', '90'] - model = PaddleClas(model_name="text_image_orientation") + model = PaddleClas(model_name='text_image_orientation') clas_result = model.predict(input_data=image) try: clas_result = next(clas_result)[0] - if clas_result["scores"][0] < 0.5: + if clas_result['scores'][0] < 0.5: return angles - angles = clas_result["label_names"] + angles = clas_result['label_names'] except Exception as e: - logging.error("获取图片旋转角度失败", exc_info=e) + logging.error('获取图片旋转角度失败', exc_info=e) return angles @@ -201,25 +201,25 @@ def expand_to_a4_size(image): if hw_ratio >= 1.42: exp_w = int(height / 1.414 - width) x_offset = int(exp_w / 2) - exp_img = numpy.zeros((height, x_offset, 3), dtype="uint8") + exp_img = numpy.zeros((height, x_offset, 3), dtype='uint8') exp_img.fill(255) image = numpy.hstack([exp_img, image, exp_img]) elif 1 <= hw_ratio <= 1.40: exp_h = int(width * 1.414 - height) y_offset = int(exp_h / 2) - exp_img = numpy.zeros((y_offset, width, 3), dtype="uint8") + exp_img = numpy.zeros((y_offset, width, 3), dtype='uint8') exp_img.fill(255) image = numpy.vstack([exp_img, image, exp_img]) elif 0.72 <= hw_ratio < 1: exp_w = int(height * 1.414 - width) x_offset = int(exp_w / 2) - exp_img = numpy.zeros((height, x_offset, 3), dtype="uint8") + exp_img = numpy.zeros((height, x_offset, 3), dtype='uint8') exp_img.fill(255) image = numpy.hstack([exp_img, image, exp_img]) elif hw_ratio <= 0.7: exp_h = int(width / 1.414 - height) y_offset = int(exp_h / 2) - exp_img = numpy.zeros((y_offset, width, 3), dtype="uint8") + exp_img = numpy.zeros((y_offset, width, 3), dtype='uint8') exp_img.fill(255) image = numpy.vstack([exp_img, image, exp_img]) return image, x_offset, y_offset diff --git a/visual_model_test/__init__.py b/visual_model_test/__init__.py deleted file mode 100644 index 8b13789..0000000 --- a/visual_model_test/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/visual_model_test/img/PH20240401000003_3_001938_2.jpg b/visual_model_test/img/PH20240401000003_3_001938_2.jpg deleted file mode 100644 index 3613f39..0000000 Binary files a/visual_model_test/img/PH20240401000003_3_001938_2.jpg and /dev/null differ diff --git a/visual_model_test/img/PH20240511000638_1_094306_1.jpg b/visual_model_test/img/PH20240511000638_1_094306_1.jpg deleted file mode 100644 index e9da016..0000000 Binary files a/visual_model_test/img/PH20240511000638_1_094306_1.jpg and /dev/null differ diff --git a/visual_model_test/img/PH20240511000648_4_094542_2.jpg b/visual_model_test/img/PH20240511000648_4_094542_2.jpg deleted file mode 100644 index acd6bf0..0000000 Binary files a/visual_model_test/img/PH20240511000648_4_094542_2.jpg and /dev/null differ diff --git a/visual_model_test/visual_model_test.py b/visual_model_test/visual_model_test.py deleted file mode 100644 index c33dbfc..0000000 --- a/visual_model_test/visual_model_test.py +++ /dev/null @@ -1,163 +0,0 @@ -# 可视化的模型对比测试 -import os -import re -import tempfile -import time -from pprint import pprint - -import cv2 -from paddlenlp import Taskflow -from paddlenlp.utils.doc_parser import DocParser -from paddleocr import PaddleOCR - -from ucloud import ufile -from util import image_util, common_util - - -def write_visual_result(image, angle=0, layout=None, result=None): - img = image.split("?")[0] - img = re.split(r'[\\/]', img)[-1] - img_name = "" - img_type = "jpg" - last_dot_index = img.rfind(".") - if last_dot_index != -1: - img_name = img[:last_dot_index] - img_type = img[last_dot_index + 1:] - - img_array = image_util.read(image) - if angle != 0: - img_array = image_util.rotate(img_array, angle) - with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file: - cv2.imwrite(temp_file.name, img_array) - if layout: - print(layout) - DocParser.write_image_with_results( - temp_file.name, - layout=layout, - save_path="./img_result/" + img_name + "_layout." + img_type) - - if result: - print(result) - DocParser.write_image_with_results( - temp_file.name, - result=result, - save_path="./img_result/" + img_name + "_result." + img_type) - os.remove(temp_file.name) - - -def visual_model_test(model_type, test_img, task_path, schema): - if model_type == "ocr": - imgs = image_util.split(test_img) - layout = [] - temp_files_paths = [] - # doc_parser = DocParser(layout_analysis=False) - for img in imgs: - with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file: - # angle = image_util.parse_rotation_angles(img["img"])[0] - angle = 0 - rotated_img = image_util.rotate(img["img"], angle) - rotated_img, offset_x, offset_y = image_util.expand_to_a4_size(rotated_img) - cv2.imwrite(temp_file.name, rotated_img) - - img["x_offset"] -= offset_x - img["y_offset"] -= offset_y - - temp_files_paths.append(temp_file.name) - parsed_doc = common_util.get_ocr_layout( - PaddleOCR(det_db_box_thresh=0.3, det_db_thresh=0.1, det_limit_side_len=1248, drop_score=0.3, - save_crop_res=False), - temp_file.name) - # parsed_doc = doc_parser.parse({"doc": temp_file.name})["layout"] - if img["x_offset"] or img["y_offset"]: - for p in parsed_doc: - box = p[0] - box[0] += img["x_offset"] - box[1] += img["y_offset"] - box[2] += img["x_offset"] - box[3] += img["y_offset"] - layout += parsed_doc - - write_visual_result(test_img, angle, layout=layout) - else: - docs = [] - split_result = image_util.split(test_img) - temp_files_paths = [] - for img in split_result: - with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file: - angle = int(image_util.parse_rotation_angles(img["img"])[0]) - rotated_img = image_util.rotate(img["img"], angle) - cv2.imwrite(temp_file.name, rotated_img) - temp_files_paths.append(temp_file.name) - docs.append({"doc": temp_file.name}) - - my_ie = Taskflow("information_extraction", schema=schema, model="uie-x-base", task_path=task_path, - layout_analysis=False) - my_results = my_ie(docs) - write_visual_result(test_img, angle, result=my_results[0]) - - # 使用完临时文件后,记得清理(删除)它们 - for path in temp_files_paths: - try: - os.remove(path) - print(f"临时文件 {path} 已删除") - except Exception as e: - print(f"删除临时文件 {path} 时出错: {e}") - - -def batch_test(test_imgs, task_path, schema): - docs = [] - for test_img in test_imgs: - docs.append({"doc": test_img}) - my_ie = Taskflow("information_extraction", schema=schema, model="uie-x-base", task_path=task_path, - layout_analysis=False, batch_size=16) - # 批量抽取写法:(ie([{"doc": "./data/6.jpg"}, {"doc": "./data/7.jpg"}]) - my_results = my_ie(docs) - pprint(my_results) - - -def main(model_type, pic_name=None): - # 开始时间 - start_time = time.time() - - if model_type == "ocr": - task_path = None - test_img_path = ufile.get_private_url(pic_name, - "drg103") if pic_name else "../test_img/PH20240725004467_3_185708_1.jpg" - schema = None - elif model_type == "settlement": - task_path = "../services/paddle_services/model/settlement_list_model" - test_img_path = ufile.get_private_url(pic_name) if pic_name else "img/PH20240511000638_1_094306_1.jpg" - schema = ["患者姓名", "入院日期", "出院日期", "费用总额", "个人现金支付", "个人账户支付", "自费金额", - "医保类型", "住院号", "医保结算单号码", "大写总额"] - elif model_type == "discharge": - task_path = "../services/paddle_services/model/discharge_record_model" - test_img_path = ufile.get_private_url(pic_name) if pic_name else "img/PH20240401000003_3_001938_2.jpg" - schema = ["医院", "科室", "患者姓名", "入院日期", "出院日期", "主治医生", "住院号", "年龄"] - elif model_type == "cost": - task_path = "../services/paddle_services/model/cost_list_model" - test_img_path = ufile.get_private_url(pic_name) if pic_name else "img/PH20240511000648_4_094542_2.jpg" - schema = ["患者姓名", "入院日期", "出院日期", "费用总额"] - elif model_type == "cost_detail": - task_path = "../services/paddle_services/model/cost_list_detail_model" - test_img_path = ufile.get_private_url(pic_name) if pic_name else "img/PH20240511000648_4_094542_2.jpg" - schema = {"名称": ["类别", "规格", "单价", "数量", "金额"]} - else: - print("请输入正确的类型!") - return - visual_model_test(model_type, test_img_path, task_path, schema) - - # 结束时间 - end_time = time.time() - pprint(f"处理时长:{end_time - start_time}秒") - - -if __name__ == '__main__': - main("ocr") - # main("settlement") - # main("discharge") - # main("cost") - # main("cost_detail") - # write_visual_result("img/PH20240428000832_1_093844_2.jpg", layout=[([508.0975609756094, - # 659.7073170731707, - # 1000, - # 745.756097560976], 'lay', 'figure')])