124 Commits
dev ... deploy

Author SHA1 Message Date
3044f1fc08 Fix HEIC image saving issue 2024-10-21 09:32:39 +08:00
da18d890f7 Remove redundant imports 2024-10-21 09:02:30 +08:00
ae52d0594e Move temp-image cleanup into a separate script 2024-10-21 08:45:39 +08:00
d51e56b8f2 Dongtai Hospital interface not yet handled 2024-10-21 08:37:25 +08:00
83339b5e58 Extend wait time for API services 2024-10-18 16:28:52 +08:00
727743d20e Add invalid QR-code addresses; improve null checks on page numbers 2024-10-18 16:21:39 +08:00
814730a0f0 Add Nantong TCM Hospital to QR-code recognition 2024-10-18 16:14:00 +08:00
c9894d257e Lower the process limit 2024-10-18 14:45:08 +08:00
68043e5773 Raise the process limit 2024-10-18 14:30:08 +08:00
fe58bb3bfa Null-check information extraction results 2024-10-18 14:08:13 +08:00
ce44a81603 Adjust GPU allocation 2024-10-18 14:05:53 +08:00
4b90bf6dfa Test with the new version fully enabled 2024-10-18 13:55:47 +08:00
248e49bf4b Three-document classification performs poorly; disable it for now and classify by database records 2024-10-18 13:39:01 +08:00
401954dca0 Improve the three-document classification method 2024-10-18 13:04:42 +08:00
6529dc3d98 Improve text classification accuracy 2024-10-18 11:02:58 +08:00
3f93bd476a Fix page-number sorting 2024-10-18 10:49:01 +08:00
d85b3fff8f Improve page-number parsing 2024-10-18 10:24:52 +08:00
61a7802674 Adjust null checks in text classification 2024-10-18 10:19:33 +08:00
9556da47e9 Filter out abnormal page-number values 2024-10-18 10:05:38 +08:00
3710450221 Add null checks to text classification 2024-10-18 09:42:51 +08:00
27a4395ca0 Adjust keywords for detecting missing discharge-record pages 2024-10-18 09:37:43 +08:00
f116798c30 Sort missing-page notices 2024-10-18 09:06:49 +08:00
8c47beb00c Adjust text classification precision; fp16 may only accept 1024 characters 2024-10-18 08:56:12 +08:00
74920869e7 Fix data types in page-number operations 2024-10-17 17:29:39 +08:00
9d0db073d6 Extend the wait time for API services 2024-10-17 17:12:19 +08:00
8e7745f1f6 Run the document detection model on GPU 2024-10-17 17:08:29 +08:00
cc53243647 Fix null-pointer error when OCR recognizes no content 2024-10-17 16:59:10 +08:00
39da0d8a00 Test auto review and auto masking 2024-10-17 16:43:51 +08:00
a2e1f10261 Cost lists are hard to locate precisely; hold off on replacing them 2024-10-17 15:38:14 +08:00
e1bd9f3786 Update zx_ie_cost 2024-10-17 15:19:28 +08:00
46f295d422 Fix empty image-split results 2024-10-17 15:18:06 +08:00
1a0caf30d0 Limit text length for text classification 2024-10-17 15:13:41 +08:00
25df420be8 Improve data types 2024-10-17 15:03:29 +08:00
b5dffaf5bd Fix page-number analysis 2024-10-17 14:55:19 +08:00
0e4cfd10b6 Joint test of auto review and auto masking 2024-10-17 14:48:58 +08:00
f98969d957 Fix boolean operations 2024-10-17 14:34:31 +08:00
0c9bed8661 Fix merge handling of discharge records 2024-10-17 14:29:12 +08:00
d0b4a77817 Test auto review on its own 2024-10-17 14:22:44 +08:00
00e5ca7c30 Test auto review on its own 2024-10-17 14:17:47 +08:00
5dee4ed568 Add image build configuration 2024-10-17 14:04:43 +08:00
06869e691f Test auto masking on its own 2024-10-17 13:41:52 +08:00
8e06fdafa0 Clean up temp images left behind by program errors 2024-10-17 13:38:29 +08:00
84d106c7de Improve image-deletion logic in auto masking 2024-10-17 13:12:07 +08:00
9c41fab95c Improve discharge-record handling when PDFs are present 2024-10-17 12:58:23 +08:00
0060c4ad59 Fix minor errors 2024-10-17 10:38:29 +08:00
d374e0743a Improve image classification and orientation detection 2024-10-16 17:01:56 +08:00
947b4f20f3 Improve orientation-selection conditions 2024-10-15 15:04:24 +08:00
445d57e8c6 Improve orientation-selection conditions 2024-10-15 14:21:23 +08:00
b09f16fe23 Adjust where images are deleted 2024-10-15 13:45:11 +08:00
c28fc62d3f Fix global variable access 2024-10-15 13:27:55 +08:00
b332aa00dd Improve amount handling 2024-10-15 11:29:06 +08:00
5af6256376 Improve image storage; promptly delete images produced during processing 2024-10-15 10:17:01 +08:00
15ea3ff96f Update zx_ie_result 2024-10-14 13:03:14 +08:00
19237d3a3c Fix page-number parsing 2024-10-12 15:46:50 +08:00
0b0882d456 Add API logging 2024-10-12 15:26:13 +08:00
304f6897f0 Revert "Use a smaller base image"
This reverts commit a9f172fdb0.
2024-10-12 13:52:38 +08:00
a9f172fdb0 Use a smaller base image 2024-10-12 13:50:31 +08:00
ac4e4ff8f8 Fix service dependencies 2024-10-12 13:44:45 +08:00
f7fbe709bf Enable auto masking 2024-10-12 13:37:46 +08:00
396550058f Add beautifulsoup4 dependency 2024-10-12 13:31:26 +08:00
b9ac638b38 Remove an unused method 2024-10-12 13:29:14 +08:00
894cab4f0b Add rapidfuzz dependency 2024-10-12 13:17:11 +08:00
bb6d9c3b47 Add pymupdf dependency 2024-10-12 13:13:06 +08:00
f8280e87ee Fix mirror-source replacement 2024-10-12 12:56:56 +08:00
608a647621 Improve image layering 2024-10-12 11:32:09 +08:00
7b9d9ca589 Configure a domestic apt mirror 2024-10-12 11:28:31 +08:00
d9b24e906d Fix missing libGL.so.1 library 2024-10-12 11:16:32 +08:00
97c7b2cfce Add missing dependencies 2024-10-12 11:13:00 +08:00
004dd12004 Change the base image 2024-10-12 11:09:18 +08:00
cc9d020008 The base image is Debian; fix the download command line 2024-10-12 10:48:06 +08:00
7335553080 Remove privilege escalation 2024-10-12 10:44:16 +08:00
ebb10b2816 Fix missing libGL.so.1 library 2024-10-12 10:42:40 +08:00
98fb9fa861 Add missing dependencies 2024-10-12 10:36:47 +08:00
c75415164e Add dockerignore configuration 2024-10-12 10:27:01 +08:00
03d8652b8f Add missing dependencies 2024-10-12 10:23:51 +08:00
e3be5cf4b2 Fix service dependencies 2024-10-12 10:18:08 +08:00
c92b549480 Enable auto recognition 2024-10-12 10:14:31 +08:00
d36740d729 Fix decorator error 2024-10-12 10:03:39 +08:00
a1dea6f29c Unify the image flow for photo de-identification 2024-10-11 15:17:26 +08:00
0fc0c80d6f Fix missing-page check for discharge records 2024-10-11 10:26:54 +08:00
f3930cc7bd Improve auto-review checks 2024-10-11 10:03:20 +08:00
a11cefb999 Fix some English spellings; fix image passing; fix page-number parsing 2024-10-10 15:36:46 +08:00
5c0fc0f819 Fix retrieval of list values 2024-10-10 14:45:39 +08:00
77010f0598 Adjust missing-page keywords for discharge records 2024-10-10 14:08:39 +08:00
e4b58e30c0 Add more missing-page checks 2024-10-10 11:24:16 +08:00
15fe5d4f0d Store information-extraction results in tables and use them to check for missing items 2024-10-10 09:24:09 +08:00
fc69aa5b9d Fix known errors 2024-10-09 14:50:02 +08:00
795134f566 Improve case-handling logic 2024-10-09 09:39:29 +08:00
a3fa1e502e Adjust image names 2024-09-29 13:55:37 +08:00
7a4cb5263a Add text information-extraction API 2024-09-27 15:31:11 +08:00
46be9a26be Add initial auto-review feature 2024-09-27 14:53:16 +08:00
f1149854ce Unify model APIs; add a text classification API 2024-09-27 13:50:55 +08:00
117b29a737 Fix package imports 2024-09-27 08:52:47 +08:00
3219f28934 Treat services as a package 2024-09-26 17:22:51 +08:00
2e1c0a57c7 Fix method call 2024-09-26 17:20:38 +08:00
2dcd2d2a34 Add directory documentation 2024-09-26 17:13:37 +08:00
153eb70f84 Test whether the API services start successfully 2024-09-26 17:05:07 +08:00
b5aba0418b Fix model paths 2024-09-26 16:56:48 +08:00
603b027ca6 Fix the new image build file 2024-09-26 16:47:59 +08:00
d4c54b04f5 Build a new image 2024-09-26 16:43:34 +08:00
fc3e7b4ed4 Add logging to the APIs 2024-09-26 16:41:30 +08:00
a62c2af816 Fix the IP of the image-dewarping API 2024-09-26 16:23:48 +08:00
0618754da2 Test the dewarping model 2024-09-26 16:04:50 +08:00
c5a03ad16f Unify quote style; improve architecture layout 2024-09-26 15:16:57 +08:00
ff9d612e67 Adjust container-mounted files 2024-09-26 13:44:05 +08:00
86d28096d4 Change the base image 2024-09-26 11:25:05 +08:00
87180cd282 Upgrade opencc 2024-09-26 10:55:18 +08:00
f13ffd1fe9 Adjust the base image 2024-09-26 10:51:06 +08:00
09f62b36a9 Fix command syntax 2024-09-26 09:43:05 +08:00
186cab0317 Debug docker-compose startup 2024-09-25 16:29:56 +08:00
101b2126f4 Remove the fixed gunicorn startup during testing 2024-09-25 16:24:56 +08:00
d4a695e9ea Add gunicorn dependency 2024-09-25 16:14:00 +08:00
72794f699e Fix docker image build 2024-09-25 16:07:12 +08:00
3189caf7aa Add test functionality 2024-09-25 15:19:45 +08:00
b8c1202957 Restructure the project; all models are now called via APIs 2024-09-25 14:46:37 +08:00
7647df7d74 Move doc_dewarp 2024-09-24 17:10:56 +08:00
3438cf6e0e Move paddle_detection 2024-09-24 17:02:56 +08:00
90a6d5ec75 Turn dewarping and image-orientation classification into APIs 2024-09-24 08:36:34 +08:00
9c21152823 Fix URLs and image saving 2024-09-23 15:32:16 +08:00
c091a82a91 Improve image transfer over the APIs 2024-09-23 14:45:03 +08:00
a2a82df21c Improve the Flask APIs 2024-09-20 14:47:43 +08:00
f0c03e763b Improve naming; classes and modules should not share names 2024-09-20 14:32:31 +08:00
7b6e78373c Remove duplicate check_ie_result 2024-09-20 12:40:08 +08:00
65b7126348 Remove unused configuration 2024-09-20 11:32:37 +08:00
2125 changed files with 4263 additions and 2009 deletions

View File

@@ -238,8 +238,11 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# Models are mounted into the container via bind volumes
/model
# Mounted into the container via bind volumes
/log
/services/paddle_services/log
/services/paddle_services/model
/tmp_img
# docker
Dockerfile
docker-compose*.yml

8
.gitignore vendored
View File

@@ -142,7 +142,11 @@ cython_debug/
.idea
### Model
model
services/paddle_services/model
### Log Backups
*.log.*-*-*
*.log.*-*-*
### Tmp Files
/tmp_img
/test_img

View File

@@ -1,5 +1,5 @@
# Use the official Paddle image as the base
FROM registry.baidubce.com/paddlepaddle/paddle:2.6.1-gpu-cuda12.0-cudnn8.9-trt8.6
# Use the official Python image as the base
FROM python:3.10.15-bookworm
# Set the working directory
WORKDIR /app
@@ -13,12 +13,10 @@ ENV PYTHONUNBUFFERED=1 \
# Install dependencies
COPY requirements.txt /app/requirements.txt
COPY packages /app/packages
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo '$TZ' > /etc/timezone \
&& python3 -m pip install --upgrade pip \
&& pip install --no-cache-dir -r requirements.txt \
&& pip uninstall -y onnxruntime onnxruntime-gpu \
&& pip install onnxruntime-gpu==1.18.0 --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
&& sed -i 's|deb.debian.org|mirrors.aliyun.com|g' /etc/apt/sources.list.d/debian.sources \
&& apt-get update && apt-get install libgl1 -y \
&& pip install --no-cache-dir -r requirements.txt
# Copy the current directory into /app in the container
COPY . /app
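
For orientation, a minimal sketch of how this image is typically built; the tag and service name are taken from the docker-compose.yml shown later in this diff, and the exact commands are an assumption rather than part of the repository:

```bash
# Build the application image directly (tag matches the compose file)
docker build -t fcb_photo_review:2.0.0 .

# Or let Compose build it from the same context
docker compose build photo_review_1
```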

View File

@@ -1,33 +0,0 @@
# Use the official Paddle image as the base
FROM ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlex/paddlex:paddlex3.1.2-paddlepaddle3.0.0-gpu-cuda12.6-cudnn9.5-trt10.5
# Set the working directory
WORKDIR /app
# Set environment variables
ENV PYTHONUNBUFFERED=1 \
# Set the time zone
TZ=Asia/Shanghai \
# Set the pip mirror to speed up installs
PIP_INDEX_URL=https://pypi.tuna.tsinghua.edu.cn/simple
# Install language-pack-en and openssh-server
RUN apt update && \
apt install -y language-pack-en && \
apt install -y openssh-server
# Configure the SSH service
RUN mkdir /var/run/sshd && \
# Set the root password (change as needed)
echo 'root:fcb0102' | chpasswd && \
# Allow root login over SSH
sed -i 's/#PermitRootLogin prohibit-password/PermitRootLogin yes/' /etc/ssh/sshd_config
# Copy the current directory into /app in the container
COPY . /app
# Expose port 22
EXPOSE 22
# Start the SSH service
CMD ["/usr/sbin/sshd", "-D"]

View File

@@ -6,7 +6,7 @@
1. Clone the project from the remote Git repository.
2. Copy the deep learning models into the ./model directory; see the [模型更新](#模型更新) (model update) section for details.
2. Copy the deep learning models into the ./services/paddle_services/model directory; see the [模型更新](#模型更新) (model update) section for details.
3. Install docker and docker-compose.
@@ -126,7 +126,5 @@ bash update.sh
2. Add dewarping feature
21. Version 1.14.0
1. Add QR-code recognition to replace images with high-resolution originals
22. Version 1.15.0
1. Add image sharpness testing
23. Version 1.16.0
1. Upgrade the Paddle framework to 3.0
22. Version 2.0.0
1. Restructure the project; all models are now called via APIs
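
A consolidated sketch of the deployment steps from this README; the repository URL and model source path are placeholders, and the assumption that update.sh wraps the compose build/restart follows the README's own `bash update.sh` usage:

```bash
# 1. Clone the project (placeholder URL)
git clone <git-remote-url> fcb_photo_review && cd fcb_photo_review

# 2. Copy the deep learning models into the new location (placeholder source path)
cp -r /path/to/models/* ./services/paddle_services/model/

# 3. Build and start the stack (update.sh is assumed to wrap these steps)
docker compose build
docker compose up -d
```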

4
api_test.py Normal file
View File

@@ -0,0 +1,4 @@
import time
if __name__ == '__main__':
time.sleep(3600)

View File

@@ -1,15 +0,0 @@
# Automatically generate the SQLAlchemy Model corresponding to a database table
import subprocess
from db import DB_URL
if __name__ == '__main__':
table = input("请输入表名:")
out_file = f"db/{table}.py"
command = f"sqlacodegen {DB_URL} --outfile={out_file} --tables={table}"
try:
subprocess.run(command, shell=True, check=True)
print(f"{table}.py文件生成成功请检查并复制到合适的文件中")
except Exception as e:
print(f"生成{table}.py文件时发生错误: {e}")

View File

@@ -1 +0,0 @@

View File

@@ -1,108 +0,0 @@
import datetime
import json
import os
from decimal import Decimal
from io import BytesIO
from itertools import groupby
import requests
from PIL import ImageDraw, Image, ImageFont
from db import MysqlSession
from db.mysql import ZxIeCost, ZxIeDischarge, ZxIeSettlement, ZxPhhd, ZxIeResult, ZxPhrec
from ucloud import ufile
from util import image_util
def check_ie_result(pk_phhd, need_to_annotation=True):
os.makedirs(f"./check_result/{pk_phhd}", exist_ok=True)
json_result = {"pk_phhd": pk_phhd}
session = MysqlSession()
phhd = session.query(ZxPhhd.cXm).filter(ZxPhhd.pk_phhd == pk_phhd).one()
json_result["cXm"] = phhd.cXm
settlement = (session.query(ZxIeSettlement.pk_ie_settlement, ZxIeSettlement.name, ZxIeSettlement.admission_date,
ZxIeSettlement.discharge_date, ZxIeSettlement.medical_expenses,
ZxIeSettlement.personal_cash_payment, ZxIeSettlement.personal_account_payment,
ZxIeSettlement.personal_funded_amount, ZxIeSettlement.medical_insurance_type,
ZxIeSettlement.admission_id, ZxIeSettlement.settlement_id)
.filter(ZxIeSettlement.pk_phhd == pk_phhd).one())
settlement_result = settlement._asdict()
json_result["settlement"] = settlement_result
discharge = (session.query(ZxIeDischarge.pk_ie_discharge, ZxIeDischarge.hospital, ZxIeDischarge.pk_yljg,
ZxIeDischarge.department, ZxIeDischarge.pk_ylks, ZxIeDischarge.name, ZxIeDischarge.age,
ZxIeDischarge.admission_date, ZxIeDischarge.discharge_date, ZxIeDischarge.doctor,
ZxIeDischarge.admission_id)
.filter(ZxIeDischarge.pk_phhd == pk_phhd).one())
discharge_result = discharge._asdict()
json_result["discharge"] = discharge_result
cost = session.query(ZxIeCost.pk_ie_cost, ZxIeCost.name, ZxIeCost.admission_date, ZxIeCost.discharge_date,
ZxIeCost.medical_expenses).filter(ZxIeCost.pk_phhd == pk_phhd).one()
cost_result = cost._asdict()
json_result["cost"] = cost_result
phrecs = session.query(ZxPhrec.pk_phrec, ZxPhrec.pk_phhd, ZxPhrec.cRectype, ZxPhrec.cfjaddress).filter(
ZxPhrec.pk_phhd == pk_phhd).all()
for phrec in phrecs:
img_name = phrec.cfjaddress
img_path = ufile.get_private_url(img_name, "drg2015")
if not img_path:
img_path = ufile.get_private_url(img_name)
response = requests.get(img_path)
image = Image.open(BytesIO(response.content)).convert("RGB")
if need_to_annotation:
font_size = image.width * image.height / 200000
font = ImageFont.truetype("./font/simfang.ttf", size=font_size)
ocr = session.query(ZxIeResult.id, ZxIeResult.content, ZxIeResult.rotation_angle, ZxIeResult.x_offset,
ZxIeResult.y_offset).filter(ZxIeResult.pk_phrec == phrec.pk_phrec).all()
if not ocr:
os.makedirs(f"./check_result/{pk_phhd}/0", exist_ok=True)
image.save(f"./check_result/{pk_phhd}/0/{img_name}")
for _, group_results in groupby(ocr, key=lambda x: x.id):
draw = ImageDraw.Draw(image)
for ocr_item in group_results:
result = json.loads(ocr_item.content)
rotation_angle = ocr_item.rotation_angle
x_offset = ocr_item.x_offset
y_offset = ocr_item.y_offset
for key in result:
for value in result[key]:
box = value["bbox"][0]
if rotation_angle:
box = image_util.invert_rotate_rectangle(box, (image.width / 2, image.height / 2),
rotation_angle)
if x_offset:
box[0] += x_offset
box[2] += x_offset
if y_offset:
box[1] += y_offset
box[3] += y_offset
draw.rectangle(box, outline="red", width=2) # draw the rectangle
draw.text((box[0], box[1] - font_size), key, fill="blue", font=font) # draw the key above the rectangle
draw.text((box[0], box[3]), value["text"], fill="blue", font=font) # draw the text below the rectangle
os.makedirs(f"./check_result/{pk_phhd}/{ocr_item.id}", exist_ok=True)
image.save(f"./check_result/{pk_phhd}/{ocr_item.id}/{img_name}")
else:
os.makedirs(f"./check_result/{pk_phhd}/0", exist_ok=True)
image.save(f"./check_result/{pk_phhd}/0/{img_name}")
session.close()
# Custom JSON serializer
def default(obj):
if isinstance(obj, Decimal):
return float(obj)
if isinstance(obj, datetime.date):
return obj.strftime("%Y-%m-%d")
with open(f"./check_result/{pk_phhd}/result.json", "w", encoding="utf-8") as json_file:
json.dump(json_result, json_file, indent=4, ensure_ascii=False, default=default)
if __name__ == '__main__':
check_ie_result(5640504)

View File

@@ -19,5 +19,5 @@ DB_URL = f'mysql+pymysql://{USERNAME}:{PASSWORD}@{HOSTNAME}:{PORT}/{DATABASE}'
SHOW_SQL = False
Engine = create_engine(DB_URL, echo=SHOW_SQL)
Base = declarative_base()
Base = declarative_base(Engine)
MysqlSession = sessionmaker(bind=Engine)

View File

@@ -1,5 +1,5 @@
# coding: utf-8
from sqlalchemy import Column, DECIMAL, Date, DateTime, Index, String, text, LargeBinary, Text
from sqlalchemy import Column, DECIMAL, Date, DateTime, Index, String, text, LargeBinary
from sqlalchemy.dialects.mysql import BIT, CHAR, INTEGER, TINYINT, VARCHAR
from db import Base
@@ -56,8 +56,7 @@ class ZxIeCost(Base):
pk_ie_cost = Column(INTEGER(11), primary_key=True, comment='费用明细信息抽取主键')
pk_phhd = Column(INTEGER(11), nullable=False, unique=True, comment='报销案子主键')
content = Column(Text, comment='详细内容')
name = Column(String(20), comment='患者姓名')
name = Column(String(30), comment='患者姓名')
admission_date_str = Column(String(255), comment='入院日期字符串')
admission_date = Column(Date, comment='入院日期')
discharge_date_str = Column(String(255), comment='出院日期字符串')
@@ -97,19 +96,19 @@ class ZxIeDischarge(Base):
pk_ie_discharge = Column(INTEGER(11), primary_key=True, comment='出院记录信息抽取主键')
pk_phhd = Column(INTEGER(11), nullable=False, unique=True, comment='报销案子主键')
content = Column(Text, comment='详细内容')
hospital = Column(String(200), comment='医院')
content = Column(String(5000), comment='详细内容')
hospital = Column(String(255), comment='医院')
pk_yljg = Column(INTEGER(11), comment='医院主键')
department = Column(String(200), comment='科室')
department = Column(String(255), comment='科室')
pk_ylks = Column(INTEGER(11), comment='科室主键')
name = Column(String(20), comment='患者姓名')
name = Column(String(30), comment='患者姓名')
age = Column(INTEGER(3), comment='年龄')
admission_date_str = Column(String(255), comment='入院日期字符串')
admission_date = Column(Date, comment='入院日期')
discharge_date_str = Column(String(255), comment='出院日期字符串')
discharge_date = Column(Date, comment='出院日期')
doctor = Column(String(20), comment='主治医生')
admission_id = Column(String(20), comment='住院号')
doctor = Column(String(30), comment='主治医生')
admission_id = Column(String(50), comment='住院号')
create_time = Column(DateTime, server_default=text("CURRENT_TIMESTAMP"), comment='创建时间')
creator = Column(String(255), comment='创建人')
update_time = Column(DateTime, server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"),
@@ -123,7 +122,7 @@ class ZxIeResult(Base):
pk_ocr = Column(INTEGER(11), primary_key=True, comment='图片OCR识别主键')
pk_phhd = Column(INTEGER(11), nullable=False, comment='报销单主键')
pk_phrec = Column(INTEGER(11), nullable=False, comment='图片主键')
id = Column(INTEGER(11), nullable=False, comment='识别批次')
id = Column(CHAR(32), nullable=False, comment='识别批次')
cfjaddress = Column(String(200), nullable=False, comment='云存储文件名')
content = Column(String(5000), comment='OCR识别内容')
rotation_angle = Column(INTEGER(11), comment='旋转角度')
@@ -141,7 +140,7 @@ class ZxIeSettlement(Base):
pk_ie_settlement = Column(INTEGER(11), primary_key=True, comment='结算清单信息抽取主键')
pk_phhd = Column(INTEGER(11), nullable=False, unique=True, comment='报销案子主键')
name = Column(String(20), comment='患者姓名')
name = Column(String(30), comment='患者姓名')
admission_date_str = Column(String(255), comment='入院日期字符串')
admission_date = Column(Date, comment='入院日期')
discharge_date_str = Column(String(255), comment='出院日期字符串')
@@ -155,9 +154,9 @@ class ZxIeSettlement(Base):
personal_funded_amount_str = Column(String(255), comment='自费金额字符串')
personal_funded_amount = Column(DECIMAL(18, 2), comment='自费金额')
medical_insurance_type_str = Column(String(255), comment='医保类型字符串')
medical_insurance_type = Column(String(10), comment='医保类型')
admission_id = Column(String(20), comment='住院号')
settlement_id = Column(String(30), comment='医保结算单号码')
medical_insurance_type = Column(String(40), comment='医保类型')
admission_id = Column(String(50), comment='住院号')
settlement_id = Column(String(50), comment='医保结算单号码')
create_time = Column(DateTime, server_default=text("CURRENT_TIMESTAMP"), comment='创建时间')
creator = Column(String(255), comment='创建人')
update_time = Column(DateTime, server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"),
@@ -415,17 +414,19 @@ class ZxIeReview(Base):
pk_ie_review = Column(INTEGER(11), primary_key=True, comment='自动审核主键')
pk_phhd = Column(INTEGER(11), nullable=False, comment='报销案子主键')
success = Column(BIT(1))
integrity = Column(BIT(1))
has_settlement = Column(BIT(1))
has_discharge = Column(BIT(1))
has_cost = Column(BIT(1))
full_page = Column(BIT(1))
page_description = Column(String(255), comment='具体缺页描述')
consistency = Column(BIT(1), comment='三项资料一致性。0不一致1一致')
name_match = Column(CHAR(1), server_default=text("'0'"),
comment='三项资料姓名是否一致。0不一致1一致2结算单不一致3出院记录不一致4费用清单不一致5与报销申请对象不一致')
comment='三项资料姓名是否一致。0不一致1一致2结算单不一致3出院记录不一致4费用清单不一致5与报销申请对象不一致')
admission_date_match = Column(CHAR(1), server_default=text("'0'"),
comment='三项资料入院日期是否一致。0不一致1一致2结算单不一致3出院记录不一致4费用清单不一致')
comment='三项资料入院日期是否一致。0不一致1一致2结算单不一致3出院记录不一致4费用清单不一致')
discharge_date_match = Column(CHAR(1), server_default=text("'0'"),
comment='三项资料出院日期是否一致。0不一致1一致2结算单不一致3出院记录不一致4费用清单不一致')
comment='三项资料出院日期是否一致。0不一致1一致2结算单不一致3出院记录不一致4费用清单不一致')
create_time = Column(DateTime, server_default=text("CURRENT_TIMESTAMP"), comment='创建时间')
creator = Column(String(255), comment='创建人')
update_time = Column(DateTime, server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"),

View File

@@ -1,74 +0,0 @@
# Delete expired data from the local database
import logging.config
from datetime import datetime, timedelta
from db import MysqlSession
from db.mysql import ZxPhhd, ZxIeCost, ZxIeDischarge, ZxIeResult, ZxIeSettlement, ZxPhrec
from log import LOGGING_CONFIG
# Expiration window in days (not recommended to be shorter than one month)
EXPIRATION_DAYS = 183
# Batch delete size (best kept between 1000 and 10000)
BATCH_SIZE = 5000
# Database session object
session = None
def batch_delete_by_pk_phhd(model, pk_phhds):
"""
Batch-delete rows of the given model whose primary key is in the given list
Args:
model: SQLAlchemy model class (maps to a database table)
pk_phhds: list of primary key values to delete
Returns:
the number of deleted rows
"""
delete_count = (
session.query(model)
.filter(model.pk_phhd.in_(pk_phhds))
.delete(synchronize_session=False)
)
session.commit()
logging.getLogger("sql").info(f"{model.__tablename__}成功删除{delete_count}条数据")
return delete_count
if __name__ == '__main__':
logging.config.dictConfig(LOGGING_CONFIG)
deadline = datetime.now() - timedelta(days=EXPIRATION_DAYS)
double_deadline = deadline - timedelta(days=EXPIRATION_DAYS)
session = MysqlSession()
try:
while 1:
# Cases that have completed the whole workflow are deleted once past the expiration window
phhds = (session.query(ZxPhhd.pk_phhd)
.filter(ZxPhhd.paint_flag == "9")
.filter(ZxPhhd.billdate < deadline)
.limit(BATCH_SIZE)
.all())
if not phhds or len(phhds) <= 0:
# Cases that did not pass review (and may be re-shot and re-uploaded) are deleted after twice the expiration window
phhds = (session.query(ZxPhhd.pk_phhd)
.filter(ZxPhhd.exsuccess_flag == "9")
.filter(ZxPhhd.paint_flag == "0")
.filter(ZxPhhd.billdate < double_deadline)
.limit(BATCH_SIZE)
.all())
if not phhds or len(phhds) <= 0:
# No matching data; exit the loop
break
pk_phhd_values = [phhd.pk_phhd for phhd in phhds]
logging.getLogger("sql").info(f"过期的pk_phhd有{','.join(map(str, pk_phhd_values))}")
batch_delete_by_pk_phhd(ZxPhrec, pk_phhd_values)
batch_delete_by_pk_phhd(ZxIeResult, pk_phhd_values)
batch_delete_by_pk_phhd(ZxIeSettlement, pk_phhd_values)
batch_delete_by_pk_phhd(ZxIeDischarge, pk_phhd_values)
batch_delete_by_pk_phhd(ZxIeCost, pk_phhd_values)
batch_delete_by_pk_phhd(ZxPhhd, pk_phhd_values)
except Exception as e:
session.rollback()
logging.getLogger('error').error('过期数据删除失败!', exc_info=e)
finally:
session.close()

View File

@@ -1,32 +0,0 @@
import base64
import cv2
import numpy as np
from flask import Flask, request, jsonify
from paddle_detection import detector
app = Flask(__name__)
@app.route("/det/detect_books", methods=['POST'])
def detect_books():
try:
file = request.files['image']
image_data = file.read()
nparr = np.frombuffer(image_data, np.uint8)
image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
result = detector.get_book_areas(image)
encoded_images = []
for i in result:
_, encoded_image = cv2.imencode('.jpg', i)
byte_stream = encoded_image.tobytes()
img_str = base64.b64encode(byte_stream).decode('utf-8')
encoded_images.append(img_str)
return jsonify(encoded_images), 200
except Exception as e:
return jsonify({'error': str(e)}), 500
if __name__ == '__main__':
app.run("0.0.0.0")

View File

@@ -1,26 +0,0 @@
services:
fcb_ai_dev:
image: fcb_ai_dev:0.0.10
build:
context: .
dockerfile: Dockerfile.dev
# Container name (customizable)
container_name: fcb_ai_dev
hostname: fcb_ai_dev
# Always restart the container
restart: always
# Port mapping; change the host port as needed
ports:
- "8022:22"
# Volume mapping; adjust paths as appropriate
volumes:
- ./log:/app/log
- ./model:/app/model
# Enable GPU support
deploy:
resources:
reservations:
devices:
- device_ids: [ '0', '1' ]
capabilities: [ 'gpu' ]
driver: 'nvidia'

View File

@@ -1,46 +1,150 @@
x-env:
&template
image: fcb_photo_review:1.15.7
x-base:
&base_template
restart: always
x-review:
&review_template
<<: *template
x-project:
&project_template
<<: *base_template
image: fcb_photo_review:2.0.0
volumes:
- ./log:/app/log
- ./model:/app/model
deploy:
resources:
reservations:
devices:
- device_ids: [ '0', '1' ]
capabilities: [ 'gpu' ]
driver: 'nvidia'
- ./tmp_img:/app/tmp_img
x-mask:
&mask_template
<<: *template
x-paddle:
&paddle_template
<<: *base_template
image: fcb_paddle:0.0.1
volumes:
- ./log:/app/log
deploy:
resources:
reservations:
devices:
- device_ids: [ '1' ]
capabilities: [ 'gpu' ]
driver: 'nvidia'
- ./services/paddle_services/log:/app/log
- ./services/paddle_services/model:/app/model
- ./tmp_img:/app/tmp_img
services:
ocr:
<<: *paddle_template
build:
context: ./services/paddle_services
container_name: ocr
hostname: ocr
command: [ '-w', '4', 'ocr:app', '--bind', '0.0.0.0:5001' ]
deploy:
resources:
reservations:
devices:
- device_ids: [ '0' ]
capabilities: [ 'gpu' ]
driver: 'nvidia'
ie_settlement:
<<: *paddle_template
container_name: ie_settlement
hostname: ie_settlement
command: [ '-w', '5', 'ie_settlement:app', '--bind', '0.0.0.0:5002' ]
deploy:
resources:
reservations:
devices:
- device_ids: [ '0' ]
capabilities: [ 'gpu' ]
driver: 'nvidia'
ie_discharge:
<<: *paddle_template
container_name: ie_discharge
hostname: ie_discharge
command: [ '-w', '5', 'ie_discharge:app', '--bind', '0.0.0.0:5003' ]
deploy:
resources:
reservations:
devices:
- device_ids: [ '1' ]
capabilities: [ 'gpu' ]
driver: 'nvidia'
ie_cost:
<<: *paddle_template
container_name: ie_cost
hostname: ie_cost
command: [ '-w', '5', 'ie_cost:app', '--bind', '0.0.0.0:5004' ]
deploy:
resources:
reservations:
devices:
- device_ids: [ '1' ]
capabilities: [ 'gpu' ]
driver: 'nvidia'
clas_orientation:
<<: *paddle_template
container_name: clas_orientation
hostname: clas_orientation
command: [ '-w', '3', 'clas_orientation:app', '--bind', '0.0.0.0:5005' ]
deploy:
resources:
reservations:
devices:
- device_ids: [ '0' ]
capabilities: [ 'gpu' ]
driver: 'nvidia'
det_book:
<<: *paddle_template
container_name: det_book
hostname: det_book
command: [ '-w', '4', 'det_book:app', '--bind', '0.0.0.0:5006' ]
deploy:
resources:
reservations:
devices:
- device_ids: [ '1' ]
capabilities: [ 'gpu' ]
driver: 'nvidia'
dewarp:
<<: *paddle_template
container_name: dewarp
hostname: dewarp
command: [ '-w', '4', 'dewarp:app', '--bind', '0.0.0.0:5007' ]
deploy:
resources:
reservations:
devices:
- device_ids: [ '0' ]
capabilities: [ 'gpu' ]
driver: 'nvidia'
# clas_text:
# <<: *paddle_template
# container_name: clas_text
# hostname: clas_text
# command: [ '-w', '3', 'clas_text:app', '--bind', '0.0.0.0:5008' ]
# deploy:
# resources:
# reservations:
# devices:
# - device_ids: [ '1' ]
# capabilities: [ 'gpu' ]
# driver: 'nvidia'
photo_review_1:
<<: *review_template
<<: *project_template
build:
context: .
container_name: photo_review_1
hostname: photo_review_1
depends_on:
- ocr
- ie_settlement
- ie_discharge
- ie_cost
- clas_orientation
- det_book
- dewarp
# - clas_text
command: [ 'photo_review.py', '--clean', 'True' ]
photo_review_2:
<<: *review_template
<<: *project_template
container_name: photo_review_2
hostname: photo_review_2
depends_on:
@@ -48,57 +152,41 @@ services:
command: [ 'photo_review.py' ]
photo_review_3:
<<: *review_template
<<: *project_template
container_name: photo_review_3
hostname: photo_review_3
depends_on:
- photo_review_2
- photo_review_1
command: [ 'photo_review.py' ]
photo_review_4:
<<: *review_template
<<: *project_template
container_name: photo_review_4
hostname: photo_review_4
depends_on:
- photo_review_3
- photo_review_1
command: [ 'photo_review.py' ]
photo_review_5:
<<: *review_template
<<: *project_template
container_name: photo_review_5
hostname: photo_review_5
depends_on:
- photo_review_4
- photo_review_1
command: [ 'photo_review.py' ]
photo_mask_1:
<<: *mask_template
<<: *project_template
container_name: photo_mask_1
hostname: photo_mask_1
depends_on:
- photo_review_5
- photo_review_1
command: [ 'photo_mask.py', '--clean', 'True' ]
photo_mask_2:
<<: *mask_template
<<: *project_template
container_name: photo_mask_2
hostname: photo_mask_2
depends_on:
- photo_mask_1
command: [ 'photo_mask.py' ]
#
# photo_review_6:
# <<: *review_template
# container_name: photo_review_6
# hostname: photo_review_6
# depends_on:
# - photo_mask_2
# command: [ 'photo_review.py' ]
#
# photo_review_7:
# <<: *review_template
# container_name: photo_review_7
# hostname: photo_review_7
# depends_on:
# - photo_review_6
# command: [ 'photo_review.py' ]
command: [ 'photo_mask.py' ]
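
The compose file above leans on YAML anchors (x-base, x-project, x-paddle) so each service only overrides its name, GPU device and gunicorn command; a minimal sketch of operating the stack, assuming the file is used as-is:

```bash
# Start the model API services and the photo_review/photo_mask workers
docker compose up -d

# List the containers (ocr, ie_*, clas_orientation, det_book, dewarp, photo_review_*, photo_mask_*)
docker compose ps

# Follow the first review worker, which releases leftover cases on startup (--clean True)
docker compose logs -f photo_review_1
```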

Binary file not shown.

View File

@@ -1,153 +0,0 @@
# PaddleOCR
------
## Dataset
This part is all done inside the PPOCRLabel directory
```bash
# Enter the PPOCRLabel directory
cd .\PPOCRLabel\
```
### Labeling
PPOCRLabel.py can be run directly from PyCharm, but it defaults to English
```bash
# Run the labeling app in Chinese
python PPOCRLabel.py --lang ch
# Labeling with keyword extraction
python PPOCRLabel.py --lang ch --kie True
```
### Splitting the dataset
```bash
python gen_ocr_train_val_test.py --trainValTestRatio 6:2:2 --datasetRootPath ../train_data/drivingData
```
------
## Detection model
Return to the project root directory first
### Training
```bash
python tools/train.py -c configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml
```
### Testing
```bash
python tools/infer_det.py -c configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml -o Global.pretrained_model=output/det_v4_bankcard/best_accuracy.pdparams Global.infer_img=train_data/drivingData/1.jpg
```
### Resuming training
```bash
python tools/train.py -c configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml -o Global.checkpoints=./output/det_v4_bankcard/latest
```
------
## Recognition model
### Training
```bash
python tools/train.py -c configs/rec/PP-OCRv4/ch_PP-OCRv4_rec_ampO2_ultra.yml
```
### Testing
```bash
python tools/infer_rec.py -c configs/rec/PP-OCRv4/ch_PP-OCRv4_rec_ampO2_ultra.yml -o Global.pretrained_model=output/rec_v4_bankcard/best_accuracy.pdparams Global.infer_img=train_data/drivingData/crop_img/1_crop_0.jpg
```
------
## Inference models
### Detection model export
```bash
python tools/export_model.py -c configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml -o Global.pretrained_model=output/det_v4_bankcard/best_accuracy.pdparams
```
### Recognition model export
```bash
python tools/export_model.py -c configs/rec/PP-OCRv4/ch_PP-OCRv4_rec_ampO2_ultra.yml -o Global.pretrained_model=output/rec_v4_bankcard/best_accuracy.pdparams
```
### Detection + recognition test
```bash
python tools/infer/predict_system.py --det_model_dir=inference_model/det_v4_bankcard --rec_model_dir=inference_model/rec_v4_bankcard --rec_char_dict_path=ppocr/utils/num_dict.txt --image_dir=train_data/drivingData/1.jpg
```
------
## Mobile models
### Detection model conversion
```bash
paddle_lite_opt --model_file=inference_model/det_v4_bankcard/inference.pdmodel --param_file=inference_model/det_v4_bankcard/inference.pdiparams --optimize_out=inference_model/det_v4_nb_bankcard --valid_targets=arm --optimize_out_type=naive_buffer
```
### Recognition model conversion
```bash
paddle_lite_opt --model_file=inference_model/rec_v4_bankcard/inference.pdmodel --param_file=inference_model/rec_v4_bankcard/inference.pdiparams --optimize_out=inference_model/rec_v4_nb_bankcard --valid_targets=arm --optimize_out_type=naive_buffer
```
------
------
# PaddleNLP
## Dataset
Use Label Studio for data labeling (installation omitted)
```bash
# Open Anaconda Prompt
# Activate the environment where Label Studio is installed
conda activate label-studio
# Start Label Studio
label-studio start
```
[Labeling guide](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/applications/information_extraction/label_studio_doc.md)
### Data conversion
```bash
# Run inside PaddleNLP\applications\information_extraction
python label_studio.py --label_studio_file ./document/data/label_studio.json --save_dir ./document/data --splits 0.8 0.1 0.1 --task_type ext
```
------
## Training the model
```bash
# Run inside PaddleNLP\applications\information_extraction\document (dual-GPU training)
python -u -m paddle.distributed.launch --gpus "0,1" finetune.py --device gpu --logging_steps 5 --save_steps 25 --eval_steps 25 --seed 42 --model_name_or_path uie-x-base --output_dir ./checkpoint/model_best --train_path data/train.txt --dev_path data/dev.txt --max_seq_len 512 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --num_train_epochs 10 --learning_rate 1e-5 --do_train --do_eval --do_export --export_model_dir ./checkpoint/model_best --overwrite_output_dir --disable_tqdm False --metric_for_best_model eval_f1 --load_best_model_at_end True --save_total_limit 1
```
------
References:
[PaddleOCR: train your own model, detailed tutorial](https://blog.csdn.net/qq_52852432/article/details/131817619?utm_medium=distribute.pc_relevant.none-task-blog-2~default~baidujs_baidulandingword~default-0-131817619-blog-124628731.235^v40^pc_relevant_3m_sort_dl_base1&spm=1001.2101.3001.4242.1&utm_relevant_index=3)
[On-device deployment](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/deploy/lite/readme_ch.md)
[PaddleNLP key information extraction](https://blog.csdn.net/z5z5z5z56/article/details/130346646)

View File

@@ -1,329 +0,0 @@
anyio 4.0.0
astor 0.8.1
certifi 2019.11.28
chardet 3.0.4
dbus-python 1.2.16
decorator 5.1.1
distro-info 0.23+ubuntu1.1
exceptiongroup 1.1.3
h11 0.14.0
httpcore 1.0.2
httpx 0.25.1
idna 2.8
numpy 1.26.2
opt-einsum 3.3.0
paddlepaddle-gpu 2.6.1.post120
Pillow 10.1.0
pip 24.0
protobuf 4.25.0
PyGObject 3.36.0
python-apt 2.0.1+ubuntu0.20.4.1
requests 2.22.0
requests-unixsocket 0.2.0
setuptools 68.2.2
six 1.14.0
sniffio 1.3.0
unattended-upgrades 0.1
urllib3 1.25.8
ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddle:3.1.0-gpu-cuda12.9-cudnn9.9:
python:3.10.12
Package Version
------------------------ ----------
anyio 4.9.0
certifi 2025.6.15
decorator 5.2.1
exceptiongroup 1.3.0
h11 0.16.0
httpcore 1.0.9
httpx 0.28.1
idna 3.10
networkx 3.4.2
numpy 2.2.6
nvidia-cublas-cu12 12.9.0.13
nvidia-cuda-cccl-cu12 12.9.27
nvidia-cuda-cupti-cu12 12.9.19
nvidia-cuda-nvrtc-cu12 12.9.41
nvidia-cuda-runtime-cu12 12.9.37
nvidia-cudnn-cu12 9.9.0.52
nvidia-cufft-cu12 11.4.0.6
nvidia-cufile-cu12 1.14.0.30
nvidia-curand-cu12 10.3.10.19
nvidia-cusolver-cu12 11.7.4.40
nvidia-cusparse-cu12 12.5.9.5
nvidia-cusparselt-cu12 0.7.1
nvidia-nccl-cu12 2.26.5
nvidia-nvjitlink-cu12 12.9.41
nvidia-nvtx-cu12 12.9.19
opt-einsum 3.3.0
paddlepaddle-gpu 3.1.0
pillow 11.2.1
pip 25.1.1
protobuf 6.31.1
setuptools 59.6.0
sniffio 1.3.1
typing_extensions 4.14.0
wheel 0.37.1
ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlex/paddlex:paddlex3.1.2-paddlepaddle3.0.0-gpu-cuda12.6-cudnn9.5-trt10.5
python:3.10.18
Package Version Editable project location
------------------------- -------------------- -------------------------
aiohappyeyeballs 2.6.1
aiohttp 3.12.13
aiosignal 1.4.0
aistudio_sdk 0.3.5
albucore 0.0.13+pdx
albumentations 1.4.10+pdx
alembic 1.16.2
annotated-types 0.7.0
anyio 4.9.0
astor 0.8.1
asttokens 3.0.0
async-timeout 4.0.3
attrdict3 2.0.2
attrs 25.3.0
babel 2.17.0
bce-python-sdk 0.9.35
beautifulsoup4 4.13.4
blinker 1.9.0
cachetools 6.1.0
certifi 2019.11.28
cffi 1.17.1
chardet 3.0.4
charset-normalizer 3.4.2
chinese-calendar 1.8.0
click 8.2.1
cloudpickle 3.1.1
colorama 0.4.6
colorlog 6.9.0
ConfigSpace 1.2.1
contourpy 1.3.2
cssselect 1.3.0
cssutils 2.11.1
cycler 0.12.1
Cython 3.1.2
dataclasses-json 0.6.7
datasets 3.6.0
dbus-python 1.2.16
decorator 5.2.1
decord 0.6.0
descartes 1.1.0
dill 0.3.4
distro 1.9.0
distro-info 0.23+ubuntu1.1
easydict 1.13
einops 0.8.1
et_xmlfile 2.0.0
exceptiongroup 1.2.2
executing 2.2.0
faiss-cpu 1.8.0.post1
fastapi 0.116.0
filelock 3.18.0
fire 0.7.0
FLAML 2.3.5
Flask 3.1.1
flask-babel 4.0.0
fonttools 4.58.5
frozenlist 1.7.0
fsspec 2025.3.0
ftfy 6.3.1
future 1.0.0
gast 0.3.3
GPUtil 1.4.0
greenlet 3.2.3
h11 0.14.0
h5py 3.14.0
hf-xet 1.1.5
hpbandster 0.7.4
httpcore 1.0.7
httpx 0.28.1
httpx-sse 0.4.1
huggingface-hub 0.33.2
idna 2.8
imageio 2.37.0
imagesize 1.4.1
imgaug 0.4.0+pdx
ipython 8.37.0
itsdangerous 2.2.0
jedi 0.19.2
jieba 0.42.1
Jinja2 3.1.6
jiter 0.10.0
joblib 1.5.1
jsonpatch 1.33
jsonpointer 3.0.0
jsonschema 4.24.0
jsonschema-specifications 2025.4.1
kiwisolver 1.4.8
langchain 0.3.26
langchain-community 0.3.27
langchain-core 0.3.68
langchain-openai 0.3.27
langchain-text-splitters 0.3.8
langsmith 0.4.4
lapx 0.5.11.post1
lazy_loader 0.4
llvmlite 0.44.0
lmdb 1.6.2
lxml 6.0.0
Mako 1.3.10
markdown-it-py 3.0.0
MarkupSafe 3.0.2
marshmallow 3.26.1
matplotlib 3.5.3
matplotlib-inline 0.1.7
mdurl 0.1.2
more-itertools 10.7.0
motmetrics 1.4.0
msgpack 1.1.1
multidict 6.6.3
multiprocess 0.70.12.2
mypy_extensions 1.1.0
netifaces 0.11.0
networkx 3.4.2
numba 0.61.2
numpy 1.24.4
nuscenes-devkit 1.1.11+pdx
onnx 1.17.0
onnxoptimizer 0.3.13
openai 1.93.1
opencv-contrib-python 4.10.0.84
openpyxl 3.1.5
opt-einsum 3.3.0
optuna 4.4.0
orjson 3.10.18
packaging 24.2
paddle2onnx 2.0.2rc3
paddle3d 0.0.0
paddleclas 2.6.0
paddledet 0.0.0
paddlefsl 1.1.0
paddlenlp 2.8.0.post0
paddlepaddle-gpu 3.0.0
paddleseg 0.0.0.dev0
paddlets 1.1.0
paddlex 3.1.2 /root/PaddleX
pandas 1.3.5
parso 0.8.4
patsy 1.0.1
pexpect 4.9.0
pillow 11.1.0
pip 25.1.1
polygraphy 0.49.24
ppvideo 2.3.0
premailer 3.10.0
prettytable 3.16.0
prompt_toolkit 3.0.51
propcache 0.3.2
protobuf 6.30.1
psutil 7.0.0
ptyprocess 0.7.0
pure_eval 0.2.3
py-cpuinfo 9.0.0
pyarrow 20.0.0
pybind11 2.13.6
pybind11-stubgen 2.5.1
pyclipper 1.3.0.post6
pycocotools 2.0.8
pycparser 2.22
pycryptodome 3.23.0
pydantic 2.11.7
pydantic_core 2.33.2
pydantic-settings 2.10.1
Pygments 2.19.2
PyGObject 3.36.0
PyMatting 1.1.14
pyod 2.0.5
pypandoc 1.15
pyparsing 3.2.3
pypdfium2 4.30.1
pyquaternion 0.9.9
Pyro4 4.82
python-apt 2.0.1+ubuntu0.20.4.1
python-dateutil 2.9.0.post0
python-docx 1.2.0
python-dotenv 1.1.1
pytz 2025.2
PyWavelets 1.3.0
PyYAML 6.0.2
RapidFuzz 3.13.0
rarfile 4.2
ray 2.47.1
referencing 0.36.2
regex 2024.11.6
requests 2.32.4
requests-toolbelt 1.0.0
requests-unixsocket 0.2.0
rich 14.0.0
rpds-py 0.26.0
ruamel.yaml 0.18.14
ruamel.yaml.clib 0.2.12
safetensors 0.5.3
scikit-image 0.25.2
scikit-learn 1.3.2
scipy 1.15.3
seaborn 0.13.2
sentencepiece 0.2.0
seqeval 1.2.2
serpent 1.41
setuptools 68.2.2
shap 0.48.0
Shapely 1.8.5.post1
shellingham 1.5.4
six 1.14.0
sklearn 0.0
slicer 0.0.8
sniffio 1.3.1
soundfile 0.13.1
soupsieve 2.7
SQLAlchemy 2.0.41
stack-data 0.6.3
starlette 0.46.2
statsmodels 0.14.1
tenacity 9.1.2
tensorboardX 2.6.4
tensorrt 10.5.0
termcolor 3.1.0
terminaltables 3.1.10
threadpoolctl 3.6.0
tifffile 2025.5.10
tiktoken 0.9.0
tokenizers 0.19.1
tomli 2.2.1
tool_helpers 0.1.2
tqdm 4.67.1
traitlets 5.14.3
typeguard 4.4.4
typer 0.16.0
typing_extensions 4.14.1
typing-inspect 0.9.0
typing-inspection 0.4.1
tzdata 2025.2
ujson 5.10.0
unattended-upgrades 0.1
urllib3 1.25.8
uvicorn 0.35.0
visualdl 2.5.3
Wand 0.6.13
wcwidth 0.2.13
Werkzeug 3.1.3
xmltodict 0.14.2
xxhash 3.5.0
yacs 0.1.8
yarl 1.20.1
zstandard 0.23.0

View File

@@ -1,14 +1,15 @@
import os
import socket
# Project root directory
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Hostname, used to tell containers apart
HOSTNAME = socket.gethostname()
# Ensure the log file paths exist; create them if missing
LOG_PATHS = [
f"log/{HOSTNAME}/ucloud",
f"log/{HOSTNAME}/error",
f"log/{HOSTNAME}/qr",
f"log/{HOSTNAME}/sql",
os.path.join(PROJECT_ROOT, 'log', HOSTNAME, 'ucloud'),
os.path.join(PROJECT_ROOT, 'log', HOSTNAME, 'error'),
os.path.join(PROJECT_ROOT, 'log', HOSTNAME, 'qr'),
]
for path in LOG_PATHS:
if not os.path.exists(path):
@@ -75,16 +76,6 @@ LOGGING_CONFIG = {
'backupCount': 14,
'encoding': 'utf-8',
},
'sql': {
'class': 'logging.handlers.TimedRotatingFileHandler',
'level': 'INFO',
'formatter': 'standard',
'filename': f'log/{HOSTNAME}/sql/fcb_photo_review_sql.log',
'when': 'midnight',
'interval': 1,
'backupCount': 14,
'encoding': 'utf-8',
},
},
# loggers defines the individual loggers
@@ -109,10 +100,5 @@ LOGGING_CONFIG = {
'level': 'DEBUG',
'propagate': False,
},
'sql': {
'handlers': ['console', 'sql'],
'level': 'DEBUG',
'propagate': False,
},
},
}

View File

@@ -8,15 +8,13 @@ MAX_WAIT_TIME = 3
# Program error notification email configuration
ERROR_EMAIL_CONFIG = {
# SMTP server address
"smtp_server": "smtp.163.com",
'smtp_server': 'smtp.163.com',
# Port used to connect to SMTP
"port": 994,
'port': 994,
# Sender address (make sure the SMTP mail service is enabled for it)
"sender": "EchoLiu618@163.com",
'sender': 'EchoLiu618@163.com',
# Authorization code -- a dedicated password for third-party mail clients, not the mailbox password
"authorization_code": "OKPQLIIVLVGRZYVH",
'authorization_code': 'OKPQLIIVLVGRZYVH',
# Recipient addresses
"receivers": ["1515783401@qq.com"],
# Number of retry attempts
"retry_times": 3,
'receivers': ['1515783401@qq.com'],
}

View File

@@ -5,18 +5,18 @@ from email.mime.text import MIMEText
from tenacity import retry, stop_after_attempt, wait_random
from auto_email import ERROR_EMAIL_CONFIG, TRY_TIMES, MIN_WAIT_TIME, MAX_WAIT_TIME
from log import HOSTNAME
from my_email import ERROR_EMAIL_CONFIG, TRY_TIMES, MIN_WAIT_TIME, MAX_WAIT_TIME
@retry(stop=stop_after_attempt(TRY_TIMES), wait=wait_random(MIN_WAIT_TIME, MAX_WAIT_TIME), reraise=True,
after=lambda x: logging.warning("发送邮件失败!"))
after=lambda x: logging.warning('发送邮件失败!'))
def send_email(email_config, massage):
smtp_server = email_config["smtp_server"]
port = email_config["port"]
sender = email_config["sender"]
authorization_code = email_config["authorization_code"]
receivers = email_config["receivers"]
smtp_server = email_config['smtp_server']
port = email_config['port']
sender = email_config['sender']
authorization_code = email_config['authorization_code']
receivers = email_config['receivers']
mail = smtplib.SMTP_SSL(smtp_server, port) # connect to the SMTP service
mail.login(sender, authorization_code) # log in to the SMTP service
mail.sendmail(sender, receivers, massage.as_string()) # send the email
@@ -34,13 +34,13 @@ def send_error_email(program_name, error_name, error_detail):
"""
# SMTP 服务器配置
sender = ERROR_EMAIL_CONFIG["sender"]
receivers = ERROR_EMAIL_CONFIG["receivers"]
sender = ERROR_EMAIL_CONFIG['sender']
receivers = ERROR_EMAIL_CONFIG['receivers']
# 获取程序出错的时间
error_time = datetime.datetime.strftime(datetime.datetime.today(), "%Y-%m-%d %H:%M:%S:%f")
error_time = datetime.datetime.strftime(datetime.datetime.today(), '%Y-%m-%d %H:%M:%S:%f')
# 邮件内容
subject = f"【程序异常提醒】{program_name}({HOSTNAME}) {error_time}" # 邮件的标题
subject = f'【程序异常提醒】{program_name}({HOSTNAME}) {error_time}' # 邮件的标题
content = f'''<div class="emailcontent" style="width:100%;max-width:720px;text-align:left;margin:0 auto;padding-top:80px;padding-bottom:20px">
<div class="emailtitle">
<h1 style="color:#fff;background:#51a0e3;line-height:70px;font-size:24px;font-weight:400;padding-left:40px;margin:0">程序运行异常通知</h1>

View File

@@ -1 +0,0 @@
README_cn.md

View File

@@ -1,4 +0,0 @@
from onnxruntime import InferenceSession
PADDLE_DET = InferenceSession("model/object_det_model/ppyoloe_plus_crn_l_80e_coco_w_nms.onnx",
providers=["CPUExecutionProvider"], provider_options=[{"device_id": 0}])

View File

@@ -1 +0,0 @@
README_cn.md

View File

@@ -1 +0,0 @@
README_cn.md

View File

@@ -1 +0,0 @@
README_cn.md

View File

@@ -1 +0,0 @@
README_cn.md

View File

@@ -1 +0,0 @@
README_cn.md

View File

@@ -1 +0,0 @@
README_cn.md

View File

@@ -1 +0,0 @@
README_cn.md

View File

@@ -1 +0,0 @@
README_cn.md

View File

@@ -1 +0,0 @@
README_cn.md

View File

@@ -1 +0,0 @@
README_en.md

View File

@@ -1,76 +0,0 @@
import base64
import logging
import tempfile
from collections import defaultdict
import cv2
import numpy as np
import requests
from tenacity import retry, stop_after_attempt, wait_random
from paddle_detection import PADDLE_DET
from paddle_detection.deploy.third_engine.onnx.infer import PredictConfig
from paddle_detection.deploy.third_engine.onnx.preprocess import Compose
from util import image_util, util
def predict_image(infer_config, predictor, img_path):
# load preprocess transforms
transforms = Compose(infer_config.preprocess_infos)
# predict image
inputs = transforms(img_path)
inputs["image"] = np.array(inputs["image"]).astype('float32')
inputs_name = [var.name for var in predictor.get_inputs()]
inputs = {k: inputs[k][None,] for k in inputs_name}
outputs = predictor.run(output_names=None, input_feed=inputs)
bboxes = np.array(outputs[0])
result = defaultdict(list)
for bbox in bboxes:
if bbox[0] > -1 and bbox[1] > infer_config.draw_threshold:
result[bbox[0]].append({"score": bbox[1], "box": bbox[2:]})
return result
def detect_image(img_path):
infer_cfg = "model/object_det_model/infer_cfg.yml"
# load infer config
infer_config = PredictConfig(infer_cfg)
return predict_image(infer_config, PADDLE_DET, img_path)
def get_book_areas(image):
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
cv2.imwrite(temp_file.name, image)
detect_result = detect_image(temp_file.name)
util.delete_temp_file(temp_file.name)
book_areas = detect_result[73]
result = []
for book_area in book_areas:
result.append(image_util.capture(image, book_area["box"]))
return result
@retry(stop=stop_after_attempt(3), wait=wait_random(1, 3), reraise=True,
after=lambda x: logging.warning("获取文档区域失败!"))
def request_book_areas(image):
url = "http://det_api:5000/det/detect_books"
_, encoded_image = cv2.imencode('.jpg', image)
byte_stream = encoded_image.tobytes()
files = {"image": ("image.jpg", byte_stream)}
response = requests.post(url, files=files)
if response.status_code == 200:
img_str_list = response.json()
result = []
for img_str in img_str_list:
img_data = base64.b64decode(img_str)
np_array = np.frombuffer(img_data, np.uint8)
img = cv2.imdecode(np_array, cv2.IMREAD_COLOR)
height, width = img.shape[:2]
if max(height, width) / min(height, width) <= 6.5:
result.append(img) # 过滤异常结果
return result
else:
return []

View File

@@ -5,35 +5,36 @@ from time import sleep
from sqlalchemy import update
from auto_email.error_email import send_error_email
from db import MysqlSession
from db.mysql import ZxPhhd
from log import LOGGING_CONFIG
from my_email.error_email import send_error_email
from photo_mask import auto_photo_mask, SEND_ERROR_EMAIL
if __name__ == '__main__':
program_name = "照片审核自动涂抹脚本"
program_name = '照片审核自动涂抹脚本'
logging.config.dictConfig(LOGGING_CONFIG)
logging.info('等待接口服务启动...')
sleep(60)
parser = argparse.ArgumentParser()
parser.add_argument("--clean", default=False, type=bool, help="是否将涂抹中的案子改为待涂抹状态")
parser.add_argument('--clean', default=False, type=bool, help='是否将涂抹中的案子改为待涂抹状态')
args = parser.parse_args()
if args.clean:
# Mainly used at startup to release cases still marked as being masked
session = MysqlSession()
update_flag = (update(ZxPhhd).where(ZxPhhd.paint_flag == "2").values(paint_flag="1"))
update_flag = (update(ZxPhhd).where(ZxPhhd.paint_flag == '2').values(paint_flag='1'))
session.execute(update_flag)
session.commit()
session.close()
logging.info("已释放残余的涂抹案子!")
else:
sleep(5)
logging.info('已释放残余的涂抹案子!')
try:
logging.info(f"{program_name}】开始运行")
logging.info(f'{program_name}】开始运行')
auto_photo_mask.main()
except Exception as e:
error_logger = logging.getLogger("error")
error_logger = logging.getLogger('error')
error_logger.error(traceback.format_exc())
if SEND_ERROR_EMAIL:
send_error_email(program_name, repr(e), traceback.format_exc())
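
For completeness, a sketch of invoking this entry point by hand; the compose services pass the same arguments, and the assumption that the container entrypoint runs python follows from the command lists shown above:

```bash
# Release leftover "masking" cases and start the worker (as photo_mask_1 does)
python photo_mask.py --clean True

# Plain worker without the cleanup step (as photo_mask_2 does)
python photo_mask.py
```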

View File

@@ -1,5 +1,3 @@
from paddleocr import PaddleOCR
"""
Project configuration
"""
@@ -40,13 +38,3 @@ SIMILAR_CHAR = {
"": [""],
"": [""],
}
# If you do not want spaces to be recognized, set use_space_char=False. Be sure to test this setting: in version 2.7.3 it has a bug that causes recognition to fail
OCR = PaddleOCR(
gpu_id=0,
show_log=False,
det_db_thresh=0.1,
det_db_box_thresh=0.3,
det_limit_side_len=1248,
drop_score=0.3
)

View File

@@ -1,7 +1,9 @@
import logging.config
import os
import re
import tempfile
import shutil
import time
import uuid
from time import sleep
import cv2
@@ -10,9 +12,10 @@ from sqlalchemy import update, and_
from db import MysqlSession
from db.mysql import ZxPhrec, ZxPhhd
from log import HOSTNAME
from photo_mask import OCR, PHHD_BATCH_SIZE, SLEEP_MINUTES, NAME_KEYS, ID_CARD_NUM_KEYS, SIMILAR_CHAR
from photo_mask import PHHD_BATCH_SIZE, SLEEP_MINUTES, NAME_KEYS, ID_CARD_NUM_KEYS, SIMILAR_CHAR
from photo_review import set_batch_id
from ucloud import BUCKET, ufile
from util import image_util, util
from util import image_util, common_util, model_util
def find_boxes(content, layout, offset=0, length=None, improve=False, image_path=None, extra_content=None):
@@ -55,14 +58,15 @@ def find_boxes(content, layout, offset=0, length=None, improve=False, image_path
if improve:
# 再次识别,提高精度
image = cv2.imread(image_path)
img_name, img_ext = common_util.parse_save_path(image_path)
# 截图时偏大一点
capture_box = util.zoom_rectangle(box, 0.2)
capture_box = common_util.zoom_rectangle(box, 0.2)
captured_image = image_util.capture(image, capture_box)
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_file:
captured_image, offset_x, offset_y = image_util.expand_to_a4_size(captured_image)
cv2.imwrite(temp_file.name, captured_image)
captured_image_path = common_util.get_processed_img_path(f'{img_name}.capture.{img_ext}')
cv2.imwrite(captured_image_path, captured_image)
captured_a4_img_path, offset_x, offset_y = image_util.expand_to_a4_size(captured_image_path)
try:
layouts, _ = util.get_ocr_layout(OCR, temp_file.name)
layouts = common_util.ocr_result_to_layout(model_util.ocr(captured_a4_img_path))
except TypeError:
# 如果是类型错误,大概率是没识别到文字
layouts = []
@@ -86,22 +90,17 @@ def find_boxes(content, layout, offset=0, length=None, improve=False, image_path
temp_box[3] + capture_box[1] - offset_y,
])
break
util.delete_temp_file(temp_file.name)
if not boxes:
boxes.append(box)
return boxes
def get_mask_layout(image, name, id_card_num):
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
cv2.imwrite(temp_file.name, image)
def get_mask_layout(img_path, name, id_card_num):
result = []
try:
try:
layouts, _ = util.get_ocr_layout(OCR, temp_file.name)
# layouts = OCR.parse({"doc": temp_file.name})["layout"]
layouts = common_util.ocr_result_to_layout(model_util.ocr(img_path))
except TypeError:
# 如果是类型错误,大概率是没识别到文字
layouts = []
@@ -135,12 +134,12 @@ def get_mask_layout(image, name, id_card_num):
find_id_card_num_by_key = True
matches = re.findall(r, layout[1])
for match in matches:
result += find_boxes(match, layout, improve=True, image_path=temp_file.name, extra_content=r)
result += find_boxes(match, layout, improve=True, image_path=img_path, extra_content=r)
find_name_by_key = False
break
if id_card_num in layout[1]:
result += find_boxes(id_card_num, layout, improve=True, image_path=temp_file.name)
result += find_boxes(id_card_num, layout, improve=True, image_path=img_path)
find_id_card_num_by_key = False
def _find_boxes_by_keys(keys):
@@ -160,13 +159,9 @@ def get_mask_layout(image, name, id_card_num):
result += _find_boxes_by_keys(ID_CARD_NUM_KEYS)
return result
except MemoryError as e:
raise e
except Exception as e:
logging.error("涂抹时出错!", exc_info=e)
return result
finally:
util.delete_temp_file(temp_file.name)
def handle_image_for_mask(split_result):
@@ -176,11 +171,15 @@ def handle_image_for_mask(split_result):
return expand_img, split_result["x_offset"], split_result["y_offset"]
def mask_photo(img_url, name, id_card_num, color=(255, 255, 255)):
def _mask(i, n, icn, c):
def mask_photo(img_path, name, id_card_num, color=(255, 255, 255)):
def _mask(ip, n, icn, c):
i = cv2.imread(ip)
img_name, img_ext = common_util.parse_save_path(ip)
do_mask = False
split_results = image_util.split(i)
split_results = image_util.split(ip)
for split_result in split_results:
if not split_result['img']:
continue
to_mask_img, x_offset, y_offset = handle_image_for_mask(split_result)
results = get_mask_layout(to_mask_img, n, icn)
@@ -195,29 +194,27 @@ def mask_photo(img_url, name, id_card_num, color=(255, 255, 255)):
result[3] + y_offset,
)
cv2.rectangle(i, (int(result[0]), int(result[1])), (int(result[2]), int(result[3])), c, -1, 0)
return do_mask, i
masked_path = common_util.get_processed_img_path(f'{img_name}.mask.{img_ext}')
cv2.imwrite(masked_path, i)
return do_mask, masked_path
# 打开图片
image, _ = image_util.read(img_url)
if image is None:
return False, image
original_image = image
is_masked, image = _mask(image, name, id_card_num, color)
original_image = img_path
is_masked, img_path = _mask(img_path, name, id_card_num, color)
if not is_masked:
# 如果没有涂抹,可能是图片方向不对
angles = image_util.parse_rotation_angles(image)
angles = model_util.clas_orientation(img_path)
angle = angles[0]
if angle != "0":
image = image_util.rotate(image, int(angle))
is_masked, image = _mask(image, name, id_card_num, color)
img_path = image_util.rotate(img_path, int(angle))
is_masked, img_path = _mask(img_path, name, id_card_num, color)
if not is_masked:
# 如果旋转后也没有涂抹,恢复原来的方向
image = original_image
img_path = original_image
else:
# 如果旋转有效果,打一个日志
logging.info(f"图片旋转了{angle}°")
return is_masked, image
return is_masked, img_path
def photo_mask(pk_phhd, name, id_card_num):
@@ -227,32 +224,37 @@ def photo_mask(pk_phhd, name, id_card_num):
ZxPhrec.cRectype.in_(["3", "4"])
)).all()
session.close()
# 同一批图的标识
set_batch_id(uuid.uuid4().hex)
processed_img_dir = common_util.get_processed_img_path('')
os.makedirs(processed_img_dir, exist_ok=True)
for phrec in phrecs:
img_url = ufile.get_private_url(phrec.cfjaddress)
if not img_url:
continue
is_masked, image = mask_photo(img_url, name, id_card_num)
original_img_path = common_util.save_to_local(img_url)
img_path = common_util.get_processed_img_path(phrec.cfjaddress)
shutil.copy2(original_img_path, img_path)
is_masked, image = mask_photo(img_path, name, id_card_num)
# 如果涂抹了要备份以及更新
if is_masked:
ufile.copy_file(BUCKET, phrec.cfjaddress, "drg2015", phrec.cfjaddress)
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
cv2.imwrite(temp_file.name, image)
try:
ufile.upload_file(phrec.cfjaddress, temp_file.name)
ufile.copy_file(BUCKET, phrec.cfjaddress, "drg2015", phrec.cfjaddress)
ufile.upload_file(phrec.cfjaddress, image)
session = MysqlSession()
update_flag = (update(ZxPhrec).where(ZxPhrec.pk_phrec == phrec.pk_phrec).values(
paint_user=HOSTNAME,
paint_date=util.get_default_datetime()))
paint_date=common_util.get_default_datetime()))
session.execute(update_flag)
session.commit()
session.close()
except Exception as e:
logging.error("上传图片出错", exc_info=e)
finally:
util.delete_temp_file(temp_file.name)
# 删除多余图片
if os.path.exists(processed_img_dir) and os.path.isdir(processed_img_dir):
shutil.rmtree(processed_img_dir)
def main():
@@ -282,7 +284,7 @@ def main():
update_flag = (update(ZxPhhd).where(ZxPhhd.pk_phhd == pk_phhd).values(
paint_flag="8",
paint_user=HOSTNAME,
paint_date=util.get_default_datetime(),
paint_date=common_util.get_default_datetime(),
fZcfwfy=time.time() - start_time))
session.execute(update_flag)
session.commit()

View File

@@ -8,7 +8,7 @@ from db import MysqlSession
from db.mysql import ZxIeOcrerror, ZxPhrec
from photo_mask.auto_photo_mask import mask_photo
from ucloud import ufile
from util import image_util, util
from util import image_util, common_util
def check_error(error_ocr):
@@ -23,7 +23,7 @@ def check_error(error_ocr):
image = mask_photo(img_url, name, id_card_num, (0, 0, 0))[1]
final_img_url = ufile.get_private_url(error_ocr.cfjaddress, "drg100")
final_image, _ = image_util.read(final_img_url)
final_image = image_util.read(final_img_url)
return image_util.combined(final_image, image)
@@ -91,7 +91,7 @@ if __name__ == '__main__':
session = MysqlSession()
update_error = (update(ZxIeOcrerror).where(ZxIeOcrerror.pk_phrec == ocr_error.pk_phrec).values(
checktime=util.get_default_datetime(), cfjaddress2=error_descript))
checktime=common_util.get_default_datetime(), cfjaddress2=error_descript))
session.execute(update_error)
session.commit()
session.close()

View File

@@ -7,7 +7,7 @@ from sqlalchemy import update, and_
from db import MysqlSession
from db.mysql import ZxIeOcrerror
from photo_mask.photo_mask_error_check import auto_check_error
from util import util
from util import common_util
if __name__ == '__main__':
today = date.today()
@@ -29,7 +29,7 @@ if __name__ == '__main__':
if error_descript == "未知错误":
check_time = None
else:
check_time = util.get_default_datetime()
check_time = common_util.get_default_datetime()
session = MysqlSession()
update_error = (update(ZxIeOcrerror).where(ZxIeOcrerror.pk_phrec == ocr_error.pk_phrec).values(
@@ -41,5 +41,5 @@ if __name__ == '__main__':
print(result)
with open("photo_mask_error_report.txt", 'w', encoding='utf-8') as file:
file.write(json.dumps(result, indent=4, ensure_ascii=False))
file.write(util.get_default_datetime())
file.write(common_util.get_default_datetime())
print("结果已保存。")

View File

@@ -5,36 +5,36 @@ from time import sleep
from sqlalchemy import update
from auto_email.error_email import send_error_email
from db import MysqlSession
from db.mysql import ZxPhhd
from log import LOGGING_CONFIG
from my_email.error_email import send_error_email
from photo_review import auto_photo_review, SEND_ERROR_EMAIL
# The project must be started from here; otherwise relative paths in the code may cause errors
# Entry point of the photo-review auto-recognition script
if __name__ == '__main__':
program_name = "照片审核自动识别脚本"
program_name = '照片审核自动识别脚本'
logging.config.dictConfig(LOGGING_CONFIG)
logging.info('等待接口服务启动...')
sleep(60)
parser = argparse.ArgumentParser()
parser.add_argument("--clean", default=False, type=bool, help="是否将识别中的案子改为待识别状态")
parser.add_argument('--clean', default=False, type=bool, help='是否将识别中的案子改为待识别状态')
args = parser.parse_args()
if args.clean:
# Mainly used at startup to release cases still marked as being recognized
# At startup, release cases still marked as being recognized
session = MysqlSession()
update_flag = (update(ZxPhhd).where(ZxPhhd.exsuccess_flag == "2").values(exsuccess_flag="1"))
update_flag = (update(ZxPhhd).where(ZxPhhd.exsuccess_flag == '2').values(exsuccess_flag='1'))
session.execute(update_flag)
session.commit()
session.close()
logging.info("已释放残余的识别案子!")
else:
sleep(5)
logging.info('已释放残余的识别案子!')
try:
logging.info(f"{program_name}】开始运行")
logging.info(f'{program_name}】开始运行')
auto_photo_review.main()
except Exception as e:
error_logger = logging.getLogger("error")
error_logger.error(traceback.format_exc())
logging.getLogger('error').error(traceback.format_exc())
if SEND_ERROR_EMAIL:
send_error_email(program_name, repr(e), traceback.format_exc())

View File

@@ -1,112 +1,77 @@
import jieba
from paddlenlp import Taskflow
from paddleocr import PaddleOCR
"""
'''
Project configuration
"""
'''
# Number of cases fetched from the database per batch
PHHD_BATCH_SIZE = 10
# Wait time in minutes when no cases are found
SLEEP_MINUTES = 5
# Whether to send error emails
SEND_ERROR_EMAIL = True
# Whether to enable layout analysis
LAYOUT_ANALYSIS = False
# Processing batch ID (declared here only)
BATCH_ID = ''
"""
Information extraction keyword configuration
"""
# Patient name
PATIENT_NAME = ["患者姓名"]
# Admission date
ADMISSION_DATE = ["入院日期"]
# Discharge date
DISCHARGE_DATE = ["出院日期"]
# Medical expenses incurred
MEDICAL_EXPENSES = ["费用总额"]
# Personal cash payment
PERSONAL_CASH_PAYMENT = ["个人现金支付"]
# Personal account payment
PERSONAL_ACCOUNT_PAYMENT = ["个人账户支付"]
# Personal out-of-pocket amount
PERSONAL_FUNDED_AMOUNT = ["自费金额", "个人自费"]
# Medical insurance type
MEDICAL_INSURANCE_TYPE = ["医保类型"]
# Hospital
HOSPITAL = ["医院"]
# Department
DEPARTMENT = ["科室"]
# Attending doctor
DOCTOR = ["主治医生"]
# Admission ID
ADMISSION_ID = ["住院号"]
# Medical insurance settlement ID
SETTLEMENT_ID = ["医保结算单号码"]
# Age
AGE = ["年龄"]
# Total amount in capital Chinese numerals
UPPERCASE_MEDICAL_EXPENSES = ["大写总额"]
SETTLEMENT_LIST_SCHEMA = \
(PATIENT_NAME + ADMISSION_DATE + DISCHARGE_DATE + MEDICAL_EXPENSES + PERSONAL_CASH_PAYMENT
+ PERSONAL_ACCOUNT_PAYMENT + PERSONAL_FUNDED_AMOUNT + MEDICAL_INSURANCE_TYPE + ADMISSION_ID + SETTLEMENT_ID
+ UPPERCASE_MEDICAL_EXPENSES)
DISCHARGE_RECORD_SCHEMA = \
HOSPITAL + DEPARTMENT + PATIENT_NAME + ADMISSION_DATE + DISCHARGE_DATE + DOCTOR + ADMISSION_ID + AGE
COST_LIST_SCHEMA = PATIENT_NAME + ADMISSION_DATE + DISCHARGE_DATE + MEDICAL_EXPENSES
"""
'''
Alias configuration
"""
'''
# Use the values in the alias map to replace the matched key. For efficiency, only the first matched key is replaced.
HOSPITAL_ALIAS = {
"沐阳": ["沭阳"],
"连水": ["涟水"],
"唯宁": ["睢宁"], # 雕宁
"九〇四": ["904"],
"漂水": ["溧水"],
'沐阳': ['沭阳'],
'连水': ['涟水'],
'唯宁': ['睢宁'], # 雕宁
'九〇四': ['904'],
'漂水': ['溧水'],
}
DEPARTMENT_ALIAS = {
"耳鼻喉": ["耳鼻咽喉"],
"急症": ["急诊"],
'耳鼻喉': ['耳鼻咽喉'],
'急症': ['急诊'],
}
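As the comment above the alias tables notes, replacement is first-hit only: the first alias key found in the text is swapped for its value and the remaining keys are skipped. A minimal sketch of that behaviour, using a hypothetical helper name (apply_alias is not part of this repository):

```python
def apply_alias(text, alias):
    # Only the first matching key is replaced, per the efficiency note above.
    for wrong, candidates in alias.items():
        if wrong in text:
            return text.replace(wrong, candidates[0])
    return text

# Illustrative call: apply_alias('沐阳县人民医院', HOSPITAL_ALIAS) -> '沭阳县人民医院'
```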
"""
'''
搜索过滤配置
"""
'''
# 默认会过滤单字
HOSPITAL_FILTER = ["医院", "人民", "第一", "第二", "第三", "大学", "附属"]
HOSPITAL_FILTER = ['医院', '人民', '第一', '第二', '第三', '大学', '附属']
DEPARTMENT_FILTER = ["", "", "西", ""]
DEPARTMENT_FILTER = ['', '', '西', '']
"""
'''
分词配置
"""
jieba.suggest_freq(("肿瘤", "医院"), True)
jieba.suggest_freq(("", ""), True)
jieba.suggest_freq(("感染", ""), True)
jieba.suggest_freq(("", ""), True)
jieba.suggest_freq(("", ""), True)
'''
jieba.suggest_freq(('肿瘤', '医院'), True)
jieba.suggest_freq(('', ''), True)
jieba.suggest_freq(('感染', ''), True)
jieba.suggest_freq(('', ''), True)
jieba.suggest_freq(('', ''), True)
"""
模型配置
"""
SETTLEMENT_IE = Taskflow("information_extraction", schema=SETTLEMENT_LIST_SCHEMA, model="uie-x-base",
task_path="model/settlement_list_model", layout_analysis=LAYOUT_ANALYSIS, precision="fp16")
DISCHARGE_IE = Taskflow("information_extraction", schema=DISCHARGE_RECORD_SCHEMA, model="uie-x-base",
task_path="model/discharge_record_model", layout_analysis=LAYOUT_ANALYSIS, precision="fp16")
COST_IE = Taskflow("information_extraction", schema=COST_LIST_SCHEMA, model="uie-x-base", device_id=1,
task_path="model/cost_list_model", layout_analysis=LAYOUT_ANALYSIS, precision="fp16")
'''
出院记录缺页判断关键词配置
'''
DISCHARGE_KEY = {
'入院诊断': ['入院诊断'],
'入院情况': ['入院情况', '入院时情况', '入院时主要症状'],
'入院日期': ['入院日期', '入院时间'],
'诊疗经过': ['诊疗经过', '住院经过', '治疗经过'],
'出院诊断': ['出院诊断'],
'出院情况': ['出院情况', '出院时情况'],
'出院日期': ['出院日期', '出院时间'],
'出院医嘱': ['出院医嘱', '出院医瞩']
}
OCR = PaddleOCR(
device="gpu:0",
ocr_version="PP-OCRv4",
use_textline_orientation=False,
# 检测像素阈值,输出的概率图中,得分大于该阈值的像素点才会被认为是文字像素点
text_det_thresh=0.1,
# 检测框阈值,检测结果边框内,所有像素点的平均得分大于该阈值时,该结果会被认为是文字区域
text_det_box_thresh=0.3,
)
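The two detection thresholds above appear to be set below the usual DB defaults (0.3 / 0.6), which makes text detection more permissive. A rough usage sketch of the shared instance, assuming the PaddleOCR 3.x `predict` API as in the official examples (the file name is a placeholder):

```python
# Not from the repository: run the shared OCR instance once and dump what it found.
for res in OCR.predict('sample_settlement.jpg'):
    res.print()  # detected boxes, recognised text and scores
```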
def get_batch_id():
"""
获取处理批号
:return: 处理批号
"""
return BATCH_ID
def set_batch_id(batch_id):
"""
修改处理批号
:param batch_id: 新批号
"""
global BATCH_ID
BATCH_ID = batch_id
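A short sketch of how the batch id is used elsewhere in this change: one fresh id is generated per case and read back wherever extraction rows are written (this mirrors the photo_review and information_extraction code further down):

```python
import uuid

set_batch_id(uuid.uuid4().hex)   # one id per reimbursement case
current_batch = get_batch_id()   # e.g. attached to ZxIeResult rows as their id
```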

View File

@@ -1,120 +1,76 @@
import copy
import json
import logging
import os
import tempfile
import re
import shutil
import time
import uuid
from collections import defaultdict
from time import sleep
import cv2
import fitz
import jieba
import numpy as np
import requests
import zxingcpp
from rapidfuzz import process, fuzz
from sqlalchemy import update
from db import MysqlSession
from db.mysql import BdYljg, BdYlks, ZxIeResult, ZxIeCost, ZxIeDischarge, ZxIeSettlement, ZxPhhd, ZxPhrec
from db.mysql import BdYljg, BdYlks, ZxIeCost, ZxIeDischarge, ZxIeSettlement, ZxPhhd, ZxPhrec, ZxIeReview, ZxIeResult
from log import HOSTNAME
from photo_review import PATIENT_NAME, ADMISSION_DATE, DISCHARGE_DATE, MEDICAL_EXPENSES, PERSONAL_CASH_PAYMENT, \
PERSONAL_ACCOUNT_PAYMENT, PERSONAL_FUNDED_AMOUNT, MEDICAL_INSURANCE_TYPE, HOSPITAL, DEPARTMENT, DOCTOR, \
ADMISSION_ID, SETTLEMENT_ID, AGE, OCR, SETTLEMENT_IE, DISCHARGE_IE, COST_IE, PHHD_BATCH_SIZE, SLEEP_MINUTES, \
UPPERCASE_MEDICAL_EXPENSES, HOSPITAL_ALIAS, HOSPITAL_FILTER, DEPARTMENT_ALIAS, DEPARTMENT_FILTER
from photo_review import PHHD_BATCH_SIZE, SLEEP_MINUTES, HOSPITAL_ALIAS, HOSPITAL_FILTER, DEPARTMENT_ALIAS, \
DEPARTMENT_FILTER, DISCHARGE_KEY, set_batch_id, get_batch_id
from services.paddle_services import IE_KEY
from ucloud import ufile, BUCKET
from util import image_util, util, html_util
from util.data_util import handle_date, handle_decimal, parse_department, handle_name, \
handle_insurance_type, handle_original_data, handle_hospital, handle_department, handle_age, parse_money, \
parse_hospital, handle_doctor, handle_text, handle_admission_id, handle_settlement_id
from util import image_util, common_util, html_util, model_util
from util.data_util import handle_date, handle_decimal, parse_department, handle_name, handle_insurance_type, \
handle_original_data, handle_hospital, handle_department, handle_id, handle_age, parse_money, parse_hospital, \
parse_page_num, handle_tiny_int
# 合并信息抽取结果
def merge_result(result1, result2):
for key in result2:
result1[key] = result1.get(key, []) + result2[key]
return result1
def parse_qrcode(img_path, image_id):
"""
解析二维码,尝试从中获取高清图片
:param img_path: 待解析图片
:param image_id: 图片id
:return: 解析结果
"""
def ie_temp_image(ie, ocr, image, is_screenshot=False):
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
cv2.imwrite(temp_file.name, image)
ie_result = []
ocr_pure_text = ''
angle = '0'
try:
layout, angle = util.get_ocr_layout(ocr, temp_file.name, is_screenshot)
if not layout:
# 无识别结果
ie_result = []
else:
ie_result = ie({"doc": temp_file.name, "layout": layout})[0]
for lay in layout:
ocr_pure_text += lay[1]
except MemoryError as e:
# 显存不足时应该抛出错误,让程序重启,同时释放显存
raise e
except Exception as e:
logging.error("信息抽取时出错", exc_info=e)
finally:
try:
os.remove(temp_file.name)
except Exception as e:
logging.info(f"删除临时文件 {temp_file.name} 时出错", exc_info=e)
return ie_result, ocr_pure_text, angle
# 关键信息提取
def request_ie_result(task_enum, phrecs):
url = task_enum.request_url()
identity = int(time.time())
images = []
for phrec in phrecs:
images.append({"name": phrec.cfjaddress, "pk": phrec.pk_phrec})
payload = {"images": images, "schema": task_enum.schema(), "pk_phhd": phrecs[0].pk_phhd, "identity": identity}
response = requests.post(url, json=payload)
if response.status_code == 200:
return response.json()["data"]
else:
raise Exception(f"请求信息抽取结果失败,状态码:{response.status_code}")
# 尝试从二维码中获取高清图片
def get_better_image_from_qrcode(image, image_id, dpi=150):
def _parse_pdf_url(pdf_url_to_parse):
pdf_file = None
local_pdf_path = None
img_name, img_ext = common_util.parse_save_path(img_path)
try:
local_pdf_path = html_util.download_pdf(pdf_url_to_parse)
# 打开PDF文件
pdf_file = fitz.open(local_pdf_path)
# 选择第一页
page = pdf_file[0]
# 定义缩放系数DPI
default_dpi = 72
zoom = dpi / default_dpi
# 设置矩阵变换参数
mat = fitz.Matrix(zoom, zoom)
# 渲染页面
pix = page.get_pixmap(matrix=mat)
# 将渲染结果转换为OpenCV兼容的格式
img = np.frombuffer(pix.samples, dtype=np.uint8).reshape((pix.height, pix.width, -1))
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
return img, page.get_text()
pdf_imgs = image_util.pdf_to_imgs(local_pdf_path)
# 结算单部分
better_settlement_path = common_util.get_processed_img_path(f'{img_name}.better_settlement.jpg')
cv2.imwrite(better_settlement_path, pdf_imgs[0][0])
# 费用清单部分
better_cost_path = common_util.get_processed_img_path(f'{img_name}.better_cost.jpg')
total_height = sum([p[0].shape[0] for p in pdf_imgs[1:]])
common_width = pdf_imgs[1][0].shape[1]
better_cost_img = np.zeros((total_height, common_width, 3), dtype=np.uint8)
current_y = 0
for pdf in pdf_imgs[1:]:
height = pdf[0].shape[0]
better_cost_img[current_y:current_y + height, :, :] = pdf[0]
current_y += height
# cost_text += pdf[1] # 费用清单文本暂时没用到
cv2.imwrite(better_cost_path, better_cost_img)
return better_settlement_path, pdf_imgs[0][1], better_cost_path
except Exception as ex:
logging.getLogger('error').error('解析pdf失败', exc_info=ex)
return None, None
return None, None, None
finally:
if pdf_file:
pdf_file.close()
if local_pdf_path:
util.delete_temp_file(local_pdf_path)
common_util.delete_temp_file(local_pdf_path)
jsczt_base_url = 'http://einvoice.jsczt.cn'
try:
results = zxingcpp.read_barcodes(image)
img = cv2.imread(img_path)
results = zxingcpp.read_barcodes(img, text_mode=zxingcpp.TextMode.HRI)
except Exception as e:
logging.getLogger('error').info('二维码识别失败', exc_info=e)
results = []
@@ -139,152 +95,122 @@ def get_better_image_from_qrcode(image, image_id, dpi=150):
if not pdf_url:
continue
return _parse_pdf_url(pdf_url)
elif url.startswith('http://weixin.qq.com'):
elif (url.startswith('http://weixin.qq.com')
or url == 'https://ybj.jszwfw.gov.cn/hsa-app-panel/index.html'):
# 无效地址
continue
elif url.startswith('http://dzpj.ntzyy.com'):
# 南通市中医院
return _parse_pdf_url(url)
# elif url.startswith('https://apph5.ztejsapp.cn/nj/view/elecInvoiceForOther/QRCode2Invoice'):
# pdf_url = html_util.get_dtsrmyy_pdf_url(url)
# if not pdf_url:
# continue
# return _parse_pdf_url(pdf_url)
else:
logging.getLogger('qr').info(f'[{image_id}]中有未知二维码内容:{url}')
except Exception as e:
logging.getLogger('error').error('从二维码中获取高清图片时出错', exc_info=e)
continue
return None, None
return None, None, None
# 关键信息提取
def information_extraction(ie, phrecs, identity):
result = {}
def information_extraction(phrec, pk_phhd):
"""
处理单张图片
:param phrec:图片信息
:param pk_phhd:案子主键
:return: 记录类型、信息抽取结果、OCR文本
"""
img_path = common_util.get_processed_img_path(phrec.cfjaddress)
if not os.path.exists(img_path):
original_img_path = common_util.get_img_path(phrec.cfjaddress)
if not original_img_path:
img_url = ufile.get_private_url(phrec.cfjaddress)
if not img_url:
return None, None, None
original_img_path = common_util.save_to_local(img_url)
shutil.copy2(original_img_path, img_path)
if image_util.is_photo(img_path):
book_img_path = model_util.det_book(img_path) # 识别文档区域并裁剪
dewarped_img_path = model_util.dewarp(book_img_path) # 去扭曲
else: # todo:也可能是图片,后续添加细分逻辑
dewarped_img_path = img_path
angles = model_util.clas_orientation(dewarped_img_path)
ocr_text = ''
for phrec in phrecs:
img_path = ufile.get_private_url(phrec.cfjaddress)
if not img_path:
continue
image, exif_data = image_util.read(img_path)
if image is None:
# 图片可能因为某些原因获取不到
continue
# 尝试从二维码中获取高清图片
better_image, text = get_better_image_from_qrcode(image, phrec.cfjaddress)
if phrec.cRectype != '1':
better_image = None # 非结算单暂时不进行替换
zx_ie_results = []
if better_image is not None:
img_angle = '0'
image = better_image
if text:
info_extract = ie(text)[0]
else:
info_extract = ie_temp_image(ie, OCR, image, True)[0]
ie_result = {'result': info_extract, 'angle': '0'}
now = util.get_default_datetime()
if not ie_result['result']:
info_extract = []
rec_type = None
for angle in angles:
ocr_result = []
rotated_img = image_util.rotate(dewarped_img_path, int(angle))
split_results = image_util.split(rotated_img)
for split_result in split_results:
if split_result['img'] is None:
continue
a4_img = image_util.expand_to_a4_size(split_result['img'])
tmp_ocr_result = model_util.ocr(a4_img)
if tmp_ocr_result:
ocr_result += tmp_ocr_result
tmp_ocr_text = common_util.ocr_result_to_text(ocr_result)
result_json = json.dumps(ie_result['result'], ensure_ascii=False)
if len(result_json) > 5000:
result_json = result_json[:5000]
zx_ie_results.append(ZxIeResult(pk_phhd=phrec.pk_phhd, pk_phrec=phrec.pk_phrec, id=identity,
cfjaddress=phrec.cfjaddress, content=result_json,
rotation_angle=int(ie_result['angle']),
x_offset=0, y_offset=0, create_time=now,
creator=HOSTNAME, update_time=now, updater=HOSTNAME))
result = merge_result(result, ie_result['result'])
# if any(key in tmp_ocr_text for key in ['出院记录', '出院小结', '死亡记录']):
# tmp_rec_type = '出院记录'
# elif any(key in tmp_ocr_text for key in ['费用汇总清单', '费用清单', '费用明细', '结账清单', '费用小项统计']):
# tmp_rec_type = '费用清单'
# elif any(key in tmp_ocr_text for key in ['住院收费票据', '结算单', '财政部监制', '结算凭证']):
# tmp_rec_type = '基本医保结算单'
# else:
# tmp_rec_type = model_util.clas_text(tmp_ocr_text) if tmp_ocr_text else None
# if not tmp_rec_type:
rec_dict = {
'1': '基本医保结算单',
'3': '出院记录',
'4': '费用清单',
}
tmp_rec_type = rec_dict.get(phrec.cRectype)
if tmp_rec_type == '基本医保结算单':
tmp_info_extract = model_util.ie_settlement(rotated_img, common_util.ocr_result_to_layout(ocr_result))
elif tmp_rec_type == '出院记录':
tmp_info_extract = model_util.ie_discharge(rotated_img, common_util.ocr_result_to_layout(ocr_result))
elif tmp_rec_type == '费用清单':
tmp_info_extract = model_util.ie_cost(rotated_img, common_util.ocr_result_to_layout(ocr_result))
else:
is_screenshot = image_util.is_screenshot(image, exif_data)
target_images = []
# target_images += detector.request_book_areas(image) # 识别文档区域并裁剪
if not target_images:
target_images.append(image) # 识别失败
angle_count = defaultdict(int, {'0': 0}) # 分割后图片的最优角度统计
for target_image in target_images:
split_results = image_util.split(target_image)
for split_result in split_results:
if split_result['img'] is None or split_result['img'].size == 0:
continue
ie_temp_result = ie_temp_image(ie, OCR, split_result['img'], is_screenshot)
ocr_text += ie_temp_result[1]
ie_results = [{'result': ie_temp_result[0], 'angle': ie_temp_result[2]}]
now = util.get_default_datetime()
best_angle = ['0', 0]
for ie_result in ie_results:
if not ie_result['result']:
continue
tmp_info_extract = []
result_json = json.dumps(ie_result['result'], ensure_ascii=False)
if len(result_json) > 5000:
result_json = result_json[:5000]
zx_ie_results.append(ZxIeResult(pk_phhd=phrec.pk_phhd, pk_phrec=phrec.pk_phrec, id=identity,
cfjaddress=phrec.cfjaddress, content=result_json,
rotation_angle=int(ie_result['angle']),
x_offset=split_result['x_offset'],
y_offset=split_result['y_offset'], create_time=now,
creator=HOSTNAME, update_time=now, updater=HOSTNAME))
if tmp_info_extract and len(tmp_info_extract) > len(info_extract):
info_extract = tmp_info_extract
ocr_text = tmp_ocr_text
rec_type = tmp_rec_type
result = merge_result(result, ie_result['result'])
if len(ie_result['result']) > best_angle[1]:
best_angle = [ie_result['angle'], len(ie_result['result'])]
angle_count[best_angle[0]] += 1
img_angle = max(angle_count, key=angle_count.get)
if img_angle != '0' or better_image is not None:
image = image_util.rotate(image, int(img_angle))
with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as temp_file:
cv2.imwrite(temp_file.name, image)
try:
if img_angle != '0':
ufile.upload_file(phrec.cfjaddress, temp_file.name)
logging.info(f'旋转图片[{phrec.cfjaddress}]替换成功,已旋转{img_angle}度。')
# 修正旋转角度
for zx_ie_result in zx_ie_results:
zx_ie_result.rotation_angle -= int(img_angle)
else:
ufile.copy_file(BUCKET, phrec.cfjaddress, "drg2015", phrec.cfjaddress)
ufile.upload_file(phrec.cfjaddress, temp_file.name)
logging.info(f'高清图片[{phrec.cfjaddress}]替换成功!')
except Exception as e:
logging.error(f'上传图片({phrec.cfjaddress})失败', exc_info=e)
finally:
util.delete_temp_file(temp_file.name)
if info_extract:
result_json = json.dumps(info_extract, ensure_ascii=False)
if len(result_json) > 5000:
result_json = result_json[:5000]
now = common_util.get_default_datetime()
session = MysqlSession()
session.add_all(zx_ie_results)
session.add(ZxIeResult(pk_phhd=pk_phhd, pk_phrec=phrec.pk_phrec, id=get_batch_id(),
cfjaddress=phrec.cfjaddress, content=result_json, create_time=now,
creator=HOSTNAME, update_time=now, updater=HOSTNAME))
session.commit()
# # 添加清晰度测试
# if better_image is None:
# # 替换后图片默认清晰
# clarity_result = image_util.parse_clarity(image)
# unsharp_flag = 0 if (clarity_result[0] == 0 and clarity_result[1] >= 0.8) else 1
# update_clarity = (update(ZxPhrec).where(ZxPhrec.pk_phrec == phrec.pk_phrec).values(
# cfjaddress2=json.dumps(clarity_result),
# unsharp_flag=unsharp_flag,
# ))
# session.execute(update_clarity)
# session.commit()
# session.close()
result['ocr_text'] = ocr_text
return result
session.close()
return rec_type, info_extract, ocr_text
# 从keys中获取准确率最高的value
def get_best_value_in_keys(source, keys):
def get_best_value_of_key(source, key):
# 最终结果
result = None
# 最大可能性
best_probability = 0
for key in keys:
values = source.get(key)
if values:
for value in values:
text = value.get("text")
probability = value.get("probability")
values = source.get(key)
if values:
for value in values:
for v in value:
text = v.get("text")
probability = v.get("probability")
if text and probability > best_probability:
result = text
best_probability = probability
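The nested shape these helpers walk is the one parse_pdf_text builds further down: a list of candidate lists whose items are dicts with 'text' and 'probability'. A hand-made illustration (the values are invented) of what get_best_value_of_key picks:

```python
sample = {'患者姓名': [[{'text': '张三', 'probability': 0.62},
                       {'text': '张珊', 'probability': 0.97}]]}
# get_best_value_of_key(sample, '患者姓名') returns '张珊',
# the candidate with the highest probability.
```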
@@ -292,11 +218,11 @@ def get_best_value_in_keys(source, keys):
# 从keys中获取所有value组成list
def get_values_of_keys(source, keys):
def get_values_of_key(source, key):
result = []
for key in keys:
value = source.get(key)
if value:
values = source.get(key)
if values:
for value in values:
for v in value:
v = v.get("text")
if v:
@@ -310,11 +236,11 @@ def save_or_update_ie(table, pk_phhd, data):
obj = table(**data)
session = MysqlSession()
db_data = session.query(table).filter_by(pk_phhd=pk_phhd).one_or_none()
now = util.get_default_datetime()
now = common_util.get_default_datetime()
if db_data:
# 更新
db_data.update_time = now
db_data.updater = HOSTNAME
db_data.creator = HOSTNAME
for k, v in data.items():
setattr(db_data, k, v)
else:
@@ -385,23 +311,24 @@ def search_department(department):
return best_match
def settlement_task(pk_phhd, settlement_list, identity):
settlement_list_ie_result = information_extraction(SETTLEMENT_IE, settlement_list, identity)
def settlement_task(pk_phhd, settlement_list_ie_result):
settlement_data = {
"pk_phhd": pk_phhd,
"name": handle_name(get_best_value_in_keys(settlement_list_ie_result, PATIENT_NAME)),
"admission_date_str": handle_original_data(get_best_value_in_keys(settlement_list_ie_result, ADMISSION_DATE)),
"discharge_date_str": handle_original_data(get_best_value_in_keys(settlement_list_ie_result, DISCHARGE_DATE)),
"name": handle_name(get_best_value_of_key(settlement_list_ie_result, IE_KEY['name'])),
"admission_date_str": handle_original_data(
get_best_value_of_key(settlement_list_ie_result, IE_KEY['admission_date'])),
"discharge_date_str": handle_original_data(
get_best_value_of_key(settlement_list_ie_result, IE_KEY['discharge_date'])),
"personal_cash_payment_str": handle_original_data(
get_best_value_in_keys(settlement_list_ie_result, PERSONAL_CASH_PAYMENT)),
get_best_value_of_key(settlement_list_ie_result, IE_KEY['personal_cash_payment'])),
"personal_account_payment_str": handle_original_data(
get_best_value_in_keys(settlement_list_ie_result, PERSONAL_ACCOUNT_PAYMENT)),
get_best_value_of_key(settlement_list_ie_result, IE_KEY['personal_account_payment'])),
"personal_funded_amount_str": handle_original_data(
get_best_value_in_keys(settlement_list_ie_result, PERSONAL_FUNDED_AMOUNT)),
get_best_value_of_key(settlement_list_ie_result, IE_KEY['personal_funded_amount'])),
"medical_insurance_type_str": handle_original_data(
get_best_value_in_keys(settlement_list_ie_result, MEDICAL_INSURANCE_TYPE)),
"admission_id": handle_admission_id(get_best_value_in_keys(settlement_list_ie_result, ADMISSION_ID)),
"settlement_id": handle_settlement_id(get_best_value_in_keys(settlement_list_ie_result, SETTLEMENT_ID)),
get_best_value_of_key(settlement_list_ie_result, IE_KEY['medical_insurance_type'])),
"admission_id": handle_id(get_best_value_of_key(settlement_list_ie_result, IE_KEY['admission_id'])),
"settlement_id": handle_id(get_best_value_of_key(settlement_list_ie_result, IE_KEY['settlement_id'])),
}
settlement_data["admission_date"] = handle_date(settlement_data["admission_date_str"])
settlement_data["admission_date"] = handle_date(settlement_data["admission_date_str"])
@@ -411,32 +338,30 @@ def settlement_task(pk_phhd, settlement_list, identity):
settlement_data["personal_funded_amount"] = handle_decimal(settlement_data["personal_funded_amount_str"])
settlement_data["medical_insurance_type"] = handle_insurance_type(settlement_data["medical_insurance_type_str"])
parse_money_result = parse_money(get_best_value_in_keys(settlement_list_ie_result, UPPERCASE_MEDICAL_EXPENSES),
get_best_value_in_keys(settlement_list_ie_result, MEDICAL_EXPENSES))
parse_money_result = parse_money(
get_best_value_of_key(settlement_list_ie_result, IE_KEY['uppercase_medical_expenses']),
get_best_value_of_key(settlement_list_ie_result, IE_KEY['medical_expenses']))
settlement_data["medical_expenses_str"] = handle_original_data(parse_money_result[0])
settlement_data["medical_expenses"] = parse_money_result[1]
if not settlement_data["settlement_id"]:
# 如果没有结算单号就填住院号
settlement_data["settlement_id"] = settlement_data["admission_id"]
save_or_update_ie(ZxIeSettlement, pk_phhd, settlement_data)
return settlement_data
def discharge_task(pk_phhd, discharge_record, identity):
discharge_record_ie_result = information_extraction(DISCHARGE_IE, discharge_record, identity)
hospitals = get_values_of_keys(discharge_record_ie_result, HOSPITAL)
departments = get_values_of_keys(discharge_record_ie_result, DEPARTMENT)
def discharge_task(pk_phhd, discharge_record_ie_result):
hospitals = get_values_of_key(discharge_record_ie_result, IE_KEY['hospital'])
departments = get_values_of_key(discharge_record_ie_result, IE_KEY['department'])
discharge_data = {
"pk_phhd": pk_phhd,
"hospital": handle_hospital(",".join(hospitals)),
"department": handle_department(",".join(departments)),
"name": handle_name(get_best_value_in_keys(discharge_record_ie_result, PATIENT_NAME)),
"admission_date_str": handle_original_data(get_best_value_in_keys(discharge_record_ie_result, ADMISSION_DATE)),
"discharge_date_str": handle_original_data(get_best_value_in_keys(discharge_record_ie_result, DISCHARGE_DATE)),
"doctor": handle_doctor(get_best_value_in_keys(discharge_record_ie_result, DOCTOR)),
"admission_id": handle_admission_id(get_best_value_in_keys(discharge_record_ie_result, ADMISSION_ID)),
"age": handle_age(get_best_value_in_keys(discharge_record_ie_result, AGE)),
"content": handle_text(discharge_record_ie_result['ocr_text']),
"name": handle_name(get_best_value_of_key(discharge_record_ie_result, IE_KEY['name'])),
"admission_date_str": handle_original_data(
get_best_value_of_key(discharge_record_ie_result, IE_KEY['admission_date'])),
"discharge_date_str": handle_original_data(
get_best_value_of_key(discharge_record_ie_result, IE_KEY['discharge_date'])),
"doctor": handle_name(get_best_value_of_key(discharge_record_ie_result, IE_KEY['doctor'])),
"admission_id": handle_id(get_best_value_of_key(discharge_record_ie_result, IE_KEY['admission_id'])),
"age": handle_age(get_best_value_of_key(discharge_record_ie_result, IE_KEY['age'])),
}
discharge_data["admission_date"] = handle_date(discharge_data["admission_date_str"])
discharge_data["discharge_date"] = handle_date(discharge_data["discharge_date_str"])
@@ -493,53 +418,270 @@ def discharge_task(pk_phhd, discharge_record, identity):
if best_match:
discharge_data["pk_ylks"] = best_match[2]
save_or_update_ie(ZxIeDischarge, pk_phhd, discharge_data)
return discharge_data
def cost_task(pk_phhd, cost_list, identity):
cost_list_ie_result = information_extraction(COST_IE, cost_list, identity)
def cost_task(pk_phhd, cost_list_ie_result):
cost_data = {
"pk_phhd": pk_phhd,
"name": handle_name(get_best_value_in_keys(cost_list_ie_result, PATIENT_NAME)),
"admission_date_str": handle_original_data(get_best_value_in_keys(cost_list_ie_result, ADMISSION_DATE)),
"discharge_date_str": handle_original_data(get_best_value_in_keys(cost_list_ie_result, DISCHARGE_DATE)),
"medical_expenses_str": handle_original_data(get_best_value_in_keys(cost_list_ie_result, MEDICAL_EXPENSES)),
"content": handle_text(cost_list_ie_result['ocr_text']),
"name": handle_name(get_best_value_of_key(cost_list_ie_result, IE_KEY['name'])),
"admission_date_str": handle_original_data(
get_best_value_of_key(cost_list_ie_result, IE_KEY['admission_date'])),
"discharge_date_str": handle_original_data(
get_best_value_of_key(cost_list_ie_result, IE_KEY['discharge_date'])),
"medical_expenses_str": handle_original_data(
get_best_value_of_key(cost_list_ie_result, IE_KEY['medical_expenses']))
}
cost_data["admission_date"] = handle_date(cost_data["admission_date_str"])
cost_data["discharge_date"] = handle_date(cost_data["discharge_date_str"])
cost_data["medical_expenses"] = handle_decimal(cost_data["medical_expenses_str"])
if cost_list_ie_result.get(IE_KEY['page']):
page_nums, page_count = parse_page_num(cost_list_ie_result[IE_KEY['page']])
if page_nums:
page_nums_str = [str(num) for num in page_nums]
cost_data['page_nums'] = handle_original_data(','.join(page_nums_str))
cost_data['page_count'] = handle_tiny_int(page_count)
save_or_update_ie(ZxIeCost, pk_phhd, cost_data)
return cost_data
def photo_review(pk_phhd):
settlement_list = []
discharge_record = []
cost_list = []
def parse_pdf_text(settlement_text):
pattern = (r'(?:交款人:(.*?)\n|住院时间:(.*?)至(.*?)\n|\(小写\)(.*?)\n|个人现金支付:(.*?)\n|个人账户支付:(.*?)\n'
r'|个人自费:(.*?)\n|医保类型:(.*?)\n|住院科别:(.*?)\n|住院号:(.*?)\n|票据号码:(.*?)\n|)')
# 查找所有匹配项
matches = re.findall(pattern, settlement_text)
results = {}
keys = ['患者姓名', '入院日期', '出院日期', '费用总额', '个人现金支付', '个人账户支付', '个人自费', '医保类型',
'科室', '住院号', '医保结算单号码']
for match in matches:
for key, value in zip(keys, match):
if value:
results[key] = [[{'text': value, 'probability': 1}]]
settlement_key = ['患者姓名', '入院日期', '出院日期', '费用总额', '个人现金支付', '个人账户支付', '个人自费',
'医保类型', '住院号', '医保结算单号码']
discharge_key = ['科室', '患者姓名', '入院日期', '出院日期', '住院号']
cost_key = ['患者姓名', '入院日期', '出院日期', '费用总额']
settlement_result = {key: copy.copy(results[key]) for key in settlement_key if key in results}
discharge_result = {key: copy.copy(results[key]) for key in discharge_key if key in results}
cost_result = {key: copy.copy(results[key]) for key in cost_key if key in results}
return settlement_result, discharge_result, cost_result
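A worked illustration of what parse_pdf_text returns, using an invented fragment of invoice text (the real text comes from the e-invoice PDF downloaded via the QR code):

```python
settlement, discharge, cost = parse_pdf_text('交款人:张三\n住院号:0012345\n')
# settlement['患者姓名'] -> [[{'text': '张三', 'probability': 1}]]
# discharge keeps 患者姓名 and 住院号 as well; cost only keeps 患者姓名,
# because each result dict is filtered by its own key list above.
```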
def photo_review(pk_phhd, name):
"""
处理单个报销案子
:param pk_phhd: 报销单主键
:param name: 报销人姓名
"""
settlement_result = defaultdict(list)
discharge_result = defaultdict(list)
cost_result = defaultdict(list)
session = MysqlSession()
phrecs = session.query(ZxPhrec.pk_phrec, ZxPhrec.pk_phhd, ZxPhrec.cRectype, ZxPhrec.cfjaddress).filter(
phrecs = session.query(ZxPhrec.pk_phrec, ZxPhrec.cRectype, ZxPhrec.cfjaddress).filter(
ZxPhrec.pk_phhd == pk_phhd
).all()
).order_by(ZxPhrec.cRectype, ZxPhrec.rowno).all()
session.close()
for phrec in phrecs:
if phrec.cRectype == "1":
settlement_list.append(phrec)
elif phrec.cRectype == "3":
discharge_record.append(phrec)
elif phrec.cRectype == "4":
cost_list.append(phrec)
# 同一批图的标识
identity = int(time.time())
settlement_task(pk_phhd, settlement_list, identity)
discharge_task(pk_phhd, discharge_record, identity)
cost_task(pk_phhd, cost_list, identity)
set_batch_id(uuid.uuid4().hex)
processed_img_dir = common_util.get_processed_img_path('')
os.makedirs(processed_img_dir, exist_ok=True)
has_pdf = False # 是否获取到了pdf;获取到pdf时可以直接利用它更快地获取信息
better_settlement_path = None
better_cost_path = None
settlement_text = ''
qrcode_img_id = None
for phrec in phrecs:
original_img_path = common_util.get_img_path(phrec.cfjaddress)
if not original_img_path:
img_url = ufile.get_private_url(phrec.cfjaddress)
if not img_url:
continue
original_img_path = common_util.save_to_local(img_url)
img_path = common_util.get_processed_img_path(phrec.cfjaddress)
shutil.copy2(original_img_path, img_path)
# 尝试从二维码中获取高清图片
better_settlement_path, settlement_text, better_cost_path = parse_qrcode(img_path, phrec.cfjaddress)
if better_settlement_path:
has_pdf = True
qrcode_img_id = phrec.cfjaddress
break
discharge_text = ''
if has_pdf:
settlement_result, discharge_result, cost_result = parse_pdf_text(settlement_text)
discharge_ie_result = defaultdict(list)
is_cost_updated = False
for phrec in phrecs:
if phrec.cRectype == '1':
if phrec.cfjaddress == qrcode_img_id:
try:
ufile.copy_file(BUCKET, phrec.cfjaddress, "drg2015", phrec.cfjaddress)
ufile.upload_file(phrec.cfjaddress, better_settlement_path)
except Exception as e:
logging.error("更新结算单pdf图片出错", exc_info=e)
elif phrec.cRectype == '3':
rec_type, ie_result, ocr_text = information_extraction(phrec, pk_phhd)
if rec_type == '出院记录':
discharge_text += ocr_text
for key, value in ie_result.items():
discharge_ie_result[key].append(value)
# 暂不替换费用清单
# elif phrec.cRectype == '4':
# if not is_cost_updated:
# try:
# ufile.copy_file(BUCKET, phrec.cfjaddress, "drg2015", phrec.cfjaddress)
# ufile.upload_file(phrec.cfjaddress, better_cost_path)
# except Exception as e:
# logging.error("更新费用清单pdf图片出错", exc_info=e)
# finally:
# is_cost_updated = True
# 合并出院记录
for key, value in discharge_ie_result.items():
ie_value = get_best_value_of_key(discharge_ie_result, key)
pdf_value = discharge_result.get(key)[0][0]['text'] if discharge_result.get(key) else ''
similarity_ratio = fuzz.ratio(ie_value, pdf_value)
if similarity_ratio < 60:
discharge_result[key] = [[{'text': ie_value, 'probability': 1}]]
else:
for phrec in phrecs:
rec_type, ie_result, ocr_text = information_extraction(phrec, pk_phhd)
if rec_type == '基本医保结算单':
rec_result = settlement_result
elif rec_type == '出院记录':
rec_result = discharge_result
discharge_text += ocr_text
elif rec_type == '费用清单':
rec_result = cost_result
else:
rec_result = None
if rec_result is not None:
for key, value in ie_result.items():
rec_result[key].append(value)
# 删除多余图片
if os.path.exists(processed_img_dir) and os.path.isdir(processed_img_dir):
shutil.rmtree(processed_img_dir)
settlement_data = settlement_task(pk_phhd, settlement_result)
discharge_data = discharge_task(pk_phhd, discharge_result)
cost_data = cost_task(pk_phhd, cost_result)
# 三项资料完整性判断
# 三项资料缺项判断
review_result = {
'pk_phhd': pk_phhd,
'has_settlement': bool(settlement_result),
'has_discharge': bool(discharge_result),
'has_cost': bool(cost_result),
}
if (review_result['has_settlement'] and settlement_data.get('personal_account_payment')
and settlement_data.get('personal_cash_payment') and settlement_data.get('medical_expenses')):
review_result['has_settlement'] &= (
float(settlement_data['personal_account_payment']) + float(settlement_data['personal_cash_payment'])
< float(settlement_data['medical_expenses'])
)
if has_pdf:
review_result['has_discharge'] &= bool(discharge_text)
# 三项资料缺页判断
page_description = []
if review_result['has_discharge']:
for discharge_item in DISCHARGE_KEY:
if not any(key in discharge_text for key in DISCHARGE_KEY[discharge_item]):
page_description.append(f"《出院记录》缺页")
break
if review_result['has_cost']:
cost_missing_page = {}
if cost_data.get('page_nums') and cost_data.get('page_count'):
page_nums = cost_data['page_nums'].split(',')
required_set = set(range(1, cost_data['page_count'] + 1))
page_set = set([int(num) for num in page_nums])
cost_missing_page = required_set - page_set
if cost_missing_page:
cost_missing_page = sorted(cost_missing_page)
cost_missing_page = [str(num) for num in cost_missing_page]
page_description.append(f"《住院费用清单》,缺第{','.join(cost_missing_page)}")
if page_description:
review_result['full_page'] = False
review_result['page_description'] = ';'.join(page_description)
else:
review_result['full_page'] = True
review_result['integrity'] = (review_result['has_settlement'] and review_result['has_discharge']
and review_result['has_cost'] and review_result['full_page'])
# 三项资料一致性判断
# 姓名一致性
name_list = [settlement_data['name'], discharge_data['name'], cost_data['name']]
if sum(not bool(n) for n in name_list) > 1: # 有2个及以上空值直接认为都不一致
review_result['name_match'] = '0'
else:
unique_name = set(name_list)
if len(unique_name) == 1:
review_result['name_match'] = '1' if name == unique_name.pop() else '5'
elif len(unique_name) == 2:
if settlement_data['name'] != discharge_data['name'] and settlement_data['name'] != cost_data['name']:
review_result['name_match'] = '2'
elif discharge_data['name'] != settlement_data['name'] and discharge_data['name'] != cost_data['name']:
review_result['name_match'] = '3'
else:
review_result['name_match'] = '4'
else:
review_result['name_match'] = '0'
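An inferred summary of the name_match codes assigned above; the mapping is deduced from the branch conditions and is not documented in the source:

```python
# Inferred, not authoritative:
#   '0'  too many blank names, or all three names differ
#   '1'  all three names agree and match the claimant's name
#   '2'  the settlement-sheet name is the odd one out
#   '3'  the discharge-record name is the odd one out
#   '4'  the cost-list name is the odd one out
#   '5'  all three names agree but differ from the claimant's name
```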
# 住院日期一致性
if (settlement_data['admission_date'] and discharge_data['admission_date']
and settlement_data['discharge_date'] and discharge_data['discharge_date']
and settlement_data['admission_date'] == discharge_data['admission_date']
and settlement_data['discharge_date'] == discharge_data['discharge_date']):
review_result['admission_date_match'] = '1'
else:
review_result['admission_date_match'] = '0'
# 出院日期一致性
discharge_date_list = [settlement_data['discharge_date'], discharge_data['discharge_date'],
cost_data['discharge_date']]
if sum(not bool(d) for d in discharge_date_list) > 1:
review_result['discharge_date_match'] = '0'
else:
unique_discharge_date = set(discharge_date_list)
if len(unique_discharge_date) == 1:
review_result['discharge_date_match'] = '1'
elif len(unique_discharge_date) == 2:
if (settlement_data['discharge_date'] != discharge_data['discharge_date']
and settlement_data['discharge_date'] != cost_data['discharge_date']):
review_result['discharge_date_match'] = '2'
elif (discharge_data['discharge_date'] != settlement_data['discharge_date']
and discharge_data['discharge_date'] != cost_data['discharge_date']):
review_result['discharge_date_match'] = '3'
else:
review_result['discharge_date_match'] = '4'
else:
review_result['discharge_date_match'] = '0'
review_result['consistency'] = (
review_result['name_match'] == '1' and review_result['admission_date_match'] == '1'
and review_result['discharge_date_match'] == '1')
review_result['success'] = review_result['integrity'] and review_result['consistency']
save_or_update_ie(ZxIeReview, pk_phhd, review_result)
def main():
"""
照片审核批量控制
"""
while 1:
session = MysqlSession()
phhds = (session.query(ZxPhhd.pk_phhd)
phhds = (session.query(ZxPhhd.pk_phhd, ZxPhhd.cXm)
.join(ZxPhrec, ZxPhhd.pk_phhd == ZxPhrec.pk_phhd, isouter=True)
.filter(ZxPhhd.exsuccess_flag == "1")
.filter(ZxPhrec.pk_phrec.isnot(None))
@@ -556,14 +698,14 @@ def main():
pk_phhd = phhd.pk_phhd
logging.info(f"开始识别:{pk_phhd}")
start_time = time.time()
photo_review(pk_phhd)
photo_review(pk_phhd, phhd.cXm)
# 识别完成更新标识
session = MysqlSession()
update_flag = (update(ZxPhhd).where(ZxPhhd.pk_phhd == pk_phhd).values(
exsuccess_flag="8",
ref_id1=HOSTNAME,
checktime=util.get_default_datetime(),
checktime=common_util.get_default_datetime(),
fFSYLFY=time.time() - start_time))
session.execute(update_flag)
session.commit()

View File

@@ -6,7 +6,7 @@ from sqlalchemy.sql.functions import count
from db import MysqlSession
from db.mysql import ZxPhhd, ViewErrorReview
from util import util
from util import common_util
def handle_reason(reason):
@@ -95,5 +95,5 @@ if __name__ == '__main__':
print(result)
with open("photo_review_error_report.txt", 'w', encoding='utf-8') as file:
file.write(json.dumps(result, indent=4, ensure_ascii=False))
file.write(util.get_default_datetime())
file.write(common_util.get_default_datetime())
print("结果已保存。")

View File

@@ -1,11 +1,16 @@
aistudio_sdk==0.2.6
onnxconverter-common==1.15.0
onnxruntime-gpu==1.22.0
OpenCC==1.1.6
paddle2onnx==1.2.3
paddlenlp==3.0.0b4
paddleocr==3.1.1
PyMuPDF==1.26.3
beautifulsoup4==4.12.3 # 网页分析
jieba==0.42.1 # 中文分词
numpy==1.26.4
OpenCC==1.1.9 # 中文繁简转换
opencv-python==4.6.0.66
opencv-python-headless==4.10.0.84
pillow==10.4.0
PyMuPDF==1.24.9 # pdf处理
pymysql==1.1.1
ufile==3.2.11
zxing-cpp==2.3.0
rapidfuzz==3.9.4 # 文本相似度
requests==2.32.3
sqlacodegen==2.3.0.post1 # 实体类生成
sqlalchemy==1.4.52 # ORM框架
tenacity==8.5.0 # 重试
ufile==3.2.9 # 云空间
zxing-cpp==2.2.0 # 二维码识别

View File

@@ -0,0 +1,245 @@
### PyCharm template
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf
# AWS User-specific
.idea/**/aws.xml
# Generated files
.idea/**/contentModel.xml
# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml
# Gradle
.idea/**/gradle.xml
.idea/**/libraries
# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr
# CMake
cmake-build-*/
# Mongo Explorer plugin
.idea/**/mongoSettings.xml
# File-based project format
*.iws
# IntelliJ
out/
# mpeltonen/sbt-idea plugin
.idea_modules/
# JIRA plugin
atlassian-ide-plugin.xml
# Cursive Clojure plugin
.idea/replstate.xml
# SonarLint plugin
.idea/sonarlint/
# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties
# Editor-based Rest Client
.idea/httpRequests
# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# 通过卷绑定挂载到容器中
/log
/model
# docker
Dockerfile

View File

@@ -0,0 +1,28 @@
# 使用官方的paddle镜像作为基础
FROM registry.baidubce.com/paddlepaddle/paddle:2.6.1-gpu-cuda12.0-cudnn8.9-trt8.6
# 设置工作目录
WORKDIR /app
# 设置环境变量
ENV PYTHONUNBUFFERED=1 \
# 设置时区
TZ=Asia/Shanghai \
# 设置pip镜像地址加快安装速度
PIP_INDEX_URL=https://pypi.tuna.tsinghua.edu.cn/simple
# 安装依赖
COPY requirements.txt /app/requirements.txt
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo "$TZ" > /etc/timezone \
&& pip install --no-cache-dir -r requirements.txt \
&& pip uninstall -y onnxruntime onnxruntime-gpu \
&& pip install onnxruntime-gpu==1.18.0 --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
# 将当前目录内容复制到容器的/app内
COPY . /app
# 暴露端口
# EXPOSE 8081
# 运行api接口,具体接口在命令行或docker-compose.yml文件中定义
ENTRYPOINT ["gunicorn"]
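A hedged example of building and running this image; the tag, module path and port are placeholders, since the real gunicorn arguments are supplied on the command line or in docker-compose.yml:

```shell
docker build -t photo-review-api .
# ENTRYPOINT is gunicorn, so only its arguments are passed at run time:
docker run --gpus all -p 5005:5005 photo-review-api \
    -w 1 -b 0.0.0.0:5005 services.orientation_service:app
```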

View File

@@ -0,0 +1,21 @@
"""
信息抽取关键词配置
"""
IE_KEY = {
'name': '患者姓名',
'admission_date': '入院日期',
'discharge_date': '出院日期',
'medical_expenses': '费用总额',
'personal_cash_payment': '个人现金支付',
'personal_account_payment': '个人账户支付',
'personal_funded_amount': '自费金额',
'medical_insurance_type': '医保类型',
'hospital': '医院',
'department': '科室',
'doctor': '主治医生',
'admission_id': '住院号',
'settlement_id': '医保结算单号码',
'age': '年龄',
'uppercase_medical_expenses': '大写总额',
'page': '页码',
}

View File

@@ -0,0 +1,30 @@
import logging.config
from flask import Flask, request
from paddleclas import PaddleClas
from log import LOGGING_CONFIG
from utils import process_request
app = Flask(__name__)
CLAS = PaddleClas(model_name='text_image_orientation')
@app.route(rule='/', methods=['POST'])
@process_request
def main():
"""
判断图片旋转角度,逆时针旋转该角度后为正。可能值['0', '90', '180', '270']
:return: 最有可能的两个角度
"""
img_path = request.form.get('img_path')
clas_result = CLAS.predict(input_data=img_path)
clas_result = next(clas_result)[0]
if clas_result['scores'][0] < 0.5:
return ['0', '90']
return clas_result['label_names']
if __name__ == '__main__':
logging.config.dictConfig(LOGGING_CONFIG)
app.run('0.0.0.0', 5005)
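A minimal client-side sketch of calling this service; the response envelope depends on the process_request decorator, so the parsing below is an assumption:

```python
import requests

# The service reads the image path from form data (request.form.get('img_path') above).
resp = requests.post('http://localhost:5005/', data={'img_path': '/app/tmp/sample.jpg'})
angles = resp.json()  # assumed to carry the two most likely angles, e.g. ['0', '90']
```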

View File

@@ -0,0 +1,31 @@
import logging.config
from flask import Flask, request
from paddlenlp import Taskflow
from log import LOGGING_CONFIG
from utils import process_request
app = Flask(__name__)
schema = ['基本医保结算单', '出院记录', '费用清单']
CLAS = Taskflow('zero_shot_text_classification', model='utc-xbase', schema=schema,
task_path='model/text_classification', precision='fp32')
@app.route('/', methods=['POST'])
@process_request
def main():
text = request.form.get('text')
cls_result = CLAS(text)
cls_result = cls_result[0].get('predictions')
if cls_result:
cls_result = cls_result[0]
if cls_result['score'] and float(cls_result['score']) < 0.8:
logging.info(f"识别结果置信度{cls_result['score']}过低text: {text}")
return None
return cls_result['label']
if __name__ == '__main__':
logging.config.dictConfig(LOGGING_CONFIG)
app.run('0.0.0.0', 5008)

View File

@@ -0,0 +1,31 @@
import logging.config
import os.path
import cv2
from flask import Flask, request
from log import LOGGING_CONFIG
from paddle_detection import detector
from utils import process_request, parse_img_path
app = Flask(__name__)
@app.route('/', methods=['POST'])
@process_request
def main():
img_path = request.form.get('img_path')
result = detector.get_book_areas(img_path)
dirname, img_name, img_ext = parse_img_path(img_path)
books_path = []
for i in range(len(result)):
save_path = os.path.join(dirname, f'{img_name}.book_{i}.{img_ext}')
cv2.imwrite(save_path, result[i])
books_path.append(save_path)
return books_path
if __name__ == '__main__':
logging.config.dictConfig(LOGGING_CONFIG)
app.run('0.0.0.0', 5006)

View File

@@ -0,0 +1,28 @@
import logging.config
import os
import cv2
from flask import Flask, request
from doc_dewarp import dewarper
from log import LOGGING_CONFIG
from utils import process_request, parse_img_path
app = Flask(__name__)
@app.route('/', methods=['POST'])
@process_request
def main():
img_path = request.form.get('img_path')
img = cv2.imread(img_path)
dewarped_img = dewarper.dewarp_image(img)
dirname, img_name, img_ext = parse_img_path(img_path)
save_path = os.path.join(dirname, f'{img_name}.dewarped.{img_ext}')
cv2.imwrite(save_path, dewarped_img)
return save_path
if __name__ == '__main__':
logging.config.dictConfig(LOGGING_CONFIG)
app.run('0.0.0.0', 5007)

View File

@@ -0,0 +1,167 @@
input/
output/
runs/
*_dump/
*_log/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

View File

@@ -0,0 +1,34 @@
repos:
# Common hooks
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: check-added-large-files
- id: check-merge-conflict
- id: check-symlinks
- id: detect-private-key
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/Lucas-C/pre-commit-hooks.git
rev: v1.5.1
hooks:
- id: remove-crlf
- id: remove-tabs
name: Tabs remover (Python)
files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
args: [--whitespaces-count, '4']
# For Python files
- repo: https://github.com/psf/black.git
rev: 23.3.0
hooks:
- id: black
files: (.*\.(py|pyi|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
- repo: https://github.com/pycqa/isort
rev: 5.11.5
hooks:
- id: isort
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.272
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix, --no-cache]

View File

@@ -0,0 +1,398 @@
import copy
from typing import Optional
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .extractor import BasicEncoder
from .position_encoding import build_position_encoding
from .weight_init import weight_init_
class attnLayer(nn.Layer):
def __init__(
self,
d_model,
nhead=8,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
normalize_before=False,
):
super().__init__()
self.self_attn = nn.MultiHeadAttention(d_model, nhead, dropout=dropout)
self.multihead_attn_list = nn.LayerList(
[
copy.deepcopy(nn.MultiHeadAttention(d_model, nhead, dropout=dropout))
for i in range(2)
]
)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(p=dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2_list = nn.LayerList(
[copy.deepcopy(nn.LayerNorm(d_model)) for i in range(2)]
)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(p=dropout)
self.dropout2_list = nn.LayerList(
[copy.deepcopy(nn.Dropout(p=dropout)) for i in range(2)]
)
self.dropout3 = nn.Dropout(p=dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def with_pos_embed(self, tensor, pos: Optional[paddle.Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(
self,
tgt,
memory_list,
tgt_mask=None,
memory_mask=None,
pos=None,
memory_pos=None,
):
q = k = self.with_pos_embed(tgt, pos)
tgt2 = self.self_attn(
q.transpose((1, 0, 2)),
k.transpose((1, 0, 2)),
value=tgt.transpose((1, 0, 2)),
attn_mask=tgt_mask,
)
tgt2 = tgt2.transpose((1, 0, 2))
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
for memory, multihead_attn, norm2, dropout2, m_pos in zip(
memory_list,
self.multihead_attn_list,
self.norm2_list,
self.dropout2_list,
memory_pos,
):
tgt2 = multihead_attn(
query=self.with_pos_embed(tgt, pos).transpose((1, 0, 2)),
key=self.with_pos_embed(memory, m_pos).transpose((1, 0, 2)),
value=memory.transpose((1, 0, 2)),
attn_mask=memory_mask,
).transpose((1, 0, 2))
tgt = tgt + dropout2(tgt2)
tgt = norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout3(tgt2)
tgt = self.norm3(tgt)
return tgt
def forward_pre(
self,
tgt,
memory,
tgt_mask=None,
memory_mask=None,
pos=None,
memory_pos=None,
):
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, pos)
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask)
tgt = tgt + self.dropout1(tgt2)
tgt2 = self.norm2(tgt)
tgt2 = self.multihead_attn(
query=self.with_pos_embed(tgt2, pos),
key=self.with_pos_embed(memory, memory_pos),
value=memory,
attn_mask=memory_mask,
)
tgt = tgt + self.dropout2(tgt2)
tgt2 = self.norm3(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout3(tgt2)
return tgt
def forward(
self,
tgt,
memory_list,
tgt_mask=None,
memory_mask=None,
pos=None,
memory_pos=None,
):
if self.normalize_before:
return self.forward_pre(
tgt,
memory_list,
tgt_mask,
memory_mask,
pos,
memory_pos,
)
return self.forward_post(
tgt,
memory_list,
tgt_mask,
memory_mask,
pos,
memory_pos,
)
def _get_clones(module, N):
return nn.LayerList([copy.deepcopy(module) for i in range(N)])
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
class TransDecoder(nn.Layer):
def __init__(self, num_attn_layers: int, hidden_dim: int = 128):
super(TransDecoder, self).__init__()
attn_layer = attnLayer(hidden_dim)
self.layers = _get_clones(attn_layer, num_attn_layers)
self.position_embedding = build_position_encoding(hidden_dim)
def forward(self, image: paddle.Tensor, query_embed: paddle.Tensor):
pos = self.position_embedding(
paddle.ones([image.shape[0], image.shape[2], image.shape[3]], dtype="bool")
)
b, c, h, w = image.shape
image = image.flatten(2).transpose(perm=[2, 0, 1])
pos = pos.flatten(2).transpose(perm=[2, 0, 1])
for layer in self.layers:
query_embed = layer(query_embed, [image], pos=pos, memory_pos=[pos, pos])
query_embed = query_embed.transpose(perm=[1, 2, 0]).reshape([b, c, h, w])
return query_embed
class TransEncoder(nn.Layer):
def __init__(self, num_attn_layers: int, hidden_dim: int = 128):
super(TransEncoder, self).__init__()
attn_layer = attnLayer(hidden_dim)
self.layers = _get_clones(attn_layer, num_attn_layers)
self.position_embedding = build_position_encoding(hidden_dim)
def forward(self, image: paddle.Tensor):
pos = self.position_embedding(
paddle.ones([image.shape[0], image.shape[2], image.shape[3]], dtype="bool")
)
b, c, h, w = image.shape
image = image.flatten(2).transpose(perm=[2, 0, 1])
pos = pos.flatten(2).transpose(perm=[2, 0, 1])
for layer in self.layers:
image = layer(image, [image], pos=pos, memory_pos=[pos, pos])
image = image.transpose(perm=[1, 2, 0]).reshape([b, c, h, w])
return image
class FlowHead(nn.Layer):
def __init__(self, input_dim=128, hidden_dim=256):
super(FlowHead, self).__init__()
self.conv1 = nn.Conv2D(input_dim, hidden_dim, 3, padding=1)
self.conv2 = nn.Conv2D(hidden_dim, 2, 3, padding=1)
self.relu = nn.ReLU()
def forward(self, x):
return self.conv2(self.relu(self.conv1(x)))
class UpdateBlock(nn.Layer):
def __init__(self, hidden_dim: int = 128):
super(UpdateBlock, self).__init__()
self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
self.mask = nn.Sequential(
nn.Conv2D(hidden_dim, 256, 3, padding=1),
nn.ReLU(),
nn.Conv2D(256, 64 * 9, 1, padding=0),
)
def forward(self, image, coords):
mask = 0.25 * self.mask(image)
dflow = self.flow_head(image)
coords = coords + dflow
return mask, coords
def coords_grid(batch, ht, wd):
coords = paddle.meshgrid(paddle.arange(end=ht), paddle.arange(end=wd))
coords = paddle.stack(coords[::-1], axis=0).astype(dtype="float32")
return coords[None].tile([batch, 1, 1, 1])
def upflow8(flow, mode="bilinear"):
new_size = 8 * flow.shape[2], 8 * flow.shape[3]
return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
class OverlapPatchEmbed(nn.Layer):
"""Image to Patch Embedding"""
def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
super().__init__()
img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
patch_size = (
patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)
)
self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
self.num_patches = self.H * self.W
self.proj = nn.Conv2D(
in_chans,
embed_dim,
patch_size,
stride=stride,
padding=(patch_size[0] // 2, patch_size[1] // 2),
)
self.norm = nn.LayerNorm(embed_dim)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
weight_init_(m, "trunc_normal_", std=0.02)
elif isinstance(m, nn.LayerNorm):
weight_init_(m, "Constant", value=1.0)
elif isinstance(m, nn.Conv2D):
weight_init_(
m.weight, "kaiming_normal_", mode="fan_out", nonlinearity="relu"
)
def forward(self, x):
x = self.proj(x)
_, _, H, W = x.shape
x = x.flatten(2)
perm = list(range(x.ndim))
perm[1] = 2
perm[2] = 1
x = x.transpose(perm=perm)
x = self.norm(x)
return x, H, W
class GeoTr(nn.Layer):
def __init__(self):
super(GeoTr, self).__init__()
self.hidden_dim = hdim = 256
self.fnet = BasicEncoder(output_dim=hdim, norm_fn="instance")
self.encoder_block = [("encoder_block" + str(i)) for i in range(3)]
for i in self.encoder_block:
self.__setattr__(i, TransEncoder(2, hidden_dim=hdim))
self.down_layer = [("down_layer" + str(i)) for i in range(2)]
for i in self.down_layer:
self.__setattr__(i, nn.Conv2D(256, 256, 3, stride=2, padding=1))
self.decoder_block = [("decoder_block" + str(i)) for i in range(3)]
for i in self.decoder_block:
self.__setattr__(i, TransDecoder(2, hidden_dim=hdim))
self.up_layer = [("up_layer" + str(i)) for i in range(2)]
for i in self.up_layer:
self.__setattr__(
i, nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
)
self.query_embed = nn.Embedding(81, self.hidden_dim)
self.update_block = UpdateBlock(self.hidden_dim)
def initialize_flow(self, img):
N, _, H, W = img.shape
coodslar = coords_grid(N, H, W)
coords0 = coords_grid(N, H // 8, W // 8)
coords1 = coords_grid(N, H // 8, W // 8)
return coodslar, coords0, coords1
def upsample_flow(self, flow, mask):
N, _, H, W = flow.shape
mask = mask.reshape([N, 1, 9, 8, 8, H, W])
mask = F.softmax(mask, axis=2)
up_flow = F.unfold(8 * flow, [3, 3], paddings=1)
up_flow = up_flow.reshape([N, 2, 9, 1, 1, H, W])
up_flow = paddle.sum(mask * up_flow, axis=2)
up_flow = up_flow.transpose(perm=[0, 1, 4, 2, 5, 3])
return up_flow.reshape([N, 2, 8 * H, 8 * W])
def forward(self, image):
fmap = self.fnet(image)
fmap = F.relu(fmap)
fmap1 = self.__getattr__(self.encoder_block[0])(fmap)
fmap1d = self.__getattr__(self.down_layer[0])(fmap1)
fmap2 = self.__getattr__(self.encoder_block[1])(fmap1d)
fmap2d = self.__getattr__(self.down_layer[1])(fmap2)
fmap3 = self.__getattr__(self.encoder_block[2])(fmap2d)
query_embed0 = self.query_embed.weight.unsqueeze(1).tile([1, fmap3.shape[0], 1])
fmap3d_ = self.__getattr__(self.decoder_block[0])(fmap3, query_embed0)
fmap3du_ = (
self.__getattr__(self.up_layer[0])(fmap3d_)
.flatten(2)
.transpose(perm=[2, 0, 1])
)
fmap2d_ = self.__getattr__(self.decoder_block[1])(fmap2, fmap3du_)
fmap2du_ = (
self.__getattr__(self.up_layer[1])(fmap2d_)
.flatten(2)
.transpose(perm=[2, 0, 1])
)
fmap_out = self.__getattr__(self.decoder_block[2])(fmap1, fmap2du_)
coodslar, coords0, coords1 = self.initialize_flow(image)
coords1 = coords1.detach()
mask, coords1 = self.update_block(fmap_out, coords1)
flow_up = self.upsample_flow(coords1 - coords0, mask)
bm_up = coodslar + flow_up
return bm_up

View File

@@ -0,0 +1,73 @@
# DocTrPP: DocTr++ in PaddlePaddle
## Introduction
This is a PaddlePaddle implementation of DocTr++. The original paper is [DocTr++: Deep Unrestricted Document Image Rectification](https://arxiv.org/abs/2304.08796). The original code is [here](https://github.com/fh2019ustc/DocTr-Plus).
![demo](https://github.com/GreatV/DocTrPP/assets/17264618/4e491512-bfc4-4e69-a833-fd1c6e17158c)
## Requirements
Install the latest version of PaddlePaddle by following the official installation guide: [link](https://www.paddlepaddle.org.cn/).
## Training
1. Data Preparation
To prepare datasets, refer to [doc3D](https://github.com/cvlab-stonybrook/doc3D-dataset).
2. Training
```shell
sh train.sh
```
or
```shell
export OPENCV_IO_ENABLE_OPENEXR=1
export CUDA_VISIBLE_DEVICES=0
python train.py --img-size 288 \
--name "DocTr++" \
--batch-size 12 \
--lr 2.5e-5 \
--exist-ok \
--use-vdl
```
3. Load Trained Model and Continue Training
```shell
export OPENCV_IO_ENABLE_OPENEXR=1
export CUDA_VISIBLE_DEVICES=0
python train.py --img-size 288 \
--name "DocTr++" \
--batch-size 12 \
--lr 2.5e-5 \
--resume "runs/train/DocTr++/weights/last.ckpt" \
--exist-ok \
--use-vdl
```
## Test and Inference
Run rectification (dewarping) on a single image:
```shell
python predict.py -i "crop/12_2 copy.png" -m runs/train/DocTr++/weights/best.ckpt -o 12.2.png
```
![document image rectification](https://raw.githubusercontent.com/greatv/DocTrPP/main/doc/imgs/document_image_rectification.jpg)
## Export to ONNX
```shell
pip install paddle2onnx
python export.py -m ./best.ckpt --format onnx
```
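Once exported, the ONNX model can be loaded with onnxruntime. A hedged sketch follows: the `model.onnx` name and 288×288 input size follow `export.py`'s defaults, the input file name is a placeholder, and the preprocessing mirrors `utils.to_tensor` rather than anything confirmed for the exported graph.

```python
import cv2
import numpy as np
from onnxruntime import InferenceSession

# Sketch only: export.py writes model.onnx next to the checkpoint.
sess = InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
img = cv2.imread("input.png")                          # placeholder input image
img = cv2.resize(img, (288, 288))[:, :, ::-1]          # BGR -> RGB
img = (img.astype(np.float32) / 255.0).transpose(2, 0, 1)[None]
bm = sess.run(None, {sess.get_inputs()[0].name: np.ascontiguousarray(img)})[0]
print(bm.shape)                                        # backward map, e.g. (1, 2, 288, 288)
```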
## Model Download
The trained model can be downloaded from [here](https://github.com/GreatV/DocTrPP/releases/download/v0.0.2/best.ckpt).

View File

@@ -0,0 +1,7 @@
import os
from onnxruntime import InferenceSession
MODEL_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
'model', 'dewarp_model', 'doc_tr_pp.onnx')
DOC_TR = InferenceSession(MODEL_PATH, providers=['CUDAExecutionProvider'], provider_options=[{'device_id': 0}])

View File

@@ -0,0 +1,133 @@
import argparse
import os
import random
import cv2
import hdf5storage as h5
import matplotlib.pyplot as plt
import numpy as np
import paddle
import paddle.nn.functional as F
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_root",
nargs="?",
type=str,
default="~/datasets/doc3d/",
help="Path to the downloaded dataset",
)
parser.add_argument(
"--folder", nargs="?", type=int, default=1, help="Folder ID to read from"
)
parser.add_argument(
"--output",
nargs="?",
type=str,
default="output.png",
help="Output filename for the image",
)
args = parser.parse_args()
root = os.path.expanduser(args.data_root)
folder = args.folder
dirname = os.path.join(root, "img", str(folder))
choices = [f for f in os.listdir(dirname) if "png" in f]
fname = random.choice(choices)
# Read Image
img_path = os.path.join(dirname, fname)
img = cv2.imread(img_path)
# Read 3D Coords
wc_path = os.path.join(root, "wc", str(folder), fname[:-3] + "exr")
wc = cv2.imread(wc_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
# scale wc
# value obtained from the entire dataset
xmx, xmn, ymx, ymn, zmx, zmn = (
1.2539363,
-1.2442188,
1.2396319,
-1.2289206,
0.6436657,
-0.67492497,
)
wc[:, :, 0] = (wc[:, :, 0] - zmn) / (zmx - zmn)
wc[:, :, 1] = (wc[:, :, 1] - ymn) / (ymx - ymn)
wc[:, :, 2] = (wc[:, :, 2] - xmn) / (xmx - xmn)
# Read Backward Map
bm_path = os.path.join(root, "bm", str(folder), fname[:-3] + "mat")
bm = h5.loadmat(bm_path)["bm"]
# Read UV Map
uv_path = os.path.join(root, "uv", str(folder), fname[:-3] + "exr")
uv = cv2.imread(uv_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
# Read Depth Map
dmap_path = os.path.join(root, "dmap", str(folder), fname[:-3] + "exr")
dmap = cv2.imread(dmap_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)[:, :, 0]
# do some clipping and scaling to display it
dmap[dmap > 30.0] = 30
dmap = 1 - ((dmap - np.min(dmap)) / (np.max(dmap) - np.min(dmap)))
# Read Normal Map
norm_path = os.path.join(root, "norm", str(folder), fname[:-3] + "exr")
norm = cv2.imread(norm_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
# Read Albedo
alb_path = os.path.join(root, "alb", str(folder), fname[:-3] + "png")
alb = cv2.imread(alb_path)
# Read Checkerboard Image
recon_path = os.path.join(root, "recon", str(folder), fname[:-8] + "chess480001.png")
recon = cv2.imread(recon_path)
# Display image and GTs
# use the backward mapping to dewarp the image
# scale bm to -1.0 to 1.0
bm_ = bm / np.array([448, 448])
bm_ = (bm_ - 0.5) * 2
bm_ = np.reshape(bm_, (1, 448, 448, 2))
bm_ = paddle.to_tensor(bm_, dtype="float32")
img_ = alb.transpose((2, 0, 1)).astype(np.float32) / 255.0
img_ = np.expand_dims(img_, 0)
img_ = paddle.to_tensor(img_, dtype="float32")
uw = F.grid_sample(img_, bm_)
uw = uw[0].numpy().transpose((1, 2, 0))
f, axrr = plt.subplots(2, 5)
for ax in axrr:
for a in ax:
a.set_xticks([])
a.set_yticks([])
axrr[0][0].imshow(img)
axrr[0][0].title.set_text("image")
axrr[0][1].imshow(wc)
axrr[0][1].title.set_text("3D coords")
axrr[0][2].imshow(bm[:, :, 0])
axrr[0][2].title.set_text("bm 0")
axrr[0][3].imshow(bm[:, :, 1])
axrr[0][3].title.set_text("bm 1")
if uv is None:
uv = np.zeros_like(img)
axrr[0][4].imshow(uv)
axrr[0][4].title.set_text("uv map")
axrr[1][0].imshow(dmap)
axrr[1][0].title.set_text("depth map")
axrr[1][1].imshow(norm)
axrr[1][1].title.set_text("normal map")
axrr[1][2].imshow(alb)
axrr[1][2].title.set_text("albedo")
axrr[1][3].imshow(recon)
axrr[1][3].title.set_text("checkerboard")
axrr[1][4].imshow(uw)
axrr[1][4].title.set_text("gt unwarped")
plt.tight_layout()
plt.savefig(args.output)

View File

@@ -0,0 +1,21 @@
import cv2
import numpy as np
import paddle
from . import DOC_TR
from .utils import to_tensor, to_image
def dewarp_image(image):
img = cv2.resize(image, (288, 288)).astype(np.float32)
y = to_tensor(image)
img = np.transpose(img, (2, 0, 1))
bm = DOC_TR.run(None, {'image': img[None,]})[0]
bm = paddle.to_tensor(bm)
bm = paddle.nn.functional.interpolate(
bm, y.shape[2:], mode='bilinear', align_corners=False
)
bm_nhwc = bm.transpose([0, 2, 3, 1])
out = paddle.nn.functional.grid_sample(y, (bm_nhwc / 288 - 0.5) * 2)
return to_image(out)
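A hedged usage sketch for this helper; the package name and file paths below are placeholders, since the diff does not show where this module lives.

```python
import cv2

from dewarp import dewarp_image  # assumed package/module name

image = cv2.imread("warped_page.jpg")   # placeholder input
flat = dewarp_image(image)              # rectified image at the original resolution
cv2.imwrite("flat_page.jpg", flat)
```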

View File

@@ -0,0 +1,161 @@
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_1.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_2.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_3.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_4.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_5.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_6.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_7.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_8.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_9.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_10.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_11.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_12.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_13.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_14.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_15.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_16.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_17.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_18.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_19.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_20.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_21.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_1.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_2.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_3.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_4.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_5.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_6.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_7.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_8.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_9.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_10.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_11.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_12.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_13.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_14.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_15.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_16.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_17.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_18.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_19.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_20.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_21.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_1.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_2.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_3.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_4.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_5.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_6.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_7.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_8.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_9.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_10.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_11.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_12.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_13.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_14.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_15.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_16.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_17.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_18.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_19.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_20.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_21.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_1.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_2.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_3.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_4.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_5.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_6.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_7.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_8.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_10.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_11.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_12.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_13.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_14.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_15.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_16.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_17.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_18.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_19.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_20.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_21.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_1.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_2.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_3.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_4.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_5.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_6.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_7.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_8.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_1.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_2.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_3.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_4.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_5.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_6.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_7.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_8.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_9.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_10.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_11.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_12.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_13.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_14.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_15.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_16.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_17.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_18.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_19.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_20.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_21.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_1.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_2.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_3.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_4.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_5.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_6.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_7.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_8.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_9.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_10.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_11.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_12.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_13.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_14.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_15.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_16.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_17.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_18.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_19.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_20.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_21.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_1.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_2.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_3.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_4.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_5.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_6.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_7.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_8.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_9.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_10.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_11.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_12.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_13.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_14.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_15.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_16.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_17.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_18.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_19.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_20.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_21.zip"

Binary file not shown (image, 76 KiB).

View File

@@ -0,0 +1,129 @@
import collections
import os
import random
import albumentations as A
import cv2
import hdf5storage as h5
import numpy as np
import paddle
from paddle import io
# Set random seed
random.seed(12345678)
class Doc3dDataset(io.Dataset):
def __init__(self, root, split="train", is_augment=False, image_size=512):
self.root = os.path.expanduser(root)
self.split = split
self.is_augment = is_augment
self.files = collections.defaultdict(list)
self.image_size = (
image_size if isinstance(image_size, tuple) else (image_size, image_size)
)
# Augmentation
self.augmentation = A.Compose(
[
A.ColorJitter(),
]
)
for split in ["train", "val"]:
path = os.path.join(self.root, split + ".txt")
file_list = []
with open(path, "r") as file:
file_list = [file_id.rstrip() for file_id in file.readlines()]
self.files[split] = file_list
def __len__(self):
return len(self.files[self.split])
def __getitem__(self, index):
image_name = self.files[self.split][index]
# Read image
image_path = os.path.join(self.root, "img", image_name + ".png")
image = cv2.imread(image_path)
# Read 3D Coordinates
wc_path = os.path.join(self.root, "wc", image_name + ".exr")
wc = cv2.imread(wc_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
# Read backward map
bm_path = os.path.join(self.root, "bm", image_name + ".mat")
bm = h5.loadmat(bm_path)["bm"]
image, bm = self.transform(wc, bm, image)
return image, bm
def tight_crop(self, wc: np.ndarray):
mask = ((wc[:, :, 0] != 0) & (wc[:, :, 1] != 0) & (wc[:, :, 2] != 0)).astype(
np.uint8
)
mask_size = mask.shape
[y, x] = mask.nonzero()
min_x = min(x)
max_x = max(x)
min_y = min(y)
max_y = max(y)
wc = wc[min_y : max_y + 1, min_x : max_x + 1, :]
s = 10
wc = np.pad(wc, ((s, s), (s, s), (0, 0)), "constant")
cx1 = random.randint(0, 2 * s)
cx2 = random.randint(0, 2 * s) + 1
cy1 = random.randint(0, 2 * s)
cy2 = random.randint(0, 2 * s) + 1
wc = wc[cy1:-cy2, cx1:-cx2, :]
top: int = min_y - s + cy1
bottom: int = mask_size[0] - max_y - s + cy2
left: int = min_x - s + cx1
right: int = mask_size[1] - max_x - s + cx2
top = np.clip(top, 0, mask_size[0])
bottom = np.clip(bottom, 1, mask_size[0] - 1)
left = np.clip(left, 0, mask_size[1])
right = np.clip(right, 1, mask_size[1] - 1)
return wc, top, bottom, left, right
def transform(self, wc, bm, img):
wc, top, bottom, left, right = self.tight_crop(wc)
img = img[top:-bottom, left:-right, :]
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
if self.is_augment:
img = self.augmentation(image=img)["image"]
# resize image
img = cv2.resize(img, self.image_size)
img = img.astype(np.float32) / 255.0
img = img.transpose(2, 0, 1)
# resize bm
bm = bm.astype(np.float32)
bm[:, :, 1] = bm[:, :, 1] - top
bm[:, :, 0] = bm[:, :, 0] - left
bm = bm / np.array([448.0 - left - right, 448.0 - top - bottom])
bm0 = cv2.resize(bm[:, :, 0], (self.image_size[0], self.image_size[1]))
bm1 = cv2.resize(bm[:, :, 1], (self.image_size[0], self.image_size[1]))
bm0 = bm0 * self.image_size[0]
bm1 = bm1 * self.image_size[1]
bm = np.stack([bm0, bm1], axis=-1)
img = paddle.to_tensor(img).astype(dtype="float32")
bm = paddle.to_tensor(bm).astype(dtype="float32")
return img, bm
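For orientation, a hedged sketch of how this dataset is consumed (mirroring the training script later in this diff): each sample is an image tensor in CHW layout and a backward map in HWC layout, both at `image_size` resolution. The dataset path is a placeholder and assumes the doc3d data and train.txt/val.txt splits already exist.

```python
from paddle.io import DataLoader

from doc3d_dataset import Doc3dDataset

dataset = Doc3dDataset("~/datasets/doc3d", split="train", is_augment=True, image_size=288)
loader = DataLoader(dataset, batch_size=4, shuffle=True)
img, bm = next(iter(loader))
print(img.shape)   # [4, 3, 288, 288]  (NCHW image)
print(bm.shape)    # [4, 288, 288, 2]  (NHWC backward map)
```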

View File

@@ -0,0 +1,66 @@
import argparse
import os
import paddle
from GeoTr import GeoTr
def export(args):
model_path = args.model
imgsz = args.imgsz
format = args.format
model = GeoTr()
checkpoint = paddle.load(model_path)
model.set_state_dict(checkpoint["model"])
model.eval()
dirname = os.path.dirname(model_path)
if format == "static" or format == "onnx":
model = paddle.jit.to_static(
model,
input_spec=[
paddle.static.InputSpec(shape=[1, 3, imgsz, imgsz], dtype="float32")
],
full_graph=True,
)
paddle.jit.save(model, os.path.join(dirname, "model"))
if format == "onnx":
onnx_path = os.path.join(dirname, "model.onnx")
os.system(
f"paddle2onnx --model_dir {dirname}"
" --model_filename model.pdmodel"
" --params_filename model.pdiparams"
f" --save_file {onnx_path}"
" --opset_version 11"
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="export model")
parser.add_argument(
"--model",
"-m",
nargs="?",
type=str,
default="",
help="The path of model",
)
parser.add_argument(
"--imgsz", type=int, default=288, help="The size of input image"
)
parser.add_argument(
"--format",
type=str,
default="static",
help="The format of exported model, which can be static or onnx",
)
args = parser.parse_args()
export(args)

View File

@@ -0,0 +1,110 @@
import paddle.nn as nn
from .weight_init import weight_init_
class ResidualBlock(nn.Layer):
"""Residual Block with custom normalization."""
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2D(in_planes, planes, 3, padding=1, stride=stride)
self.conv2 = nn.Conv2D(planes, planes, 3, padding=1)
self.relu = nn.ReLU()
if norm_fn == "group":
num_groups = planes // 8
self.norm1 = nn.GroupNorm(num_groups, planes)
self.norm2 = nn.GroupNorm(num_groups, planes)
if not stride == 1:
self.norm3 = nn.GroupNorm(num_groups, planes)
elif norm_fn == "batch":
self.norm1 = nn.BatchNorm2D(planes)
self.norm2 = nn.BatchNorm2D(planes)
if not stride == 1:
self.norm3 = nn.BatchNorm2D(planes)
elif norm_fn == "instance":
self.norm1 = nn.InstanceNorm2D(planes)
self.norm2 = nn.InstanceNorm2D(planes)
if not stride == 1:
self.norm3 = nn.InstanceNorm2D(planes)
elif norm_fn == "none":
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
if not stride == 1:
self.norm3 = nn.Sequential()
if stride == 1:
self.downsample = None
else:
self.downsample = nn.Sequential(
nn.Conv2D(in_planes, planes, 1, stride=stride), self.norm3
)
def forward(self, x):
y = x
y = self.relu(self.norm1(self.conv1(y)))
y = self.relu(self.norm2(self.conv2(y)))
if self.downsample is not None:
x = self.downsample(x)
return self.relu(x + y)
class BasicEncoder(nn.Layer):
"""Basic Encoder with custom normalization."""
def __init__(self, output_dim=128, norm_fn="batch"):
super(BasicEncoder, self).__init__()
self.norm_fn = norm_fn
if self.norm_fn == "group":
self.norm1 = nn.GroupNorm(8, 64)
elif self.norm_fn == "batch":
self.norm1 = nn.BatchNorm2D(64)
elif self.norm_fn == "instance":
self.norm1 = nn.InstanceNorm2D(64)
elif self.norm_fn == "none":
self.norm1 = nn.Sequential()
self.conv1 = nn.Conv2D(3, 64, 7, stride=2, padding=3)
self.relu1 = nn.ReLU()
self.in_planes = 64
self.layer1 = self._make_layer(64, stride=1)
self.layer2 = self._make_layer(128, stride=2)
self.layer3 = self._make_layer(192, stride=2)
self.conv2 = nn.Conv2D(192, output_dim, 1)
for m in self.sublayers():
if isinstance(m, nn.Conv2D):
weight_init_(
m.weight, "kaiming_normal_", mode="fan_out", nonlinearity="relu"
)
elif isinstance(m, (nn.BatchNorm2D, nn.InstanceNorm2D, nn.GroupNorm)):
weight_init_(m, "Constant", value=1, bias_value=0.0)
def _make_layer(self, dim, stride=1):
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
layers = layer1, layer2
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.conv2(x)
return x
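A quick shape check (a sketch, not part of the repository; the module name is assumed): conv1 and the last two residual stages each halve the resolution, so a 288×288 input comes out as a 36×36 feature map, which is exactly the 1/8 grid that GeoTr.initialize_flow expects.

```python
import paddle

from extractor import BasicEncoder  # assumed module name; the diff does not show file names

fnet = BasicEncoder(output_dim=256, norm_fn="instance")
fnet.eval()
with paddle.no_grad():
    fmap = fnet(paddle.rand([1, 3, 288, 288]))
print(fmap.shape)   # expected: [1, 256, 36, 36], i.e. input resolution / 8
```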

View File

@@ -0,0 +1,48 @@
import copy
import os
import matplotlib.pyplot as plt
import paddle.optimizer as optim
from GeoTr import GeoTr
def plot_lr_scheduler(optimizer, scheduler, epochs=65, save_dir=""):
"""
Plot the learning rate scheduler
"""
optimizer = copy.copy(optimizer)
scheduler = copy.copy(scheduler)
lr = []
for _ in range(epochs):
for _ in range(30):
lr.append(scheduler.get_lr())
optimizer.step()
scheduler.step()
epoch = [float(i) / 30.0 for i in range(len(lr))]
plt.figure()
plt.plot(epoch, lr, ".-", label="Learning Rate")
plt.xlabel("epoch")
plt.ylabel("Learning Rate")
plt.title("Learning Rate Scheduler")
plt.savefig(os.path.join(save_dir, "lr_scheduler.png"), dpi=300)
plt.close()
if __name__ == "__main__":
model = GeoTr()
scheduler = optim.lr.OneCycleLR(
max_learning_rate=1e-4,
total_steps=1950,
phase_pct=0.1,
end_learning_rate=1e-4 / 2.5e5,
)
optimizer = optim.AdamW(learning_rate=scheduler, parameters=model.parameters())
plot_lr_scheduler(
scheduler=scheduler, optimizer=optimizer, epochs=65, save_dir="./"
)

View File

@@ -0,0 +1,124 @@
import math
from typing import Optional
import paddle
import paddle.nn as nn
import paddle.nn.initializer as init
class NestedTensor(object):
def __init__(self, tensors, mask: Optional[paddle.Tensor]):
self.tensors = tensors
self.mask = mask
def decompose(self):
return self.tensors, self.mask
def __repr__(self):
return str(self.tensors)
class PositionEmbeddingSine(nn.Layer):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(
self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, mask):
assert mask is not None
y_embed = mask.cumsum(axis=1, dtype="float32")
x_embed = mask.cumsum(axis=2, dtype="float32")
if self.normalize:
eps = 1e-06
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = paddle.arange(end=self.num_pos_feats, dtype="float32")
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = paddle.stack(
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), axis=4
).flatten(3)
pos_y = paddle.stack(
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), axis=4
).flatten(3)
pos = paddle.concat((pos_y, pos_x), axis=3).transpose(perm=[0, 3, 1, 2])
return pos
class PositionEmbeddingLearned(nn.Layer):
"""
Absolute pos embedding, learned.
"""
def __init__(self, num_pos_feats=256):
super().__init__()
self.row_embed = nn.Embedding(50, num_pos_feats)
self.col_embed = nn.Embedding(50, num_pos_feats)
self.reset_parameters()
def reset_parameters(self):
init_uniform = init.Uniform()
init_uniform(self.row_embed.weight)
init_uniform(self.col_embed.weight)
def forward(self, tensor_list: NestedTensor):
x = tensor_list.tensors
h, w = x.shape[-2:]
i = paddle.arange(end=w)
j = paddle.arange(end=h)
x_emb = self.col_embed(i)
y_emb = self.row_embed(j)
pos = (
paddle.concat(
[
x_emb.unsqueeze(0).tile([h, 1, 1]),
y_emb.unsqueeze(1).tile([1, w, 1]),
],
axis=-1,
)
.transpose([2, 0, 1])
.unsqueeze(0)
.tile([x.shape[0], 1, 1, 1])
)
return pos
def build_position_encoding(hidden_dim=512, position_embedding="sine"):
N_steps = hidden_dim // 2
if position_embedding in ("v2", "sine"):
position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
elif position_embedding in ("v3", "learned"):
position_embedding = PositionEmbeddingLearned(N_steps)
else:
raise ValueError(f"not supported {position_embedding}")
return position_embedding
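A hedged example of the sine variant (the module name is assumed): given a mask of valid positions, the embedding returns one `hidden_dim`-channel position map per pixel.

```python
import paddle

from position_encoding import build_position_encoding  # assumed module name

pos_enc = build_position_encoding(hidden_dim=256, position_embedding="sine")
mask = paddle.ones([1, 36, 36])   # N, H, W mask over valid positions
pos = pos_enc(mask)
print(pos.shape)                   # [1, 256, 36, 36]
```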

View File

@@ -0,0 +1,69 @@
import argparse
import cv2
import paddle
from GeoTr import GeoTr
from utils import to_image, to_tensor
def run(args):
image_path = args.image
model_path = args.model
output_path = args.output
checkpoint = paddle.load(model_path)
state_dict = checkpoint["model"]
model = GeoTr()
model.set_state_dict(state_dict)
model.eval()
img_org = cv2.imread(image_path)
img = cv2.resize(img_org, (288, 288))
x = to_tensor(img)
y = to_tensor(img_org)
bm = model(x)
bm = paddle.nn.functional.interpolate(
bm, y.shape[2:], mode="bilinear", align_corners=False
)
bm_nhwc = bm.transpose([0, 2, 3, 1])
out = paddle.nn.functional.grid_sample(y, (bm_nhwc / 288 - 0.5) * 2)
out_image = to_image(out)
cv2.imwrite(output_path, out_image)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="predict")
parser.add_argument(
"--image",
"-i",
nargs="?",
type=str,
default="",
help="The path of image",
)
parser.add_argument(
"--model",
"-m",
nargs="?",
type=str,
default="",
help="The path of model",
)
parser.add_argument(
"--output",
"-o",
nargs="?",
type=str,
default="",
help="The path of output",
)
args = parser.parse_args()
print(args)
run(args)

View File

@@ -0,0 +1,7 @@
hdf5storage
loguru
numpy
scipy
opencv-python
matplotlib
albumentations

View File

@@ -0,0 +1,57 @@
import argparse
import glob
import os
import random
from pathlib import Path
from loguru import logger
random.seed(1234567)
def run(args):
data_root = os.path.expanduser(args.data_root)
ratio = args.train_ratio
data_path = os.path.join(data_root, "img", "*", "*.png")
img_list = glob.glob(data_path, recursive=True)
img_list.sort()
random.shuffle(img_list)
train_size = int(len(img_list) * ratio)
train_text_path = os.path.join(data_root, "train.txt")
with open(train_text_path, "w") as file:
for item in img_list[:train_size]:
parts = Path(item).parts
item = os.path.join(parts[-2], parts[-1])
file.write("%s\n" % item.split(".png")[0])
val_text_path = os.path.join(data_root, "val.txt")
with open(val_text_path, "w") as file:
for item in img_list[train_size:]:
parts = Path(item).parts
item = os.path.join(parts[-2], parts[-1])
file.write("%s\n" % item.split(".png")[0])
logger.info(f"TRAIN LABEL: {train_text_path}")
logger.info(f"VAL LABEL: {val_text_path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_root",
type=str,
default="~/datasets/doc3d",
help="Data path to load data",
)
parser.add_argument(
"--train_ratio", type=float, default=0.8, help="Ratio of training data"
)
args = parser.parse_args()
logger.info(args)
run(args)

View File

@@ -0,0 +1,381 @@
import argparse
import inspect
import os
import random
from pathlib import Path
from typing import Optional
import numpy as np
import paddle
import paddle.nn as nn
import paddle.optimizer as optim
from loguru import logger
from paddle.io import DataLoader
from paddle.nn import functional as F
from paddle_msssim import ms_ssim, ssim
from doc3d_dataset import Doc3dDataset
from GeoTr import GeoTr
from utils import to_image
FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]
RANK = int(os.getenv("RANK", -1))
def init_seeds(seed=0, deterministic=False):
random.seed(seed)
np.random.seed(seed)
paddle.seed(seed)
if deterministic:
os.environ["FLAGS_cudnn_deterministic"] = "1"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
os.environ["PYTHONHASHSEED"] = str(seed)
def colorstr(*input):
# Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code,
# i.e. colorstr('blue', 'hello world')
*args, string = (
input if len(input) > 1 else ("blue", "bold", input[0])
) # color arguments, string
colors = {
"black": "\033[30m", # basic colors
"red": "\033[31m",
"green": "\033[32m",
"yellow": "\033[33m",
"blue": "\033[34m",
"magenta": "\033[35m",
"cyan": "\033[36m",
"white": "\033[37m",
"bright_black": "\033[90m", # bright colors
"bright_red": "\033[91m",
"bright_green": "\033[92m",
"bright_yellow": "\033[93m",
"bright_blue": "\033[94m",
"bright_magenta": "\033[95m",
"bright_cyan": "\033[96m",
"bright_white": "\033[97m",
"end": "\033[0m", # misc
"bold": "\033[1m",
"underline": "\033[4m",
}
return "".join(colors[x] for x in args) + f"{string}" + colors["end"]
def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
# Print function arguments (optional args dict)
x = inspect.currentframe().f_back # previous frame
file, _, func, _, _ = inspect.getframeinfo(x)
if args is None: # get args automatically
args, _, _, frm = inspect.getargvalues(x)
args = {k: v for k, v in frm.items() if k in args}
try:
file = Path(file).resolve().relative_to(ROOT).with_suffix("")
except ValueError:
file = Path(file).stem
s = (f"{file}: " if show_file else "") + (f"{func}: " if show_func else "")
logger.info(colorstr(s) + ", ".join(f"{k}={v}" for k, v in args.items()))
def increment_path(path, exist_ok=False, sep="", mkdir=False):
# Increment file or directory path,
# i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
path = Path(path) # os-agnostic
if path.exists() and not exist_ok:
path, suffix = (
(path.with_suffix(""), path.suffix) if path.is_file() else (path, "")
)
for n in range(2, 9999):
p = f"{path}{sep}{n}{suffix}" # increment path
if not os.path.exists(p):
break
path = Path(p)
if mkdir:
path.mkdir(parents=True, exist_ok=True) # make directory
return path
def train(args):
save_dir = Path(args.save_dir)
use_vdl = args.use_vdl
if use_vdl:
from visualdl import LogWriter
log_dir = save_dir / "vdl"
vdl_writer = LogWriter(str(log_dir))
# Directories
weights_dir = save_dir / "weights"
weights_dir.mkdir(parents=True, exist_ok=True)
last = weights_dir / "last.ckpt"
best = weights_dir / "best.ckpt"
# Hyperparameters
# Config
init_seeds(args.seed)
# Train loader
train_dataset = Doc3dDataset(
args.data_root,
split="train",
is_augment=True,
image_size=args.img_size,
)
train_loader = DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.workers,
)
# Validation loader
val_dataset = Doc3dDataset(
args.data_root,
split="val",
is_augment=False,
image_size=args.img_size,
)
val_loader = DataLoader(
val_dataset, batch_size=args.batch_size, num_workers=args.workers
)
# Model
model = GeoTr()
if use_vdl:
vdl_writer.add_graph(
model,
input_spec=[
paddle.static.InputSpec([1, 3, args.img_size, args.img_size], "float32")
],
)
# Data Parallel Mode
if RANK == -1 and paddle.device.cuda.device_count() > 1:
model = paddle.DataParallel(model)
# Scheduler
scheduler = optim.lr.OneCycleLR(
max_learning_rate=args.lr,
total_steps=args.epochs * len(train_loader),
phase_pct=0.1,
end_learning_rate=args.lr / 2.5e5,
)
# Optimizer
optimizer = optim.AdamW(
learning_rate=scheduler,
parameters=model.parameters(),
)
# loss function
l1_loss_fn = nn.L1Loss()
mse_loss_fn = nn.MSELoss()
# Resume
best_fitness, start_epoch = 0.0, 0
if args.resume:
ckpt = paddle.load(args.resume)
model.set_state_dict(ckpt["model"])
optimizer.set_state_dict(ckpt["optimizer"])
scheduler.set_state_dict(ckpt["scheduler"])
best_fitness = ckpt["best_fitness"]
start_epoch = ckpt["epoch"] + 1
# Train
for epoch in range(start_epoch, args.epochs):
model.train()
for i, (img, target) in enumerate(train_loader):
img = paddle.to_tensor(img) # NCHW
target = paddle.to_tensor(target) # NHWC
pred = model(img) # NCHW
pred_nhwc = pred.transpose([0, 2, 3, 1])
loss = l1_loss_fn(pred_nhwc, target)
mse_loss = mse_loss_fn(pred_nhwc, target)
if use_vdl:
vdl_writer.add_scalar(
"Train/L1 Loss", float(loss), epoch * len(train_loader) + i
)
vdl_writer.add_scalar(
"Train/MSE Loss", float(mse_loss), epoch * len(train_loader) + i
)
vdl_writer.add_scalar(
"Train/Learning Rate",
float(scheduler.get_lr()),
epoch * len(train_loader) + i,
)
loss.backward()
optimizer.step()
scheduler.step()
optimizer.clear_grad()
if i % 10 == 0:
logger.info(
f"[TRAIN MODE] Epoch: {epoch}, Iter: {i}, L1 Loss: {float(loss)}, "
f"MSE Loss: {float(mse_loss)}, LR: {float(scheduler.get_lr())}"
)
# Validation
model.eval()
with paddle.no_grad():
avg_ssim = paddle.zeros([])
avg_ms_ssim = paddle.zeros([])
avg_l1_loss = paddle.zeros([])
avg_mse_loss = paddle.zeros([])
for i, (img, target) in enumerate(val_loader):
img = paddle.to_tensor(img)
target = paddle.to_tensor(target)
pred = model(img)
pred_nhwc = pred.transpose([0, 2, 3, 1])
# predict image
out = F.grid_sample(img, (pred_nhwc / args.img_size - 0.5) * 2)
out_gt = F.grid_sample(img, (target / args.img_size - 0.5) * 2)
# calculate ssim
ssim_val = ssim(out, out_gt, data_range=1.0)
ms_ssim_val = ms_ssim(out, out_gt, data_range=1.0)
loss = l1_loss_fn(pred_nhwc, target)
mse_loss = mse_loss_fn(pred_nhwc, target)
# calculate fitness
avg_ssim += ssim_val
avg_ms_ssim += ms_ssim_val
avg_l1_loss += loss
avg_mse_loss += mse_loss
if i % 10 == 0:
logger.info(
f"[VAL MODE] Epoch: {epoch}, VAL Iter: {i}, "
f"L1 Loss: {float(loss)} MSE Loss: {float(mse_loss)}, "
f"MS-SSIM: {float(ms_ssim_val)}, SSIM: {float(ssim_val)}"
)
if use_vdl and i == 0:
img_0 = to_image(out[0])
img_gt_0 = to_image(out_gt[0])
vdl_writer.add_image("Val/Predicted Image No.0", img_0, epoch)
vdl_writer.add_image("Val/Target Image No.0", img_gt_0, epoch)
img_1 = to_image(out[1])
img_gt_1 = to_image(out_gt[1])
img_gt_1 = img_gt_1.astype("uint8")
vdl_writer.add_image("Val/Predicted Image No.1", img_1, epoch)
vdl_writer.add_image("Val/Target Image No.1", img_gt_1, epoch)
img_2 = to_image(out[2])
img_gt_2 = to_image(out_gt[2])
vdl_writer.add_image("Val/Predicted Image No.2", img_2, epoch)
vdl_writer.add_image("Val/Target Image No.2", img_gt_2, epoch)
avg_ssim /= len(val_loader)
avg_ms_ssim /= len(val_loader)
avg_l1_loss /= len(val_loader)
avg_mse_loss /= len(val_loader)
if use_vdl:
vdl_writer.add_scalar("Val/L1 Loss", float(loss), epoch)
vdl_writer.add_scalar("Val/MSE Loss", float(mse_loss), epoch)
vdl_writer.add_scalar("Val/SSIM", float(ssim_val), epoch)
vdl_writer.add_scalar("Val/MS-SSIM", float(ms_ssim_val), epoch)
# Save
ckpt = {
"model": model.state_dict(),
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
"best_fitness": best_fitness,
"epoch": epoch,
}
paddle.save(ckpt, str(last))
if best_fitness < avg_ssim:
best_fitness = avg_ssim
paddle.save(ckpt, str(best))
if use_vdl:
vdl_writer.close()
def main(args):
print_args(vars(args))
args.save_dir = str(
increment_path(Path(args.project) / args.name, exist_ok=args.exist_ok)
)
train(args)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Hyperparams")
parser.add_argument(
"--data-root",
nargs="?",
type=str,
default="~/datasets/doc3d",
help="The root path of the dataset",
)
parser.add_argument(
"--img-size",
nargs="?",
type=int,
default=288,
help="The size of the input image",
)
parser.add_argument(
"--epochs",
nargs="?",
type=int,
default=65,
help="The number of training epochs",
)
parser.add_argument(
"--batch-size", nargs="?", type=int, default=12, help="Batch Size"
)
parser.add_argument(
"--lr", nargs="?", type=float, default=1e-04, help="Learning Rate"
)
parser.add_argument(
"--resume",
nargs="?",
type=str,
default=None,
help="Path to previous saved model to restart from",
)
parser.add_argument("--workers", type=int, default=8, help="max dataloader workers")
parser.add_argument(
"--project", default=ROOT / "runs/train", help="save to project/name"
)
parser.add_argument("--name", default="exp", help="save to project/name")
parser.add_argument(
"--exist-ok",
action="store_true",
help="existing project/name ok, do not increment",
)
parser.add_argument("--seed", type=int, default=0, help="Global training seed")
parser.add_argument("--use-vdl", action="store_true", help="use VisualDL as logger")
args = parser.parse_args()
main(args)

View File

@@ -0,0 +1,10 @@
export OPENCV_IO_ENABLE_OPENEXR=1
export FLAGS_logtostderr=0
export CUDA_VISIBLE_DEVICES=0
python train.py --img-size 288 \
--name "DocTr++" \
--batch-size 12 \
--lr 1e-4 \
--exist-ok \
--use-vdl

View File

@@ -0,0 +1,67 @@
import numpy as np
import paddle
from paddle.nn import functional as F
def to_tensor(img: np.ndarray):
"""
Converts a numpy array image (HWC) to a Paddle tensor (NCHW).
Args:
img (numpy.ndarray): The input image as a numpy array.
Returns:
out (paddle.Tensor): The output tensor.
"""
img = img[:, :, ::-1]
img = img.astype("float32") / 255.0
img = img.transpose(2, 0, 1)
out: paddle.Tensor = paddle.to_tensor(img)
out = paddle.unsqueeze(out, axis=0)
return out
def to_image(x: paddle.Tensor):
"""
Converts a Paddle tensor (NCHW) to a numpy array image (HWC).
Args:
x (paddle.Tensor): The input tensor.
Returns:
out (numpy.ndarray): The output image as a numpy array.
"""
out: np.ndarray = x.squeeze().numpy()
out = out.transpose(1, 2, 0)
out = out * 255.0
out = out.astype("uint8")
out = out[:, :, ::-1]
return out
def unwarp(img, bm, bm_data_format="NCHW"):
"""
Unwarp an image using a flow field.
Args:
img (paddle.Tensor): The input image.
bm (paddle.Tensor): The flow field.
Returns:
out (paddle.Tensor): The output image.
"""
_, _, h, w = img.shape
if bm_data_format == "NHWC":
bm = bm.transpose([0, 3, 1, 2])
# NCHW
bm = F.upsample(bm, size=(h, w), mode="bilinear", align_corners=True)
# NHWC
bm = bm.transpose([0, 2, 3, 1])
# NCHW
out = F.grid_sample(img, bm)
return out
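A round-trip sketch for these helpers (paths are placeholders): to_tensor and to_image invert each other up to uint8 quantization, and unwarp applies a backward map, already scaled to [-1, 1], to a full-resolution image. The identity grid below exists only to show the call signature.

```python
import cv2
import paddle

from utils import to_tensor, to_image, unwarp

img = cv2.imread("page.jpg")                  # placeholder path; HWC, BGR, uint8
x = to_tensor(img)                            # [1, 3, H, W], RGB, float32 in [0, 1]
restored = to_image(x)                        # back to HWC, BGR, uint8

# Identity backward map in NHWC layout, values in [-1, 1].
_, _, h, w = x.shape
ys, xs = paddle.meshgrid(paddle.linspace(-1.0, 1.0, h), paddle.linspace(-1.0, 1.0, w))
bm = paddle.stack([xs, ys], axis=-1).unsqueeze(0)   # [1, H, W, 2]
out = unwarp(x, bm, bm_data_format="NHWC")          # [1, 3, H, W]
```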

View File

@@ -0,0 +1,152 @@
import math
import numpy as np
import paddle
import paddle.nn.initializer as init
from scipy import special
def weight_init_(
layer, func, weight_name=None, bias_name=None, bias_value=0.0, **kwargs
):
"""
In-place params init function.
Usage:
.. code-block:: python
import paddle
import numpy as np
data = np.ones([3, 4], dtype='float32')
linear = paddle.nn.Linear(4, 4)
input = paddle.to_tensor(data)
print(linear.weight)
linear(input)
weight_init_(linear, 'Normal', 'fc_w0', 'fc_b0', std=0.01, mean=0.1)
print(linear.weight)
"""
if hasattr(layer, "weight") and layer.weight is not None:
getattr(init, func)(**kwargs)(layer.weight)
if weight_name is not None:
# override weight name
layer.weight.name = weight_name
if hasattr(layer, "bias") and layer.bias is not None:
init.Constant(bias_value)(layer.bias)
if bias_name is not None:
# override bias name
layer.bias.name = bias_name
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
print(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect."
)
with paddle.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
lower = norm_cdf((a - mean) / std)
upper = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to [2l-1, 2u-1].
tmp = np.random.uniform(
2 * lower - 1, 2 * upper - 1, size=list(tensor.shape)
).astype(np.float32)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tmp = special.erfinv(tmp)
# Transform to proper mean, std
tmp *= std * math.sqrt(2.0)
tmp += mean
# Clamp to ensure it's in the proper range
tmp = np.clip(tmp, a, b)
tensor.set_value(paddle.to_tensor(tmp))
return tensor
def _calculate_fan_in_and_fan_out(tensor):
dimensions = tensor.dim()
if dimensions < 2:
raise ValueError(
"Fan in and fan out can not be computed for tensor "
"with fewer than 2 dimensions"
)
num_input_fmaps = tensor.shape[1]
num_output_fmaps = tensor.shape[0]
receptive_field_size = 1
if tensor.dim() > 2:
receptive_field_size = tensor[0][0].numel()
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size
return fan_in, fan_out
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
def kaiming_normal_(tensor, a=0.0, mode="fan_in", nonlinearity="leaky_relu"):
def _calculate_correct_fan(tensor, mode):
mode = mode.lower()
valid_modes = ["fan_in", "fan_out"]
if mode not in valid_modes:
raise ValueError(
"Mode {} not supported, please use one of {}".format(mode, valid_modes)
)
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
return fan_in if mode == "fan_in" else fan_out
def calculate_gain(nonlinearity, param=None):
linear_fns = [
"linear",
"conv1d",
"conv2d",
"conv3d",
"conv_transpose1d",
"conv_transpose2d",
"conv_transpose3d",
]
if nonlinearity in linear_fns or nonlinearity == "sigmoid":
return 1
elif nonlinearity == "tanh":
return 5.0 / 3
elif nonlinearity == "relu":
return math.sqrt(2.0)
elif nonlinearity == "leaky_relu":
if param is None:
negative_slope = 0.01
elif (
not isinstance(param, bool)
and isinstance(param, int)
or isinstance(param, float)
):
negative_slope = param
else:
raise ValueError("negative_slope {} not a valid number".format(param))
return math.sqrt(2.0 / (1 + negative_slope**2))
else:
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
fan = _calculate_correct_fan(tensor, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
with paddle.no_grad():
paddle.nn.initializer.Normal(0, std)(tensor)
return tensor

View File

@@ -0,0 +1,36 @@
import json
import logging.config
from flask import Flask, request
from paddlenlp import Taskflow
from __init__ import IE_KEY
from log import LOGGING_CONFIG
from utils import process_request
app = Flask(__name__)
COST_LIST_SCHEMA = tuple(IE_KEY[key] for key in [
'name', 'admission_date', 'discharge_date', 'medical_expenses', 'page'
])
COST = Taskflow('information_extraction', schema=COST_LIST_SCHEMA, model='uie-x-base',
task_path='model/cost_list_model', layout_analysis=False, precision='fp16')
@app.route('/', methods=['POST'], endpoint='cost')
@process_request
def main():
img_path = request.form.get('img_path')
layout = request.form.get('layout')
return COST({'doc': img_path, 'layout': json.loads(layout)})
@app.route('/text', methods=['POST'])
@process_request
def text():
t = request.form.get('text')
return COST(t)
if __name__ == '__main__':
logging.config.dictConfig(LOGGING_CONFIG)
app.run('0.0.0.0', 5004)
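A hedged client-side sketch for this service: host, image path, and layout are placeholders, and the exact response format depends on the process_request wrapper, which is not shown in this diff.

```python
import json

import requests

layout = []  # placeholder; pass the OCR layout returned for this image
resp = requests.post(
    "http://localhost:5004/",
    data={"img_path": "/data/cost_list.jpg", "layout": json.dumps(layout)},
)
print(resp.text)

# Plain-text extraction via the /text endpoint.
resp = requests.post("http://localhost:5004/text", data={"text": "placeholder text"})
print(resp.text)
```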

View File

@@ -0,0 +1,36 @@
import json
import logging.config
from flask import Flask, request
from paddlenlp import Taskflow
from __init__ import IE_KEY
from log import LOGGING_CONFIG
from utils import process_request
app = Flask(__name__)
DISCHARGE_RECORD_SCHEMA = tuple(IE_KEY[key] for key in [
'hospital', 'department', 'name', 'admission_date', 'discharge_date', 'doctor', 'admission_id', 'age'
])
DISCHARGE = Taskflow('information_extraction', schema=DISCHARGE_RECORD_SCHEMA, model='uie-x-base',
task_path='model/discharge_record_model', layout_analysis=False, precision='fp16')
@app.route('/', methods=['POST'], endpoint='discharge')
@process_request
def main():
img_path = request.form.get('img_path')
layout = request.form.get('layout')
return DISCHARGE({'doc': img_path, 'layout': json.loads(layout)})
@app.route('/text', methods=['POST'])
@process_request
def text():
t = request.form.get('text')
return DISCHARGE(t)
if __name__ == '__main__':
logging.config.dictConfig(LOGGING_CONFIG)
app.run('0.0.0.0', 5003)

View File

@@ -0,0 +1,38 @@
import json
import logging.config
from flask import Flask, request
from paddlenlp import Taskflow
from __init__ import IE_KEY
from log import LOGGING_CONFIG
from utils import process_request
app = Flask(__name__)
SETTLEMENT_LIST_SCHEMA = tuple(IE_KEY[key] for key in [
'name', 'admission_date', 'discharge_date', 'medical_expenses', 'personal_cash_payment',
'personal_account_payment', 'personal_funded_amount', 'medical_insurance_type', 'admission_id', 'settlement_id',
'uppercase_medical_expenses'
])
SETTLEMENT_IE = Taskflow('information_extraction', schema=SETTLEMENT_LIST_SCHEMA, model='uie-x-base',
task_path='model/settlement_list_model', layout_analysis=False, precision='fp16')
@app.route('/', methods=['POST'], endpoint='settlement')
@process_request
def main():
img_path = request.form.get('img_path')
layout = request.form.get('layout')
return SETTLEMENT_IE({'doc': img_path, 'layout': json.loads(layout)})
@app.route('/text', methods=['POST'])
@process_request
def text():
t = request.form.get('text')
return SETTLEMENT_IE(t)
if __name__ == '__main__':
logging.config.dictConfig(LOGGING_CONFIG)
app.run('0.0.0.0', 5002)

View File

@@ -0,0 +1,70 @@
import os
import socket
# Get the hostname to tell containers apart
HOSTNAME = socket.gethostname()
# Create the log directories if they do not already exist
LOG_PATHS = [
f'log/{HOSTNAME}/error',
]
for path in LOG_PATHS:
if not os.path.exists(path):
os.makedirs(path)
# Logging configuration dictionary
LOGGING_CONFIG = {
'version': 1, # required; version of the configuration schema
'disable_existing_loggers': False, # whether to disable logger instances that already exist
# formatters define the different log message formats
'formatters': {
'standard': {
'format': '%(asctime)s [%(levelname)s] %(name)s: %(message)s',
'datefmt': '%Y-%m-%d %H:%M:%S',
},
},
# handlers define the different kinds of log handlers
'handlers': {
'console': {
'class': 'logging.StreamHandler', # console handler
'level': 'DEBUG',
'formatter': 'standard',
'stream': 'ext://sys.stdout', # write to stdout; encoding follows the system default, usually UTF-8
},
'file': {
'class': 'logging.handlers.TimedRotatingFileHandler', # file handler with time-based log rotation
'level': 'INFO',
'formatter': 'standard',
'filename': f'log/{HOSTNAME}/fcb_photo_review.log', # log file path
'when': 'midnight',
'interval': 1,
'backupCount': 14, # number of rotated backups to keep
'encoding': 'utf-8', # write the file as UTF-8 so Chinese text is stored correctly
},
'error': {
'class': 'logging.handlers.TimedRotatingFileHandler',
'level': 'INFO',
'formatter': 'standard',
'filename': f'log/{HOSTNAME}/error/fcb_photo_review_error.log',
'when': 'midnight',
'interval': 1,
'backupCount': 14,
'encoding': 'utf-8',
},
},
# loggers define the individual loggers
'loggers': {
'': { # root logger
'handlers': ['console', 'file'], # handlers attached to this logger
'level': 'DEBUG', # level of the root logger
'propagate': False, # whether records propagate to ancestor loggers
},
'error': {
'handlers': ['console', 'file', 'error'],
'level': 'DEBUG',
'propagate': False,
},
},
}

View File

@@ -0,0 +1 @@
Directory for the fine-tuned information extraction model for hospitalization cost lists

View File

@@ -0,0 +1 @@
Directory for the document image dewarping (rectification) model

View File

@@ -0,0 +1 @@
Directory for the fine-tuned information extraction model for discharge records

View File

@@ -0,0 +1 @@
Directory for the document detection model

View File

@@ -0,0 +1 @@
Directory for the fine-tuned information extraction model for basic medical insurance settlement statements

View File

@@ -0,0 +1 @@
Directory for the text classification model

View File

@@ -0,0 +1,24 @@
import logging.config
from flask import Flask, request
from paddleocr import PaddleOCR
from log import LOGGING_CONFIG
from utils import process_request
app = Flask(__name__)
# Set use_space_char=False if recognized text should not contain spaces. Test this setting carefully: in version 2.7.3 it has a bug that makes recognition fail.
OCR = PaddleOCR(use_angle_cls=False, show_log=False, det_db_thresh=0.1, det_db_box_thresh=0.3, det_limit_side_len=1248,
drop_score=0.3)
@app.route('/', methods=['POST'])
@process_request
def main():
img_path = request.form.get('img_path')
return OCR.ocr(img_path, cls=False)
if __name__ == '__main__':
logging.config.dictConfig(LOGGING_CONFIG)
app.run('0.0.0.0', 5001)
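And a matching hedged sketch for calling the OCR service; host and image path are placeholders, and the response format is whatever process_request returns for the PaddleOCR result.

```python
import requests

# Port 5001 matches app.run above; the image path must be reachable by the service.
resp = requests.post("http://localhost:5001/", data={"img_path": "/data/page_001.jpg"})
print(resp.text)
```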

Some files were not shown because too many files have changed in this diff.