Compare commits

61 Commits
deploy ... dev

Author SHA1 Message Date
97903b2722 新增功能,获取案子全部原始图片 2025-11-27 08:44:56 +08:00
670172e79e 更新OCR版本,Beta版,还不能上线 2025-09-15 15:41:30 +08:00
d266c2828c 优化镜像构建 2025-09-12 14:42:51 +08:00
bc4c95c18c 添加判断截图方法 2025-09-12 14:40:52 +08:00
af08078380 关闭清晰度测试 2025-09-12 14:11:54 +08:00
8984948107 如果没有结算单号就填住院号 2025-09-12 14:09:46 +08:00
843511b6f3 添加测试环境更新脚本 2025-09-01 15:04:24 +08:00
a5d7da6536 调整测试镜像 2025-09-01 14:59:50 +08:00
99d555aba9 调整测试镜像 2025-09-01 14:54:13 +08:00
7ca5b9d908 更新测试环境镜像版本:0.0.8 2025-08-28 16:06:14 +08:00
5e5e35fd9f 调整测试镜像 2025-08-28 15:12:15 +08:00
d080b66ebf 新增OCR相关文档 2025-08-20 13:58:50 +08:00
1e8ef432df 修正ie表更新的修改人 2025-08-20 10:55:25 +08:00
a6515e971b 删除特殊调整的案子查询 2025-08-20 01:51:51 +00:00
e3fd3f618f 更新版本1.15.5 2025-08-20 01:51:51 +00:00
b6ae36a8ec 临时调整识别案子范围,尝试修正模型替换导致的识别率下降 2025-08-20 01:51:51 +00:00
a99a615e22 临时调整识别案子范围,尝试修正模型替换导致的识别率下降 2025-08-20 01:51:51 +00:00
5f645b5b4b dev分支添加开发环境配置文件 2025-08-19 10:19:34 +08:00
1625f0294f 从主分支删除开发环境相关内容 2025-08-19 10:11:12 +08:00
f19f8cbcae 复制项目内容 2025-08-18 16:38:53 +08:00
ba3e23d185 修正grep警告 2025-08-18 16:01:01 +08:00
0abf7abb5b 修正语言包警告 2025-08-18 15:15:29 +08:00
47ac6aadbe 修正语言包警告 2025-08-18 15:10:39 +08:00
09ede1af25 删除doc_dewarp扭曲矫正 2025-08-18 14:21:10 +08:00
e40d963bf5 调整开发环境镜像构建 2025-08-18 14:05:38 +08:00
34344edd29 调整开发环境镜像构建 2025-08-18 13:58:43 +08:00
b387db1e08 添加远程开发环境容器 2025-08-18 13:29:23 +08:00
88ca27928f 删除过期数据 2025-08-13 10:50:15 +08:00
109a5e9444 删除过期数据 2025-08-12 11:13:14 +08:00
ab5f78cc7b 测试图片清晰度模型效果 2025-07-30 15:11:01 +08:00
04358ee646 增加图片损坏的判断 2025-06-26 08:58:57 +08:00
a67c53f470 增加图片损坏的判断 2025-06-26 08:48:11 +08:00
cd604bc1eb 修正图片可能因为某些原因获取不到而无法继续的问题 2025-05-21 11:26:01 +08:00
0de9fc14b5 修正高清图片为空的判断 2025-04-01 15:24:32 +08:00
5287df4959 添加图片清晰度测试,保存结果对照 2025-04-01 15:11:40 +08:00
3e9c0c99b9 Revert "测试图片清晰度"
This reverts commit a740f16e6b.
2025-04-01 14:47:56 +08:00
a740f16e6b 测试图片清晰度 2025-04-01 14:27:54 +08:00
b9606771cf 判断图片清晰度 2025-04-01 14:24:19 +08:00
110bc57abc Revert "测试清晰度模型运行情况"
This reverts commit f965bc4289.
2025-03-20 16:15:33 +08:00
f965bc4289 测试清晰度模型运行情况 2025-03-20 15:26:54 +08:00
8b6bf03d76 二维码识别替换前备份 2025-03-20 10:23:04 +08:00
73536aea89 删除一个识别服务 2025-02-17 16:08:30 +08:00
12f6554d8c 修正全角数字存入数据库失败的问题 2025-02-17 13:00:52 +08:00
be94bc7f09 删除一个识别服务 2025-02-17 10:41:54 +08:00
8307fcd549 添加两个识别服务 2025-02-17 10:09:11 +08:00
ab2dbf7c15 删除两个涂抹服务 2025-02-08 10:59:18 +08:00
dac7a1b5ce 添加两个涂抹服务 2025-02-08 09:55:17 +08:00
935fa26067 修正无效方法调用 2025-01-22 16:25:04 +08:00
16ab4c78d5 调整各字段长度限制,与其他表长度保持一致 2025-01-22 16:19:52 +08:00
ed77b8ed82 调整医生姓名长度限制 2025-01-22 15:44:14 +08:00
f4ec3b1eb4 paddlelite2.10可成功转换 2025-01-15 16:47:21 +08:00
a93f68413c 批量下载图片 2025-01-15 16:44:32 +08:00
0a947274dc paddle模型转nb端侧部署模型 2025-01-08 16:04:57 +08:00
f7540c4574 修正计算误差导致分割产生无效图片 2025-01-02 13:05:10 +08:00
5e6a471954 增加ocr结果存表 2024-12-24 14:55:43 +08:00
96b8a06e6c 优化医生名字的处理,没有时填充无 2024-12-03 13:27:18 +08:00
be27f753ba 优化数据字段的处理,增加限制条件 2024-12-03 13:12:58 +08:00
8ea5420520 捕获到内存错误主动抛出中止程序,以释放显存 2024-11-28 10:27:44 +08:00
9749577c5a 优化docker服务,删除无效的det_api 2024-11-05 17:30:16 +08:00
e0f6b82dad 更新版本号 2024-11-05 17:21:37 +08:00
8223759fdf 调整自动识别的文字识别模型 2024-11-05 17:18:12 +08:00
52 changed files with 1107 additions and 2401 deletions

View File

@@ -15,6 +15,7 @@ ENV PYTHONUNBUFFERED=1 \
COPY requirements.txt /app/requirements.txt
COPY packages /app/packages
# Set the container timezone from $TZ, then install Python dependencies.
# NOTE: $TZ must be double-quoted so the shell expands it; the previous
# single-quoted form wrote the literal string "$TZ" into /etc/timezone.
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo "$TZ" > /etc/timezone \
    && python3 -m pip install --upgrade pip \
    && pip install --no-cache-dir -r requirements.txt \
    # Replace the CPU onnxruntime pulled in transitively with the CUDA 12 GPU build.
    && pip uninstall -y onnxruntime onnxruntime-gpu \
    && pip install --no-cache-dir onnxruntime-gpu==1.18.0 --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/

33
Dockerfile.dev Normal file
View File

@@ -0,0 +1,33 @@
# Use the official PaddleX image as the base.
FROM ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlex/paddlex:paddlex3.1.2-paddlepaddle3.0.0-gpu-cuda12.6-cudnn9.5-trt10.5

# Set the working directory.
WORKDIR /app

# Set environment variables.
ENV PYTHONUNBUFFERED=1 \
    # Container timezone.
    TZ=Asia/Shanghai \
    # pip mirror to speed up package installation.
    PIP_INDEX_URL=https://pypi.tuna.tsinghua.edu.cn/simple

# Install language-pack-en and openssh-server.
RUN apt update && \
    apt install -y language-pack-en && \
    apt install -y openssh-server

# Configure the SSH service.
RUN mkdir /var/run/sshd && \
    # Root password; change as needed.
    # NOTE(review): a hardcoded credential is baked into the image — acceptable
    # only for an internal dev container, never for a production image.
    echo 'root:fcb0102' | chpasswd && \
    # Allow root login over SSH.
    sed -i 's/#PermitRootLogin prohibit-password/PermitRootLogin yes/' /etc/ssh/sshd_config

# Copy the current directory contents into /app inside the container.
COPY . /app

# Expose port 22.
EXPOSE 22

# Run the SSH daemon in the foreground as the container entry point.
CMD ["/usr/sbin/sshd", "-D"]

View File

@@ -126,3 +126,7 @@ bash update.sh
2. 新增扭曲矫正功能
21. 版本号1.14.0
1. 新增二维码识别替换高清图片功能
22. 版本号1.15.0
1. 新增图片清晰度测试
23. 版本号1.16.0
1. 更新paddle框架至3.0

View File

@@ -14,7 +14,7 @@ from ucloud import ufile
from util import image_util
def check_ie_result(pk_phhd):
def check_ie_result(pk_phhd, need_to_annotation=True):
os.makedirs(f"./check_result/{pk_phhd}", exist_ok=True)
json_result = {"pk_phhd": pk_phhd}
session = MysqlSession()
@@ -46,10 +46,13 @@ def check_ie_result(pk_phhd):
ZxPhrec.pk_phhd == pk_phhd).all()
for phrec in phrecs:
img_name = phrec.cfjaddress
img_path = ufile.get_private_url(img_name, "drg2015")
if not img_path:
img_path = ufile.get_private_url(img_name)
response = requests.get(img_path)
image = Image.open(BytesIO(response.content)).convert("RGB")
if need_to_annotation:
font_size = image.width * image.height / 200000
font = ImageFont.truetype("./font/simfang.ttf", size=font_size)
@@ -85,6 +88,9 @@ def check_ie_result(pk_phhd):
draw.text((box[0], box[3]), value["text"], fill="blue", font=font) # 在矩形下方绘制文本
os.makedirs(f"./check_result/{pk_phhd}/{ocr_item.id}", exist_ok=True)
image.save(f"./check_result/{pk_phhd}/{ocr_item.id}/{img_name}")
else:
os.makedirs(f"./check_result/{pk_phhd}/0", exist_ok=True)
image.save(f"./check_result/{pk_phhd}/0/{img_name}")
session.close()
# 自定义JSON处理器
@@ -99,4 +105,4 @@ def check_ie_result(pk_phhd):
if __name__ == '__main__':
check_ie_result(0)
check_ie_result(5640504)

View File

@@ -19,5 +19,5 @@ DB_URL = f'mysql+pymysql://{USERNAME}:{PASSWORD}@{HOSTNAME}:{PORT}/{DATABASE}'
SHOW_SQL = False
Engine = create_engine(DB_URL, echo=SHOW_SQL)
Base = declarative_base(Engine)
Base = declarative_base()
MysqlSession = sessionmaker(bind=Engine)

View File

@@ -1,5 +1,5 @@
# coding: utf-8
from sqlalchemy import Column, DECIMAL, Date, DateTime, Index, String, text, LargeBinary
from sqlalchemy import Column, DECIMAL, Date, DateTime, Index, String, text, LargeBinary, Text
from sqlalchemy.dialects.mysql import BIT, CHAR, INTEGER, TINYINT, VARCHAR
from db import Base
@@ -56,13 +56,16 @@ class ZxIeCost(Base):
pk_ie_cost = Column(INTEGER(11), primary_key=True, comment='费用明细信息抽取主键')
pk_phhd = Column(INTEGER(11), nullable=False, unique=True, comment='报销案子主键')
name = Column(String(30), comment='患者姓名')
content = Column(Text, comment='详细内容')
name = Column(String(20), comment='患者姓名')
admission_date_str = Column(String(255), comment='入院日期字符串')
admission_date = Column(Date, comment='入院日期')
discharge_date_str = Column(String(255), comment='出院日期字符串')
discharge_date = Column(Date, comment='出院日期')
medical_expenses_str = Column(String(255), comment='费用总额字符串')
medical_expenses = Column(DECIMAL(18, 2), comment='费用总额')
page_nums = Column(String(255), comment='页码')
page_count = Column(TINYINT(4), comment='页数')
create_time = Column(DateTime, server_default=text("CURRENT_TIMESTAMP"), comment='创建时间')
creator = Column(String(255), comment='创建人')
update_time = Column(DateTime, server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"),
@@ -94,19 +97,19 @@ class ZxIeDischarge(Base):
pk_ie_discharge = Column(INTEGER(11), primary_key=True, comment='出院记录信息抽取主键')
pk_phhd = Column(INTEGER(11), nullable=False, unique=True, comment='报销案子主键')
content = Column(String(5000), comment='详细内容')
hospital = Column(String(255), comment='医院')
content = Column(Text, comment='详细内容')
hospital = Column(String(200), comment='医院')
pk_yljg = Column(INTEGER(11), comment='医院主键')
department = Column(String(255), comment='科室')
department = Column(String(200), comment='科室')
pk_ylks = Column(INTEGER(11), comment='科室主键')
name = Column(String(30), comment='患者姓名')
name = Column(String(20), comment='患者姓名')
age = Column(INTEGER(3), comment='年龄')
admission_date_str = Column(String(255), comment='入院日期字符串')
admission_date = Column(Date, comment='入院日期')
discharge_date_str = Column(String(255), comment='出院日期字符串')
discharge_date = Column(Date, comment='出院日期')
doctor = Column(String(30), comment='主治医生')
admission_id = Column(String(50), comment='住院号')
doctor = Column(String(20), comment='主治医生')
admission_id = Column(String(20), comment='住院号')
create_time = Column(DateTime, server_default=text("CURRENT_TIMESTAMP"), comment='创建时间')
creator = Column(String(255), comment='创建人')
update_time = Column(DateTime, server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"),
@@ -138,7 +141,7 @@ class ZxIeSettlement(Base):
pk_ie_settlement = Column(INTEGER(11), primary_key=True, comment='结算清单信息抽取主键')
pk_phhd = Column(INTEGER(11), nullable=False, unique=True, comment='报销案子主键')
name = Column(String(30), comment='患者姓名')
name = Column(String(20), comment='患者姓名')
admission_date_str = Column(String(255), comment='入院日期字符串')
admission_date = Column(Date, comment='入院日期')
discharge_date_str = Column(String(255), comment='出院日期字符串')
@@ -152,9 +155,9 @@ class ZxIeSettlement(Base):
personal_funded_amount_str = Column(String(255), comment='自费金额字符串')
personal_funded_amount = Column(DECIMAL(18, 2), comment='自费金额')
medical_insurance_type_str = Column(String(255), comment='医保类型字符串')
medical_insurance_type = Column(String(40), comment='医保类型')
admission_id = Column(String(50), comment='住院号')
settlement_id = Column(String(50), comment='医保结算单号码')
medical_insurance_type = Column(String(10), comment='医保类型')
admission_id = Column(String(20), comment='住院号')
settlement_id = Column(String(30), comment='医保结算单号码')
create_time = Column(DateTime, server_default=text("CURRENT_TIMESTAMP"), comment='创建时间')
creator = Column(String(255), comment='创建人')
update_time = Column(DateTime, server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"),

74
delete_deprecated_data.py Normal file
View File

@@ -0,0 +1,74 @@
# Purge expired data from the local database.
import logging.config
from datetime import datetime, timedelta

from db import MysqlSession
from db.mysql import ZxPhhd, ZxIeCost, ZxIeDischarge, ZxIeResult, ZxIeSettlement, ZxPhrec
from log import LOGGING_CONFIG

# Retention window in days (not recommended to go below one month).
EXPIRATION_DAYS = 183
# Rows deleted per batch (best kept between 1000 and 10000).
BATCH_SIZE = 5000

# Module-level database session; opened in __main__ and shared by the helper below.
session = None
def batch_delete_by_pk_phhd(model, pk_phhds):
    """Bulk-delete rows of *model* whose ``pk_phhd`` is in *pk_phhds*.

    Args:
        model: SQLAlchemy model class mapped to the target table.
        pk_phhds: list of ``pk_phhd`` primary-key values to delete.

    Returns:
        The number of rows deleted.
    """
    # Guard the empty case: avoids issuing a pointless `... IN ()` statement
    # and an empty commit/log line.
    if not pk_phhds:
        return 0
    # synchronize_session=False skips reconciling in-memory objects — safe here
    # because no instances of these rows are held in the session.
    delete_count = (
        session.query(model)
        .filter(model.pk_phhd.in_(pk_phhds))
        .delete(synchronize_session=False)
    )
    session.commit()
    logging.getLogger("sql").info(f"{model.__tablename__}成功删除{delete_count}条数据")
    return delete_count
if __name__ == '__main__':
    logging.config.dictConfig(LOGGING_CONFIG)
    # Cases older than this are eligible for deletion.
    deadline = datetime.now() - timedelta(days=EXPIRATION_DAYS)
    # Cases that never passed review get twice the retention window.
    double_deadline = deadline - timedelta(days=EXPIRATION_DAYS)
    session = MysqlSession()
    try:
        while True:
            # Cases that completed the whole pipeline (paint_flag == "9"):
            # delete once past the normal retention window.
            phhds = (session.query(ZxPhhd.pk_phhd)
                     .filter(ZxPhhd.paint_flag == "9")
                     .filter(ZxPhhd.billdate < deadline)
                     .limit(BATCH_SIZE)
                     .all())
            if not phhds:
                # Cases that failed review and may still be re-shot/re-uploaded:
                # only delete after twice the retention window.
                phhds = (session.query(ZxPhhd.pk_phhd)
                         .filter(ZxPhhd.exsuccess_flag == "9")
                         .filter(ZxPhhd.paint_flag == "0")
                         .filter(ZxPhhd.billdate < double_deadline)
                         .limit(BATCH_SIZE)
                         .all())
            if not phhds:
                # No expired data left — stop.
                break
            pk_phhd_values = [phhd.pk_phhd for phhd in phhds]
            logging.getLogger("sql").info(f"过期的pk_phhd有{','.join(map(str, pk_phhd_values))}")
            # Delete child tables first, the parent ZxPhhd rows last.
            for model in (ZxPhrec, ZxIeResult, ZxIeSettlement, ZxIeDischarge, ZxIeCost, ZxPhhd):
                batch_delete_by_pk_phhd(model, pk_phhd_values)
    except Exception as e:
        session.rollback()
        logging.getLogger('error').error('过期数据删除失败!', exc_info=e)
    finally:
        session.close()

167
doc_dewarp/.gitignore vendored
View File

@@ -1,167 +0,0 @@
input/
output/
runs/
*_dump/
*_log/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

View File

@@ -1,34 +0,0 @@
repos:
# Common hooks
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: check-added-large-files
- id: check-merge-conflict
- id: check-symlinks
- id: detect-private-key
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/Lucas-C/pre-commit-hooks.git
rev: v1.5.1
hooks:
- id: remove-crlf
- id: remove-tabs
name: Tabs remover (Python)
files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
args: [--whitespaces-count, '4']
# For Python files
- repo: https://github.com/psf/black.git
rev: 23.3.0
hooks:
- id: black
files: (.*\.(py|pyi|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
- repo: https://github.com/pycqa/isort
rev: 5.11.5
hooks:
- id: isort
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.272
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix, --no-cache]

View File

@@ -1,398 +0,0 @@
import copy
from typing import Optional
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .extractor import BasicEncoder
from .position_encoding import build_position_encoding
from .weight_init import weight_init_
class attnLayer(nn.Layer):
    """Transformer layer with self-attention followed by two parallel
    cross-attention branches (one per entry of ``memory_list``) and a
    position-wise feed-forward network.

    ``normalize_before`` selects pre-norm (``forward_pre``) vs. post-norm
    (``forward_post``) ordering; only the post-norm path is consistent with
    the attributes created in ``__init__`` (see NOTE on ``forward_pre``).
    """

    def __init__(
        self,
        d_model,
        nhead=8,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()
        # Self-attention over the target sequence.
        self.self_attn = nn.MultiHeadAttention(d_model, nhead, dropout=dropout)
        # Two independent cross-attention branches, consumed pairwise with
        # memory_list in forward_post.
        self.multihead_attn_list = nn.LayerList(
            [
                copy.deepcopy(nn.MultiHeadAttention(d_model, nhead, dropout=dropout))
                for i in range(2)
            ]
        )
        # Position-wise feed-forward network.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(p=dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        # One LayerNorm / Dropout per cross-attention branch.
        self.norm2_list = nn.LayerList(
            [copy.deepcopy(nn.LayerNorm(d_model)) for i in range(2)]
        )
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(p=dropout)
        self.dropout2_list = nn.LayerList(
            [copy.deepcopy(nn.Dropout(p=dropout)) for i in range(2)]
        )
        self.dropout3 = nn.Dropout(p=dropout)
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[paddle.Tensor]):
        # Add the positional encoding when one is provided.
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        tgt,
        memory_list,
        tgt_mask=None,
        memory_mask=None,
        pos=None,
        memory_pos=None,
    ):
        # Self-attention. Inputs appear to be (seq, batch, dim); the
        # transpose((1, 0, 2)) converts to batch-first for
        # nn.MultiHeadAttention — assumption, confirm against callers.
        q = k = self.with_pos_embed(tgt, pos)
        tgt2 = self.self_attn(
            q.transpose((1, 0, 2)),
            k.transpose((1, 0, 2)),
            value=tgt.transpose((1, 0, 2)),
            attn_mask=tgt_mask,
        )
        tgt2 = tgt2.transpose((1, 0, 2))
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        # One cross-attention branch per memory tensor, each with its own
        # attention module, norm, dropout, and positional encoding.
        for memory, multihead_attn, norm2, dropout2, m_pos in zip(
            memory_list,
            self.multihead_attn_list,
            self.norm2_list,
            self.dropout2_list,
            memory_pos,
        ):
            tgt2 = multihead_attn(
                query=self.with_pos_embed(tgt, pos).transpose((1, 0, 2)),
                key=self.with_pos_embed(memory, m_pos).transpose((1, 0, 2)),
                value=memory.transpose((1, 0, 2)),
                attn_mask=memory_mask,
            ).transpose((1, 0, 2))
            tgt = tgt + dropout2(tgt2)
            tgt = norm2(tgt)
        # Feed-forward with residual connection.
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(
        self,
        tgt,
        memory,
        tgt_mask=None,
        memory_mask=None,
        pos=None,
        memory_pos=None,
    ):
        # NOTE(review): this pre-norm path looks broken/dead — __init__ only
        # defines norm2_list / multihead_attn_list (there is no self.norm2 or
        # self.multihead_attn), and forward() passes a memory *list* here, so
        # calling with normalize_before=True would raise AttributeError.
        # Only forward_post is exercised with the default
        # normalize_before=False.
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask)
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(
            query=self.with_pos_embed(tgt2, pos),
            key=self.with_pos_embed(memory, memory_pos),
            value=memory,
            attn_mask=memory_mask,
        )
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(
        self,
        tgt,
        memory_list,
        tgt_mask=None,
        memory_mask=None,
        pos=None,
        memory_pos=None,
    ):
        # Dispatch on the norm ordering chosen at construction time.
        if self.normalize_before:
            return self.forward_pre(
                tgt,
                memory_list,
                tgt_mask,
                memory_mask,
                pos,
                memory_pos,
            )
        return self.forward_post(
            tgt,
            memory_list,
            tgt_mask,
            memory_mask,
            pos,
            memory_pos,
        )
def _get_clones(module, N):
    """Return a LayerList holding N independent deep copies of *module*."""
    return nn.LayerList([copy.deepcopy(module) for _ in range(N)])
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
class TransDecoder(nn.Layer):
    """Stack of attnLayer blocks decoding query embeddings against an image
    feature map, with sine positional encodings for the spatial positions."""

    def __init__(self, num_attn_layers: int, hidden_dim: int = 128):
        super(TransDecoder, self).__init__()
        attn_layer = attnLayer(hidden_dim)
        # num_attn_layers independent copies of the same layer configuration.
        self.layers = _get_clones(attn_layer, num_attn_layers)
        self.position_embedding = build_position_encoding(hidden_dim)

    def forward(self, image: paddle.Tensor, query_embed: paddle.Tensor):
        # Positional encoding computed over every spatial location (the ones
        # tensor acts as an all-true mask).
        pos = self.position_embedding(
            paddle.ones([image.shape[0], image.shape[2], image.shape[3]], dtype="bool")
        )
        b, c, h, w = image.shape
        # Flatten (b, c, h, w) -> (h*w, b, c): sequence-first layout.
        image = image.flatten(2).transpose(perm=[2, 0, 1])
        pos = pos.flatten(2).transpose(perm=[2, 0, 1])
        # Each layer cross-attends the queries to the image features.
        for layer in self.layers:
            query_embed = layer(query_embed, [image], pos=pos, memory_pos=[pos, pos])
        # Restore the (b, c, h, w) spatial layout.
        query_embed = query_embed.transpose(perm=[1, 2, 0]).reshape([b, c, h, w])
        return query_embed
class TransEncoder(nn.Layer):
    """Stack of attnLayer blocks applied to an image feature map, attending
    the features to themselves (self-encoding) with sine positional
    encodings."""

    def __init__(self, num_attn_layers: int, hidden_dim: int = 128):
        super(TransEncoder, self).__init__()
        attn_layer = attnLayer(hidden_dim)
        self.layers = _get_clones(attn_layer, num_attn_layers)
        self.position_embedding = build_position_encoding(hidden_dim)

    def forward(self, image: paddle.Tensor):
        # Positional encoding over every spatial location (all-true mask).
        pos = self.position_embedding(
            paddle.ones([image.shape[0], image.shape[2], image.shape[3]], dtype="bool")
        )
        b, c, h, w = image.shape
        # Flatten (b, c, h, w) -> (h*w, b, c): sequence-first layout.
        image = image.flatten(2).transpose(perm=[2, 0, 1])
        pos = pos.flatten(2).transpose(perm=[2, 0, 1])
        # The feature map serves as both queries and memory (self-attention
        # via the cross-attention machinery of attnLayer).
        for layer in self.layers:
            image = layer(image, [image], pos=pos, memory_pos=[pos, pos])
        # Restore the (b, c, h, w) spatial layout.
        image = image.transpose(perm=[1, 2, 0]).reshape([b, c, h, w])
        return image
class FlowHead(nn.Layer):
    """Two-layer convolutional head mapping a feature map to a 2-channel
    flow field."""

    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()
        self.conv1 = nn.Conv2D(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2D(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        # conv -> ReLU -> conv, producing 2 output channels (dx, dy).
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden)
class UpdateBlock(nn.Layer):
    """Predicts a flow refinement and an upsampling mask from a feature map
    (an update step in the style of RAFT — presumably; confirm against the
    DocTr++ paper)."""

    def __init__(self, hidden_dim: int = 128):
        super(UpdateBlock, self).__init__()
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
        # 64 * 9 channels: per-pixel weights for 8x8 upsampling over a 3x3
        # coarse neighborhood (consumed by GeoTr.upsample_flow).
        self.mask = nn.Sequential(
            nn.Conv2D(hidden_dim, 256, 3, padding=1),
            nn.ReLU(),
            nn.Conv2D(256, 64 * 9, 1, padding=0),
        )

    def forward(self, image, coords):
        # 0.25 scales the mask logits before the softmax in upsample_flow —
        # presumably for numerical stability; TODO confirm.
        mask = 0.25 * self.mask(image)
        # Residual flow update added to the current coordinate estimate.
        dflow = self.flow_head(image)
        coords = coords + dflow
        return mask, coords
def coords_grid(batch, ht, wd):
    """Return a (batch, 2, ht, wd) float32 tensor of pixel coordinates."""
    coords = paddle.meshgrid(paddle.arange(end=ht), paddle.arange(end=wd))
    # [::-1] orders the two channels as (x, y) rather than (row, col).
    coords = paddle.stack(coords[::-1], axis=0).astype(dtype="float32")
    # Broadcast the single grid across the batch dimension.
    return coords[None].tile([batch, 1, 1, 1])
def upflow8(flow, mode="bilinear"):
    """Upsample a flow field 8x spatially, scaling its values by 8 so the
    displacements remain correct at the new resolution."""
    new_size = 8 * flow.shape[2], 8 * flow.shape[3]
    return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
class OverlapPatchEmbed(nn.Layer):
    """Image to Patch Embedding.

    Projects an image to a sequence of patch embeddings with a strided
    convolution whose kernel (patch_size) exceeds its stride, so neighboring
    patches overlap.
    """

    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        # Accept either an int or an explicit (h, w) tuple for both sizes.
        img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
        patch_size = (
            patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)
        )
        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        # Convolutional projection; padding of patch_size // 2 keeps the
        # output roughly aligned with the input grid.
        self.proj = nn.Conv2D(
            in_chans,
            embed_dim,
            patch_size,
            stride=stride,
            padding=(patch_size[0] // 2, patch_size[1] // 2),
        )
        self.norm = nn.LayerNorm(embed_dim)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Per-sublayer initialization, applied recursively via self.apply().
        if isinstance(m, nn.Linear):
            weight_init_(m, "trunc_normal_", std=0.02)
        elif isinstance(m, nn.LayerNorm):
            weight_init_(m, "Constant", value=1.0)
        elif isinstance(m, nn.Conv2D):
            weight_init_(
                m.weight, "kaiming_normal_", mode="fan_out", nonlinearity="relu"
            )

    def forward(self, x):
        # (b, c, h, w) -> (b, embed_dim, H, W) via the strided projection.
        x = self.proj(x)
        _, _, H, W = x.shape
        # Flatten spatial dims, then swap axes 1 and 2:
        # (b, embed_dim, H*W) -> (b, H*W, embed_dim).
        x = x.flatten(2)
        perm = list(range(x.ndim))
        perm[1] = 2
        perm[2] = 1
        x = x.transpose(perm=perm)
        x = self.norm(x)
        return x, H, W
class GeoTr(nn.Layer):
    """DocTr++ geometric rectification network.

    Encodes a distorted document image with a CNN backbone plus a
    3-level transformer encoder/decoder pyramid, then predicts a dense
    backward map (``bm_up``) of pixel coordinates used to sample the
    input and produce the rectified image.
    """

    def __init__(self):
        super(GeoTr, self).__init__()
        self.hidden_dim = hdim = 256
        self.fnet = BasicEncoder(output_dim=hdim, norm_fn="instance")
        # Sub-layers are attached via __setattr__ with generated names so
        # paddle registers them as sub-layers; the lists hold the names only.
        self.encoder_block = [("encoder_block" + str(i)) for i in range(3)]
        for i in self.encoder_block:
            self.__setattr__(i, TransEncoder(2, hidden_dim=hdim))
        # Two strided convs downsample between encoder levels.
        self.down_layer = [("down_layer" + str(i)) for i in range(2)]
        for i in self.down_layer:
            self.__setattr__(i, nn.Conv2D(256, 256, 3, stride=2, padding=1))
        self.decoder_block = [("decoder_block" + str(i)) for i in range(3)]
        for i in self.decoder_block:
            self.__setattr__(i, TransDecoder(2, hidden_dim=hdim))
        # Two bilinear upsamplers restore resolution between decoder levels.
        self.up_layer = [("up_layer" + str(i)) for i in range(2)]
        for i in self.up_layer:
            self.__setattr__(
                i, nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
            )
        # 81 learned queries (presumably a 9x9 control grid — TODO confirm).
        self.query_embed = nn.Embedding(81, self.hidden_dim)
        self.update_block = UpdateBlock(self.hidden_dim)

    def initialize_flow(self, img):
        """Build full-resolution and 1/8-resolution coordinate grids."""
        N, _, H, W = img.shape
        coodslar = coords_grid(N, H, W)
        coords0 = coords_grid(N, H // 8, W // 8)
        coords1 = coords_grid(N, H // 8, W // 8)
        return coodslar, coords0, coords1

    def upsample_flow(self, flow, mask):
        """Upsample flow 8x using mask as per-pixel convex-combination
        weights over each 3x3 coarse neighborhood (RAFT-style convex
        upsampling — presumably; TODO confirm)."""
        N, _, H, W = flow.shape
        # (N, 1, 9 weights, 8x8 sub-pixels, H, W); softmax over the 9 weights.
        mask = mask.reshape([N, 1, 9, 8, 8, H, W])
        mask = F.softmax(mask, axis=2)
        # Gather each pixel's 3x3 neighborhood of (pre-scaled) flow values.
        up_flow = F.unfold(8 * flow, [3, 3], paddings=1)
        up_flow = up_flow.reshape([N, 2, 9, 1, 1, H, W])
        # Weighted sum over the neighborhood, then interleave sub-pixels.
        up_flow = paddle.sum(mask * up_flow, axis=2)
        up_flow = up_flow.transpose(perm=[0, 1, 4, 2, 5, 3])
        return up_flow.reshape([N, 2, 8 * H, 8 * W])

    def forward(self, image):
        # CNN features at 1/8 resolution.
        fmap = self.fnet(image)
        fmap = F.relu(fmap)
        # 3-level encoder pyramid: encode, downsample, encode, ...
        fmap1 = self.__getattr__(self.encoder_block[0])(fmap)
        fmap1d = self.__getattr__(self.down_layer[0])(fmap1)
        fmap2 = self.__getattr__(self.encoder_block[1])(fmap1d)
        fmap2d = self.__getattr__(self.down_layer[1])(fmap2)
        fmap3 = self.__getattr__(self.encoder_block[2])(fmap2d)
        # Learned queries broadcast across the batch.
        query_embed0 = self.query_embed.weight.unsqueeze(1).tile([1, fmap3.shape[0], 1])
        # Decoder pyramid: decode at the coarsest level, upsample, and feed
        # the result as queries to the next (finer) decoder level.
        fmap3d_ = self.__getattr__(self.decoder_block[0])(fmap3, query_embed0)
        fmap3du_ = (
            self.__getattr__(self.up_layer[0])(fmap3d_)
            .flatten(2)
            .transpose(perm=[2, 0, 1])
        )
        fmap2d_ = self.__getattr__(self.decoder_block[1])(fmap2, fmap3du_)
        fmap2du_ = (
            self.__getattr__(self.up_layer[1])(fmap2d_)
            .flatten(2)
            .transpose(perm=[2, 0, 1])
        )
        fmap_out = self.__getattr__(self.decoder_block[2])(fmap1, fmap2du_)
        # Predict a coordinate update at 1/8 resolution, then convex-upsample
        # the displacement and add it to the full-resolution grid.
        coodslar, coords0, coords1 = self.initialize_flow(image)
        coords1 = coords1.detach()
        mask, coords1 = self.update_block(fmap_out, coords1)
        flow_up = self.upsample_flow(coords1 - coords0, mask)
        bm_up = coodslar + flow_up
        return bm_up

View File

@@ -1,73 +0,0 @@
# DocTrPP: DocTr++ in PaddlePaddle
## Introduction
This is a PaddlePaddle implementation of DocTr++. The original paper is [DocTr++: Deep Unrestricted Document Image Rectification](https://arxiv.org/abs/2304.08796). The original code is [here](https://github.com/fh2019ustc/DocTr-Plus).
![demo](https://github.com/GreatV/DocTrPP/assets/17264618/4e491512-bfc4-4e69-a833-fd1c6e17158c)
## Requirements
You need to install the latest version of PaddlePaddle, which is done through this [link](https://www.paddlepaddle.org.cn/).
## Training
1. Data Preparation
To prepare datasets, refer to [doc3D](https://github.com/cvlab-stonybrook/doc3D-dataset).
2. Training
```shell
sh train.sh
```
or
```shell
export OPENCV_IO_ENABLE_OPENEXR=1
export CUDA_VISIBLE_DEVICES=0
python train.py --img-size 288 \
--name "DocTr++" \
--batch-size 12 \
--lr 2.5e-5 \
--exist-ok \
--use-vdl
```
3. Load Trained Model and Continue Training
```shell
export OPENCV_IO_ENABLE_OPENEXR=1
export CUDA_VISIBLE_DEVICES=0
python train.py --img-size 288 \
--name "DocTr++" \
--batch-size 12 \
--lr 2.5e-5 \
--resume "runs/train/DocTr++/weights/last.ckpt" \
--exist-ok \
--use-vdl
```
## Test and Inference
Test the dewarp result on a single image:
```shell
python predict.py -i "crop/12_2 copy.png" -m runs/train/DocTr++/weights/best.ckpt -o 12.2.png
```
![document image rectification](https://raw.githubusercontent.com/greatv/DocTrPP/main/doc/imgs/document_image_rectification.jpg)
## Export to onnx
```
pip install paddle2onnx
python export.py -m ./best.ckpt --format onnx
```
## Model Download
The trained model can be downloaded from [here](https://github.com/GreatV/DocTrPP/releases/download/v0.0.2/best.ckpt).

View File

@@ -1,4 +0,0 @@
from onnxruntime import InferenceSession

# Document-dewarping ONNX model (DocTr++), loaded once at import time onto
# GPU 0 via the CUDA execution provider.
DOC_TR = InferenceSession("model/dewarp_model/doc_tr_pp.onnx",
                          providers=["CUDAExecutionProvider"], provider_options=[{"device_id": 0}])

View File

@@ -1,133 +0,0 @@
# Visualize one random doc3d training sample: the input image alongside its
# ground-truth maps (3D coords, backward map, UV, depth, normals, albedo,
# checkerboard recon) and the unwarped result produced from the backward map.
import argparse
import os
import random

import cv2
import hdf5storage as h5
import matplotlib.pyplot as plt
import numpy as np
import paddle
import paddle.nn.functional as F

# Enable OpenCV's EXR codec (needed to read the .exr ground-truth maps).
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"

parser = argparse.ArgumentParser()
parser.add_argument(
    "--data_root",
    nargs="?",
    type=str,
    default="~/datasets/doc3d/",
    help="Path to the downloaded dataset",
)
parser.add_argument(
    "--folder", nargs="?", type=int, default=1, help="Folder ID to read from"
)
parser.add_argument(
    "--output",
    nargs="?",
    type=str,
    default="output.png",
    help="Output filename for the image",
)
args = parser.parse_args()

root = os.path.expanduser(args.data_root)
folder = args.folder
dirname = os.path.join(root, "img", str(folder))
# Pick a random PNG from the chosen folder.
choices = [f for f in os.listdir(dirname) if "png" in f]
fname = random.choice(choices)

# Read Image
img_path = os.path.join(dirname, fname)
img = cv2.imread(img_path)

# Read 3D Coords
wc_path = os.path.join(root, "wc", str(folder), fname[:-3] + "exr")
wc = cv2.imread(wc_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
# scale wc
# value obtained from the entire dataset
xmx, xmn, ymx, ymn, zmx, zmn = (
    1.2539363,
    -1.2442188,
    1.2396319,
    -1.2289206,
    0.6436657,
    -0.67492497,
)
# Normalize each channel to [0, 1]; note channel 0 uses the z extrema and
# channel 2 the x extrema — presumably matching the dataset's storage order.
wc[:, :, 0] = (wc[:, :, 0] - zmn) / (zmx - zmn)
wc[:, :, 1] = (wc[:, :, 1] - ymn) / (ymx - ymn)
wc[:, :, 2] = (wc[:, :, 2] - xmn) / (xmx - xmn)

# Read Backward Map
bm_path = os.path.join(root, "bm", str(folder), fname[:-3] + "mat")
bm = h5.loadmat(bm_path)["bm"]

# Read UV Map
uv_path = os.path.join(root, "uv", str(folder), fname[:-3] + "exr")
uv = cv2.imread(uv_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)

# Read Depth Map
dmap_path = os.path.join(root, "dmap", str(folder), fname[:-3] + "exr")
dmap = cv2.imread(dmap_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)[:, :, 0]
# do some clipping and scaling to display it
dmap[dmap > 30.0] = 30
dmap = 1 - ((dmap - np.min(dmap)) / (np.max(dmap) - np.min(dmap)))

# Read Normal Map
norm_path = os.path.join(root, "norm", str(folder), fname[:-3] + "exr")
norm = cv2.imread(norm_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)

# Read Albedo
alb_path = os.path.join(root, "alb", str(folder), fname[:-3] + "png")
alb = cv2.imread(alb_path)

# Read Checkerboard Image
recon_path = os.path.join(root, "recon", str(folder), fname[:-8] + "chess480001.png")
recon = cv2.imread(recon_path)

# Display image and GTs
# use the backward mapping to dewarp the image
# scale bm to -1.0 to 1.0 (grid_sample expects normalized coordinates)
bm_ = bm / np.array([448, 448])
bm_ = (bm_ - 0.5) * 2
bm_ = np.reshape(bm_, (1, 448, 448, 2))
bm_ = paddle.to_tensor(bm_, dtype="float32")
img_ = alb.transpose((2, 0, 1)).astype(np.float32) / 255.0
img_ = np.expand_dims(img_, 0)
img_ = paddle.to_tensor(img_, dtype="float32")
uw = F.grid_sample(img_, bm_)
uw = uw[0].numpy().transpose((1, 2, 0))

# Lay out a 2x5 grid of panels without axis ticks.
f, axrr = plt.subplots(2, 5)
for ax in axrr:
    for a in ax:
        a.set_xticks([])
        a.set_yticks([])
axrr[0][0].imshow(img)
axrr[0][0].title.set_text("image")
axrr[0][1].imshow(wc)
axrr[0][1].title.set_text("3D coords")
axrr[0][2].imshow(bm[:, :, 0])
axrr[0][2].title.set_text("bm 0")
axrr[0][3].imshow(bm[:, :, 1])
axrr[0][3].title.set_text("bm 1")
# cv2.imread returns None on failure; fall back to a black panel.
if uv is None:
    uv = np.zeros_like(img)
axrr[0][4].imshow(uv)
axrr[0][4].title.set_text("uv map")
axrr[1][0].imshow(dmap)
axrr[1][0].title.set_text("depth map")
axrr[1][1].imshow(norm)
axrr[1][1].title.set_text("normal map")
axrr[1][2].imshow(alb)
axrr[1][2].title.set_text("albedo")
axrr[1][3].imshow(recon)
axrr[1][3].title.set_text("checkerboard")
axrr[1][4].imshow(uw)
axrr[1][4].title.set_text("gt unwarped")
plt.tight_layout()
plt.savefig(args.output)

View File

@@ -1,21 +0,0 @@
import cv2
import numpy as np
import paddle
from . import DOC_TR
from .utils import to_tensor, to_image
def dewarp_image(image):
    """Rectify (dewarp) a distorted document photo.

    Args:
        image: HWC numpy image of arbitrary size (channel order as produced
            by the caller — assumed consistent with training; TODO confirm).

    Returns:
        The dewarped image at the original resolution, as produced by
        ``to_image``.
    """
    # The ONNX model expects a fixed 288x288 float input.
    img = cv2.resize(image, (288, 288)).astype(np.float32)
    # Keep the full-resolution original as a tensor; the predicted backward
    # map is applied to it, not to the resized copy.
    y = to_tensor(image)
    img = np.transpose(img, (2, 0, 1))  # HWC -> CHW
    # Predict the backward map (bm) in 288x288 pixel coordinates.
    bm = DOC_TR.run(None, {"image": img[None,]})[0]
    bm = paddle.to_tensor(bm)
    # Upsample the backward map to the original image resolution.
    bm = paddle.nn.functional.interpolate(
        bm, y.shape[2:], mode="bilinear", align_corners=False
    )
    bm_nhwc = np.transpose(bm, (0, 2, 3, 1))
    # grid_sample expects coordinates normalized to [-1, 1]; bm values are
    # still in the model's 288-pixel coordinate space, hence /288.
    out = paddle.nn.functional.grid_sample(y, (bm_nhwc / 288 - 0.5) * 2)
    return to_image(out)

View File

@@ -1,161 +0,0 @@
# Download the Doc3D dataset archives from the PaddleSeg BOS bucket.
BASE_URL="https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d"

# download <name> <id>... : fetch ${name}_${id}.zip for each listed id.
download() {
    name=$1
    shift
    for i in "$@"; do
        wget "${BASE_URL}/${name}_${i}.zip"
    done
}

download img $(seq 1 21)
download wc $(seq 1 21)
download bm $(seq 1 21)
# NOTE: uv_9.zip is absent from the published archives, so it is skipped.
download uv $(seq 1 8) $(seq 10 21)
# NOTE: albedo archives only exist for parts 1-8.
download alb $(seq 1 8)
download recon $(seq 1 21)
download norm $(seq 1 21)
download dmap $(seq 1 21)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 76 KiB

View File

@@ -1,129 +0,0 @@
import collections
import os
import random
import albumentations as A
import cv2
import hdf5storage as h5
import numpy as np
import paddle
from paddle import io
# Set random seed
# Fixed module-level seed so the random tight-crop jitter is reproducible.
random.seed(12345678)
class Doc3dDataset(io.Dataset):
    """Doc3D dataset yielding (image, backward-map) training pairs.

    Expects the on-disk layout ``<root>/{img,wc,bm}/...`` plus
    ``<root>/train.txt`` and ``<root>/val.txt`` listing sample ids
    (one ``folder/stem`` per line, no extension).
    """

    def __init__(self, root, split="train", is_augment=False, image_size=512):
        # root: dataset root directory (``~`` is expanded).
        # split: which id list to serve, "train" or "val".
        # is_augment: apply color jitter to the input image when True.
        # image_size: output size; an int is promoted to a square (w, h).
        self.root = os.path.expanduser(root)
        self.split = split
        self.is_augment = is_augment
        self.files = collections.defaultdict(list)
        self.image_size = (
            image_size if isinstance(image_size, tuple) else (image_size, image_size)
        )
        # Augmentation
        self.augmentation = A.Compose(
            [
                A.ColorJitter(),
            ]
        )
        # NOTE(review): this loop reuses (shadows) the `split` parameter;
        # harmless because self.split was captured above, but fragile.
        for split in ["train", "val"]:
            path = os.path.join(self.root, split + ".txt")
            file_list = []
            with open(path, "r") as file:
                file_list = [file_id.rstrip() for file_id in file.readlines()]
            self.files[split] = file_list

    def __len__(self):
        # Number of samples in the active split.
        return len(self.files[self.split])

    def __getitem__(self, index):
        """Load one sample and return the transformed (image, bm) pair."""
        image_name = self.files[self.split][index]
        # Read image
        image_path = os.path.join(self.root, "img", image_name + ".png")
        image = cv2.imread(image_path)
        # Read 3D Coordinates
        wc_path = os.path.join(self.root, "wc", image_name + ".exr")
        wc = cv2.imread(wc_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
        # Read backward map
        bm_path = os.path.join(self.root, "bm", image_name + ".mat")
        bm = h5.loadmat(bm_path)["bm"]
        image, bm = self.transform(wc, bm, image)
        return image, bm

    def tight_crop(self, wc: np.ndarray):
        """Crop the world-coordinate map to the document, with random jitter.

        Returns the cropped wc plus the (top, bottom, left, right) margins
        that must also be removed from the paired RGB image.
        """
        # Document pixels are where all three wc channels are non-zero.
        mask = ((wc[:, :, 0] != 0) & (wc[:, :, 1] != 0) & (wc[:, :, 2] != 0)).astype(
            np.uint8
        )
        mask_size = mask.shape
        [y, x] = mask.nonzero()
        min_x = min(x)
        max_x = max(x)
        min_y = min(y)
        max_y = max(y)
        wc = wc[min_y : max_y + 1, min_x : max_x + 1, :]
        # Pad by s, then randomly crop back 0..2s per side so the document
        # position inside the crop varies between epochs.
        s = 10
        wc = np.pad(wc, ((s, s), (s, s), (0, 0)), "constant")
        cx1 = random.randint(0, 2 * s)
        cx2 = random.randint(0, 2 * s) + 1
        cy1 = random.randint(0, 2 * s)
        cy2 = random.randint(0, 2 * s) + 1
        wc = wc[cy1:-cy2, cx1:-cx2, :]
        top: int = min_y - s + cy1
        bottom: int = mask_size[0] - max_y - s + cy2
        left: int = min_x - s + cx1
        right: int = mask_size[1] - max_x - s + cx2
        # Clamp so the negative-slice indexing below stays valid
        # (bottom/right must be at least 1).
        top = np.clip(top, 0, mask_size[0])
        bottom = np.clip(bottom, 1, mask_size[0] - 1)
        left = np.clip(left, 0, mask_size[1])
        right = np.clip(right, 1, mask_size[1] - 1)
        return wc, top, bottom, left, right

    def transform(self, wc, bm, img):
        """Crop, (optionally) augment and resize; rescale bm to match."""
        wc, top, bottom, left, right = self.tight_crop(wc)
        img = img[top:-bottom, left:-right, :]
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.is_augment:
            img = self.augmentation(image=img)["image"]
        # resize image
        img = cv2.resize(img, self.image_size)
        img = img.astype(np.float32) / 255.0
        img = img.transpose(2, 0, 1)
        # resize bm
        # NOTE(review): the 448.0 constants assume bm was rendered at 448x448;
        # confirm against the dataset generator.
        bm = bm.astype(np.float32)
        bm[:, :, 1] = bm[:, :, 1] - top
        bm[:, :, 0] = bm[:, :, 0] - left
        bm = bm / np.array([448.0 - left - right, 448.0 - top - bottom])
        bm0 = cv2.resize(bm[:, :, 0], (self.image_size[0], self.image_size[1]))
        bm1 = cv2.resize(bm[:, :, 1], (self.image_size[0], self.image_size[1]))
        # Back to absolute pixel coordinates in the resized frame.
        bm0 = bm0 * self.image_size[0]
        bm1 = bm1 * self.image_size[1]
        bm = np.stack([bm0, bm1], axis=-1)
        img = paddle.to_tensor(img).astype(dtype="float32")
        bm = paddle.to_tensor(bm).astype(dtype="float32")
        return img, bm

View File

@@ -1,66 +0,0 @@
import argparse
import os
import paddle
from GeoTr import GeoTr
def export(args):
    """Export a trained GeoTr checkpoint to a Paddle static graph or ONNX.

    Args:
        args: parsed CLI namespace with `model` (checkpoint path), `imgsz`
            (square input size) and `format` ("static" or "onnx").

    Side effects: writes model.pdmodel/model.pdiparams (and model.onnx when
    requested) next to the checkpoint.
    """
    model_path = args.model
    imgsz = args.imgsz
    # Renamed local: the original `format` shadowed the builtin.
    export_format = args.format
    model = GeoTr()
    checkpoint = paddle.load(model_path)
    model.set_state_dict(checkpoint["model"])
    model.eval()
    dirname = os.path.dirname(model_path)
    # ONNX export goes through the static graph first.
    if export_format in ("static", "onnx"):
        model = paddle.jit.to_static(
            model,
            input_spec=[
                paddle.static.InputSpec(shape=[1, 3, imgsz, imgsz], dtype="float32")
            ],
            full_graph=True,
        )
        paddle.jit.save(model, os.path.join(dirname, "model"))
    if export_format == "onnx":
        onnx_path = os.path.join(dirname, "model.onnx")
        os.system(
            f"paddle2onnx --model_dir {dirname}"
            " --model_filename model.pdmodel"
            " --params_filename model.pdiparams"
            f" --save_file {onnx_path}"
            " --opset_version 11"
        )
if __name__ == "__main__":
    # CLI wrapper around export().
    cli = argparse.ArgumentParser(description="export model")
    cli.add_argument(
        "--model", "-m", nargs="?", type=str, default="", help="The path of model"
    )
    cli.add_argument(
        "--imgsz", type=int, default=288, help="The size of input image"
    )
    cli.add_argument(
        "--format",
        type=str,
        default="static",
        help="The format of exported model, which can be static or onnx",
    )
    export(cli.parse_args())

View File

@@ -1,110 +0,0 @@
import paddle.nn as nn
from .weight_init import weight_init_
class ResidualBlock(nn.Layer):
    """Residual Block with custom normalization.

    Two 3x3 convolutions with a configurable normalization layer and an
    optional strided 1x1 projection on the identity path.

    Args:
        in_planes (int): input channel count.
        planes (int): output channel count.
        norm_fn (str): one of "group", "batch", "instance" or "none".
        stride (int): stride of the first conv (and of the downsample).

    Raises:
        ValueError: if `norm_fn` is not a supported value.
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2D(in_planes, planes, 3, padding=1, stride=stride)
        self.conv2 = nn.Conv2D(planes, planes, 3, padding=1)
        self.relu = nn.ReLU()
        if norm_fn == "group":
            num_groups = planes // 8
            self.norm1 = nn.GroupNorm(num_groups, planes)
            self.norm2 = nn.GroupNorm(num_groups, planes)
            if not stride == 1:
                self.norm3 = nn.GroupNorm(num_groups, planes)
        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2D(planes)
            self.norm2 = nn.BatchNorm2D(planes)
            if not stride == 1:
                self.norm3 = nn.BatchNorm2D(planes)
        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2D(planes)
            self.norm2 = nn.InstanceNorm2D(planes)
            if not stride == 1:
                self.norm3 = nn.InstanceNorm2D(planes)
        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if not stride == 1:
                self.norm3 = nn.Sequential()
        else:
            # FIX: an unknown norm_fn previously left norm1/norm2 undefined
            # and failed later with a confusing AttributeError; fail fast.
            raise ValueError(f"unsupported norm_fn: {norm_fn!r}")
        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2D(in_planes, planes, 1, stride=stride), self.norm3
            )

    def forward(self, x):
        """Return relu(identity + F(x)), projecting the identity if strided."""
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))
        if self.downsample is not None:
            x = self.downsample(x)
        return self.relu(x + y)
class BasicEncoder(nn.Layer):
    """Basic Encoder with custom normalization.

    Stem conv (stride 2) followed by three residual stages (64, 128, 192
    channels; the last two stride 2) and a 1x1 projection to `output_dim`.
    Overall spatial downsampling is 8x.
    """

    def __init__(self, output_dim=128, norm_fn="batch"):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn
        if self.norm_fn == "group":
            self.norm1 = nn.GroupNorm(8, 64)
        elif self.norm_fn == "batch":
            self.norm1 = nn.BatchNorm2D(64)
        elif self.norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2D(64)
        elif self.norm_fn == "none":
            self.norm1 = nn.Sequential()
        self.conv1 = nn.Conv2D(3, 64, 7, stride=2, padding=3)
        self.relu1 = nn.ReLU()
        # _make_layer mutates self.in_planes, so these three calls must
        # stay in this order.
        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(128, stride=2)
        self.layer3 = self._make_layer(192, stride=2)
        self.conv2 = nn.Conv2D(192, output_dim, 1)
        # Kaiming init for convs, constant init for norm layers.
        for m in self.sublayers():
            if isinstance(m, nn.Conv2D):
                weight_init_(
                    m.weight, "kaiming_normal_", mode="fan_out", nonlinearity="relu"
                )
            elif isinstance(m, (nn.BatchNorm2D, nn.InstanceNorm2D, nn.GroupNorm)):
                weight_init_(m, "Constant", value=1, bias_value=0.0)

    def _make_layer(self, dim, stride=1):
        # Two residual blocks; only the first may stride. Updates
        # self.in_planes for the next stage.
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = layer1, layer2
        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode a (N, 3, H, W) image to (N, output_dim, H/8, W/8) features."""
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.conv2(x)
        return x

View File

@@ -1,48 +0,0 @@
import copy
import os
import matplotlib.pyplot as plt
import paddle.optimizer as optim
from GeoTr import GeoTr
def plot_lr_scheduler(optimizer, scheduler, epochs=65, save_dir=""):
    """
    Plot the learning rate scheduler

    Steps the (copied) scheduler through `epochs` x 30 iterations, records
    the lr at each step and saves the curve to <save_dir>/lr_scheduler.png.
    """
    # NOTE(review): copy.copy is shallow — stepping the copies may still
    # mutate state shared with the originals; only use with throwaway
    # optimizer/scheduler instances (as __main__ below does).
    optimizer = copy.copy(optimizer)
    scheduler = copy.copy(scheduler)
    lr = []
    # 30 simulated iterations per epoch.
    for _ in range(epochs):
        for _ in range(30):
            lr.append(scheduler.get_lr())
            optimizer.step()
            scheduler.step()
    # Fractional epoch index for the x axis.
    epoch = [float(i) / 30.0 for i in range(len(lr))]
    plt.figure()
    plt.plot(epoch, lr, ".-", label="Learning Rate")
    plt.xlabel("epoch")
    plt.ylabel("Learning Rate")
    plt.title("Learning Rate Scheduler")
    plt.savefig(os.path.join(save_dir, "lr_scheduler.png"), dpi=300)
    plt.close()
if __name__ == "__main__":
    # Build the same OneCycle schedule shape used by training and render it.
    model = GeoTr()
    lr_scheduler = optim.lr.OneCycleLR(
        max_learning_rate=1e-4,
        total_steps=1950,
        phase_pct=0.1,
        end_learning_rate=1e-4 / 2.5e5,
    )
    optimizer = optim.AdamW(learning_rate=lr_scheduler, parameters=model.parameters())
    plot_lr_scheduler(
        scheduler=lr_scheduler, optimizer=optimizer, epochs=65, save_dir="./"
    )

View File

@@ -1,124 +0,0 @@
import math
from typing import Optional
import paddle
import paddle.nn as nn
import paddle.nn.initializer as init
class NestedTensor(object):
    """Pairs a batched tensor with an optional mask (DETR-style helper).

    `mask` semantics are defined by the consumer: PositionEmbeddingSine
    cumsums it along H and W to derive per-pixel coordinates.
    """

    def __init__(self, tensors, mask: Optional[paddle.Tensor]):
        self.tensors = tensors
        self.mask = mask

    def decompose(self):
        """Return the (tensors, mask) pair."""
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)
class PositionEmbeddingSine(nn.Layer):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.

    Produces a (N, 2*num_pos_feats, H, W) map: sin/cos features of the
    per-pixel y coordinate concatenated with those of the x coordinate.
    """

    def __init__(
        self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
    ):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, mask):
        assert mask is not None
        # Cumulative sums over rows/columns turn the mask into per-pixel
        # y/x coordinates (1-based over the masked region).
        y_embed = mask.cumsum(axis=1, dtype="float32")
        x_embed = mask.cumsum(axis=2, dtype="float32")
        if self.normalize:
            # Rescale coordinates to [0, scale]; eps guards division by zero.
            eps = 1e-06
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
        # Geometric frequency ladder, paired so sin/cos share a frequency.
        dim_t = paddle.arange(end=self.num_pos_feats, dtype="float32")
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # Interleave sin (even indices) and cos (odd indices).
        pos_x = paddle.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), axis=4
        ).flatten(3)
        pos_y = paddle.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), axis=4
        ).flatten(3)
        # Concatenate y then x features and move them to the channel axis.
        pos = paddle.concat((pos_y, pos_x), axis=3).transpose(perm=[0, 3, 1, 2])
        return pos
class PositionEmbeddingLearned(nn.Layer):
    """
    Absolute pos embedding, learned.

    Keeps separate learned tables for up to 50 rows and 50 columns and
    broadcasts them over the feature-map grid.
    """

    def __init__(self, num_pos_feats=256):
        super().__init__()
        self.row_embed = nn.Embedding(50, num_pos_feats)
        self.col_embed = nn.Embedding(50, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize both embedding tables with uniform noise."""
        uniform_ = init.Uniform()
        uniform_(self.row_embed.weight)
        uniform_(self.col_embed.weight)

    def forward(self, tensor_list: NestedTensor):
        """Return a (N, 2*num_pos_feats, H, W) positional map for the batch."""
        x = tensor_list.tensors
        h, w = x.shape[-2:]
        cols = paddle.arange(end=w)
        rows = paddle.arange(end=h)
        x_emb = self.col_embed(cols)  # (w, num_pos_feats)
        y_emb = self.row_embed(rows)  # (h, num_pos_feats)
        # Broadcast the 1-D embeddings over the grid, concatenate along the
        # feature axis, move features to the channel position, repeat per item.
        grid = paddle.concat(
            [
                x_emb.unsqueeze(0).tile([h, 1, 1]),
                y_emb.unsqueeze(1).tile([1, w, 1]),
            ],
            axis=-1,
        )
        return grid.transpose([2, 0, 1]).unsqueeze(0).tile([x.shape[0], 1, 1, 1])
def build_position_encoding(hidden_dim=512, position_embedding="sine"):
    """Create a position-encoding layer producing `hidden_dim` channels.

    Each of the sine/learned variants emits 2 * (hidden_dim // 2) channels.
    Accepts the aliases "v2" (sine) and "v3" (learned).
    """
    half_dim = hidden_dim // 2
    if position_embedding in ("v2", "sine"):
        return PositionEmbeddingSine(half_dim, normalize=True)
    if position_embedding in ("v3", "learned"):
        return PositionEmbeddingLearned(half_dim)
    raise ValueError(f"not supported {position_embedding}")

View File

@@ -1,69 +0,0 @@
import argparse
import cv2
import paddle
from GeoTr import GeoTr
from utils import to_image, to_tensor
def run(args):
    """Dewarp a single image with a GeoTr checkpoint and save the result.

    Reads args.image, runs the model at 288x288, applies the upsampled
    backward map to the full-resolution image and writes args.output.
    """
    image_path = args.image
    model_path = args.model
    output_path = args.output
    checkpoint = paddle.load(model_path)
    state_dict = checkpoint["model"]
    model = GeoTr()
    model.set_state_dict(state_dict)
    model.eval()
    img_org = cv2.imread(image_path)
    # The network runs at a fixed 288x288; the map is applied at full size.
    img = cv2.resize(img_org, (288, 288))
    x = to_tensor(img)
    y = to_tensor(img_org)
    # Predicted backward map, (1, 2, 288, 288) in 0..288 pixel units.
    bm = model(x)
    bm = paddle.nn.functional.interpolate(
        bm, y.shape[2:], mode="bilinear", align_corners=False
    )
    # NCHW -> NHWC, then normalize pixel coordinates to [-1, 1] for grid_sample.
    bm_nhwc = bm.transpose([0, 2, 3, 1])
    out = paddle.nn.functional.grid_sample(y, (bm_nhwc / 288 - 0.5) * 2)
    out_image = to_image(out)
    cv2.imwrite(output_path, out_image)
if __name__ == "__main__":
    # Command-line wrapper around run().
    cli = argparse.ArgumentParser(description="predict")
    cli.add_argument(
        "--image", "-i", nargs="?", type=str, default="", help="The path of image"
    )
    cli.add_argument(
        "--model", "-m", nargs="?", type=str, default="", help="The path of model"
    )
    cli.add_argument(
        "--output", "-o", nargs="?", type=str, default="", help="The path of output"
    )
    args = cli.parse_args()
    print(args)
    run(args)

View File

@@ -1,7 +0,0 @@
hdf5storage
loguru
numpy
scipy
opencv-python
matplotlib
albumentations

View File

@@ -1,57 +0,0 @@
import argparse
import glob
import os
import random
from pathlib import Path
from loguru import logger
# Fixed seed so the train/val split is reproducible across runs.
random.seed(1234567)
def _write_split(label_path, items):
    """Write one `<folder>/<stem>` sample id per line to *label_path*."""
    with open(label_path, "w") as file:
        for item in items:
            parts = Path(item).parts
            item = os.path.join(parts[-2], parts[-1])
            file.write("%s\n" % item.split(".png")[0])


def run(args):
    """Shuffle all dataset images and write train.txt / val.txt label files.

    Args:
        args: parsed CLI namespace with `data_root` (dataset root, `~` is
            expanded) and `train_ratio` (fraction of samples for training).
    """
    data_root = os.path.expanduser(args.data_root)
    ratio = args.train_ratio
    data_path = os.path.join(data_root, "img", "*", "*.png")
    img_list = glob.glob(data_path, recursive=True)
    # FIX: `sorted(img_list)` returned a new list that was discarded, so the
    # seeded shuffle ran on filesystem-dependent glob order and the split
    # was not reproducible. Sort in place before shuffling.
    img_list.sort()
    random.shuffle(img_list)
    train_size = int(len(img_list) * ratio)
    train_text_path = os.path.join(data_root, "train.txt")
    _write_split(train_text_path, img_list[:train_size])
    val_text_path = os.path.join(data_root, "val.txt")
    _write_split(val_text_path, img_list[train_size:])
    logger.info(f"TRAIN LABEL: {train_text_path}")
    logger.info(f"VAL LABEL: {val_text_path}")
if __name__ == "__main__":
    # CLI wrapper: parse options, log them, then generate the split files.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--data_root",
        type=str,
        default="~/datasets/doc3d",
        help="Data path to load data",
    )
    cli.add_argument(
        "--train_ratio", type=float, default=0.8, help="Ratio of training data"
    )
    args = cli.parse_args()
    logger.info(args)
    run(args)

View File

@@ -1,381 +0,0 @@
import argparse
import inspect
import os
import random
from pathlib import Path
from typing import Optional
import numpy as np
import paddle
import paddle.nn as nn
import paddle.optimizer as optim
from loguru import logger
from paddle.io import DataLoader
from paddle.nn import functional as F
from paddle_msssim import ms_ssim, ssim
from doc3d_dataset import Doc3dDataset
from GeoTr import GeoTr
from utils import to_image
# Resolve this script's directory so run folders are created relative to it.
FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]
# Distributed rank injected by the launcher; -1 means single-process run.
RANK = int(os.getenv("RANK", -1))
def init_seeds(seed=0, deterministic=False):
    """Seed the Python, NumPy and Paddle RNGs; optionally force determinism."""
    for seed_fn in (random.seed, np.random.seed, paddle.seed):
        seed_fn(seed)
    if deterministic:
        # cuDNN/cuBLAS determinism plus a stable interpreter hash seed.
        os.environ["FLAGS_cudnn_deterministic"] = "1"
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        os.environ["PYTHONHASHSEED"] = str(seed)
def colorstr(*input):
    # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code,
    # i.e. colorstr('blue', 'hello world')
    # All arguments but the last name colors/modifiers; the last is the
    # text. A single argument is rendered blue + bold.
    if len(input) > 1:
        *args, string = input
    else:
        args, string = ("blue", "bold"), input[0]
    colors = {
        "black": "\033[30m",  # basic colors
        "red": "\033[31m",
        "green": "\033[32m",
        "yellow": "\033[33m",
        "blue": "\033[34m",
        "magenta": "\033[35m",
        "cyan": "\033[36m",
        "white": "\033[37m",
        "bright_black": "\033[90m",  # bright colors
        "bright_red": "\033[91m",
        "bright_green": "\033[92m",
        "bright_yellow": "\033[93m",
        "bright_blue": "\033[94m",
        "bright_magenta": "\033[95m",
        "bright_cyan": "\033[96m",
        "bright_white": "\033[97m",
        "end": "\033[0m",  # misc
        "bold": "\033[1m",
        "underline": "\033[4m",
    }
    prefix = "".join(colors[x] for x in args)
    return prefix + f"{string}" + colors["end"]
def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
    # Print function arguments (optional args dict)
    # Inspects one frame up the stack so the log line is attributed to the
    # caller; because of this fixed frame depth it must be called directly,
    # not through another wrapper.
    x = inspect.currentframe().f_back  # previous frame
    file, _, func, _, _ = inspect.getframeinfo(x)
    if args is None:  # get args automatically
        # Recover the caller's named arguments from its local frame.
        args, _, _, frm = inspect.getargvalues(x)
        args = {k: v for k, v in frm.items() if k in args}
    try:
        # Prefer a short path relative to this script's directory.
        file = Path(file).resolve().relative_to(ROOT).with_suffix("")
    except ValueError:
        file = Path(file).stem
    s = (f"{file}: " if show_file else "") + (f"{func}: " if show_func else "")
    logger.info(colorstr(s) + ", ".join(f"{k}={v}" for k, v in args.items()))
def increment_path(path, exist_ok=False, sep="", mkdir=False):
    # Increment file or directory path,
    # i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
    # With exist_ok=True the path is returned unchanged even if it exists;
    # with mkdir=True the (possibly incremented) directory is created.
    path = Path(path)  # os-agnostic
    if path.exists() and not exist_ok:
        # For files, keep the extension after the numeric suffix.
        if path.is_file():
            stem, suffix = path.with_suffix(""), path.suffix
        else:
            stem, suffix = path, ""
        for n in range(2, 9999):
            candidate = f"{stem}{sep}{n}{suffix}"  # increment path
            if not os.path.exists(candidate):
                break
        path = Path(candidate)
    if mkdir:
        path.mkdir(parents=True, exist_ok=True)  # make directory
    return path
def train(args):
    """Train GeoTr on Doc3D, keeping `last.ckpt` / `best.ckpt` checkpoints.

    Minimizes the L1 distance between the predicted and ground-truth
    backward maps; validation additionally scores SSIM / MS-SSIM on the
    unwarped images, and the epoch-average SSIM drives `best.ckpt`.
    Scalars and sample images are logged to VisualDL when --use-vdl is set.
    """
    save_dir = Path(args.save_dir)
    use_vdl = args.use_vdl
    if use_vdl:
        from visualdl import LogWriter

        log_dir = save_dir / "vdl"
        vdl_writer = LogWriter(str(log_dir))
    # Directories
    weights_dir = save_dir / "weights"
    # FIX: previously only weights_dir.parent (the run dir) was created,
    # leaving the weights/ directory itself to chance.
    weights_dir.mkdir(parents=True, exist_ok=True)
    last = weights_dir / "last.ckpt"
    best = weights_dir / "best.ckpt"
    # Hyperparameters
    # Config
    init_seeds(args.seed)
    # Train loader
    train_dataset = Doc3dDataset(
        args.data_root,
        split="train",
        is_augment=True,
        image_size=args.img_size,
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
    )
    # Validation loader
    val_dataset = Doc3dDataset(
        args.data_root,
        split="val",
        is_augment=False,
        image_size=args.img_size,
    )
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size, num_workers=args.workers
    )
    # Model
    model = GeoTr()
    if use_vdl:
        vdl_writer.add_graph(
            model,
            input_spec=[
                paddle.static.InputSpec([1, 3, args.img_size, args.img_size], "float32")
            ],
        )
    # Data Parallel Mode
    if RANK == -1 and paddle.device.cuda.device_count() > 1:
        model = paddle.DataParallel(model)
    # Scheduler
    scheduler = optim.lr.OneCycleLR(
        max_learning_rate=args.lr,
        total_steps=args.epochs * len(train_loader),
        phase_pct=0.1,
        end_learning_rate=args.lr / 2.5e5,
    )
    # Optimizer
    optimizer = optim.AdamW(
        learning_rate=scheduler,
        parameters=model.parameters(),
    )
    # loss function
    l1_loss_fn = nn.L1Loss()
    mse_loss_fn = nn.MSELoss()
    # Resume
    best_fitness, start_epoch = 0.0, 0
    if args.resume:
        ckpt = paddle.load(args.resume)
        model.set_state_dict(ckpt["model"])
        optimizer.set_state_dict(ckpt["optimizer"])
        scheduler.set_state_dict(ckpt["scheduler"])
        best_fitness = ckpt["best_fitness"]
        start_epoch = ckpt["epoch"] + 1
    # Train
    for epoch in range(start_epoch, args.epochs):
        model.train()
        for i, (img, target) in enumerate(train_loader):
            img = paddle.to_tensor(img)  # NCHW
            target = paddle.to_tensor(target)  # NHWC
            pred = model(img)  # NCHW
            # Match the target's NHWC layout before computing the losses.
            pred_nhwc = pred.transpose([0, 2, 3, 1])
            loss = l1_loss_fn(pred_nhwc, target)
            mse_loss = mse_loss_fn(pred_nhwc, target)
            if use_vdl:
                vdl_writer.add_scalar(
                    "Train/L1 Loss", float(loss), epoch * len(train_loader) + i
                )
                vdl_writer.add_scalar(
                    "Train/MSE Loss", float(mse_loss), epoch * len(train_loader) + i
                )
                vdl_writer.add_scalar(
                    "Train/Learning Rate",
                    float(scheduler.get_lr()),
                    epoch * len(train_loader) + i,
                )
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.clear_grad()
            if i % 10 == 0:
                logger.info(
                    f"[TRAIN MODE] Epoch: {epoch}, Iter: {i}, L1 Loss: {float(loss)}, "
                    f"MSE Loss: {float(mse_loss)}, LR: {float(scheduler.get_lr())}"
                )
        # Validation
        model.eval()
        with paddle.no_grad():
            avg_ssim = paddle.zeros([])
            avg_ms_ssim = paddle.zeros([])
            avg_l1_loss = paddle.zeros([])
            avg_mse_loss = paddle.zeros([])
            for i, (img, target) in enumerate(val_loader):
                img = paddle.to_tensor(img)
                target = paddle.to_tensor(target)
                pred = model(img)
                pred_nhwc = pred.transpose([0, 2, 3, 1])
                # predict image
                out = F.grid_sample(img, (pred_nhwc / args.img_size - 0.5) * 2)
                out_gt = F.grid_sample(img, (target / args.img_size - 0.5) * 2)
                # calculate ssim
                ssim_val = ssim(out, out_gt, data_range=1.0)
                ms_ssim_val = ms_ssim(out, out_gt, data_range=1.0)
                loss = l1_loss_fn(pred_nhwc, target)
                mse_loss = mse_loss_fn(pred_nhwc, target)
                # calculate fitness
                avg_ssim += ssim_val
                avg_ms_ssim += ms_ssim_val
                avg_l1_loss += loss
                avg_mse_loss += mse_loss
                if i % 10 == 0:
                    logger.info(
                        f"[VAL MODE] Epoch: {epoch}, VAL Iter: {i}, "
                        f"L1 Loss: {float(loss)} MSE Loss: {float(mse_loss)}, "
                        f"MS-SSIM: {float(ms_ssim_val)}, SSIM: {float(ssim_val)}"
                    )
                if use_vdl and i == 0:
                    # NOTE(review): logging samples 0..2 assumes the first
                    # validation batch has at least 3 items — confirm.
                    img_0 = to_image(out[0])
                    img_gt_0 = to_image(out_gt[0])
                    vdl_writer.add_image("Val/Predicted Image No.0", img_0, epoch)
                    vdl_writer.add_image("Val/Target Image No.0", img_gt_0, epoch)
                    img_1 = to_image(out[1])
                    img_gt_1 = to_image(out_gt[1])
                    img_gt_1 = img_gt_1.astype("uint8")
                    vdl_writer.add_image("Val/Predicted Image No.1", img_1, epoch)
                    vdl_writer.add_image("Val/Target Image No.1", img_gt_1, epoch)
                    img_2 = to_image(out[2])
                    img_gt_2 = to_image(out_gt[2])
                    vdl_writer.add_image("Val/Predicted Image No.2", img_2, epoch)
                    vdl_writer.add_image("Val/Target Image No.2", img_gt_2, epoch)
            avg_ssim /= len(val_loader)
            avg_ms_ssim /= len(val_loader)
            avg_l1_loss /= len(val_loader)
            avg_mse_loss /= len(val_loader)
            if use_vdl:
                # FIX: log the epoch-level averages; the last batch's values
                # (loss/mse_loss/ssim_val/ms_ssim_val) were logged here before.
                vdl_writer.add_scalar("Val/L1 Loss", float(avg_l1_loss), epoch)
                vdl_writer.add_scalar("Val/MSE Loss", float(avg_mse_loss), epoch)
                vdl_writer.add_scalar("Val/SSIM", float(avg_ssim), epoch)
                vdl_writer.add_scalar("Val/MS-SSIM", float(avg_ms_ssim), epoch)
        # Save
        ckpt = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "best_fitness": best_fitness,
            "epoch": epoch,
        }
        paddle.save(ckpt, str(last))
        if best_fitness < avg_ssim:
            # FIX: keep best_fitness a plain float so subsequent checkpoints
            # do not serialize a Paddle tensor under "best_fitness".
            best_fitness = float(avg_ssim)
            paddle.save(ckpt, str(best))
    if use_vdl:
        vdl_writer.close()
def main(args):
    """Entry point: echo the CLI options, choose a run directory, and train.

    Args:
        args (argparse.Namespace): Parsed command-line options. ``save_dir``
            is filled in here before training starts.
    """
    print_args(vars(args))
    # Resolve a unique run directory such as runs/train/exp, exp2, ...
    run_dir = increment_path(Path(args.project) / args.name, exist_ok=args.exist_ok)
    args.save_dir = str(run_dir)
    train(args)
if __name__ == "__main__":
    # CLI definition for the training entry point.
    cli = argparse.ArgumentParser(description="Hyperparams")
    # Dataset and input geometry.
    cli.add_argument(
        "--data-root", nargs="?", type=str, default="~/datasets/doc3d",
        help="The root path of the dataset",
    )
    cli.add_argument(
        "--img-size", nargs="?", type=int, default=288,
        help="The size of the input image",
    )
    # Optimization schedule.
    cli.add_argument(
        "--epochs", nargs="?", type=int, default=65,
        help="The number of training epochs",
    )
    cli.add_argument("--batch-size", nargs="?", type=int, default=12, help="Batch Size")
    cli.add_argument("--lr", nargs="?", type=float, default=1e-04, help="Learning Rate")
    cli.add_argument(
        "--resume", nargs="?", type=str, default=None,
        help="Path to previous saved model to restart from",
    )
    cli.add_argument("--workers", type=int, default=8, help="max dataloader workers")
    # Run bookkeeping and logging.
    cli.add_argument("--project", default=ROOT / "runs/train", help="save to project/name")
    cli.add_argument("--name", default="exp", help="save to project/name")
    cli.add_argument(
        "--exist-ok", action="store_true",
        help="existing project/name ok, do not increment",
    )
    cli.add_argument("--seed", type=int, default=0, help="Global training seed")
    cli.add_argument("--use-vdl", action="store_true", help="use VisualDL as logger")
    main(cli.parse_args())

View File

@@ -1,10 +0,0 @@
export OPENCV_IO_ENABLE_OPENEXR=1
export FLAGS_logtostderr=0
export CUDA_VISIBLE_DEVICES=0
python train.py --img-size 288 \
--name "DocTr++" \
--batch-size 12 \
--lr 1e-4 \
--exist-ok \
--use-vdl

View File

@@ -1,67 +0,0 @@
import numpy as np
import paddle
from paddle.nn import functional as F
def to_tensor(img: np.ndarray):
    """
    Converts a numpy array image (HWC) to a Paddle tensor (NCHW).

    Args:
        img (numpy.ndarray): Input image, BGR channel order, HWC layout.

    Returns:
        out (paddle.Tensor): Float tensor in [0, 1] with a leading batch axis.
    """
    # BGR -> RGB, scale to [0, 1], then HWC -> CHW.
    rgb = img[:, :, ::-1].astype("float32") / 255.0
    chw = rgb.transpose(2, 0, 1)
    tensor: paddle.Tensor = paddle.to_tensor(chw)
    # Prepend the batch dimension (N=1).
    return paddle.unsqueeze(tensor, axis=0)
def to_image(x: paddle.Tensor):
    """
    Converts a Paddle tensor (NCHW) to a numpy array image (HWC).

    Args:
        x (paddle.Tensor): Input tensor with values in [0, 1].

    Returns:
        out (numpy.ndarray): uint8 image in BGR channel order.
    """
    # Drop singleton dims, move to host memory, then CHW -> HWC.
    arr: np.ndarray = x.squeeze().numpy().transpose(1, 2, 0)
    # [0, 1] float -> [0, 255] uint8.
    arr = (arr * 255.0).astype("uint8")
    # RGB -> BGR (OpenCV convention).
    return arr[:, :, ::-1]
def unwarp(img, bm, bm_data_format="NCHW"):
    """
    Unwarp an image using a flow field.

    Args:
        img (paddle.Tensor): The input image (NCHW).
        bm (paddle.Tensor): The backward-mapping flow field.
        bm_data_format (str): Layout of ``bm``, either "NCHW" or "NHWC".

    Returns:
        out (paddle.Tensor): The output (unwarped) image.
    """
    _, _, height, width = img.shape
    if bm_data_format == "NHWC":
        # Bring the flow to channel-first for resizing.
        bm = bm.transpose([0, 3, 1, 2])
    # Match the flow resolution to the image (bilinear, NCHW)...
    grid = F.upsample(bm, size=(height, width), mode="bilinear", align_corners=True)
    # ...then back to NHWC, which is what grid_sample expects for the grid.
    grid = grid.transpose([0, 2, 3, 1])
    return F.grid_sample(img, grid)

View File

@@ -1,152 +0,0 @@
import math
import numpy as np
import paddle
import paddle.nn.initializer as init
from scipy import special
def weight_init_(
    layer, func, weight_name=None, bias_name=None, bias_value=0.0, **kwargs
):
    """
    In-place parameter initialization for a Paddle layer.

    The weight (if present and not None) is initialized with the
    ``paddle.nn.initializer`` class named by ``func``, constructed with
    ``**kwargs``; the bias (if present and not None) is filled with the
    constant ``bias_value``.  Optional new names can be assigned to either
    parameter.

    Usage:
    .. code-block:: python

        import paddle
        import numpy as np
        data = np.ones([3, 4], dtype='float32')
        linear = paddle.nn.Linear(4, 4)
        input = paddle.to_tensor(data)
        print(linear.weight)
        linear(input)
        weight_init_(linear, 'Normal', 'fc_w0', 'fc_b0', std=0.01, mean=0.1)
        print(linear.weight)
    """
    weight = getattr(layer, "weight", None)
    if weight is not None:
        # Look up the initializer class by name and apply it in place.
        getattr(init, func)(**kwargs)(weight)
        if weight_name is not None:
            # override weight name
            weight.name = weight_name
    bias = getattr(layer, "bias", None)
    if bias is not None:
        init.Constant(bias_value)(bias)
        if bias_name is not None:
            # override bias name
            bias.name = bias_name
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """
    Fill ``tensor`` in place with samples from a truncated normal distribution.

    Values are drawn from N(mean, std) restricted to [a, b] via inverse-CDF
    sampling (uniform -> erfinv), then written back with ``set_value`` under
    ``paddle.no_grad()``.

    Args:
        tensor (paddle.Tensor): Tensor to fill in place.
        mean (float): Mean of the underlying normal distribution.
        std (float): Standard deviation of the underlying normal distribution.
        a (float): Lower truncation bound.
        b (float): Upper truncation bound.

    Returns:
        paddle.Tensor: The same ``tensor``, for call chaining.
    """
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
    # Warn (print only, no exception) when [a, b] barely overlaps the
    # distribution's mass — the resulting samples would be badly skewed.
    if (mean < a - 2 * std) or (mean > b + 2 * std):
        print(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect."
        )
    with paddle.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        lower = norm_cdf((a - mean) / std)
        upper = norm_cdf((b - mean) / std)
        # Uniformly fill tensor with values from [l, u], then translate to [2l-1, 2u-1].
        tmp = np.random.uniform(
            2 * lower - 1, 2 * upper - 1, size=list(tensor.shape)
        ).astype(np.float32)
        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tmp = special.erfinv(tmp)
        # Transform to proper mean, std
        tmp *= std * math.sqrt(2.0)
        tmp += mean
        # Clamp to ensure it's in the proper range
        tmp = np.clip(tmp, a, b)
        tensor.set_value(paddle.to_tensor(tmp))
    return tensor
def _calculate_fan_in_and_fan_out(tensor):
dimensions = tensor.dim()
if dimensions < 2:
raise ValueError(
"Fan in and fan out can not be computed for tensor "
"with fewer than 2 dimensions"
)
num_input_fmaps = tensor.shape[1]
num_output_fmaps = tensor.shape[0]
receptive_field_size = 1
if tensor.dim() > 2:
receptive_field_size = tensor[0][0].numel()
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size
return fan_in, fan_out
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    """Fill ``tensor`` in place from a normal N(mean, std) truncated to [a, b]."""
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
def kaiming_normal_(tensor, a=0.0, mode="fan_in", nonlinearity="leaky_relu"):
    """
    Kaiming (He) normal initialization, applied in place.

    Args:
        tensor: Weight tensor to initialize.
        a: Negative slope of the rectifier (only used for "leaky_relu" gain).
        mode: "fan_in" (preserve forward variance) or "fan_out" (backward).
        nonlinearity: Name of the following non-linearity, selects the gain.

    Returns:
        The initialized tensor.
    """
    def _resolve_fan(t, fan_mode):
        # Pick fan_in or fan_out according to the requested mode.
        fan_mode = fan_mode.lower()
        allowed = ["fan_in", "fan_out"]
        if fan_mode not in allowed:
            raise ValueError(
                "Mode {} not supported, please use one of {}".format(fan_mode, allowed)
            )
        fan_in, fan_out = _calculate_fan_in_and_fan_out(t)
        return fan_in if fan_mode == "fan_in" else fan_out

    def _gain(nl, param=None):
        # Recommended scaling gain for each supported non-linearity.
        linear_fns = [
            "linear",
            "conv1d",
            "conv2d",
            "conv3d",
            "conv_transpose1d",
            "conv_transpose2d",
            "conv_transpose3d",
        ]
        if nl in linear_fns or nl == "sigmoid":
            return 1
        if nl == "tanh":
            return 5.0 / 3
        if nl == "relu":
            return math.sqrt(2.0)
        if nl == "leaky_relu":
            if param is None:
                slope = 0.01
            elif (
                not isinstance(param, bool) and isinstance(param, int)
            ) or isinstance(param, float):
                # Accept real numbers but reject bools masquerading as ints.
                slope = param
            else:
                raise ValueError("negative_slope {} not a valid number".format(param))
            return math.sqrt(2.0 / (1 + slope**2))
        raise ValueError("Unsupported nonlinearity {}".format(nl))

    fan = _resolve_fan(tensor, mode)
    std = _gain(nonlinearity, a) / math.sqrt(fan)
    with paddle.no_grad():
        # Paddle initializers assign in place when called on a tensor.
        paddle.nn.initializer.Normal(0, std)(tensor)
    return tensor

26
docker-compose.dev.yml Normal file
View File

@@ -0,0 +1,26 @@
services:
fcb_ai_dev:
image: fcb_ai_dev:0.0.10
build:
context: .
dockerfile: Dockerfile.dev
# 容器名称,可自定义
container_name: fcb_ai_dev
hostname: fcb_ai_dev
# 始终重启容器
restart: always
# 端口映射,根据需要修改主机端口
ports:
- "8022:22"
# 数据卷映射,根据实际路径修改
volumes:
- ./log:/app/log
- ./model:/app/model
# 启用GPU支持
deploy:
resources:
reservations:
devices:
- device_ids: [ '0', '1' ]
capabilities: [ 'gpu' ]
driver: 'nvidia'

View File

@@ -1,6 +1,6 @@
x-env:
&template
image: fcb_photo_review:1.14.6
image: fcb_photo_review:1.15.7
restart: always
x-review:
@@ -31,30 +31,12 @@ x-mask:
driver: 'nvidia'
services:
det_api:
<<: *template
build:
context: .
container_name: det_api
hostname: det_api
volumes:
- ./log:/app/log
- ./model:/app/model
# command: [ 'det_api.py' ]
deploy:
resources:
reservations:
devices:
- device_ids: [ '0' ]
capabilities: [ 'gpu' ]
driver: 'nvidia'
photo_review_1:
<<: *review_template
build:
context: .
container_name: photo_review_1
hostname: photo_review_1
depends_on:
- det_api
command: [ 'photo_review.py', '--clean', 'True' ]
photo_review_2:

Binary file not shown.

153
document/PaddleOCR命令.md Normal file
View File

@@ -0,0 +1,153 @@
# PaddleOCR
------
## 数据集
该部分内容均在PPOCRLabel目录下进行
```bash
# 进入PPOCRLabel目录
cd .\PPOCRLabel\
```
### 打标
可以对PPOCRLabel.py直接使用PyCharm中的Run,但是默认是英文的
```bash
# 以中文运行打标应用
python PPOCRLabel.py --lang ch
# 含有关键词提取的打标
python PPOCRLabel.py --lang ch --kie True
```
### 划分数据集
```bash
python gen_ocr_train_val_test.py --trainValTestRatio 6:2:2 --datasetRootPath ../train_data/drivingData
```
------
## 检测模型
先回到项目根目录
### 训练
```bash
python tools/train.py -c configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml
```
### 测试
```bash
python tools/infer_det.py -c configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml -o Global.pretrained_model=output/det_v4_bankcard/best_accuracy.pdparams Global.infer_img=train_data/drivingData/1.jpg
```
### 恢复训练
```bash
python tools/train.py -c configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml -o Global.checkpoints=./output/det_v4_bankcard/latest
```
------
## 识别模型
### 训练
```bash
python tools/train.py -c configs/rec/PP-OCRv4/ch_PP-OCRv4_rec_ampO2_ultra.yml
```
### 测试
```bash
python tools/infer_rec.py -c configs/rec/PP-OCRv4/ch_PP-OCRv4_rec_ampO2_ultra.yml -o Global.pretrained_model=output/rec_v4_bankcard/best_accuracy.pdparams Global.infer_img=train_data/drivingData/crop_img/1_crop_0.jpg
```
------
## 推理模型
### 检测模型转换
```bash
python tools/export_model.py -c configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml -o Global.pretrained_model=output/det_v4_bankcard/best_accuracy.pdparams
```
### 识别模型转换
```bash
python tools/export_model.py -c configs/rec/PP-OCRv4/ch_PP-OCRv4_rec_ampO2_ultra.yml -o Global.pretrained_model=output/rec_v4_bankcard/best_accuracy.pdparams
```
### 检测识别测试
```bash
python tools/infer/predict_system.py --det_model_dir=inference_model/det_v4_bankcard --rec_model_dir=inference_model/rec_v4_bankcard --rec_char_dict_path=ppocr/utils/num_dict.txt --image_dir=train_data/drivingData/1.jpg
```
------
## 移动端模型
### 检测模型转换
```bash
paddle_lite_opt --model_file=inference_model/det_v4_bankcard/inference.pdmodel --param_file=inference_model/det_v4_bankcard/inference.pdiparams --optimize_out=inference_model/det_v4_nb_bankcard --valid_targets=arm --optimize_out_type=naive_buffer
```
### 识别模型转换
```bash
paddle_lite_opt --model_file=inference_model/rec_v4_bankcard/inference.pdmodel --param_file=inference_model/rec_v4_bankcard/inference.pdiparams --optimize_out=inference_model/rec_v4_nb_bankcard --valid_targets=arm --optimize_out_type=naive_buffer
```
------
------
# PaddleNLP
## 数据集
使用Label Studio进行数据标注,安装过程省略
```bash
# 打开Anaconda Prompt
# 激活安装Label Studio的环境
conda activate label-studio
# 启动Label Studio
label-studio start
```
[打标流程](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/applications/information_extraction/label_studio_doc.md)
### 数据转换
```bash
# 进入PaddleNLP\applications\information_extraction后执行
python label_studio.py --label_studio_file ./document/data/label_studio.json --save_dir ./document/data --splits 0.8 0.1 0.1 --task_type ext
```
------
## 训练模型
```bash
# 进入PaddleNLP\applications\information_extraction\document后执行(双卡训练)
python -u -m paddle.distributed.launch --gpus "0,1" finetune.py --device gpu --logging_steps 5 --save_steps 25 --eval_steps 25 --seed 42 --model_name_or_path uie-x-base --output_dir ./checkpoint/model_best --train_path data/train.txt --dev_path data/dev.txt --max_seq_len 512 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --num_train_epochs 10 --learning_rate 1e-5 --do_train --do_eval --do_export --export_model_dir ./checkpoint/model_best --overwrite_output_dir --disable_tqdm False --metric_for_best_model eval_f1 --load_best_model_at_end True --save_total_limit 1
```
------
参考:
[PaddleOCR训练属于自己的模型详细教程](https://blog.csdn.net/qq_52852432/article/details/131817619?utm_medium=distribute.pc_relevant.none-task-blog-2~default~baidujs_baidulandingword~default-0-131817619-blog-124628731.235^v40^pc_relevant_3m_sort_dl_base1&spm=1001.2101.3001.4242.1&utm_relevant_index=3)
[端侧部署](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/deploy/lite/readme_ch.md)
[PaddleNLP关键信息抽取](https://blog.csdn.net/z5z5z5z56/article/details/130346646)

View File

@@ -0,0 +1,329 @@
anyio 4.0.0
astor 0.8.1
certifi 2019.11.28
chardet 3.0.4
dbus-python 1.2.16
decorator 5.1.1
distro-info 0.23+ubuntu1.1
exceptiongroup 1.1.3
h11 0.14.0
httpcore 1.0.2
httpx 0.25.1
idna 2.8
numpy 1.26.2
opt-einsum 3.3.0
paddlepaddle-gpu 2.6.1.post120
Pillow 10.1.0
pip 24.0
protobuf 4.25.0
PyGObject 3.36.0
python-apt 2.0.1+ubuntu0.20.4.1
requests 2.22.0
requests-unixsocket 0.2.0
setuptools 68.2.2
six 1.14.0
sniffio 1.3.0
unattended-upgrades 0.1
urllib3 1.25.8
ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddle:3.1.0-gpu-cuda12.9-cudnn9.9:
python:3.10.12
Package Version
------------------------ ----------
anyio 4.9.0
certifi 2025.6.15
decorator 5.2.1
exceptiongroup 1.3.0
h11 0.16.0
httpcore 1.0.9
httpx 0.28.1
idna 3.10
networkx 3.4.2
numpy 2.2.6
nvidia-cublas-cu12 12.9.0.13
nvidia-cuda-cccl-cu12 12.9.27
nvidia-cuda-cupti-cu12 12.9.19
nvidia-cuda-nvrtc-cu12 12.9.41
nvidia-cuda-runtime-cu12 12.9.37
nvidia-cudnn-cu12 9.9.0.52
nvidia-cufft-cu12 11.4.0.6
nvidia-cufile-cu12 1.14.0.30
nvidia-curand-cu12 10.3.10.19
nvidia-cusolver-cu12 11.7.4.40
nvidia-cusparse-cu12 12.5.9.5
nvidia-cusparselt-cu12 0.7.1
nvidia-nccl-cu12 2.26.5
nvidia-nvjitlink-cu12 12.9.41
nvidia-nvtx-cu12 12.9.19
opt-einsum 3.3.0
paddlepaddle-gpu 3.1.0
pillow 11.2.1
pip 25.1.1
protobuf 6.31.1
setuptools 59.6.0
sniffio 1.3.1
typing_extensions 4.14.0
wheel 0.37.1
ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlex/paddlex:paddlex3.1.2-paddlepaddle3.0.0-gpu-cuda12.6-cudnn9.5-trt10.5
python:3.10.18
Package Version Editable project location
------------------------- -------------------- -------------------------
aiohappyeyeballs 2.6.1
aiohttp 3.12.13
aiosignal 1.4.0
aistudio_sdk 0.3.5
albucore 0.0.13+pdx
albumentations 1.4.10+pdx
alembic 1.16.2
annotated-types 0.7.0
anyio 4.9.0
astor 0.8.1
asttokens 3.0.0
async-timeout 4.0.3
attrdict3 2.0.2
attrs 25.3.0
babel 2.17.0
bce-python-sdk 0.9.35
beautifulsoup4 4.13.4
blinker 1.9.0
cachetools 6.1.0
certifi 2019.11.28
cffi 1.17.1
chardet 3.0.4
charset-normalizer 3.4.2
chinese-calendar 1.8.0
click 8.2.1
cloudpickle 3.1.1
colorama 0.4.6
colorlog 6.9.0
ConfigSpace 1.2.1
contourpy 1.3.2
cssselect 1.3.0
cssutils 2.11.1
cycler 0.12.1
Cython 3.1.2
dataclasses-json 0.6.7
datasets 3.6.0
dbus-python 1.2.16
decorator 5.2.1
decord 0.6.0
descartes 1.1.0
dill 0.3.4
distro 1.9.0
distro-info 0.23+ubuntu1.1
easydict 1.13
einops 0.8.1
et_xmlfile 2.0.0
exceptiongroup 1.2.2
executing 2.2.0
faiss-cpu 1.8.0.post1
fastapi 0.116.0
filelock 3.18.0
fire 0.7.0
FLAML 2.3.5
Flask 3.1.1
flask-babel 4.0.0
fonttools 4.58.5
frozenlist 1.7.0
fsspec 2025.3.0
ftfy 6.3.1
future 1.0.0
gast 0.3.3
GPUtil 1.4.0
greenlet 3.2.3
h11 0.14.0
h5py 3.14.0
hf-xet 1.1.5
hpbandster 0.7.4
httpcore 1.0.7
httpx 0.28.1
httpx-sse 0.4.1
huggingface-hub 0.33.2
idna 2.8
imageio 2.37.0
imagesize 1.4.1
imgaug 0.4.0+pdx
ipython 8.37.0
itsdangerous 2.2.0
jedi 0.19.2
jieba 0.42.1
Jinja2 3.1.6
jiter 0.10.0
joblib 1.5.1
jsonpatch 1.33
jsonpointer 3.0.0
jsonschema 4.24.0
jsonschema-specifications 2025.4.1
kiwisolver 1.4.8
langchain 0.3.26
langchain-community 0.3.27
langchain-core 0.3.68
langchain-openai 0.3.27
langchain-text-splitters 0.3.8
langsmith 0.4.4
lapx 0.5.11.post1
lazy_loader 0.4
llvmlite 0.44.0
lmdb 1.6.2
lxml 6.0.0
Mako 1.3.10
markdown-it-py 3.0.0
MarkupSafe 3.0.2
marshmallow 3.26.1
matplotlib 3.5.3
matplotlib-inline 0.1.7
mdurl 0.1.2
more-itertools 10.7.0
motmetrics 1.4.0
msgpack 1.1.1
multidict 6.6.3
multiprocess 0.70.12.2
mypy_extensions 1.1.0
netifaces 0.11.0
networkx 3.4.2
numba 0.61.2
numpy 1.24.4
nuscenes-devkit 1.1.11+pdx
onnx 1.17.0
onnxoptimizer 0.3.13
openai 1.93.1
opencv-contrib-python 4.10.0.84
openpyxl 3.1.5
opt-einsum 3.3.0
optuna 4.4.0
orjson 3.10.18
packaging 24.2
paddle2onnx 2.0.2rc3
paddle3d 0.0.0
paddleclas 2.6.0
paddledet 0.0.0
paddlefsl 1.1.0
paddlenlp 2.8.0.post0
paddlepaddle-gpu 3.0.0
paddleseg 0.0.0.dev0
paddlets 1.1.0
paddlex 3.1.2 /root/PaddleX
pandas 1.3.5
parso 0.8.4
patsy 1.0.1
pexpect 4.9.0
pillow 11.1.0
pip 25.1.1
polygraphy 0.49.24
ppvideo 2.3.0
premailer 3.10.0
prettytable 3.16.0
prompt_toolkit 3.0.51
propcache 0.3.2
protobuf 6.30.1
psutil 7.0.0
ptyprocess 0.7.0
pure_eval 0.2.3
py-cpuinfo 9.0.0
pyarrow 20.0.0
pybind11 2.13.6
pybind11-stubgen 2.5.1
pyclipper 1.3.0.post6
pycocotools 2.0.8
pycparser 2.22
pycryptodome 3.23.0
pydantic 2.11.7
pydantic_core 2.33.2
pydantic-settings 2.10.1
Pygments 2.19.2
PyGObject 3.36.0
PyMatting 1.1.14
pyod 2.0.5
pypandoc 1.15
pyparsing 3.2.3
pypdfium2 4.30.1
pyquaternion 0.9.9
Pyro4 4.82
python-apt 2.0.1+ubuntu0.20.4.1
python-dateutil 2.9.0.post0
python-docx 1.2.0
python-dotenv 1.1.1
pytz 2025.2
PyWavelets 1.3.0
PyYAML 6.0.2
RapidFuzz 3.13.0
rarfile 4.2
ray 2.47.1
referencing 0.36.2
regex 2024.11.6
requests 2.32.4
requests-toolbelt 1.0.0
requests-unixsocket 0.2.0
rich 14.0.0
rpds-py 0.26.0
ruamel.yaml 0.18.14
ruamel.yaml.clib 0.2.12
safetensors 0.5.3
scikit-image 0.25.2
scikit-learn 1.3.2
scipy 1.15.3
seaborn 0.13.2
sentencepiece 0.2.0
seqeval 1.2.2
serpent 1.41
setuptools 68.2.2
shap 0.48.0
Shapely 1.8.5.post1
shellingham 1.5.4
six 1.14.0
sklearn 0.0
slicer 0.0.8
sniffio 1.3.1
soundfile 0.13.1
soupsieve 2.7
SQLAlchemy 2.0.41
stack-data 0.6.3
starlette 0.46.2
statsmodels 0.14.1
tenacity 9.1.2
tensorboardX 2.6.4
tensorrt 10.5.0
termcolor 3.1.0
terminaltables 3.1.10
threadpoolctl 3.6.0
tifffile 2025.5.10
tiktoken 0.9.0
tokenizers 0.19.1
tomli 2.2.1
tool_helpers 0.1.2
tqdm 4.67.1
traitlets 5.14.3
typeguard 4.4.4
typer 0.16.0
typing_extensions 4.14.1
typing-inspect 0.9.0
typing-inspection 0.4.1
tzdata 2025.2
ujson 5.10.0
unattended-upgrades 0.1
urllib3 1.25.8
uvicorn 0.35.0
visualdl 2.5.3
Wand 0.6.13
wcwidth 0.2.13
Werkzeug 3.1.3
xmltodict 0.14.2
xxhash 3.5.0
yacs 0.1.8
yarl 1.20.1
zstandard 0.23.0

Binary file not shown.

View File

@@ -8,6 +8,7 @@ LOG_PATHS = [
f"log/{HOSTNAME}/ucloud",
f"log/{HOSTNAME}/error",
f"log/{HOSTNAME}/qr",
f"log/{HOSTNAME}/sql",
]
for path in LOG_PATHS:
if not os.path.exists(path):
@@ -74,6 +75,16 @@ LOGGING_CONFIG = {
'backupCount': 14,
'encoding': 'utf-8',
},
'sql': {
'class': 'logging.handlers.TimedRotatingFileHandler',
'level': 'INFO',
'formatter': 'standard',
'filename': f'log/{HOSTNAME}/sql/fcb_photo_review_sql.log',
'when': 'midnight',
'interval': 1,
'backupCount': 14,
'encoding': 'utf-8',
},
},
# loggers定义了日志记录器
@@ -98,5 +109,10 @@ LOGGING_CONFIG = {
'level': 'DEBUG',
'propagate': False,
},
'sql': {
'handlers': ['console', 'sql'],
'level': 'DEBUG',
'propagate': False,
},
},
}

View File

@@ -62,7 +62,7 @@ def find_boxes(content, layout, offset=0, length=None, improve=False, image_path
captured_image, offset_x, offset_y = image_util.expand_to_a4_size(captured_image)
cv2.imwrite(temp_file.name, captured_image)
try:
layouts = util.get_ocr_layout(OCR, temp_file.name)
layouts, _ = util.get_ocr_layout(OCR, temp_file.name)
except TypeError:
# 如果是类型错误,大概率是没识别到文字
layouts = []
@@ -100,7 +100,7 @@ def get_mask_layout(image, name, id_card_num):
result = []
try:
try:
layouts = util.get_ocr_layout(OCR, temp_file.name)
layouts, _ = util.get_ocr_layout(OCR, temp_file.name)
# layouts = OCR.parse({"doc": temp_file.name})["layout"]
except TypeError:
# 如果是类型错误,大概率是没识别到文字
@@ -160,6 +160,8 @@ def get_mask_layout(image, name, id_card_num):
result += _find_boxes_by_keys(ID_CARD_NUM_KEYS)
return result
except MemoryError as e:
raise e
except Exception as e:
logging.error("涂抹时出错!", exc_info=e)
return result
@@ -196,7 +198,9 @@ def mask_photo(img_url, name, id_card_num, color=(255, 255, 255)):
return do_mask, i
# 打开图片
image = image_util.read(img_url)
image, _ = image_util.read(img_url)
if image is None:
return False, image
original_image = image
is_masked, image = _mask(image, name, id_card_num, color)
if not is_masked:

View File

@@ -23,7 +23,7 @@ def check_error(error_ocr):
image = mask_photo(img_url, name, id_card_num, (0, 0, 0))[1]
final_img_url = ufile.get_private_url(error_ocr.cfjaddress, "drg100")
final_image = image_util.read(final_img_url)
final_image, _ = image_util.read(final_img_url)
return image_util.combined(final_image, image)

View File

@@ -13,14 +13,14 @@ from photo_review import auto_photo_review, SEND_ERROR_EMAIL
# 项目必须从此处启动,否则代码中的相对路径可能导致错误的发生
if __name__ == '__main__':
program_name = '照片审核自动识别脚本'
program_name = "照片审核自动识别脚本"
logging.config.dictConfig(LOGGING_CONFIG)
parser = argparse.ArgumentParser()
parser.add_argument("--clean", default=False, type=bool, help="是否将识别中的案子改为待识别状态")
args = parser.parse_args()
if args.clean:
# 主要用于启动时,清除仍在涂抹中的案子
# 主要用于启动时,清除仍在识别中的案子
session = MysqlSession()
update_flag = (update(ZxPhhd).where(ZxPhhd.exsuccess_flag == "2").values(exsuccess_flag="1"))
session.execute(update_flag)
@@ -34,7 +34,7 @@ if __name__ == '__main__':
logging.info(f"{program_name}】开始运行")
auto_photo_review.main()
except Exception as e:
error_logger = logging.getLogger('error')
error_logger = logging.getLogger("error")
error_logger.error(traceback.format_exc())
if SEND_ERROR_EMAIL:
send_error_email(program_name, repr(e), traceback.format_exc())

View File

@@ -2,9 +2,9 @@ import jieba
from paddlenlp import Taskflow
from paddleocr import PaddleOCR
'''
"""
项目配置
'''
"""
# 每次从数据库获取的案子数量
PHHD_BATCH_SIZE = 10
# 没有查询到案子的等待时间(分钟)
@@ -18,35 +18,35 @@ LAYOUT_ANALYSIS = False
信息抽取关键词配置
"""
# 患者姓名
PATIENT_NAME = ['患者姓名']
PATIENT_NAME = ["患者姓名"]
# 入院日期
ADMISSION_DATE = ['入院日期']
ADMISSION_DATE = ["入院日期"]
# 出院日期
DISCHARGE_DATE = ['出院日期']
DISCHARGE_DATE = ["出院日期"]
# 发生医疗费
MEDICAL_EXPENSES = ['费用总额']
MEDICAL_EXPENSES = ["费用总额"]
# 个人现金支付
PERSONAL_CASH_PAYMENT = ['个人现金支付']
PERSONAL_CASH_PAYMENT = ["个人现金支付"]
# 个人账户支付
PERSONAL_ACCOUNT_PAYMENT = ['个人账户支付']
PERSONAL_ACCOUNT_PAYMENT = ["个人账户支付"]
# 个人自费金额
PERSONAL_FUNDED_AMOUNT = ['自费金额', '个人自费']
PERSONAL_FUNDED_AMOUNT = ["自费金额", "个人自费"]
# 医保类别
MEDICAL_INSURANCE_TYPE = ['医保类型']
MEDICAL_INSURANCE_TYPE = ["医保类型"]
# 就诊医院
HOSPITAL = ['医院']
HOSPITAL = ["医院"]
# 就诊科室
DEPARTMENT = ['科室']
DEPARTMENT = ["科室"]
# 主治医生
DOCTOR = ['主治医生']
DOCTOR = ["主治医生"]
# 住院号
ADMISSION_ID = ['住院号']
ADMISSION_ID = ["住院号"]
# 医保结算单号码
SETTLEMENT_ID = ['医保结算单号码']
SETTLEMENT_ID = ["医保结算单号码"]
# 年龄
AGE = ['年龄']
AGE = ["年龄"]
# 大写总额
UPPERCASE_MEDICAL_EXPENSES = ['大写总额']
UPPERCASE_MEDICAL_EXPENSES = ["大写总额"]
SETTLEMENT_LIST_SCHEMA = \
(PATIENT_NAME + ADMISSION_DATE + DISCHARGE_DATE + MEDICAL_EXPENSES + PERSONAL_CASH_PAYMENT
@@ -58,47 +58,55 @@ DISCHARGE_RECORD_SCHEMA = \
COST_LIST_SCHEMA = PATIENT_NAME + ADMISSION_DATE + DISCHARGE_DATE + MEDICAL_EXPENSES
'''
"""
别名配置
'''
"""
# 使用别名中的value替换key。考虑到效率问题只会替换第一个匹配到的key。
HOSPITAL_ALIAS = {
'沐阳': ['沭阳'],
'连水': ['涟水'],
'唯宁': ['睢宁'], # 雕宁
'九〇四': ['904'],
'漂水': ['溧水'],
"沐阳": ["沭阳"],
"连水": ["涟水"],
"唯宁": ["睢宁"], # 雕宁
"九〇四": ["904"],
"漂水": ["溧水"],
}
DEPARTMENT_ALIAS = {
'耳鼻喉': ['耳鼻咽喉'],
'急症': ['急诊'],
"耳鼻喉": ["耳鼻咽喉"],
"急症": ["急诊"],
}
'''
"""
搜索过滤配置
'''
"""
# 默认会过滤单字
HOSPITAL_FILTER = ['医院', '人民', '第一', '第二', '第三', '大学', '附属']
HOSPITAL_FILTER = ["医院", "人民", "第一", "第二", "第三", "大学", "附属"]
DEPARTMENT_FILTER = ['', '', '西', '']
DEPARTMENT_FILTER = ["", "", "西", ""]
'''
"""
分词配置
'''
jieba.suggest_freq(('肿瘤', '医院'), True)
jieba.suggest_freq(('', ''), True)
jieba.suggest_freq(('感染', ''), True)
jieba.suggest_freq(('', ''), True)
jieba.suggest_freq(('', ''), True)
"""
jieba.suggest_freq(("肿瘤", "医院"), True)
jieba.suggest_freq(("", ""), True)
jieba.suggest_freq(("感染", ""), True)
jieba.suggest_freq(("", ""), True)
jieba.suggest_freq(("", ""), True)
'''
"""
模型配置
'''
SETTLEMENT_IE = Taskflow('information_extraction', schema=SETTLEMENT_LIST_SCHEMA, model='uie-x-base',
task_path='model/settlement_list_model', layout_analysis=LAYOUT_ANALYSIS, precision='fp16')
DISCHARGE_IE = Taskflow('information_extraction', schema=DISCHARGE_RECORD_SCHEMA, model='uie-x-base',
task_path='model/discharge_record_model', layout_analysis=LAYOUT_ANALYSIS, precision='fp16')
COST_IE = Taskflow('information_extraction', schema=COST_LIST_SCHEMA, model='uie-x-base', device_id=1,
task_path='model/cost_list_model', layout_analysis=LAYOUT_ANALYSIS, precision='fp16')
"""
SETTLEMENT_IE = Taskflow("information_extraction", schema=SETTLEMENT_LIST_SCHEMA, model="uie-x-base",
task_path="model/settlement_list_model", layout_analysis=LAYOUT_ANALYSIS, precision="fp16")
DISCHARGE_IE = Taskflow("information_extraction", schema=DISCHARGE_RECORD_SCHEMA, model="uie-x-base",
task_path="model/discharge_record_model", layout_analysis=LAYOUT_ANALYSIS, precision="fp16")
COST_IE = Taskflow("information_extraction", schema=COST_LIST_SCHEMA, model="uie-x-base", device_id=1,
task_path="model/cost_list_model", layout_analysis=LAYOUT_ANALYSIS, precision="fp16")
OCR = PaddleOCR(use_angle_cls=False, show_log=False, gpu_id=1, det_db_box_thresh=0.3)
OCR = PaddleOCR(
device="gpu:0",
ocr_version="PP-OCRv4",
use_textline_orientation=False,
# 检测像素阈值,输出的概率图中,得分大于该阈值的像素点才会被认为是文字像素点
text_det_thresh=0.1,
# 检测框阈值,检测结果边框内,所有像素点的平均得分大于该阈值时,该结果会被认为是文字区域
text_det_box_thresh=0.3,
)

View File

@@ -22,11 +22,11 @@ from photo_review import PATIENT_NAME, ADMISSION_DATE, DISCHARGE_DATE, MEDICAL_E
PERSONAL_ACCOUNT_PAYMENT, PERSONAL_FUNDED_AMOUNT, MEDICAL_INSURANCE_TYPE, HOSPITAL, DEPARTMENT, DOCTOR, \
ADMISSION_ID, SETTLEMENT_ID, AGE, OCR, SETTLEMENT_IE, DISCHARGE_IE, COST_IE, PHHD_BATCH_SIZE, SLEEP_MINUTES, \
UPPERCASE_MEDICAL_EXPENSES, HOSPITAL_ALIAS, HOSPITAL_FILTER, DEPARTMENT_ALIAS, DEPARTMENT_FILTER
from ucloud import ufile
from ucloud import ufile, BUCKET
from util import image_util, util, html_util
from util.data_util import handle_date, handle_decimal, parse_department, handle_name, \
handle_insurance_type, handle_original_data, handle_hospital, handle_department, handle_id, handle_age, parse_money, \
parse_hospital
handle_insurance_type, handle_original_data, handle_hospital, handle_department, handle_age, parse_money, \
parse_hospital, handle_doctor, handle_text, handle_admission_id, handle_settlement_id
# 合并信息抽取结果
@@ -36,18 +36,25 @@ def merge_result(result1, result2):
return result1
def ie_temp_image(ie, ocr, image):
def ie_temp_image(ie, ocr, image, is_screenshot=False):
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
cv2.imwrite(temp_file.name, image)
ie_result = []
ocr_pure_text = ''
angle = '0'
try:
layout = util.get_ocr_layout(ocr, temp_file.name)
layout, angle = util.get_ocr_layout(ocr, temp_file.name, is_screenshot)
if not layout:
# 无识别结果
ie_result = []
else:
ie_result = ie({"doc": temp_file.name, "layout": layout})[0]
for lay in layout:
ocr_pure_text += lay[1]
except MemoryError as e:
# 显存不足时应该抛出错误,让程序重启,同时释放显存
raise e
except Exception as e:
logging.error("信息抽取时出错", exc_info=e)
finally:
@@ -55,7 +62,7 @@ def ie_temp_image(ie, ocr, image):
os.remove(temp_file.name)
except Exception as e:
logging.info(f"删除临时文件 {temp_file.name} 时出错", exc_info=e)
return ie_result
return ie_result, ocr_pure_text, angle
# 关键信息提取
@@ -147,12 +154,16 @@ def get_better_image_from_qrcode(image, image_id, dpi=150):
# 关键信息提取
def information_extraction(ie, phrecs, identity):
result = {}
ocr_text = ''
for phrec in phrecs:
img_path = ufile.get_private_url(phrec.cfjaddress)
if not img_path:
continue
image = image_util.read(img_path)
image, exif_data = image_util.read(img_path)
if image is None:
# 图片可能因为某些原因获取不到
continue
# 尝试从二维码中获取高清图片
better_image, text = get_better_image_from_qrcode(image, phrec.cfjaddress)
@@ -165,7 +176,7 @@ def information_extraction(ie, phrecs, identity):
if text:
info_extract = ie(text)[0]
else:
info_extract = ie_temp_image(ie, OCR, image)
info_extract = ie_temp_image(ie, OCR, image, True)[0]
ie_result = {'result': info_extract, 'angle': '0'}
now = util.get_default_datetime()
@@ -183,25 +194,20 @@ def information_extraction(ie, phrecs, identity):
result = merge_result(result, ie_result['result'])
else:
is_screenshot = image_util.is_screenshot(image, exif_data)
target_images = []
# target_images += detector.request_book_areas(image) # 识别文档区域并裁剪
if not target_images:
target_images.append(image) # 识别失败
angle_count = defaultdict(int, {'0': 0}) # 分割后图片的最优角度统计
for target_image in target_images:
# dewarped_image = dewarp.dewarp_image(target_image) # 去扭曲
dewarped_image = target_image
angles = image_util.parse_rotation_angles(dewarped_image)
split_results = image_util.split(dewarped_image)
split_results = image_util.split(target_image)
for split_result in split_results:
if split_result['img'] is None or split_result['img'].size == 0:
continue
rotated_img = image_util.rotate(split_result['img'], int(angles[0]))
ie_results = [{'result': ie_temp_image(ie, OCR, rotated_img), 'angle': angles[0]}]
if not ie_results[0]['result'] or len(ie_results[0]['result']) < len(ie.kwargs.get('schema')):
rotated_img = image_util.rotate(split_result['img'], int(angles[1]))
ie_results.append({'result': ie_temp_image(ie, OCR, rotated_img), 'angle': angles[1]})
ie_temp_result = ie_temp_image(ie, OCR, split_result['img'], is_screenshot)
ocr_text += ie_temp_result[1]
ie_results = [{'result': ie_temp_result[0], 'angle': ie_temp_result[2]}]
now = util.get_default_datetime()
best_angle = ['0', 0]
for ie_result in ie_results:
@@ -231,13 +237,15 @@ def information_extraction(ie, phrecs, identity):
with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as temp_file:
cv2.imwrite(temp_file.name, image)
try:
ufile.upload_file(phrec.cfjaddress, temp_file.name)
if img_angle != '0':
ufile.upload_file(phrec.cfjaddress, temp_file.name)
logging.info(f'旋转图片[{phrec.cfjaddress}]替换成功,已旋转{img_angle}度。')
# 修正旋转角度
for zx_ie_result in zx_ie_results:
zx_ie_result.rotation_angle -= int(img_angle)
else:
ufile.copy_file(BUCKET, phrec.cfjaddress, "drg2015", phrec.cfjaddress)
ufile.upload_file(phrec.cfjaddress, temp_file.name)
logging.info(f'高清图片[{phrec.cfjaddress}]替换成功!')
except Exception as e:
logging.error(f'上传图片({phrec.cfjaddress})失败', exc_info=e)
@@ -247,8 +255,21 @@ def information_extraction(ie, phrecs, identity):
session = MysqlSession()
session.add_all(zx_ie_results)
session.commit()
session.close()
# # 添加清晰度测试
# if better_image is None:
# # 替换后图片默认清晰
# clarity_result = image_util.parse_clarity(image)
# unsharp_flag = 0 if (clarity_result[0] == 0 and clarity_result[1] >= 0.8) else 1
# update_clarity = (update(ZxPhrec).where(ZxPhrec.pk_phrec == phrec.pk_phrec).values(
# cfjaddress2=json.dumps(clarity_result),
# unsharp_flag=unsharp_flag,
# ))
# session.execute(update_clarity)
# session.commit()
# session.close()
result['ocr_text'] = ocr_text
return result
@@ -293,7 +314,7 @@ def save_or_update_ie(table, pk_phhd, data):
if db_data:
# 更新
db_data.update_time = now
db_data.creator = HOSTNAME
db_data.updater = HOSTNAME
for k, v in data.items():
setattr(db_data, k, v)
else:
@@ -379,8 +400,8 @@ def settlement_task(pk_phhd, settlement_list, identity):
get_best_value_in_keys(settlement_list_ie_result, PERSONAL_FUNDED_AMOUNT)),
"medical_insurance_type_str": handle_original_data(
get_best_value_in_keys(settlement_list_ie_result, MEDICAL_INSURANCE_TYPE)),
"admission_id": handle_id(get_best_value_in_keys(settlement_list_ie_result, ADMISSION_ID)),
"settlement_id": handle_id(get_best_value_in_keys(settlement_list_ie_result, SETTLEMENT_ID)),
"admission_id": handle_admission_id(get_best_value_in_keys(settlement_list_ie_result, ADMISSION_ID)),
"settlement_id": handle_settlement_id(get_best_value_in_keys(settlement_list_ie_result, SETTLEMENT_ID)),
}
settlement_data["admission_date"] = handle_date(settlement_data["admission_date_str"])
settlement_data["admission_date"] = handle_date(settlement_data["admission_date_str"])
@@ -394,6 +415,10 @@ def settlement_task(pk_phhd, settlement_list, identity):
get_best_value_in_keys(settlement_list_ie_result, MEDICAL_EXPENSES))
settlement_data["medical_expenses_str"] = handle_original_data(parse_money_result[0])
settlement_data["medical_expenses"] = parse_money_result[1]
if not settlement_data["settlement_id"]:
# 如果没有结算单号就填住院号
settlement_data["settlement_id"] = settlement_data["admission_id"]
save_or_update_ie(ZxIeSettlement, pk_phhd, settlement_data)
@@ -408,9 +433,10 @@ def discharge_task(pk_phhd, discharge_record, identity):
"name": handle_name(get_best_value_in_keys(discharge_record_ie_result, PATIENT_NAME)),
"admission_date_str": handle_original_data(get_best_value_in_keys(discharge_record_ie_result, ADMISSION_DATE)),
"discharge_date_str": handle_original_data(get_best_value_in_keys(discharge_record_ie_result, DISCHARGE_DATE)),
"doctor": handle_name(get_best_value_in_keys(discharge_record_ie_result, DOCTOR)),
"admission_id": handle_id(get_best_value_in_keys(discharge_record_ie_result, ADMISSION_ID)),
"doctor": handle_doctor(get_best_value_in_keys(discharge_record_ie_result, DOCTOR)),
"admission_id": handle_admission_id(get_best_value_in_keys(discharge_record_ie_result, ADMISSION_ID)),
"age": handle_age(get_best_value_in_keys(discharge_record_ie_result, AGE)),
"content": handle_text(discharge_record_ie_result['ocr_text']),
}
discharge_data["admission_date"] = handle_date(discharge_data["admission_date_str"])
discharge_data["discharge_date"] = handle_date(discharge_data["discharge_date_str"])
@@ -476,7 +502,8 @@ def cost_task(pk_phhd, cost_list, identity):
"name": handle_name(get_best_value_in_keys(cost_list_ie_result, PATIENT_NAME)),
"admission_date_str": handle_original_data(get_best_value_in_keys(cost_list_ie_result, ADMISSION_DATE)),
"discharge_date_str": handle_original_data(get_best_value_in_keys(cost_list_ie_result, DISCHARGE_DATE)),
"medical_expenses_str": handle_original_data(get_best_value_in_keys(cost_list_ie_result, MEDICAL_EXPENSES))
"medical_expenses_str": handle_original_data(get_best_value_in_keys(cost_list_ie_result, MEDICAL_EXPENSES)),
"content": handle_text(cost_list_ie_result['ocr_text']),
}
cost_data["admission_date"] = handle_date(cost_data["admission_date_str"])
cost_data["discharge_date"] = handle_date(cost_data["discharge_date_str"])

View File

@@ -1,16 +1,11 @@
numpy==1.26.4
onnxconverter-common==1.14.0
aistudio_sdk==0.2.6
onnxconverter-common==1.15.0
onnxruntime-gpu==1.22.0
OpenCC==1.1.6
opencv-python==4.6.0.66
paddle2onnx==1.2.3
paddleclas==2.5.2
paddlenlp==2.6.1
paddleocr==2.7.3
pillow==10.4.0
paddlenlp==3.0.0b4
paddleocr==3.1.1
PyMuPDF==1.26.3
pymysql==1.1.1
requests==2.32.3
sqlacodegen==2.3.0.post1
sqlalchemy==1.4.52
tenacity==8.5.0
ufile==3.2.9
zxing-cpp==2.2.0
ufile==3.2.11
zxing-cpp==2.3.0

25
tool/batch_download.py Normal file
View File

@@ -0,0 +1,25 @@
# 批量下载图片
import pandas as pd
import requests
from ucloud import ufile
if __name__ == '__main__':
    # Read the spreadsheet that lists the images to download.
    file_path = r'D:\Echo\Downloads\Untitled.xlsx'  # replace with your file path
    df = pd.read_excel(file_path)
    # Iterate over the cfjaddress column and download each image.
    for address in df['cfjaddress']:
        # Resolve the storage key to a signed private URL.
        img_url = ufile.get_private_url(address)
        # Fetch the image. A timeout keeps one stalled connection from
        # hanging the entire batch (requests has no default timeout).
        response = requests.get(img_url, timeout=60)
        # Check whether the request succeeded
        if response.status_code == 200:
            # Save the image next to the project under ../img/<address>.
            with open(f'../img/{address}', 'wb') as file:
                file.write(response.content)
            print(f"{address}下载成功")
        else:
            print(f"{address}下载失败,状态码: {response.status_code}")

View File

@@ -0,0 +1,43 @@
import pandas as pd
from db import MysqlSession
from db.mysql import ZxPhrec
def check_unclarity():
    """Cross-check spreadsheet "unclear" rows against the DB sharpness flag.

    Reads pk_phrec values from the xlsx file and writes out every record
    whose unsharp_flag in the database is 0 (judged sharp) despite being
    listed in the spreadsheet. Output goes to check_unclarity_result.txt.
    """
    file_path = r'D:\Echo\Downloads\unclare.xlsx'  # replace with your file path
    df = pd.read_excel(file_path)
    output_file_path = 'check_unclarity_result.txt'
    # Open ONE session for the whole scan; the original opened and closed
    # a new DB session for every spreadsheet row.
    session = MysqlSession()
    try:
        with open(output_file_path, 'w', encoding='utf-8') as f:
            # Loop over the pk_phrec column
            for pk_phrec in df['pk_phrec']:
                phrec = session.query(ZxPhrec.pk_phrec, ZxPhrec.cfjaddress2, ZxPhrec.unsharp_flag).filter(
                    ZxPhrec.pk_phrec == pk_phrec
                ).one_or_none()
                if phrec and phrec.unsharp_flag == 0:
                    f.write(f"{phrec.pk_phrec} {phrec.cfjaddress2}\n")
    finally:
        # Always release the session, even if the query or write fails.
        session.close()
def check_clarity():
    """Find DB records flagged unsharp that are absent from the spreadsheet.

    Queries ZxPhrec rows in a fixed pk range with unsharp_flag == 1 and
    writes out those whose pk_phrec is NOT listed in the reference xlsx.
    Output goes to check_clarity_result.txt.
    """
    file_path = r'D:\Echo\Downloads\unclare.xlsx'
    df = pd.read_excel(file_path)
    session = MysqlSession()
    phrecs = (session.query(ZxPhrec.pk_phrec, ZxPhrec.cfjaddress2, ZxPhrec.unsharp_flag)
              .filter(ZxPhrec.pk_phrec >= 30810206)
              .filter(ZxPhrec.pk_phrec <= 31782247)
              .filter(ZxPhrec.unsharp_flag == 1)
              .all())
    session.close()
    # Build the membership set ONCE. The original called
    # df['pk_phrec'].to_list() inside the loop, rebuilding the list for
    # every DB row (O(rows * spreadsheet_rows) with O(n) membership tests).
    known_pks = set(df['pk_phrec'].to_list())
    output_file_path = 'check_clarity_result.txt'
    with open(output_file_path, 'w', encoding='utf-8') as f:
        for phrec in phrecs:
            if phrec.pk_phrec not in known_pks:
                f.write(f"{phrec.pk_phrec} {phrec.cfjaddress2}\n")


if __name__ == '__main__':
    check_clarity()

17
tool/paddle2nb.py Normal file
View File

@@ -0,0 +1,17 @@
# Convert a Paddle model to NB (naive_buffer) format.
# paddlelite 2.10 is the recommended version for this conversion.
from paddlelite.lite import Opt
if __name__ == '__main__':
    # 1. Create the Opt instance (Paddle Lite's model optimizer/converter).
    opt = Opt()
    # 2. Point it at the directory containing the source Paddle model.
    opt.set_model_dir("./model")
    # 3. Choose the target backend: arm, x86, opencl, npu.
    #    Roughly speaking, "arm" targets the CPU and "opencl" the GPU.
    opt.set_valid_places("arm")
    # 4. Choose the serialized model format: naive_buffer or protobuf.
    opt.set_model_type("naive_buffer")
    # 5. Set the output path/basename for the converted model.
    opt.set_optimize_out("model")
    # 6. Run the optimization and conversion.
    opt.run()

21
update_dev.sh Normal file
View File

@@ -0,0 +1,21 @@
#!/bin/bash
# Project update script for the test environment.
# Abort on the first failing command: without this, a failed git pull or
# image build would still tear down the currently running containers.
set -e
echo "开始更新测试项目..."
# Back up the docker-compose config. Use -f (force overwrite): the
# original -i prompts interactively and stalls unattended runs.
cp -f docker-compose.dev.yml docker-compose-backup.dev.yml
# Pull the latest code
git pull
# Build the new image
docker-compose -f docker-compose.dev.yml build
# Stop the old containers (using the backed-up config)
docker-compose -f docker-compose-backup.dev.yml down
# Start the new containers
docker-compose -f docker-compose.dev.yml up -d
# Remove the docker-compose backup
rm -f docker-compose-backup.dev.yml
# Show running containers
docker ps
# Show images
docker images
# Done
echo "测试项目更新完成,请确认容器版本正确,自行删除过期镜像。"

View File

@@ -2,7 +2,7 @@ import logging
import re
from datetime import datetime
from util import util
from util import util, string_util
# 处理金额类数据
@@ -12,6 +12,7 @@ def handle_decimal(string):
string = re.sub(r'[^0-9.]', '', string)
if not string:
return ""
string = string_util.full_to_half(string)
if "." not in string:
if len(string) > 2:
result = string[:-2] + "." + string[-2:]
@@ -89,13 +90,17 @@ def handle_date(string):
def handle_hospital(string):
if not string:
return ""
return string[:255]
# 只允许汉字、数字
string = re.sub(r'[^⺀-鿿0-9]', '', string)
return string[:200]
def handle_department(string):
if not string:
return ""
return string[:255]
# 只允许汉字
string = re.sub(r'[^⺀-鿿]', '', string)
return string[:200]
def parse_department(string):
@@ -119,7 +124,13 @@ def parse_department(string):
def handle_name(string):
if not string:
return ""
return re.sub(r'[^⺀-鿿·]', '', string)[:30]
return re.sub(r'[^⺀-鿿·]', '', string)[:20]
def handle_doctor(string):
    """Sanitize a doctor-name field.

    Keeps only CJK characters and the middle dot (·), then truncates to
    20 characters. Empty or falsy input yields "".
    """
    if not string:
        return ""
    # Strip everything that is not a CJK character or the name separator dot.
    cleaned = re.sub(r'[^⺀-鿿·]', '', string)
    return cleaned[:20]
# 处理医保类型数据
@@ -150,18 +161,29 @@ def handle_original_data(string):
# 处理id类数据
def handle_id(string):
def handle_admission_id(string):
if not string:
return ""
# 只允许字母和数字
string = re.sub(r'[^0-9a-zA-Z]', '', string)
# 防止过长存入数据库失败
return string[:50]
return string[:20]
def handle_settlement_id(string):
    """Sanitize a settlement-number field.

    Strips everything except ASCII letters and digits, then caps the
    length at 30 characters so the value fits the database column.
    """
    if not string:
        return ""
    # Drop separators and OCR noise; keep alphanumerics only.
    alnum_only = re.sub(r'[^0-9a-zA-Z]', '', string)
    # Cap the length to avoid a failed insert on an oversized value.
    return alnum_only[:30]
# 处理年龄类数据
def handle_age(string):
if not string:
return ""
string = string.split("")[0]
string = string_util.full_to_half(string.split("")[0])
num = re.sub(r'\D', '', string)
return num[-3:]
@@ -178,3 +200,9 @@ def parse_hospital(string):
split_hospitals = string_without_company.replace("医院", "医院 ")
result += split_hospitals.strip().split(" ")
return result
def handle_text(string):
    """Clip free OCR text so it fits the database column (max 16383 chars)."""
    return string[:16383] if string else ""

View File

@@ -1,9 +1,12 @@
import logging
import math
import urllib.request
from io import BytesIO
import cv2
import numpy
from PIL import Image
from PIL.ExifTags import TAGS
from paddleclas import PaddleClas
from tenacity import retry, stop_after_attempt, wait_random
@@ -14,20 +17,36 @@ def read(image_path):
"""
从网络或本地读取图片
:param image_path: 网络或本地路径
:return: NumPy数组形式的图片
:return: NumPy数组形式的图片, EXIF数据
"""
if image_path.startswith("http"):
# 发送HTTP请求并获取图像数据
resp = urllib.request.urlopen(image_path, timeout=60)
# 将数据读取为字节流
image_data = resp.read()
else:
with open(image_path, "rb") as f:
image_data = f.read()
# 解析EXIF信息基于原始字节流
exif_data = {}
try:
# 用PIL打开原始字节流
with Image.open(BytesIO(image_data)) as img:
# 获取EXIF字典
exif_info = img._getexif()
if exif_info:
# 将EXIF标签的数字ID转换为可读名称如36867对应"DateTimeOriginal"
for tag_id, value in exif_info.items():
tag_name = TAGS.get(tag_id, tag_id)
exif_data[tag_name] = value
except Exception as e:
logging.error("解析EXIF信息失败", exc_info=e)
# 将字节流转换为NumPy数组
image_np = numpy.frombuffer(image_data, numpy.uint8)
# 解码NumPy数组为OpenCV图像格式
image = cv2.imdecode(image_np, cv2.IMREAD_COLOR)
else:
image = cv2.imread(image_path)
return image
return image, exif_data
def capture(image, rectangle):
@@ -61,7 +80,7 @@ def split(image, ratio=1.414, overlap=0.05, x_compensation=3):
"""
split_result = []
if isinstance(image, str):
image = read(image)
image, _ = read(image)
height, width = image.shape[:2]
hw_ratio = height / width
wh_ratio = width / height
@@ -72,6 +91,8 @@ def split(image, ratio=1.414, overlap=0.05, x_compensation=3):
for i in range(math.ceil(height / step)):
offset = round(step * i)
cropped_img = capture(image, [0, offset, width, offset + new_img_height])
if cropped_img.shape[0] > 0:
# 计算误差可能导致图片高度为0此时不添加
split_result.append({"img": cropped_img, "x_offset": 0, "y_offset": offset})
elif wh_ratio > ratio: # 横向过长
new_img_width = height * ratio
@@ -79,6 +100,8 @@ def split(image, ratio=1.414, overlap=0.05, x_compensation=3):
for i in range(math.ceil(width / step)):
offset = round(step * i)
cropped_img = capture(image, [offset, 0, offset + new_img_width, width])
if cropped_img.shape[1] > 0:
# 计算误差可能导致图片宽度为0此时不添加
split_result.append({"img": cropped_img, "x_offset": offset, "y_offset": 0})
else:
split_result.append({"img": image, "x_offset": 0, "y_offset": 0})
@@ -247,3 +270,78 @@ def combined(img1, img2):
combined_img[:height1, :width1] = img1
combined_img[:height2, width1:width1 + width2] = img2
return combined_img
def parse_clarity(image):
    """
    Assess image sharpness with a PaddleClas classification model.
    :param image: image as a NumPy array, or a file path
    :return: [class_id, score]; defaults to [1, 0] when prediction fails
             (presumably class 1 means "unsharp" — confirm against the
             clarity_assessment model's label map)
    """
    # Fallback result used if the model call below raises.
    clarity_result = [1, 0]
    # NOTE(review): the model is loaded from disk on every call — consider
    # caching it at module level if this runs once per image.
    model = PaddleClas(inference_model_dir=r"model/clas/clarity_assessment", use_gpu=True)
    clas_result = model.predict(input_data=image)
    try:
        # predict() yields batches lazily; take the first result of the
        # first batch (next() may raise, hence the try).
        clas_result = next(clas_result)[0]
        clarity_result = [clas_result["class_ids"][0], clas_result["scores"][0]]
    except Exception as e:
        logging.error("获取图片清晰度失败", exc_info=e)
    return clarity_result
def is_photo_by_exif(exif_tags):
    """Heuristically decide whether EXIF data looks like a camera photo.

    Camera photos typically carry capture-related EXIF tags; finding two
    or more of them counts as strong evidence of a real photo.
    :param exif_tags: dict of EXIF tag name -> value (may be empty/None)
    :return: True if the tags look like a camera photo, else False
    """
    camera_tags = (
        'FNumber',          # aperture
        'ExposureTime',     # shutter speed
        'ISOSpeedRatings',  # ISO
        'FocalLength',      # focal length
        'LensModel',        # lens model
        'GPSLatitude',      # GPS position
    )
    if not exif_tags:
        # No EXIF data at all: cannot confirm it is a photo.
        return False
    matches = sum(1 for tag in camera_tags if tag in exif_tags)
    # Two or more capture-related tags -> very likely a photo;
    # otherwise we cannot tell, so report "not confirmed as photo".
    return matches >= 2
def is_screenshot_by_image_features(image):
    """Guess whether an image is a screenshot from its border pixels.

    Screenshots tend to have uniform edges (window chrome, solid margins),
    so a low standard deviation across the 10-pixel border suggests a
    screenshot.

    :param image: image as a NumPy array (H x W or H x W x C)
    :return: True if the border pixels are uniform enough to look like a
             screenshot; False otherwise, or on any analysis error
    """
    # Std-dev threshold below which the border counts as "uniform".
    edge_std_threshold = 20.0
    try:
        # Gather the 10-pixel strips along each side with one native
        # concatenate instead of extending a huge Python list of scalars.
        # Corner regions are intentionally counted twice, matching the
        # original heuristic's tuning.
        edge_pixels = numpy.concatenate((
            image[:10, :].ravel(),   # top edge
            image[-10:, :].ravel(),  # bottom edge
            image[:, :10].ravel(),   # left edge
            image[:, -10:].ravel(),  # right edge
        ))
        # Lower std => more uniform border => more screenshot-like.
        edge_std = numpy.std(edge_pixels)
        logging.info(f"边缘像素标准差: {edge_std}")
        return edge_std < edge_std_threshold
    except Exception as e:
        logging.error("图像特征分析失败", exc_info=e)
        return False
def is_screenshot(image, exif_tags):
    """Combine EXIF evidence and image features: screenshot or not.

    :param image: image as a NumPy array
    :param exif_tags: dict of EXIF tag name -> value
    :return: True if the image is judged to be a screenshot
    """
    # Strong camera EXIF evidence means it is a photo, not a screenshot.
    if is_photo_by_exif(exif_tags):
        return False
    # Otherwise fall back to analysing the border uniformity.
    return is_screenshot_by_image_features(image)

View File

@@ -1,3 +1,6 @@
import unicodedata
def blank(string):
"""
判断字符串是否为空或者纯空格
@@ -5,3 +8,23 @@ def blank(string):
:return: 字符串是否为空或者纯空格
"""
return not string or string.isspace()
def full_to_half(string):
    """
    Convert full-width characters to their half-width forms.
    :param string: input text
    :return: text with full-width/wide ('F'/'W' east-asian-width)
             characters normalized via NFKC; others are left untouched
    :raises TypeError: if the input is not a str
    """
    if not isinstance(string, str):
        raise TypeError("全角转半角的输入必须是字符串类型")
    if not string:
        return string
    converted = []
    for ch in string:
        # Only normalize wide/full-width characters; NFKC maps e.g.
        # '１' -> '1' while leaving ordinary CJK ideographs unchanged.
        if unicodedata.east_asian_width(ch) in ('F', 'W'):
            converted.append(unicodedata.normalize('NFKC', ch))
        else:
            converted.append(ch)
    return ''.join(converted)

View File

@@ -12,9 +12,10 @@ def get_default_datetime():
return datetime.now().strftime('%Y-%m-%d %H:%M:%S')
def get_ocr_layout(ocr, img_path):
def get_ocr_layout(ocr, img_path, is_screenshot=False):
"""
获取ocr识别的结果转为合适的layout形式
:param is_screenshot: 是否是截图
:param ocr: ocr模型
:param img_path: 图片本地路径
:return:
@@ -36,18 +37,18 @@ def get_ocr_layout(ocr, img_path):
return True
layout = []
ocr_result = ocr.ocr(img_path, cls=False)
ocr_result = ocr_result[0]
ocr_result = ocr.predict(input=img_path, use_doc_orientation_classify=not is_screenshot, use_doc_unwarping=not is_screenshot)
ocr_result = next(ocr_result)
if not ocr_result:
return layout
for segment in ocr_result:
box = segment[0]
return layout, "0"
angle = ocr_result.get("doc_preprocessor_res", {}).get("angle", "0")
for i in range(len(ocr_result.get('rec_texts'))):
box = ocr_result.get("rec_polys")[i].tolist()
box = _get_box(box)
if not _normal_box(box):
continue
text = segment[1][0]
layout.append((box, text))
return layout
layout.append((box, ocr_result.get("rec_texts")[i]))
return layout, str(angle)
def delete_temp_file(temp_files):

View File

@@ -24,7 +24,7 @@ def write_visual_result(image, angle=0, layout=None, result=None):
img_name = img[:last_dot_index]
img_type = img[last_dot_index + 1:]
img_array = image_util.read(image)
img_array, _ = image_util.read(image)
if angle != 0:
img_array = image_util.rotate(img_array, angle)
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
@@ -63,7 +63,7 @@ def visual_model_test(model_type, test_img, task_path, schema):
img["y_offset"] -= offset_y
temp_files_paths.append(temp_file.name)
parsed_doc = util.get_ocr_layout(
parsed_doc, _ = util.get_ocr_layout(
PaddleOCR(det_db_box_thresh=0.3, det_db_thresh=0.1, det_limit_side_len=1248, drop_score=0.3,
save_crop_res=False),
temp_file.name)