统一引号格式,优化架构排布
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -148,4 +148,5 @@ services/paddle_services/model
|
|||||||
*.log.*-*-*
|
*.log.*-*-*
|
||||||
|
|
||||||
### Tmp Files
|
### Tmp Files
|
||||||
/tmp_img
|
/tmp_img
|
||||||
|
/test_img
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
# 自动生成数据库表和sqlalchemy对应的Model
|
|
||||||
import subprocess
|
|
||||||
|
|
||||||
from db import DB_URL
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
table = input("请输入表名:")
|
|
||||||
out_file = f"db/{table}.py"
|
|
||||||
command = f"sqlacodegen {DB_URL} --outfile={out_file} --tables={table}"
|
|
||||||
|
|
||||||
try:
|
|
||||||
subprocess.run(command, shell=True, check=True)
|
|
||||||
print(f"{table}.py文件生成成功!请检查并复制到合适的文件中!")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"生成{table}.py文件时发生错误: {e}")
|
|
||||||
@@ -7,9 +7,9 @@ PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|||||||
HOSTNAME = socket.gethostname()
|
HOSTNAME = socket.gethostname()
|
||||||
# 检测日志文件的路径是否存在,不存在则创建
|
# 检测日志文件的路径是否存在,不存在则创建
|
||||||
LOG_PATHS = [
|
LOG_PATHS = [
|
||||||
f"log/{HOSTNAME}/ucloud",
|
os.path.join(PROJECT_ROOT, 'log', HOSTNAME, 'ucloud'),
|
||||||
f"log/{HOSTNAME}/error",
|
os.path.join(PROJECT_ROOT, 'log', HOSTNAME, 'error'),
|
||||||
f"log/{HOSTNAME}/qr",
|
os.path.join(PROJECT_ROOT, 'log', HOSTNAME, 'qr'),
|
||||||
]
|
]
|
||||||
for path in LOG_PATHS:
|
for path in LOG_PATHS:
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
|
|||||||
@@ -8,13 +8,13 @@ MAX_WAIT_TIME = 3
|
|||||||
# 程序异常短信配置
|
# 程序异常短信配置
|
||||||
ERROR_EMAIL_CONFIG = {
|
ERROR_EMAIL_CONFIG = {
|
||||||
# SMTP服务器地址
|
# SMTP服务器地址
|
||||||
"smtp_server": "smtp.163.com",
|
'smtp_server': 'smtp.163.com',
|
||||||
# 连接SMTP的端口
|
# 连接SMTP的端口
|
||||||
"port": 994,
|
'port': 994,
|
||||||
# 发件人邮箱地址,请确保开启了SMTP邮件服务!
|
# 发件人邮箱地址,请确保开启了SMTP邮件服务!
|
||||||
"sender": "EchoLiu618@163.com",
|
'sender': 'EchoLiu618@163.com',
|
||||||
# 授权码--用于登录第三方邮件客户端的专用密码,不是邮箱密码
|
# 授权码--用于登录第三方邮件客户端的专用密码,不是邮箱密码
|
||||||
"authorization_code": "OKPQLIIVLVGRZYVH",
|
'authorization_code': 'OKPQLIIVLVGRZYVH',
|
||||||
# 收件人邮箱地址
|
# 收件人邮箱地址
|
||||||
"receivers": ["1515783401@qq.com"],
|
'receivers': ['1515783401@qq.com'],
|
||||||
}
|
}
|
||||||
@@ -5,18 +5,18 @@ from email.mime.text import MIMEText
|
|||||||
|
|
||||||
from tenacity import retry, stop_after_attempt, wait_random
|
from tenacity import retry, stop_after_attempt, wait_random
|
||||||
|
|
||||||
from auto_email import ERROR_EMAIL_CONFIG, TRY_TIMES, MIN_WAIT_TIME, MAX_WAIT_TIME
|
|
||||||
from log import HOSTNAME
|
from log import HOSTNAME
|
||||||
|
from my_email import ERROR_EMAIL_CONFIG, TRY_TIMES, MIN_WAIT_TIME, MAX_WAIT_TIME
|
||||||
|
|
||||||
|
|
||||||
@retry(stop=stop_after_attempt(TRY_TIMES), wait=wait_random(MIN_WAIT_TIME, MAX_WAIT_TIME), reraise=True,
|
@retry(stop=stop_after_attempt(TRY_TIMES), wait=wait_random(MIN_WAIT_TIME, MAX_WAIT_TIME), reraise=True,
|
||||||
after=lambda x: logging.warning("发送邮件失败!"))
|
after=lambda x: logging.warning('发送邮件失败!'))
|
||||||
def send_email(email_config, massage):
|
def send_email(email_config, massage):
|
||||||
smtp_server = email_config["smtp_server"]
|
smtp_server = email_config['smtp_server']
|
||||||
port = email_config["port"]
|
port = email_config['port']
|
||||||
sender = email_config["sender"]
|
sender = email_config['sender']
|
||||||
authorization_code = email_config["authorization_code"]
|
authorization_code = email_config['authorization_code']
|
||||||
receivers = email_config["receivers"]
|
receivers = email_config['receivers']
|
||||||
mail = smtplib.SMTP_SSL(smtp_server, port) # 连接SMTP服务
|
mail = smtplib.SMTP_SSL(smtp_server, port) # 连接SMTP服务
|
||||||
mail.login(sender, authorization_code) # 登录到SMTP服务
|
mail.login(sender, authorization_code) # 登录到SMTP服务
|
||||||
mail.sendmail(sender, receivers, massage.as_string()) # 发送邮件
|
mail.sendmail(sender, receivers, massage.as_string()) # 发送邮件
|
||||||
@@ -34,13 +34,13 @@ def send_error_email(program_name, error_name, error_detail):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# SMTP 服务器配置
|
# SMTP 服务器配置
|
||||||
sender = ERROR_EMAIL_CONFIG["sender"]
|
sender = ERROR_EMAIL_CONFIG['sender']
|
||||||
receivers = ERROR_EMAIL_CONFIG["receivers"]
|
receivers = ERROR_EMAIL_CONFIG['receivers']
|
||||||
|
|
||||||
# 获取程序出错的时间
|
# 获取程序出错的时间
|
||||||
error_time = datetime.datetime.strftime(datetime.datetime.today(), "%Y-%m-%d %H:%M:%S:%f")
|
error_time = datetime.datetime.strftime(datetime.datetime.today(), '%Y-%m-%d %H:%M:%S:%f')
|
||||||
# 邮件内容
|
# 邮件内容
|
||||||
subject = f"【程序异常提醒】{program_name}({HOSTNAME}) {error_time}" # 邮件的标题
|
subject = f'【程序异常提醒】{program_name}({HOSTNAME}) {error_time}' # 邮件的标题
|
||||||
content = f'''<div class="emailcontent" style="width:100%;max-width:720px;text-align:left;margin:0 auto;padding-top:80px;padding-bottom:20px">
|
content = f'''<div class="emailcontent" style="width:100%;max-width:720px;text-align:left;margin:0 auto;padding-top:80px;padding-bottom:20px">
|
||||||
<div class="emailtitle">
|
<div class="emailtitle">
|
||||||
<h1 style="color:#fff;background:#51a0e3;line-height:70px;font-size:24px;font-weight:400;padding-left:40px;margin:0">程序运行异常通知</h1>
|
<h1 style="color:#fff;background:#51a0e3;line-height:70px;font-size:24px;font-weight:400;padding-left:40px;margin:0">程序运行异常通知</h1>
|
||||||
@@ -5,35 +5,35 @@ from time import sleep
|
|||||||
|
|
||||||
from sqlalchemy import update
|
from sqlalchemy import update
|
||||||
|
|
||||||
from auto_email.error_email import send_error_email
|
from my_email.error_email import send_error_email
|
||||||
from db import MysqlSession
|
from db import MysqlSession
|
||||||
from db.mysql import ZxPhhd
|
from db.mysql import ZxPhhd
|
||||||
from log import LOGGING_CONFIG
|
from log import LOGGING_CONFIG
|
||||||
from photo_mask import auto_photo_mask, SEND_ERROR_EMAIL
|
from photo_mask import auto_photo_mask, SEND_ERROR_EMAIL
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
program_name = "照片审核自动涂抹脚本"
|
program_name = '照片审核自动涂抹脚本'
|
||||||
logging.config.dictConfig(LOGGING_CONFIG)
|
logging.config.dictConfig(LOGGING_CONFIG)
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--clean", default=False, type=bool, help="是否将涂抹中的案子改为待涂抹状态")
|
parser.add_argument('--clean', default=False, type=bool, help='是否将涂抹中的案子改为待涂抹状态')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if args.clean:
|
if args.clean:
|
||||||
# 主要用于启动时,清除仍在涂抹中的案子
|
# 主要用于启动时,清除仍在涂抹中的案子
|
||||||
session = MysqlSession()
|
session = MysqlSession()
|
||||||
update_flag = (update(ZxPhhd).where(ZxPhhd.paint_flag == "2").values(paint_flag="1"))
|
update_flag = (update(ZxPhhd).where(ZxPhhd.paint_flag == '2').values(paint_flag='1'))
|
||||||
session.execute(update_flag)
|
session.execute(update_flag)
|
||||||
session.commit()
|
session.commit()
|
||||||
session.close()
|
session.close()
|
||||||
logging.info("已释放残余的涂抹案子!")
|
logging.info('已释放残余的涂抹案子!')
|
||||||
else:
|
else:
|
||||||
sleep(5)
|
sleep(5)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logging.info(f"【{program_name}】开始运行")
|
logging.info(f'【{program_name}】开始运行')
|
||||||
auto_photo_mask.main()
|
auto_photo_mask.main()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_logger = logging.getLogger("error")
|
error_logger = logging.getLogger('error')
|
||||||
error_logger.error(traceback.format_exc())
|
error_logger.error(traceback.format_exc())
|
||||||
if SEND_ERROR_EMAIL:
|
if SEND_ERROR_EMAIL:
|
||||||
send_error_email(program_name, repr(e), traceback.format_exc())
|
send_error_email(program_name, repr(e), traceback.format_exc())
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from time import sleep
|
|||||||
|
|
||||||
from sqlalchemy import update
|
from sqlalchemy import update
|
||||||
|
|
||||||
from auto_email.error_email import send_error_email
|
from my_email.error_email import send_error_email
|
||||||
from db import MysqlSession
|
from db import MysqlSession
|
||||||
from db.mysql import ZxPhhd
|
from db.mysql import ZxPhhd
|
||||||
from log import LOGGING_CONFIG
|
from log import LOGGING_CONFIG
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import socket
|
|||||||
HOSTNAME = socket.gethostname()
|
HOSTNAME = socket.gethostname()
|
||||||
# 检测日志文件的路径是否存在,不存在则创建
|
# 检测日志文件的路径是否存在,不存在则创建
|
||||||
LOG_PATHS = [
|
LOG_PATHS = [
|
||||||
f"log/{HOSTNAME}/error",
|
f'log/{HOSTNAME}/error',
|
||||||
]
|
]
|
||||||
for path in LOG_PATHS:
|
for path in LOG_PATHS:
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ def process_request(func):
|
|||||||
result = func(*args, **kwargs)
|
result = func(*args, **kwargs)
|
||||||
return jsonify(result), 200
|
return jsonify(result), 200
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.getLogger("error").error(e, exc_info=e)
|
logging.getLogger('error').error(e, exc_info=e)
|
||||||
return jsonify({'error': str(e)}), 500
|
return jsonify({'error': str(e)}), 500
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|||||||
19
tool/auto_generator.py
Normal file
19
tool/auto_generator.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
"""
|
||||||
|
自动生成数据库表和sqlalchemy对应的Model
|
||||||
|
"""
|
||||||
|
import os.path
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
from db import DB_URL
|
||||||
|
from log import PROJECT_ROOT
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
table = input('请输入表名:')
|
||||||
|
out_file = os.path.join(PROJECT_ROOT, 'db', f'{table}.py')
|
||||||
|
command = f'sqlacodegen {DB_URL} --outfile={out_file} --tables={table}'
|
||||||
|
|
||||||
|
try:
|
||||||
|
subprocess.run(command, shell=True, check=True)
|
||||||
|
print(f'{table}.py文件生成成功!请检查并复制到合适的文件中!')
|
||||||
|
except Exception as e:
|
||||||
|
print(f'生成{table}.py文件时发生错误: {e}')
|
||||||
@@ -1,15 +1,15 @@
|
|||||||
from ufile import config
|
from ufile import config
|
||||||
|
|
||||||
# 公钥
|
# 公钥
|
||||||
PUBLIC_KEY = "4Z7QYI7qml36QRjcCjKrls7aHl1R6H6uq"
|
PUBLIC_KEY = '4Z7QYI7qml36QRjcCjKrls7aHl1R6H6uq'
|
||||||
# 私钥
|
# 私钥
|
||||||
PRIVATE_KEY = "FIdW1Kev1Ge3K7GHXzSLyGG1wTnaG6LE9BxmIVubcCaG"
|
PRIVATE_KEY = 'FIdW1Kev1Ge3K7GHXzSLyGG1wTnaG6LE9BxmIVubcCaG'
|
||||||
# 桶
|
# 桶
|
||||||
BUCKET = "drg100"
|
BUCKET = 'drg100'
|
||||||
# 上传后缀
|
# 上传后缀
|
||||||
UPLOAD_SUFFIX = ".cn-sh2.ufileos.com"
|
UPLOAD_SUFFIX = '.cn-sh2.ufileos.com'
|
||||||
# 下载后缀
|
# 下载后缀
|
||||||
DOWNLOAD_SUFFIX = ".cn-sh2.ufileos.com"
|
DOWNLOAD_SUFFIX = '.cn-sh2.ufileos.com'
|
||||||
# 私空间文件地址过期时间(秒)
|
# 私空间文件地址过期时间(秒)
|
||||||
PRIVATE_EXPIRES = 3600
|
PRIVATE_EXPIRES = 3600
|
||||||
|
|
||||||
|
|||||||
@@ -17,10 +17,10 @@ def get_private_url(key, bucket=BUCKET):
|
|||||||
# 判断文件是否存在
|
# 判断文件是否存在
|
||||||
_, resp = UFILE_HANDLER.head_file(bucket, key)
|
_, resp = UFILE_HANDLER.head_file(bucket, key)
|
||||||
if resp.status_code == -1:
|
if resp.status_code == -1:
|
||||||
UCLOUD_LOGGER.warning(f"查询({key})时uCloud连接失败!")
|
UCLOUD_LOGGER.warning(f'查询({key})时uCloud连接失败!')
|
||||||
raise ConnectionError("uCloud连接失败")
|
raise ConnectionError('uCloud连接失败')
|
||||||
if resp.status_code != 200:
|
if resp.status_code != 200:
|
||||||
UCLOUD_LOGGER.warning(f"[{bucket}]中未找到({key})! status: {resp.status_code} error: {resp.error}")
|
UCLOUD_LOGGER.warning(f'[{bucket}]中未找到({key})! status: {resp.status_code} error: {resp.error}')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 获取公有空间下载url
|
# 获取公有空间下载url
|
||||||
@@ -37,11 +37,11 @@ def copy_file(source_bucket, source_key, target_bucket, target_key):
|
|||||||
# 复制文件
|
# 复制文件
|
||||||
_, resp = UFILE_HANDLER.copy(target_bucket, target_key, source_bucket, source_key)
|
_, resp = UFILE_HANDLER.copy(target_bucket, target_key, source_bucket, source_key)
|
||||||
if resp.status_code == -1:
|
if resp.status_code == -1:
|
||||||
UCLOUD_LOGGER.warning(f"复制({source_key})时uCloud连接失败!")
|
UCLOUD_LOGGER.warning(f'复制({source_key})时uCloud连接失败!')
|
||||||
return False
|
return False
|
||||||
if resp.status_code != 200:
|
if resp.status_code != 200:
|
||||||
UCLOUD_LOGGER.warning(
|
UCLOUD_LOGGER.warning(
|
||||||
f"将({source_key})从[{source_bucket}]拷贝到[{target_bucket}]失败! status: {resp.status_code} error: {resp.error}"
|
f'将({source_key})从[{source_bucket}]拷贝到[{target_bucket}]失败! status: {resp.status_code} error: {resp.error}'
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
@@ -53,9 +53,9 @@ def upload_file(key, file_path, bucket=BUCKET):
|
|||||||
# 普通上传文件至云空间
|
# 普通上传文件至云空间
|
||||||
_, resp = UFILE_HANDLER.putfile(bucket, key, file_path, header=None)
|
_, resp = UFILE_HANDLER.putfile(bucket, key, file_path, header=None)
|
||||||
if resp.status_code == -1:
|
if resp.status_code == -1:
|
||||||
UCLOUD_LOGGER.warning(f"上传({key})时uCloud连接失败!")
|
UCLOUD_LOGGER.warning(f'上传({key})时uCloud连接失败!')
|
||||||
return False
|
return False
|
||||||
if resp.status_code != 200:
|
if resp.status_code != 200:
|
||||||
UCLOUD_LOGGER.warning(f"上传({key})失败! status: {resp.status_code} error: {resp.error}")
|
UCLOUD_LOGGER.warning(f'上传({key})失败! status: {resp.status_code} error: {resp.error}')
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|||||||
@@ -62,9 +62,9 @@ def delete_temp_file(temp_files):
|
|||||||
for file in temp_files:
|
for file in temp_files:
|
||||||
try:
|
try:
|
||||||
os.remove(file)
|
os.remove(file)
|
||||||
logging.info(f"临时文件 {file} 已删除")
|
logging.info(f'临时文件 {file} 已删除')
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(f"删除临时文件 {file} 时出错: {e}")
|
logging.warning(f'删除临时文件 {file} 时出错: {e}')
|
||||||
|
|
||||||
|
|
||||||
def zoom_rectangle(rectangle, ratio):
|
def zoom_rectangle(rectangle, ratio):
|
||||||
@@ -83,40 +83,40 @@ def zoom_rectangle(rectangle, ratio):
|
|||||||
|
|
||||||
|
|
||||||
def chinese_to_money_unit(chinese):
|
def chinese_to_money_unit(chinese):
|
||||||
if chinese in ["拾", "十"]:
|
if chinese in ['拾', '十']:
|
||||||
return 10, False
|
return 10, False
|
||||||
elif chinese in ["佰", "百"]:
|
elif chinese in ['佰', '百']:
|
||||||
return 100, False
|
return 100, False
|
||||||
elif chinese in ["仟", "千"]:
|
elif chinese in ['仟', '千']:
|
||||||
return 1000, False
|
return 1000, False
|
||||||
elif chinese == "万":
|
elif chinese == '万':
|
||||||
return 10000, True
|
return 10000, True
|
||||||
elif chinese == "亿":
|
elif chinese == '亿':
|
||||||
return 100000000, True
|
return 100000000, True
|
||||||
else:
|
else:
|
||||||
return None, False
|
return None, False
|
||||||
|
|
||||||
|
|
||||||
def chinese_char_to_number(chinese):
|
def chinese_char_to_number(chinese):
|
||||||
if chinese == "零":
|
if chinese == '零':
|
||||||
return 0
|
return 0
|
||||||
elif chinese in ["一", "壹"]:
|
elif chinese in ['一', '壹']:
|
||||||
return 1
|
return 1
|
||||||
elif chinese in ["二", "贰"]:
|
elif chinese in ['二', '贰']:
|
||||||
return 2
|
return 2
|
||||||
elif chinese in ["三", "叁"]:
|
elif chinese in ['三', '叁']:
|
||||||
return 3
|
return 3
|
||||||
elif chinese in ["四", "肆"]:
|
elif chinese in ['四', '肆']:
|
||||||
return 4
|
return 4
|
||||||
elif chinese in ["五", "伍"]:
|
elif chinese in ['五', '伍']:
|
||||||
return 5
|
return 5
|
||||||
elif chinese in ["六", "陆"]:
|
elif chinese in ['六', '陆']:
|
||||||
return 6
|
return 6
|
||||||
elif chinese in ["七", "柒"]:
|
elif chinese in ['七', '柒']:
|
||||||
return 7
|
return 7
|
||||||
elif chinese in ["八", "捌"]:
|
elif chinese in ['八', '捌']:
|
||||||
return 8
|
return 8
|
||||||
elif chinese in ["九", "玖"]:
|
elif chinese in ['九', '玖']:
|
||||||
return 9
|
return 9
|
||||||
else:
|
else:
|
||||||
return -1
|
return -1
|
||||||
@@ -137,12 +137,12 @@ def chinese_to_number(chinese):
|
|||||||
section += number * (unit[0] / 10)
|
section += number * (unit[0] / 10)
|
||||||
unit = [None, False]
|
unit = [None, False]
|
||||||
elif number > 0:
|
elif number > 0:
|
||||||
raise ValueError(f"{chinese} has bad number '{chinese[i - 1]}{c}' at: {i}")
|
raise ValueError(f"{chinese} has bad number '{chinese[i - 1]}{c}' at: {i}")
|
||||||
number = num
|
number = num
|
||||||
else:
|
else:
|
||||||
unit = chinese_to_money_unit(c)
|
unit = chinese_to_money_unit(c)
|
||||||
if unit[0] is None:
|
if unit[0] is None:
|
||||||
raise ValueError(f"{chinese} has unknown unit '{c}' at: {i}")
|
raise ValueError(f"{chinese} has unknown unit '{c}' at: {i}")
|
||||||
if unit[1]:
|
if unit[1]:
|
||||||
section = (section + number) * unit[0]
|
section = (section + number) * unit[0]
|
||||||
result += section
|
result += section
|
||||||
@@ -163,14 +163,14 @@ def chinese_to_number(chinese):
|
|||||||
def chinese_money_to_number(chinese_money_amount):
|
def chinese_money_to_number(chinese_money_amount):
|
||||||
if string_util.blank(chinese_money_amount):
|
if string_util.blank(chinese_money_amount):
|
||||||
return None
|
return None
|
||||||
yi = chinese_money_amount.find("元")
|
yi = chinese_money_amount.find('元')
|
||||||
if yi == -1:
|
if yi == -1:
|
||||||
yi = chinese_money_amount.find("圆")
|
yi = chinese_money_amount.find('圆')
|
||||||
ji = chinese_money_amount.find("角")
|
ji = chinese_money_amount.find('角')
|
||||||
fi = chinese_money_amount.find("分")
|
fi = chinese_money_amount.find('分')
|
||||||
|
|
||||||
if yi == -1 and ji == -1 and fi == -1:
|
if yi == -1 and ji == -1 and fi == -1:
|
||||||
raise ValueError(f"无法解析: {chinese_money_amount}")
|
raise ValueError(f'无法解析: {chinese_money_amount}')
|
||||||
|
|
||||||
y_str = None
|
y_str = None
|
||||||
if yi > 0:
|
if yi > 0:
|
||||||
|
|||||||
@@ -8,20 +8,20 @@ from util import common_util
|
|||||||
# 处理金额类数据
|
# 处理金额类数据
|
||||||
def handle_decimal(string):
|
def handle_decimal(string):
|
||||||
if not string:
|
if not string:
|
||||||
return ""
|
return ''
|
||||||
string = re.sub(r'[^0-9.]', '', string)
|
string = re.sub(r'[^0-9.]', '', string)
|
||||||
if not string:
|
if not string:
|
||||||
return ""
|
return ''
|
||||||
if "." not in string:
|
if '.' not in string:
|
||||||
if len(string) > 2:
|
if len(string) > 2:
|
||||||
result = string[:-2] + "." + string[-2:]
|
result = string[:-2] + '.' + string[-2:]
|
||||||
else:
|
else:
|
||||||
result = string
|
result = string
|
||||||
else:
|
else:
|
||||||
front, back = string.rsplit('.', 1)
|
front, back = string.rsplit('.', 1)
|
||||||
front = front.replace(".", "")
|
front = front.replace('.', '')
|
||||||
if back:
|
if back:
|
||||||
back = "." + back[:2]
|
back = '.' + back[:2]
|
||||||
result = front + back
|
result = front + back
|
||||||
return result[:16]
|
return result[:16]
|
||||||
|
|
||||||
@@ -32,7 +32,7 @@ def parse_money(capital_num, num):
|
|||||||
money = common_util.chinese_money_to_number(capital_num)
|
money = common_util.chinese_money_to_number(capital_num)
|
||||||
return capital_num, money
|
return capital_num, money
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning("大写金额解析失败", exc_info=e)
|
logging.warning('大写金额解析失败', exc_info=e)
|
||||||
|
|
||||||
return num, handle_decimal(num)
|
return num, handle_decimal(num)
|
||||||
|
|
||||||
@@ -40,17 +40,17 @@ def parse_money(capital_num, num):
|
|||||||
# 处理日期类数据
|
# 处理日期类数据
|
||||||
def handle_date(string):
|
def handle_date(string):
|
||||||
if not string:
|
if not string:
|
||||||
return ""
|
return ''
|
||||||
|
|
||||||
string = string.replace("年", "-").replace("月", "-").replace("日", "").replace("/", "-").replace(".", "-")
|
string = string.replace('年', '-').replace('月', '-').replace('日', '').replace('/', '-').replace('.', '-')
|
||||||
string = re.sub(r'[^0-9-]', '', string)
|
string = re.sub(r'[^0-9-]', '', string)
|
||||||
string = string.strip("-")
|
string = string.strip('-')
|
||||||
if "-" in string:
|
if '-' in string:
|
||||||
dash_count = string.count("-")
|
dash_count = string.count('-')
|
||||||
if dash_count > 2:
|
if dash_count > 2:
|
||||||
third_dash_index = string.find("-", string.find("-", string.find("-") + 1) + 1)
|
third_dash_index = string.find('-', string.find('-', string.find('-') + 1) + 1)
|
||||||
string = string[:third_dash_index]
|
string = string[:third_dash_index]
|
||||||
day = string[string.rindex("-") + 1:]
|
day = string[string.rindex('-') + 1:]
|
||||||
if len(day) > 2:
|
if len(day) > 2:
|
||||||
string = string[:2 - len(day)]
|
string = string[:2 - len(day)]
|
||||||
else:
|
else:
|
||||||
@@ -58,7 +58,7 @@ def handle_date(string):
|
|||||||
string = string[:8]
|
string = string[:8]
|
||||||
|
|
||||||
if len(string) < 6:
|
if len(string) < 6:
|
||||||
return ""
|
return ''
|
||||||
|
|
||||||
# 定义可能的日期格式
|
# 定义可能的日期格式
|
||||||
formats = [
|
formats = [
|
||||||
@@ -78,23 +78,23 @@ def handle_date(string):
|
|||||||
date = datetime.strptime(string, fmt)
|
date = datetime.strptime(string, fmt)
|
||||||
# 限定日期的年份范围
|
# 限定日期的年份范围
|
||||||
if 2000 < date.year < 2100:
|
if 2000 < date.year < 2100:
|
||||||
return date.strftime("%Y-%m-%d")
|
return date.strftime('%Y-%m-%d')
|
||||||
continue
|
continue
|
||||||
except ValueError:
|
except ValueError:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
return ""
|
return ''
|
||||||
|
|
||||||
|
|
||||||
def handle_hospital(string):
|
def handle_hospital(string):
|
||||||
if not string:
|
if not string:
|
||||||
return ""
|
return ''
|
||||||
return string[:255]
|
return string[:255]
|
||||||
|
|
||||||
|
|
||||||
def handle_department(string):
|
def handle_department(string):
|
||||||
if not string:
|
if not string:
|
||||||
return ""
|
return ''
|
||||||
return string[:255]
|
return string[:255]
|
||||||
|
|
||||||
|
|
||||||
@@ -103,12 +103,12 @@ def parse_department(string):
|
|||||||
if not string:
|
if not string:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
string = string.replace(")", "").replace(")", "").replace("(", " ").replace("(", " ") # 去除括号
|
string = string.replace(')', '').replace(')', '').replace('(', ' ').replace('(', ' ') # 去除括号
|
||||||
string = re.sub(r'[^⺀-鿿 ]', '', string) # 去除非汉字字符,除了空格
|
string = re.sub(r'[^⺀-鿿 ]', '', string) # 去除非汉字字符,除了空格
|
||||||
string = re.sub(r'[一二三四五六七八九十]', '', string) # 去除中文数字
|
string = re.sub(r'[一二三四五六七八九十]', '', string) # 去除中文数字
|
||||||
string = string.replace("病区", "").replace("病", "") # 去除常见的无意义词
|
string = string.replace('病区', '').replace('病', '') # 去除常见的无意义词
|
||||||
string = string.replace("科", " ") # 分离科室
|
string = string.replace('科', ' ') # 分离科室
|
||||||
departments = string.strip().split(" ")
|
departments = string.strip().split(' ')
|
||||||
for department in departments:
|
for department in departments:
|
||||||
if department:
|
if department:
|
||||||
result.append(department)
|
result.append(department)
|
||||||
@@ -118,33 +118,33 @@ def parse_department(string):
|
|||||||
# 处理姓名类数据
|
# 处理姓名类数据
|
||||||
def handle_name(string):
|
def handle_name(string):
|
||||||
if not string:
|
if not string:
|
||||||
return ""
|
return ''
|
||||||
return re.sub(r'[^⺀-鿿·]', '', string)[:30]
|
return re.sub(r'[^⺀-鿿·]', '', string)[:30]
|
||||||
|
|
||||||
|
|
||||||
# 处理医保类型数据
|
# 处理医保类型数据
|
||||||
def handle_insurance_type(string):
|
def handle_insurance_type(string):
|
||||||
if not string:
|
if not string:
|
||||||
return ""
|
return ''
|
||||||
worker_insurance_keys = ["社保", "城保", "职", "退休"]
|
worker_insurance_keys = ['社保', '城保', '职', '退休']
|
||||||
villager_insurance_keys = ["农保", "居民"]
|
villager_insurance_keys = ['农保', '居民']
|
||||||
migrant_worker_insurance_keys = ["农民工"]
|
migrant_worker_insurance_keys = ['农民工']
|
||||||
no_insurance_keys = ["自费", "全费"]
|
no_insurance_keys = ['自费', '全费']
|
||||||
if any(key in string for key in worker_insurance_keys):
|
if any(key in string for key in worker_insurance_keys):
|
||||||
return "职工医保"
|
return '职工医保'
|
||||||
if any(key in string for key in villager_insurance_keys):
|
if any(key in string for key in villager_insurance_keys):
|
||||||
return "居民医保"
|
return '居民医保'
|
||||||
if any(key in string for key in migrant_worker_insurance_keys):
|
if any(key in string for key in migrant_worker_insurance_keys):
|
||||||
return "农民工医保"
|
return '农民工医保'
|
||||||
if any(key in string for key in no_insurance_keys):
|
if any(key in string for key in no_insurance_keys):
|
||||||
return "无医保"
|
return '无医保'
|
||||||
return "其他"
|
return '其他'
|
||||||
|
|
||||||
|
|
||||||
# 处理原始数据
|
# 处理原始数据
|
||||||
def handle_original_data(string):
|
def handle_original_data(string):
|
||||||
if not string:
|
if not string:
|
||||||
return ""
|
return ''
|
||||||
# 防止过长存入数据库失败
|
# 防止过长存入数据库失败
|
||||||
return string[:255]
|
return string[:255]
|
||||||
|
|
||||||
@@ -152,7 +152,7 @@ def handle_original_data(string):
|
|||||||
# 处理id类数据
|
# 处理id类数据
|
||||||
def handle_id(string):
|
def handle_id(string):
|
||||||
if not string:
|
if not string:
|
||||||
return ""
|
return ''
|
||||||
# 防止过长存入数据库失败
|
# 防止过长存入数据库失败
|
||||||
return string[:50]
|
return string[:50]
|
||||||
|
|
||||||
@@ -160,8 +160,8 @@ def handle_id(string):
|
|||||||
# 处理年龄类数据
|
# 处理年龄类数据
|
||||||
def handle_age(string):
|
def handle_age(string):
|
||||||
if not string:
|
if not string:
|
||||||
return ""
|
return ''
|
||||||
string = string.split("岁")[0]
|
string = string.split('岁')[0]
|
||||||
num = re.sub(r'\D', '', string)
|
num = re.sub(r'\D', '', string)
|
||||||
return num[-3:]
|
return num[-3:]
|
||||||
|
|
||||||
@@ -173,8 +173,8 @@ def parse_hospital(string):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
string = common_util.traditional_to_simple_chinese(string)
|
string = common_util.traditional_to_simple_chinese(string)
|
||||||
string_without_brackets = string.replace(")", "").replace(")", "").replace("(", " ").replace("(", " ")
|
string_without_brackets = string.replace(')', '').replace(')', '').replace('(', ' ').replace('(', ' ')
|
||||||
string_without_company = string_without_brackets.replace("有限公司", "")
|
string_without_company = string_without_brackets.replace('有限公司', '')
|
||||||
split_hospitals = string_without_company.replace("医院", "医院 ")
|
split_hospitals = string_without_company.replace('医院', '医院 ')
|
||||||
result += split_hospitals.strip().split(" ")
|
result += split_hospitals.strip().split(' ')
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ def get_jsczt_id_base(url):
|
|||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
raise Exception(f'请求江苏省财政票据失败!状态码: {response.status_code}')
|
raise Exception(f'请求江苏省财政票据失败!状态码: {response.status_code}')
|
||||||
soup = BeautifulSoup(response.text, 'html.parser')
|
soup = BeautifulSoup(response.text, 'html.parser')
|
||||||
hidden_input = soup.find('input', {'name': "idBase"})
|
hidden_input = soup.find('input', {'name': 'idBase'})
|
||||||
if hidden_input:
|
if hidden_input:
|
||||||
# 获取隐藏字段的值
|
# 获取隐藏字段的值
|
||||||
value = hidden_input.get('value')
|
value = hidden_input.get('value')
|
||||||
|
|||||||
@@ -13,14 +13,14 @@ from log import PROJECT_ROOT
|
|||||||
|
|
||||||
|
|
||||||
@retry(stop=stop_after_attempt(3), wait=wait_random(1, 3), reraise=True,
|
@retry(stop=stop_after_attempt(3), wait=wait_random(1, 3), reraise=True,
|
||||||
after=lambda x: logging.warning("获取图片失败!"))
|
after=lambda x: logging.warning('获取图片失败!'))
|
||||||
def read(image_path):
|
def read(image_path):
|
||||||
"""
|
"""
|
||||||
从网络或本地读取图片
|
从网络或本地读取图片
|
||||||
:param image_path: 网络或本地路径
|
:param image_path: 网络或本地路径
|
||||||
:return: NumPy数组形式的图片
|
:return: NumPy数组形式的图片
|
||||||
"""
|
"""
|
||||||
if image_path.startswith("http"):
|
if image_path.startswith('http'):
|
||||||
# 发送HTTP请求并获取图像数据
|
# 发送HTTP请求并获取图像数据
|
||||||
resp = urllib.request.urlopen(image_path, timeout=60)
|
resp = urllib.request.urlopen(image_path, timeout=60)
|
||||||
# 将数据读取为字节流
|
# 将数据读取为字节流
|
||||||
@@ -76,35 +76,35 @@ def split(image, ratio=1.414, overlap=0.05, x_compensation=3):
|
|||||||
for i in range(math.ceil(height / step)):
|
for i in range(math.ceil(height / step)):
|
||||||
offset = round(step * i)
|
offset = round(step * i)
|
||||||
cropped_img = capture(image, [0, offset, width, offset + new_img_height])
|
cropped_img = capture(image, [0, offset, width, offset + new_img_height])
|
||||||
split_result.append({"img": cropped_img, "x_offset": 0, "y_offset": offset})
|
split_result.append({'img': cropped_img, 'x_offset': 0, 'y_offset': offset})
|
||||||
elif wh_ratio > ratio: # 横向过长
|
elif wh_ratio > ratio: # 横向过长
|
||||||
new_img_width = height * ratio
|
new_img_width = height * ratio
|
||||||
step = height * (ratio - overlap * x_compensation) # 一般文字是横向的,所以横向截取时增大重叠部分
|
step = height * (ratio - overlap * x_compensation) # 一般文字是横向的,所以横向截取时增大重叠部分
|
||||||
for i in range(math.ceil(width / step)):
|
for i in range(math.ceil(width / step)):
|
||||||
offset = round(step * i)
|
offset = round(step * i)
|
||||||
cropped_img = capture(image, [offset, 0, offset + new_img_width, width])
|
cropped_img = capture(image, [offset, 0, offset + new_img_width, width])
|
||||||
split_result.append({"img": cropped_img, "x_offset": offset, "y_offset": 0})
|
split_result.append({'img': cropped_img, 'x_offset': offset, 'y_offset': 0})
|
||||||
else:
|
else:
|
||||||
split_result.append({"img": image, "x_offset": 0, "y_offset": 0})
|
split_result.append({'img': image, 'x_offset': 0, 'y_offset': 0})
|
||||||
return split_result
|
return split_result
|
||||||
|
|
||||||
|
|
||||||
def parse_rotation_angles(image):
|
def parse_rotation_angles(image):
|
||||||
"""
|
"""
|
||||||
判断图片旋转角度,逆时针旋转该角度后为正。可能值["0", "90", "180", "270"]
|
判断图片旋转角度,逆时针旋转该角度后为正。可能值['0', '90', '180', '270']
|
||||||
:param image: 图片NumPy数组或文件路径
|
:param image: 图片NumPy数组或文件路径
|
||||||
:return: 最有可能的两个角度
|
:return: 最有可能的两个角度
|
||||||
"""
|
"""
|
||||||
angles = ['0', '90']
|
angles = ['0', '90']
|
||||||
model = PaddleClas(model_name="text_image_orientation")
|
model = PaddleClas(model_name='text_image_orientation')
|
||||||
clas_result = model.predict(input_data=image)
|
clas_result = model.predict(input_data=image)
|
||||||
try:
|
try:
|
||||||
clas_result = next(clas_result)[0]
|
clas_result = next(clas_result)[0]
|
||||||
if clas_result["scores"][0] < 0.5:
|
if clas_result['scores'][0] < 0.5:
|
||||||
return angles
|
return angles
|
||||||
angles = clas_result["label_names"]
|
angles = clas_result['label_names']
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("获取图片旋转角度失败", exc_info=e)
|
logging.error('获取图片旋转角度失败', exc_info=e)
|
||||||
return angles
|
return angles
|
||||||
|
|
||||||
|
|
||||||
@@ -201,25 +201,25 @@ def expand_to_a4_size(image):
|
|||||||
if hw_ratio >= 1.42:
|
if hw_ratio >= 1.42:
|
||||||
exp_w = int(height / 1.414 - width)
|
exp_w = int(height / 1.414 - width)
|
||||||
x_offset = int(exp_w / 2)
|
x_offset = int(exp_w / 2)
|
||||||
exp_img = numpy.zeros((height, x_offset, 3), dtype="uint8")
|
exp_img = numpy.zeros((height, x_offset, 3), dtype='uint8')
|
||||||
exp_img.fill(255)
|
exp_img.fill(255)
|
||||||
image = numpy.hstack([exp_img, image, exp_img])
|
image = numpy.hstack([exp_img, image, exp_img])
|
||||||
elif 1 <= hw_ratio <= 1.40:
|
elif 1 <= hw_ratio <= 1.40:
|
||||||
exp_h = int(width * 1.414 - height)
|
exp_h = int(width * 1.414 - height)
|
||||||
y_offset = int(exp_h / 2)
|
y_offset = int(exp_h / 2)
|
||||||
exp_img = numpy.zeros((y_offset, width, 3), dtype="uint8")
|
exp_img = numpy.zeros((y_offset, width, 3), dtype='uint8')
|
||||||
exp_img.fill(255)
|
exp_img.fill(255)
|
||||||
image = numpy.vstack([exp_img, image, exp_img])
|
image = numpy.vstack([exp_img, image, exp_img])
|
||||||
elif 0.72 <= hw_ratio < 1:
|
elif 0.72 <= hw_ratio < 1:
|
||||||
exp_w = int(height * 1.414 - width)
|
exp_w = int(height * 1.414 - width)
|
||||||
x_offset = int(exp_w / 2)
|
x_offset = int(exp_w / 2)
|
||||||
exp_img = numpy.zeros((height, x_offset, 3), dtype="uint8")
|
exp_img = numpy.zeros((height, x_offset, 3), dtype='uint8')
|
||||||
exp_img.fill(255)
|
exp_img.fill(255)
|
||||||
image = numpy.hstack([exp_img, image, exp_img])
|
image = numpy.hstack([exp_img, image, exp_img])
|
||||||
elif hw_ratio <= 0.7:
|
elif hw_ratio <= 0.7:
|
||||||
exp_h = int(width / 1.414 - height)
|
exp_h = int(width / 1.414 - height)
|
||||||
y_offset = int(exp_h / 2)
|
y_offset = int(exp_h / 2)
|
||||||
exp_img = numpy.zeros((y_offset, width, 3), dtype="uint8")
|
exp_img = numpy.zeros((y_offset, width, 3), dtype='uint8')
|
||||||
exp_img.fill(255)
|
exp_img.fill(255)
|
||||||
image = numpy.vstack([exp_img, image, exp_img])
|
image = numpy.vstack([exp_img, image, exp_img])
|
||||||
return image, x_offset, y_offset
|
return image, x_offset, y_offset
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
|
|
||||||
Binary file not shown.
|
Before Width: | Height: | Size: 1.3 MiB |
Binary file not shown.
|
Before Width: | Height: | Size: 1.5 MiB |
Binary file not shown.
|
Before Width: | Height: | Size: 2.9 MiB |
@@ -1,163 +0,0 @@
|
|||||||
# 可视化的模型对比测试
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import tempfile
|
|
||||||
import time
|
|
||||||
from pprint import pprint
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
from paddlenlp import Taskflow
|
|
||||||
from paddlenlp.utils.doc_parser import DocParser
|
|
||||||
from paddleocr import PaddleOCR
|
|
||||||
|
|
||||||
from ucloud import ufile
|
|
||||||
from util import image_util, common_util
|
|
||||||
|
|
||||||
|
|
||||||
def write_visual_result(image, angle=0, layout=None, result=None):
    """Render layout boxes and/or extraction results onto an image and save them.

    The image is loaded with ``image_util.read``, optionally rotated, written
    to a temporary JPEG and handed to ``DocParser.write_image_with_results``.
    Renderings are saved as ``./img_result/<name>_layout.<ext>`` and
    ``./img_result/<name>_result.<ext>``, where ``<name>`` and ``<ext>`` are
    derived from the file name of ``image``.

    Args:
        image: Path or URL of the source image. Anything after ``?`` is
            ignored when deriving the output file name.
        angle: Rotation passed to ``image_util.rotate``; ``0`` skips rotation.
        layout: Layout data to draw; nothing is rendered when falsy.
        result: Extraction result to draw; nothing is rendered when falsy.
    """
    # Strip the query string, then keep only the last path component
    # (both "/" and "\\" are treated as separators).
    file_name = re.split(r'[\\/]', image.split("?")[0])[-1]
    img_name, dot, img_type = file_name.rpartition(".")
    if not dot:
        # No extension: keep the whole name as the stem instead of dropping
        # it, and fall back to the default "jpg" type.
        img_name, img_type = file_name, "jpg"

    img_array = image_util.read(image)
    if angle != 0:
        img_array = image_util.rotate(img_array, angle)

    # Only reserve a unique path here; the handle is closed before OpenCV
    # writes to it so the file is never opened twice at the same time.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
        temp_path = temp_file.name
    try:
        cv2.imwrite(temp_path, img_array)
        if layout:
            print(layout)
            DocParser.write_image_with_results(
                temp_path,
                layout=layout,
                save_path="./img_result/" + img_name + "_layout." + img_type)

        if result:
            print(result)
            DocParser.write_image_with_results(
                temp_path,
                result=result,
                save_path="./img_result/" + img_name + "_result." + img_type)
    finally:
        # Always remove the temporary file, even when rendering fails.
        os.remove(temp_path)
|
|
||||||
|
|
||||||
|
|
||||||
def visual_model_test(model_type, test_img, task_path, schema):
    """Run one model on a test image and save a visualization of its output.

    The image is split into tiles with ``image_util.split``; every tile is
    written to a temporary JPEG before being processed.

    * ``model_type == "ocr"``: each tile is padded with
      ``image_util.expand_to_a4_size`` and passed to
      ``common_util.get_ocr_layout``; the detected boxes are shifted by the
      tile offsets and drawn through ``write_visual_result(layout=...)``.
    * any other ``model_type``: each tile is rotated by the angle reported by
      ``image_util.parse_rotation_angles`` and all tiles are fed to a
      ``uie-x-base`` information-extraction ``Taskflow``; the first result is
      drawn through ``write_visual_result(result=...)``.

    Args:
        model_type: ``"ocr"`` selects the OCR branch; anything else selects
            the information-extraction branch.
        test_img: Path or URL of the image to test.
        task_path: Model directory passed to ``Taskflow`` (unused for OCR).
        schema: Extraction schema passed to ``Taskflow`` (unused for OCR).
    """
    temp_files_paths = []
    # Stays 0 when the split yields no tiles, so the name is always bound.
    angle = 0
    try:
        if model_type == "ocr":
            layout = []
            # Build the OCR engine once instead of once per tile.
            ocr_engine = PaddleOCR(det_db_box_thresh=0.3, det_db_thresh=0.1, det_limit_side_len=1248,
                                   drop_score=0.3, save_crop_res=False)
            for img in image_util.split(test_img):
                # Reserve a unique path and register it for cleanup before
                # anything that may fail writes to it.
                with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
                    temp_path = temp_file.name
                temp_files_paths.append(temp_path)

                # Rotation detection is disabled for OCR, so angle is 0 here.
                rotated_img = image_util.rotate(img["img"], angle)
                rotated_img, offset_x, offset_y = image_util.expand_to_a4_size(rotated_img)
                cv2.imwrite(temp_path, rotated_img)

                # Compensate the tile offsets for the padding added above.
                img["x_offset"] -= offset_x
                img["y_offset"] -= offset_y

                parsed_doc = common_util.get_ocr_layout(ocr_engine, temp_path)
                if img["x_offset"] or img["y_offset"]:
                    # Shift each box by the tile offsets; box is indexed as
                    # [x, y, x, y] here.
                    for p in parsed_doc:
                        box = p[0]
                        box[0] += img["x_offset"]
                        box[1] += img["y_offset"]
                        box[2] += img["x_offset"]
                        box[3] += img["y_offset"]
                layout += parsed_doc

            write_visual_result(test_img, angle, layout=layout)
        else:
            docs = []
            for img in image_util.split(test_img):
                with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
                    temp_path = temp_file.name
                temp_files_paths.append(temp_path)

                angle = int(image_util.parse_rotation_angles(img["img"])[0])
                rotated_img = image_util.rotate(img["img"], angle)
                cv2.imwrite(temp_path, rotated_img)
                docs.append({"doc": temp_path})

            # Skip model loading and result indexing when there is nothing
            # to extract from.
            if docs:
                my_ie = Taskflow("information_extraction", schema=schema, model="uie-x-base",
                                 task_path=task_path, layout_analysis=False)
                my_results = my_ie(docs)
                # NOTE(review): only the first tile's result is drawn, using
                # the last tile's angle - confirm this is intended.
                write_visual_result(test_img, angle, result=my_results[0])
    finally:
        # Remove the temporary files even when a step above raised.
        for path in temp_files_paths:
            try:
                os.remove(path)
                print(f"临时文件 {path} 已删除")
            except OSError as e:
                print(f"删除临时文件 {path} 时出错: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
def batch_test(test_imgs, task_path, schema):
    """Run the information-extraction model on several images and print the results.

    Args:
        test_imgs: Iterable of image paths; each becomes a ``{"doc": path}`` input.
        task_path: Model directory passed to ``Taskflow``.
        schema: Extraction schema passed to ``Taskflow``.
    """
    # Batch input form: ie([{"doc": "./data/6.jpg"}, {"doc": "./data/7.jpg"}])
    documents = [{"doc": img_path} for img_path in test_imgs]
    extractor = Taskflow(
        "information_extraction",
        schema=schema,
        model="uie-x-base",
        task_path=task_path,
        layout_analysis=False,
        batch_size=16,
    )
    pprint(extractor(documents))
|
|
||||||
|
|
||||||
|
|
||||||
def main(model_type, pic_name=None):
    """Run the visual test for one model type and print the elapsed time.

    Args:
        model_type: One of ``"ocr"``, ``"settlement"``, ``"discharge"``,
            ``"cost"`` or ``"cost_detail"``. Any other value prints a hint
            and returns without running a test.
        pic_name: Optional name resolved through ``ufile.get_private_url``;
            when falsy, the built-in sample image for the model type is used.
    """
    started_at = time.time()

    # model_type -> (task_path, extra get_private_url args, sample image, schema)
    presets = {
        "ocr": (
            None,
            ("drg103",),
            "../test_img/PH20240725004467_3_185708_1.jpg",
            None,
        ),
        "settlement": (
            "../services/paddle_services/model/settlement_list_model",
            (),
            "img/PH20240511000638_1_094306_1.jpg",
            ["患者姓名", "入院日期", "出院日期", "费用总额", "个人现金支付", "个人账户支付", "自费金额",
             "医保类型", "住院号", "医保结算单号码", "大写总额"],
        ),
        "discharge": (
            "../services/paddle_services/model/discharge_record_model",
            (),
            "img/PH20240401000003_3_001938_2.jpg",
            ["医院", "科室", "患者姓名", "入院日期", "出院日期", "主治医生", "住院号", "年龄"],
        ),
        "cost": (
            "../services/paddle_services/model/cost_list_model",
            (),
            "img/PH20240511000648_4_094542_2.jpg",
            ["患者姓名", "入院日期", "出院日期", "费用总额"],
        ),
        "cost_detail": (
            "../services/paddle_services/model/cost_list_detail_model",
            (),
            "img/PH20240511000648_4_094542_2.jpg",
            {"名称": ["类别", "规格", "单价", "数量", "金额"]},
        ),
    }

    preset = presets.get(model_type)
    if preset is None:
        print("请输入正确的类型!")
        return

    task_path, url_args, sample_img, schema = preset
    # Resolve the URL only when a picture name was given.
    if pic_name:
        test_img_path = ufile.get_private_url(pic_name, *url_args)
    else:
        test_img_path = sample_img

    visual_model_test(model_type, test_img_path, task_path, schema)

    finished_at = time.time()
    pprint(f"处理时长:{finished_at - started_at}秒")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
    # Manual entry point: runs the visual test for a single model type.
    # To test another model, swap the active call for one of the
    # commented-out calls below.
    main("ocr")
    # main("settlement")
    # main("discharge")
    # main("cost")
    # main("cost_detail")
    # Example of drawing a hand-written layout box directly:
    # write_visual_result("img/PH20240428000832_1_093844_2.jpg", layout=[([508.0975609756094,
    #                                                                      659.7073170731707,
    #                                                                      1000,
    #                                                                      745.756097560976], 'lay', 'figure')])
|
|
||||||
Reference in New Issue
Block a user